diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2023-11-12 21:25:31 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2023-11-12 21:39:54 +0300 |
commit | d28c55ab25cc8cedab8a5f4736c0d66e88b3da95 (patch) | |
tree | 73d373709b74fa2baaa4fe02a40a77c0a5baf6b7 | |
parent | 35b17f4f3b6e0ed855e7e47d3f1eb57470388a2c (diff) | |
download | ydb-d28c55ab25cc8cedab8a5f4736c0d66e88b3da95.tar.gz |
Intermediate changes
416 files changed, 51925 insertions, 4 deletions
diff --git a/.mapping.json b/.mapping.json index 5c586a131e..b1d5ec9df7 100644 --- a/.mapping.json +++ b/.mapping.json @@ -412,6 +412,7 @@ "contrib/libs/libunwind/CMakeLists.linux-aarch64.txt":"", "contrib/libs/libunwind/CMakeLists.linux-x86_64.txt":"", "contrib/libs/libunwind/CMakeLists.txt":"", + "contrib/libs/libunwind/CMakeLists.windows-x86_64.txt":"", "contrib/libs/liburing/CMakeLists.linux-aarch64.txt":"", "contrib/libs/liburing/CMakeLists.linux-x86_64.txt":"", "contrib/libs/liburing/CMakeLists.txt":"", @@ -2625,6 +2626,16 @@ "library/cpp/pop_count/CMakeLists.linux-x86_64.txt":"", "library/cpp/pop_count/CMakeLists.txt":"", "library/cpp/pop_count/CMakeLists.windows-x86_64.txt":"", + "library/cpp/porto/CMakeLists.darwin-x86_64.txt":"", + "library/cpp/porto/CMakeLists.linux-aarch64.txt":"", + "library/cpp/porto/CMakeLists.linux-x86_64.txt":"", + "library/cpp/porto/CMakeLists.txt":"", + "library/cpp/porto/CMakeLists.windows-x86_64.txt":"", + "library/cpp/porto/proto/CMakeLists.darwin-x86_64.txt":"", + "library/cpp/porto/proto/CMakeLists.linux-aarch64.txt":"", + "library/cpp/porto/proto/CMakeLists.linux-x86_64.txt":"", + "library/cpp/porto/proto/CMakeLists.txt":"", + "library/cpp/porto/proto/CMakeLists.windows-x86_64.txt":"", "library/cpp/presort/CMakeLists.darwin-x86_64.txt":"", "library/cpp/presort/CMakeLists.linux-aarch64.txt":"", "library/cpp/presort/CMakeLists.linux-x86_64.txt":"", @@ -3076,10 +3087,21 @@ "library/cpp/yt/backtrace/cursors/CMakeLists.windows-x86_64.txt":"", "library/cpp/yt/backtrace/cursors/dummy/CMakeLists.txt":"", "library/cpp/yt/backtrace/cursors/dummy/CMakeLists.windows-x86_64.txt":"", + "library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.darwin-x86_64.txt":"", + "library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.linux-aarch64.txt":"", + "library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.linux-x86_64.txt":"", + "library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.txt":"", + "library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.windows-x86_64.txt":"", + "library/cpp/yt/backtrace/cursors/interop/CMakeLists.darwin-x86_64.txt":"", + "library/cpp/yt/backtrace/cursors/interop/CMakeLists.linux-aarch64.txt":"", + "library/cpp/yt/backtrace/cursors/interop/CMakeLists.linux-x86_64.txt":"", + "library/cpp/yt/backtrace/cursors/interop/CMakeLists.txt":"", + "library/cpp/yt/backtrace/cursors/interop/CMakeLists.windows-x86_64.txt":"", "library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.darwin-x86_64.txt":"", "library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.linux-aarch64.txt":"", "library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.linux-x86_64.txt":"", "library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.txt":"", + "library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.windows-x86_64.txt":"", "library/cpp/yt/coding/CMakeLists.darwin-x86_64.txt":"", "library/cpp/yt/coding/CMakeLists.linux-aarch64.txt":"", "library/cpp/yt/coding/CMakeLists.linux-x86_64.txt":"", @@ -3125,11 +3147,21 @@ "library/cpp/yt/misc/CMakeLists.linux-x86_64.txt":"", "library/cpp/yt/misc/CMakeLists.txt":"", "library/cpp/yt/misc/CMakeLists.windows-x86_64.txt":"", + "library/cpp/yt/mlock/CMakeLists.darwin-x86_64.txt":"", + "library/cpp/yt/mlock/CMakeLists.linux-aarch64.txt":"", + "library/cpp/yt/mlock/CMakeLists.linux-x86_64.txt":"", + "library/cpp/yt/mlock/CMakeLists.txt":"", + "library/cpp/yt/mlock/CMakeLists.windows-x86_64.txt":"", "library/cpp/yt/small_containers/CMakeLists.darwin-x86_64.txt":"", "library/cpp/yt/small_containers/CMakeLists.linux-aarch64.txt":"", 
"library/cpp/yt/small_containers/CMakeLists.linux-x86_64.txt":"", "library/cpp/yt/small_containers/CMakeLists.txt":"", "library/cpp/yt/small_containers/CMakeLists.windows-x86_64.txt":"", + "library/cpp/yt/stockpile/CMakeLists.darwin-x86_64.txt":"", + "library/cpp/yt/stockpile/CMakeLists.linux-aarch64.txt":"", + "library/cpp/yt/stockpile/CMakeLists.linux-x86_64.txt":"", + "library/cpp/yt/stockpile/CMakeLists.txt":"", + "library/cpp/yt/stockpile/CMakeLists.windows-x86_64.txt":"", "library/cpp/yt/string/CMakeLists.darwin-x86_64.txt":"", "library/cpp/yt/string/CMakeLists.linux-aarch64.txt":"", "library/cpp/yt/string/CMakeLists.linux-x86_64.txt":"", @@ -9953,6 +9985,18 @@ "yt/yt/core/misc/isa_crc64/CMakeLists.linux-x86_64.txt":"", "yt/yt/core/misc/isa_crc64/CMakeLists.txt":"", "yt/yt/core/misc/isa_crc64/CMakeLists.windows-x86_64.txt":"", + "yt/yt/core/rpc/CMakeLists.txt":"", + "yt/yt/core/rpc/grpc/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/core/rpc/grpc/CMakeLists.linux-aarch64.txt":"", + "yt/yt/core/rpc/grpc/CMakeLists.linux-x86_64.txt":"", + "yt/yt/core/rpc/grpc/CMakeLists.txt":"", + "yt/yt/core/rpc/grpc/CMakeLists.windows-x86_64.txt":"", + "yt/yt/core/service_discovery/CMakeLists.txt":"", + "yt/yt/core/service_discovery/yp/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/core/service_discovery/yp/CMakeLists.linux-aarch64.txt":"", + "yt/yt/core/service_discovery/yp/CMakeLists.linux-x86_64.txt":"", + "yt/yt/core/service_discovery/yp/CMakeLists.txt":"", + "yt/yt/core/service_discovery/yp/CMakeLists.windows-x86_64.txt":"", "yt/yt/library/CMakeLists.darwin-x86_64.txt":"", "yt/yt/library/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/CMakeLists.linux-x86_64.txt":"", @@ -9962,6 +10006,17 @@ "yt/yt/library/auth/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/auth/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/auth/CMakeLists.txt":"", + "yt/yt/library/backtrace_introspector/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/backtrace_introspector/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/backtrace_introspector/CMakeLists.txt":"", + "yt/yt/library/backtrace_introspector/http/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/backtrace_introspector/http/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/backtrace_introspector/http/CMakeLists.txt":"", + "yt/yt/library/containers/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/containers/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/containers/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/containers/CMakeLists.txt":"", + "yt/yt/library/containers/CMakeLists.windows-x86_64.txt":"", "yt/yt/library/decimal/CMakeLists.darwin-x86_64.txt":"", "yt/yt/library/decimal/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/decimal/CMakeLists.linux-x86_64.txt":"", @@ -9970,20 +10025,50 @@ "yt/yt/library/erasure/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/erasure/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/erasure/CMakeLists.txt":"", + "yt/yt/library/monitoring/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/monitoring/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/monitoring/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/monitoring/CMakeLists.txt":"", + "yt/yt/library/monitoring/CMakeLists.windows-x86_64.txt":"", "yt/yt/library/numeric/CMakeLists.darwin-x86_64.txt":"", "yt/yt/library/numeric/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/numeric/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/numeric/CMakeLists.txt":"", + "yt/yt/library/process/CMakeLists.darwin-x86_64.txt":"", + 
"yt/yt/library/process/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/process/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/process/CMakeLists.txt":"", + "yt/yt/library/process/CMakeLists.windows-x86_64.txt":"", "yt/yt/library/profiling/CMakeLists.darwin-x86_64.txt":"", "yt/yt/library/profiling/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/profiling/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/profiling/CMakeLists.txt":"", "yt/yt/library/profiling/CMakeLists.windows-x86_64.txt":"", + "yt/yt/library/profiling/perf/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/profiling/perf/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/profiling/perf/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/profiling/perf/CMakeLists.txt":"", + "yt/yt/library/profiling/perf/CMakeLists.windows-x86_64.txt":"", "yt/yt/library/profiling/resource_tracker/CMakeLists.darwin-x86_64.txt":"", "yt/yt/library/profiling/resource_tracker/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/profiling/resource_tracker/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/profiling/resource_tracker/CMakeLists.txt":"", "yt/yt/library/profiling/resource_tracker/CMakeLists.windows-x86_64.txt":"", + "yt/yt/library/profiling/solomon/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/profiling/solomon/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/profiling/solomon/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/profiling/solomon/CMakeLists.txt":"", + "yt/yt/library/profiling/solomon/CMakeLists.windows-x86_64.txt":"", + "yt/yt/library/profiling/tcmalloc/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/profiling/tcmalloc/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/profiling/tcmalloc/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/profiling/tcmalloc/CMakeLists.txt":"", + "yt/yt/library/profiling/tcmalloc/CMakeLists.windows-x86_64.txt":"", + "yt/yt/library/program/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/program/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/program/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/program/CMakeLists.txt":"", + "yt/yt/library/program/CMakeLists.windows-x86_64.txt":"", "yt/yt/library/quantile_digest/CMakeLists.darwin-x86_64.txt":"", "yt/yt/library/quantile_digest/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/quantile_digest/CMakeLists.linux-x86_64.txt":"", @@ -10002,6 +10087,11 @@ "yt/yt/library/tracing/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/tracing/CMakeLists.txt":"", "yt/yt/library/tracing/CMakeLists.windows-x86_64.txt":"", + "yt/yt/library/tracing/jaeger/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/tracing/jaeger/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/tracing/jaeger/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/tracing/jaeger/CMakeLists.txt":"", + "yt/yt/library/tracing/jaeger/CMakeLists.windows-x86_64.txt":"", "yt/yt/library/tvm/CMakeLists.darwin-x86_64.txt":"", "yt/yt/library/tvm/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/tvm/CMakeLists.linux-x86_64.txt":"", @@ -10012,12 +10102,24 @@ "yt/yt/library/undumpable/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/undumpable/CMakeLists.txt":"", "yt/yt/library/undumpable/CMakeLists.windows-x86_64.txt":"", + "yt/yt/library/ytprof/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/ytprof/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/ytprof/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/ytprof/CMakeLists.txt":"", + "yt/yt/library/ytprof/CMakeLists.windows-x86_64.txt":"", "yt/yt/library/ytprof/api/CMakeLists.darwin-x86_64.txt":"", 
"yt/yt/library/ytprof/api/CMakeLists.linux-aarch64.txt":"", "yt/yt/library/ytprof/api/CMakeLists.linux-x86_64.txt":"", "yt/yt/library/ytprof/api/CMakeLists.txt":"", "yt/yt/library/ytprof/api/CMakeLists.windows-x86_64.txt":"", + "yt/yt/library/ytprof/http/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/ytprof/http/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/ytprof/http/CMakeLists.txt":"", + "yt/yt/library/ytprof/proto/CMakeLists.darwin-x86_64.txt":"", + "yt/yt/library/ytprof/proto/CMakeLists.linux-aarch64.txt":"", + "yt/yt/library/ytprof/proto/CMakeLists.linux-x86_64.txt":"", + "yt/yt/library/ytprof/proto/CMakeLists.txt":"", + "yt/yt/library/ytprof/proto/CMakeLists.windows-x86_64.txt":"", "yt/yt_proto/CMakeLists.txt":"", "yt/yt_proto/yt/CMakeLists.darwin-x86_64.txt":"", "yt/yt_proto/yt/CMakeLists.linux-aarch64.txt":"", diff --git a/contrib/libs/CMakeLists.windows-x86_64.txt b/contrib/libs/CMakeLists.windows-x86_64.txt index d92b43e34d..7bacb8e92b 100644 --- a/contrib/libs/CMakeLists.windows-x86_64.txt +++ b/contrib/libs/CMakeLists.windows-x86_64.txt @@ -33,6 +33,7 @@ add_subdirectory(libbz2) add_subdirectory(libc_compat) add_subdirectory(libevent) add_subdirectory(libfyaml) +add_subdirectory(libunwind) add_subdirectory(libxml) add_subdirectory(linuxvdso) add_subdirectory(llvm12) diff --git a/contrib/libs/libunwind/CMakeLists.txt b/contrib/libs/libunwind/CMakeLists.txt index 606ff46b4b..f8b31df0c1 100644 --- a/contrib/libs/libunwind/CMakeLists.txt +++ b/contrib/libs/libunwind/CMakeLists.txt @@ -10,6 +10,8 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarc include(CMakeLists.linux-aarch64.txt) elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) include(CMakeLists.linux-x86_64.txt) endif() diff --git a/contrib/libs/libunwind/CMakeLists.windows-x86_64.txt b/contrib/libs/libunwind/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..4b54ad1509 --- /dev/null +++ b/contrib/libs/libunwind/CMakeLists.windows-x86_64.txt @@ -0,0 +1,35 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(contrib-libs-libunwind) +target_compile_options(contrib-libs-libunwind PUBLIC + -D_libunwind_ +) +target_compile_options(contrib-libs-libunwind PRIVATE + -D_LIBUNWIND_IS_NATIVE_ONLY + -fno-exceptions + -fno-rtti + -funwind-tables +) +target_include_directories(contrib-libs-libunwind PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/include +) +target_link_libraries(contrib-libs-libunwind PUBLIC + cpp-sanitizer-include +) +target_sources(contrib-libs-libunwind PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/src/Unwind-EHABI.cpp + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/src/Unwind-seh.cpp + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/src/Unwind-sjlj.c + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/src/UnwindLevel1-gcc-ext.c + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/src/UnwindLevel1.c + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/src/UnwindRegistersRestore.S + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/src/UnwindRegistersSave.S + ${CMAKE_SOURCE_DIR}/contrib/libs/libunwind/src/libunwind.cpp +) diff --git a/contrib/python/anyio/.dist-info/METADATA b/contrib/python/anyio/.dist-info/METADATA new file mode 100644 index 0000000000..5e46476e02 --- /dev/null +++ b/contrib/python/anyio/.dist-info/METADATA @@ -0,0 +1,105 @@ +Metadata-Version: 2.1 +Name: anyio +Version: 3.7.1 +Summary: High level compatibility layer for multiple asynchronous event loop implementations +Author-email: Alex Grönholm <alex.gronholm@nextday.fi> +License: MIT +Project-URL: Documentation, https://anyio.readthedocs.io/en/latest/ +Project-URL: Changelog, https://anyio.readthedocs.io/en/stable/versionhistory.html +Project-URL: Source code, https://github.com/agronholm/anyio +Project-URL: Issue tracker, https://github.com/agronholm/anyio/issues +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Framework :: AnyIO +Classifier: Typing :: Typed +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: idna (>=2.8) +Requires-Dist: sniffio (>=1.1) +Requires-Dist: exceptiongroup ; python_version < "3.11" +Requires-Dist: typing-extensions ; python_version < "3.8" +Provides-Extra: doc +Requires-Dist: packaging ; extra == 'doc' +Requires-Dist: Sphinx ; extra == 'doc' +Requires-Dist: sphinx-rtd-theme (>=1.2.2) ; extra == 'doc' +Requires-Dist: sphinxcontrib-jquery ; extra == 'doc' +Requires-Dist: sphinx-autodoc-typehints (>=1.2.0) ; extra == 'doc' +Provides-Extra: test +Requires-Dist: anyio[trio] ; extra == 'test' +Requires-Dist: coverage[toml] (>=4.5) ; extra == 'test' +Requires-Dist: hypothesis (>=4.0) ; extra == 'test' +Requires-Dist: psutil (>=5.9) ; extra == 'test' +Requires-Dist: pytest (>=7.0) ; extra == 'test' +Requires-Dist: pytest-mock (>=3.6.1) ; extra == 'test' +Requires-Dist: trustme ; extra == 'test' +Requires-Dist: uvloop (>=0.17) ; (python_version < "3.12" and platform_python_implementation == "CPython" and platform_system != "Windows") and extra == 'test' +Requires-Dist: mock (>=4) ; (python_version < "3.8") and extra == 'test' +Provides-Extra: trio +Requires-Dist: trio (<0.22) ; extra == 'trio' + 
+.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg + :target: https://github.com/agronholm/anyio/actions/workflows/test.yml + :alt: Build Status +.. image:: https://coveralls.io/repos/github/agronholm/anyio/badge.svg?branch=master + :target: https://coveralls.io/github/agronholm/anyio?branch=master + :alt: Code Coverage +.. image:: https://readthedocs.org/projects/anyio/badge/?version=latest + :target: https://anyio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation +.. image:: https://badges.gitter.im/gitterHQ/gitter.svg + :target: https://gitter.im/python-trio/AnyIO + :alt: Gitter chat + +AnyIO is an asynchronous networking and concurrency library that works on top of either asyncio_ or +trio_. It implements trio-like `structured concurrency`_ (SC) on top of asyncio and works in harmony +with the native SC of trio itself. + +Applications and libraries written against AnyIO's API will run unmodified on either asyncio_ or +trio_. AnyIO can also be adopted into a library or application incrementally – bit by bit, no full +refactoring necessary. It will blend in with the native libraries of your chosen backend. + +Documentation +------------- + +View full documentation at: https://anyio.readthedocs.io/ + +Features +-------- + +AnyIO offers the following functionality: + +* Task groups (nurseries_ in trio terminology) +* High-level networking (TCP, UDP and UNIX sockets) + + * `Happy eyeballs`_ algorithm for TCP connections (more robust than that of asyncio on Python + 3.8) + * async/await style UDP sockets (unlike asyncio where you still have to use Transports and + Protocols) + +* A versatile API for byte streams and object streams +* Inter-task synchronization and communication (locks, conditions, events, semaphores, object + streams) +* Worker threads +* Subprocesses +* Asynchronous file I/O (using worker threads) +* Signal handling + +AnyIO also comes with its own pytest_ plugin which also supports asynchronous fixtures. +It even works with the popular Hypothesis_ library. + +.. _asyncio: https://docs.python.org/3/library/asyncio.html +.. _trio: https://github.com/python-trio/trio +.. _structured concurrency: https://en.wikipedia.org/wiki/Structured_concurrency +.. _nurseries: https://trio.readthedocs.io/en/stable/reference-core.html#nurseries-and-spawning +.. _Happy eyeballs: https://en.wikipedia.org/wiki/Happy_Eyeballs +.. _pytest: https://docs.pytest.org/en/latest/ +.. 
_Hypothesis: https://hypothesis.works/ diff --git a/contrib/python/anyio/.dist-info/entry_points.txt b/contrib/python/anyio/.dist-info/entry_points.txt new file mode 100644 index 0000000000..44dd9bdc30 --- /dev/null +++ b/contrib/python/anyio/.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[pytest11] +anyio = anyio.pytest_plugin diff --git a/contrib/python/anyio/.dist-info/top_level.txt b/contrib/python/anyio/.dist-info/top_level.txt new file mode 100644 index 0000000000..c77c069ecc --- /dev/null +++ b/contrib/python/anyio/.dist-info/top_level.txt @@ -0,0 +1 @@ +anyio diff --git a/contrib/python/anyio/LICENSE b/contrib/python/anyio/LICENSE new file mode 100644 index 0000000000..104eebf5a3 --- /dev/null +++ b/contrib/python/anyio/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2018 Alex Grönholm + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/anyio/README.rst b/contrib/python/anyio/README.rst new file mode 100644 index 0000000000..35afc7e312 --- /dev/null +++ b/contrib/python/anyio/README.rst @@ -0,0 +1,57 @@ +.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg + :target: https://github.com/agronholm/anyio/actions/workflows/test.yml + :alt: Build Status +.. image:: https://coveralls.io/repos/github/agronholm/anyio/badge.svg?branch=master + :target: https://coveralls.io/github/agronholm/anyio?branch=master + :alt: Code Coverage +.. image:: https://readthedocs.org/projects/anyio/badge/?version=latest + :target: https://anyio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation +.. image:: https://badges.gitter.im/gitterHQ/gitter.svg + :target: https://gitter.im/python-trio/AnyIO + :alt: Gitter chat + +AnyIO is an asynchronous networking and concurrency library that works on top of either asyncio_ or +trio_. It implements trio-like `structured concurrency`_ (SC) on top of asyncio and works in harmony +with the native SC of trio itself. + +Applications and libraries written against AnyIO's API will run unmodified on either asyncio_ or +trio_. AnyIO can also be adopted into a library or application incrementally – bit by bit, no full +refactoring necessary. It will blend in with the native libraries of your chosen backend. 
+ +Documentation +------------- + +View full documentation at: https://anyio.readthedocs.io/ + +Features +-------- + +AnyIO offers the following functionality: + +* Task groups (nurseries_ in trio terminology) +* High-level networking (TCP, UDP and UNIX sockets) + + * `Happy eyeballs`_ algorithm for TCP connections (more robust than that of asyncio on Python + 3.8) + * async/await style UDP sockets (unlike asyncio where you still have to use Transports and + Protocols) + +* A versatile API for byte streams and object streams +* Inter-task synchronization and communication (locks, conditions, events, semaphores, object + streams) +* Worker threads +* Subprocesses +* Asynchronous file I/O (using worker threads) +* Signal handling + +AnyIO also comes with its own pytest_ plugin which also supports asynchronous fixtures. +It even works with the popular Hypothesis_ library. + +.. _asyncio: https://docs.python.org/3/library/asyncio.html +.. _trio: https://github.com/python-trio/trio +.. _structured concurrency: https://en.wikipedia.org/wiki/Structured_concurrency +.. _nurseries: https://trio.readthedocs.io/en/stable/reference-core.html#nurseries-and-spawning +.. _Happy eyeballs: https://en.wikipedia.org/wiki/Happy_Eyeballs +.. _pytest: https://docs.pytest.org/en/latest/ +.. _Hypothesis: https://hypothesis.works/ diff --git a/contrib/python/anyio/anyio/__init__.py b/contrib/python/anyio/anyio/__init__.py new file mode 100644 index 0000000000..29fb3561e4 --- /dev/null +++ b/contrib/python/anyio/anyio/__init__.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +__all__ = ( + "maybe_async", + "maybe_async_cm", + "run", + "sleep", + "sleep_forever", + "sleep_until", + "current_time", + "get_all_backends", + "get_cancelled_exc_class", + "BrokenResourceError", + "BrokenWorkerProcess", + "BusyResourceError", + "ClosedResourceError", + "DelimiterNotFound", + "EndOfStream", + "ExceptionGroup", + "IncompleteRead", + "TypedAttributeLookupError", + "WouldBlock", + "AsyncFile", + "Path", + "open_file", + "wrap_file", + "aclose_forcefully", + "open_signal_receiver", + "connect_tcp", + "connect_unix", + "create_tcp_listener", + "create_unix_listener", + "create_udp_socket", + "create_connected_udp_socket", + "getaddrinfo", + "getnameinfo", + "wait_socket_readable", + "wait_socket_writable", + "create_memory_object_stream", + "run_process", + "open_process", + "create_lock", + "CapacityLimiter", + "CapacityLimiterStatistics", + "Condition", + "ConditionStatistics", + "Event", + "EventStatistics", + "Lock", + "LockStatistics", + "Semaphore", + "SemaphoreStatistics", + "create_condition", + "create_event", + "create_semaphore", + "create_capacity_limiter", + "open_cancel_scope", + "fail_after", + "move_on_after", + "current_effective_deadline", + "TASK_STATUS_IGNORED", + "CancelScope", + "create_task_group", + "TaskInfo", + "get_current_task", + "get_running_tasks", + "wait_all_tasks_blocked", + "run_sync_in_worker_thread", + "run_async_from_thread", + "run_sync_from_thread", + "current_default_worker_thread_limiter", + "create_blocking_portal", + "start_blocking_portal", + "typed_attribute", + "TypedAttributeSet", + "TypedAttributeProvider", +) + +from typing import Any + +from ._core._compat import maybe_async, maybe_async_cm +from ._core._eventloop import ( + current_time, + get_all_backends, + get_cancelled_exc_class, + run, + sleep, + sleep_forever, + sleep_until, +) +from ._core._exceptions import ( + BrokenResourceError, + BrokenWorkerProcess, + BusyResourceError, + ClosedResourceError, + 
DelimiterNotFound, + EndOfStream, + ExceptionGroup, + IncompleteRead, + TypedAttributeLookupError, + WouldBlock, +) +from ._core._fileio import AsyncFile, Path, open_file, wrap_file +from ._core._resources import aclose_forcefully +from ._core._signals import open_signal_receiver +from ._core._sockets import ( + connect_tcp, + connect_unix, + create_connected_udp_socket, + create_tcp_listener, + create_udp_socket, + create_unix_listener, + getaddrinfo, + getnameinfo, + wait_socket_readable, + wait_socket_writable, +) +from ._core._streams import create_memory_object_stream +from ._core._subprocesses import open_process, run_process +from ._core._synchronization import ( + CapacityLimiter, + CapacityLimiterStatistics, + Condition, + ConditionStatistics, + Event, + EventStatistics, + Lock, + LockStatistics, + Semaphore, + SemaphoreStatistics, + create_capacity_limiter, + create_condition, + create_event, + create_lock, + create_semaphore, +) +from ._core._tasks import ( + TASK_STATUS_IGNORED, + CancelScope, + create_task_group, + current_effective_deadline, + fail_after, + move_on_after, + open_cancel_scope, +) +from ._core._testing import ( + TaskInfo, + get_current_task, + get_running_tasks, + wait_all_tasks_blocked, +) +from ._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute + +# Re-exported here, for backwards compatibility +# isort: off +from .to_thread import current_default_worker_thread_limiter, run_sync_in_worker_thread +from .from_thread import ( + create_blocking_portal, + run_async_from_thread, + run_sync_from_thread, + start_blocking_portal, +) + +# Re-export imports so they look like they live directly in this package +key: str +value: Any +for key, value in list(locals().items()): + if getattr(value, "__module__", "").startswith("anyio."): + value.__module__ = __name__ diff --git a/contrib/python/anyio/anyio/_backends/__init__.py b/contrib/python/anyio/anyio/_backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/anyio/anyio/_backends/__init__.py diff --git a/contrib/python/anyio/anyio/_backends/_asyncio.py b/contrib/python/anyio/anyio/_backends/_asyncio.py new file mode 100644 index 0000000000..bfdb4ea7e1 --- /dev/null +++ b/contrib/python/anyio/anyio/_backends/_asyncio.py @@ -0,0 +1,2117 @@ +from __future__ import annotations + +import array +import asyncio +import concurrent.futures +import math +import socket +import sys +from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] +from collections import OrderedDict, deque +from concurrent.futures import Future +from contextvars import Context, copy_context +from dataclasses import dataclass +from functools import partial, wraps +from inspect import ( + CORO_RUNNING, + CORO_SUSPENDED, + GEN_RUNNING, + GEN_SUSPENDED, + getcoroutinestate, + getgeneratorstate, +) +from io import IOBase +from os import PathLike +from queue import Queue +from socket import AddressFamily, SocketKind +from threading import Thread +from types import TracebackType +from typing import ( + IO, + Any, + AsyncGenerator, + Awaitable, + Callable, + Collection, + Coroutine, + Generator, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) +from weakref import WeakKeyDictionary + +import sniffio + +from .. 
import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc +from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable +from .._core._eventloop import claim_worker_thread, threadlocals +from .._core._exceptions import ( + BrokenResourceError, + BusyResourceError, + ClosedResourceError, + EndOfStream, + WouldBlock, +) +from .._core._exceptions import ExceptionGroup as BaseExceptionGroup +from .._core._sockets import GetAddrInfoReturnType, convert_ipv6_sockaddr +from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter +from .._core._synchronization import Event as BaseEvent +from .._core._synchronization import ResourceGuard +from .._core._tasks import CancelScope as BaseCancelScope +from ..abc import IPSockAddrType, UDPPacketType +from ..lowlevel import RunVar + +if sys.version_info >= (3, 8): + + def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: + return task.get_coro() + +else: + + def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: + return task._coro + + +from asyncio import all_tasks, create_task, current_task, get_running_loop +from asyncio import run as native_run + + +def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]: + return [cb for cb, context in task._callbacks] + + +T_Retval = TypeVar("T_Retval") +T_contra = TypeVar("T_contra", contravariant=True) + +# Check whether there is native support for task names in asyncio (3.8+) +_native_task_names = hasattr(asyncio.Task, "get_name") + + +_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task") + + +def find_root_task() -> asyncio.Task: + root_task = _root_task.get(None) + if root_task is not None and not root_task.done(): + return root_task + + # Look for a task that has been started via run_until_complete() + for task in all_tasks(): + if task._callbacks and not task.done(): + for cb in _get_task_callbacks(task): + if ( + cb is _run_until_complete_cb + or getattr(cb, "__module__", None) == "uvloop.loop" + ): + _root_task.set(task) + return task + + # Look up the topmost task in the AnyIO task tree, if possible + task = cast(asyncio.Task, current_task()) + state = _task_states.get(task) + if state: + cancel_scope = state.cancel_scope + while cancel_scope and cancel_scope._parent_scope is not None: + cancel_scope = cancel_scope._parent_scope + + if cancel_scope is not None: + return cast(asyncio.Task, cancel_scope._host_task) + + return task + + +def get_callable_name(func: Callable) -> str: + module = getattr(func, "__module__", None) + qualname = getattr(func, "__qualname__", None) + return ".".join([x for x in (module, qualname) if x]) + + +# +# Event loop +# + +_run_vars = ( + WeakKeyDictionary() +) # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] + +current_token = get_running_loop + + +def _task_started(task: asyncio.Task) -> bool: + """Return ``True`` if the task has been started and has not finished.""" + coro = cast(Coroutine[Any, Any, Any], get_coro(task)) + try: + return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED) + except AttributeError: + try: + return getgeneratorstate(cast(Generator, coro)) in ( + GEN_RUNNING, + GEN_SUSPENDED, + ) + except AttributeError: + # task coro is async_genenerator_asend https://bugs.python.org/issue37771 + raise Exception(f"Cannot determine if task {task} has started or not") + + +def _maybe_set_event_loop_policy( + policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool +) -> None: + # On CPython, use uvloop when possible if no other policy has been given and if not + # 
explicitly disabled + if policy is None and use_uvloop and sys.implementation.name == "cpython": + try: + import uvloop + except ImportError: + pass + else: + # Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier) + if not hasattr( + asyncio.AbstractEventLoop, "shutdown_default_executor" + ) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"): + policy = uvloop.EventLoopPolicy() + + if policy is not None: + asyncio.set_event_loop_policy(policy) + + +def run( + func: Callable[..., Awaitable[T_Retval]], + *args: object, + debug: bool = False, + use_uvloop: bool = False, + policy: asyncio.AbstractEventLoopPolicy | None = None, +) -> T_Retval: + @wraps(func) + async def wrapper() -> T_Retval: + task = cast(asyncio.Task, current_task()) + task_state = TaskState(None, get_callable_name(func), None) + _task_states[task] = task_state + if _native_task_names: + task.set_name(task_state.name) + + try: + return await func(*args) + finally: + del _task_states[task] + + _maybe_set_event_loop_policy(policy, use_uvloop) + return native_run(wrapper(), debug=debug) + + +# +# Miscellaneous +# + +sleep = asyncio.sleep + + +# +# Timeouts and cancellation +# + +CancelledError = asyncio.CancelledError + + +class CancelScope(BaseCancelScope): + def __new__( + cls, *, deadline: float = math.inf, shield: bool = False + ) -> CancelScope: + return object.__new__(cls) + + def __init__(self, deadline: float = math.inf, shield: bool = False): + self._deadline = deadline + self._shield = shield + self._parent_scope: CancelScope | None = None + self._cancel_called = False + self._active = False + self._timeout_handle: asyncio.TimerHandle | None = None + self._cancel_handle: asyncio.Handle | None = None + self._tasks: set[asyncio.Task] = set() + self._host_task: asyncio.Task | None = None + self._timeout_expired = False + self._cancel_calls: int = 0 + + def __enter__(self) -> CancelScope: + if self._active: + raise RuntimeError( + "Each CancelScope may only be used for a single 'with' block" + ) + + self._host_task = host_task = cast(asyncio.Task, current_task()) + self._tasks.add(host_task) + try: + task_state = _task_states[host_task] + except KeyError: + task_name = host_task.get_name() if _native_task_names else None + task_state = TaskState(None, task_name, self) + _task_states[host_task] = task_state + else: + self._parent_scope = task_state.cancel_scope + task_state.cancel_scope = self + + self._timeout() + self._active = True + + # Start cancelling the host task if the scope was cancelled before entering + if self._cancel_called: + self._deliver_cancellation() + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + if not self._active: + raise RuntimeError("This cancel scope is not active") + if current_task() is not self._host_task: + raise RuntimeError( + "Attempted to exit cancel scope in a different task than it was " + "entered in" + ) + + assert self._host_task is not None + host_task_state = _task_states.get(self._host_task) + if host_task_state is None or host_task_state.cancel_scope is not self: + raise RuntimeError( + "Attempted to exit a cancel scope that isn't the current tasks's " + "current cancel scope" + ) + + self._active = False + if self._timeout_handle: + self._timeout_handle.cancel() + self._timeout_handle = None + + self._tasks.remove(self._host_task) + + host_task_state.cancel_scope = self._parent_scope + + # Restart the cancellation effort in the 
farthest directly cancelled parent scope if this + # one was shielded + if self._shield: + self._deliver_cancellation_to_parent() + + if exc_val is not None: + exceptions = ( + exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val] + ) + if all(isinstance(exc, CancelledError) for exc in exceptions): + if self._timeout_expired: + return self._uncancel() + elif not self._cancel_called: + # Task was cancelled natively + return None + elif not self._parent_cancelled(): + # This scope was directly cancelled + return self._uncancel() + + return None + + def _uncancel(self) -> bool: + if sys.version_info < (3, 11) or self._host_task is None: + self._cancel_calls = 0 + return True + + # Uncancel all AnyIO cancellations + for i in range(self._cancel_calls): + self._host_task.uncancel() + + self._cancel_calls = 0 + return not self._host_task.cancelling() + + def _timeout(self) -> None: + if self._deadline != math.inf: + loop = get_running_loop() + if loop.time() >= self._deadline: + self._timeout_expired = True + self.cancel() + else: + self._timeout_handle = loop.call_at(self._deadline, self._timeout) + + def _deliver_cancellation(self) -> None: + """ + Deliver cancellation to directly contained tasks and nested cancel scopes. + + Schedule another run at the end if we still have tasks eligible for cancellation. + """ + should_retry = False + current = current_task() + for task in self._tasks: + if task._must_cancel: # type: ignore[attr-defined] + continue + + # The task is eligible for cancellation if it has started and is not in a cancel + # scope shielded from this one + cancel_scope = _task_states[task].cancel_scope + while cancel_scope is not self: + if cancel_scope is None or cancel_scope._shield: + break + else: + cancel_scope = cancel_scope._parent_scope + else: + should_retry = True + if task is not current and ( + task is self._host_task or _task_started(task) + ): + self._cancel_calls += 1 + task.cancel() + + # Schedule another callback if there are still tasks left + if should_retry: + self._cancel_handle = get_running_loop().call_soon( + self._deliver_cancellation + ) + else: + self._cancel_handle = None + + def _deliver_cancellation_to_parent(self) -> None: + """Start cancellation effort in the farthest directly cancelled parent scope""" + scope = self._parent_scope + scope_to_cancel: CancelScope | None = None + while scope is not None: + if scope._cancel_called and scope._cancel_handle is None: + scope_to_cancel = scope + + # No point in looking beyond any shielded scope + if scope._shield: + break + + scope = scope._parent_scope + + if scope_to_cancel is not None: + scope_to_cancel._deliver_cancellation() + + def _parent_cancelled(self) -> bool: + # Check whether any parent has been cancelled + cancel_scope = self._parent_scope + while cancel_scope is not None and not cancel_scope._shield: + if cancel_scope._cancel_called: + return True + else: + cancel_scope = cancel_scope._parent_scope + + return False + + def cancel(self) -> DeprecatedAwaitable: + if not self._cancel_called: + if self._timeout_handle: + self._timeout_handle.cancel() + self._timeout_handle = None + + self._cancel_called = True + if self._host_task is not None: + self._deliver_cancellation() + + return DeprecatedAwaitable(self.cancel) + + @property + def deadline(self) -> float: + return self._deadline + + @deadline.setter + def deadline(self, value: float) -> None: + self._deadline = float(value) + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + 
+ if self._active and not self._cancel_called: + self._timeout() + + @property + def cancel_called(self) -> bool: + return self._cancel_called + + @property + def shield(self) -> bool: + return self._shield + + @shield.setter + def shield(self, value: bool) -> None: + if self._shield != value: + self._shield = value + if not value: + self._deliver_cancellation_to_parent() + + +async def checkpoint() -> None: + await sleep(0) + + +async def checkpoint_if_cancelled() -> None: + task = current_task() + if task is None: + return + + try: + cancel_scope = _task_states[task].cancel_scope + except KeyError: + return + + while cancel_scope: + if cancel_scope.cancel_called: + await sleep(0) + elif cancel_scope.shield: + break + else: + cancel_scope = cancel_scope._parent_scope + + +async def cancel_shielded_checkpoint() -> None: + with CancelScope(shield=True): + await sleep(0) + + +def current_effective_deadline() -> float: + try: + cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index] + except KeyError: + return math.inf + + deadline = math.inf + while cancel_scope: + deadline = min(deadline, cancel_scope.deadline) + if cancel_scope._cancel_called: + deadline = -math.inf + break + elif cancel_scope.shield: + break + else: + cancel_scope = cancel_scope._parent_scope + + return deadline + + +def current_time() -> float: + return get_running_loop().time() + + +# +# Task states +# + + +class TaskState: + """ + Encapsulates auxiliary task information that cannot be added to the Task instance itself + because there are no guarantees about its implementation. + """ + + __slots__ = "parent_id", "name", "cancel_scope" + + def __init__( + self, + parent_id: int | None, + name: str | None, + cancel_scope: CancelScope | None, + ): + self.parent_id = parent_id + self.name = name + self.cancel_scope = cancel_scope + + +_task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState] + + +# +# Task groups +# + + +class ExceptionGroup(BaseExceptionGroup): + def __init__(self, exceptions: list[BaseException]): + super().__init__() + self.exceptions = exceptions + + +class _AsyncioTaskStatus(abc.TaskStatus): + def __init__(self, future: asyncio.Future, parent_id: int): + self._future = future + self._parent_id = parent_id + + def started(self, value: T_contra | None = None) -> None: + try: + self._future.set_result(value) + except asyncio.InvalidStateError: + raise RuntimeError( + "called 'started' twice on the same task status" + ) from None + + task = cast(asyncio.Task, current_task()) + _task_states[task].parent_id = self._parent_id + + +class TaskGroup(abc.TaskGroup): + def __init__(self) -> None: + self.cancel_scope: CancelScope = CancelScope() + self._active = False + self._exceptions: list[BaseException] = [] + + async def __aenter__(self) -> TaskGroup: + self.cancel_scope.__enter__() + self._active = True + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) + if exc_val is not None: + self.cancel_scope.cancel() + self._exceptions.append(exc_val) + + while self.cancel_scope._tasks: + try: + await asyncio.wait(self.cancel_scope._tasks) + except asyncio.CancelledError: + self.cancel_scope.cancel() + + self._active = False + if not self.cancel_scope._parent_cancelled(): + exceptions = self._filter_cancellation_errors(self._exceptions) + else: + exceptions = self._exceptions 
+ + try: + if len(exceptions) > 1: + if all( + isinstance(e, CancelledError) and not e.args for e in exceptions + ): + # Tasks were cancelled natively, without a cancellation message + raise CancelledError + else: + raise ExceptionGroup(exceptions) + elif exceptions and exceptions[0] is not exc_val: + raise exceptions[0] + except BaseException as exc: + # Clear the context here, as it can only be done in-flight. + # If the context is not cleared, it can result in recursive tracebacks (see #145). + exc.__context__ = None + raise + + return ignore_exception + + @staticmethod + def _filter_cancellation_errors( + exceptions: Sequence[BaseException], + ) -> list[BaseException]: + filtered_exceptions: list[BaseException] = [] + for exc in exceptions: + if isinstance(exc, ExceptionGroup): + new_exceptions = TaskGroup._filter_cancellation_errors(exc.exceptions) + if len(new_exceptions) > 1: + filtered_exceptions.append(exc) + elif len(new_exceptions) == 1: + filtered_exceptions.append(new_exceptions[0]) + elif new_exceptions: + new_exc = ExceptionGroup(new_exceptions) + new_exc.__cause__ = exc.__cause__ + new_exc.__context__ = exc.__context__ + new_exc.__traceback__ = exc.__traceback__ + filtered_exceptions.append(new_exc) + elif not isinstance(exc, CancelledError) or exc.args: + filtered_exceptions.append(exc) + + return filtered_exceptions + + async def _run_wrapped_task( + self, coro: Coroutine, task_status_future: asyncio.Future | None + ) -> None: + # This is the code path for Python 3.7 on which asyncio freaks out if a task + # raises a BaseException. + __traceback_hide__ = __tracebackhide__ = True # noqa: F841 + task = cast(asyncio.Task, current_task()) + try: + await coro + except BaseException as exc: + if task_status_future is None or task_status_future.done(): + self._exceptions.append(exc) + self.cancel_scope.cancel() + else: + task_status_future.set_exception(exc) + else: + if task_status_future is not None and not task_status_future.done(): + task_status_future.set_exception( + RuntimeError("Child exited without calling task_status.started()") + ) + finally: + if task in self.cancel_scope._tasks: + self.cancel_scope._tasks.remove(task) + del _task_states[task] + + def _spawn( + self, + func: Callable[..., Awaitable[Any]], + args: tuple, + name: object, + task_status_future: asyncio.Future | None = None, + ) -> asyncio.Task: + def task_done(_task: asyncio.Task) -> None: + # This is the code path for Python 3.8+ + assert _task in self.cancel_scope._tasks + self.cancel_scope._tasks.remove(_task) + del _task_states[_task] + + try: + exc = _task.exception() + except CancelledError as e: + while isinstance(e.__context__, CancelledError): + e = e.__context__ + + exc = e + + if exc is not None: + if task_status_future is None or task_status_future.done(): + self._exceptions.append(exc) + self.cancel_scope.cancel() + else: + task_status_future.set_exception(exc) + elif task_status_future is not None and not task_status_future.done(): + task_status_future.set_exception( + RuntimeError("Child exited without calling task_status.started()") + ) + + if not self._active: + raise RuntimeError( + "This task group is not active; no new tasks can be started." 
+ ) + + options: dict[str, Any] = {} + name = get_callable_name(func) if name is None else str(name) + if _native_task_names: + options["name"] = name + + kwargs = {} + if task_status_future: + parent_id = id(current_task()) + kwargs["task_status"] = _AsyncioTaskStatus( + task_status_future, id(self.cancel_scope._host_task) + ) + else: + parent_id = id(self.cancel_scope._host_task) + + coro = func(*args, **kwargs) + if not asyncio.iscoroutine(coro): + raise TypeError( + f"Expected an async function, but {func} appears to be synchronous" + ) + + foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame") + if foreign_coro or sys.version_info < (3, 8): + coro = self._run_wrapped_task(coro, task_status_future) + + task = create_task(coro, **options) + if not foreign_coro and sys.version_info >= (3, 8): + task.add_done_callback(task_done) + + # Make the spawned task inherit the task group's cancel scope + _task_states[task] = TaskState( + parent_id=parent_id, name=name, cancel_scope=self.cancel_scope + ) + self.cancel_scope._tasks.add(task) + return task + + def start_soon( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> None: + self._spawn(func, args, name) + + async def start( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> None: + future: asyncio.Future = asyncio.Future() + task = self._spawn(func, args, name, future) + + # If the task raises an exception after sending a start value without a switch point + # between, the task group is cancelled and this method never proceeds to process the + # completed future. That's why we have to have a shielded cancel scope here. + with CancelScope(shield=True): + try: + return await future + except CancelledError: + task.cancel() + raise + + +# +# Threads +# + +_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]] + + +class WorkerThread(Thread): + MAX_IDLE_TIME = 10 # seconds + + def __init__( + self, + root_task: asyncio.Task, + workers: set[WorkerThread], + idle_workers: deque[WorkerThread], + ): + super().__init__(name="AnyIO worker thread") + self.root_task = root_task + self.workers = workers + self.idle_workers = idle_workers + self.loop = root_task._loop + self.queue: Queue[ + tuple[Context, Callable, tuple, asyncio.Future] | None + ] = Queue(2) + self.idle_since = current_time() + self.stopping = False + + def _report_result( + self, future: asyncio.Future, result: Any, exc: BaseException | None + ) -> None: + self.idle_since = current_time() + if not self.stopping: + self.idle_workers.append(self) + + if not future.cancelled(): + if exc is not None: + if isinstance(exc, StopIteration): + new_exc = RuntimeError("coroutine raised StopIteration") + new_exc.__cause__ = exc + exc = new_exc + + future.set_exception(exc) + else: + future.set_result(result) + + def run(self) -> None: + with claim_worker_thread("asyncio"): + threadlocals.loop = self.loop + while True: + item = self.queue.get() + if item is None: + # Shutdown command received + return + + context, func, args, future = item + if not future.cancelled(): + result = None + exception: BaseException | None = None + try: + result = context.run(func, *args) + except BaseException as exc: + exception = exc + + if not self.loop.is_closed(): + self.loop.call_soon_threadsafe( + self._report_result, future, result, exception + ) + + self.queue.task_done() + + def stop(self, f: asyncio.Task | None = None) -> None: + self.stopping = True + self.queue.put_nowait(None) + 
self.workers.discard(self) + try: + self.idle_workers.remove(self) + except ValueError: + pass + + +_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar( + "_threadpool_idle_workers" +) +_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers") + + +async def run_sync_in_worker_thread( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + await checkpoint() + + # If this is the first run in this event loop thread, set up the necessary variables + try: + idle_workers = _threadpool_idle_workers.get() + workers = _threadpool_workers.get() + except LookupError: + idle_workers = deque() + workers = set() + _threadpool_idle_workers.set(idle_workers) + _threadpool_workers.set(workers) + + async with (limiter or current_default_thread_limiter()): + with CancelScope(shield=not cancellable): + future: asyncio.Future = asyncio.Future() + root_task = find_root_task() + if not idle_workers: + worker = WorkerThread(root_task, workers, idle_workers) + worker.start() + workers.add(worker) + root_task.add_done_callback(worker.stop) + else: + worker = idle_workers.pop() + + # Prune any other workers that have been idle for MAX_IDLE_TIME seconds or longer + now = current_time() + while idle_workers: + if now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME: + break + + expired_worker = idle_workers.popleft() + expired_worker.root_task.remove_done_callback(expired_worker.stop) + expired_worker.stop() + + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, None) + worker.queue.put_nowait((context, func, args, future)) + return await future + + +def run_sync_from_thread( + func: Callable[..., T_Retval], + *args: object, + loop: asyncio.AbstractEventLoop | None = None, +) -> T_Retval: + @wraps(func) + def wrapper() -> None: + try: + f.set_result(func(*args)) + except BaseException as exc: + f.set_exception(exc) + if not isinstance(exc, Exception): + raise + + f: concurrent.futures.Future[T_Retval] = Future() + loop = loop or threadlocals.loop + loop.call_soon_threadsafe(wrapper) + return f.result() + + +def run_async_from_thread( + func: Callable[..., Awaitable[T_Retval]], *args: object +) -> T_Retval: + f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( + func(*args), threadlocals.loop + ) + return f.result() + + +class BlockingPortal(abc.BlockingPortal): + def __new__(cls) -> BlockingPortal: + return object.__new__(cls) + + def __init__(self) -> None: + super().__init__() + self._loop = get_running_loop() + + def _spawn_task_from_thread( + self, + func: Callable, + args: tuple, + kwargs: dict[str, Any], + name: object, + future: Future, + ) -> None: + run_sync_from_thread( + partial(self._task_group.start_soon, name=name), + self._call_func, + func, + args, + kwargs, + future, + loop=self._loop, + ) + + +# +# Subprocesses +# + + +@dataclass(eq=False) +class StreamReaderWrapper(abc.ByteReceiveStream): + _stream: asyncio.StreamReader + + async def receive(self, max_bytes: int = 65536) -> bytes: + data = await self._stream.read(max_bytes) + if data: + return data + else: + raise EndOfStream + + async def aclose(self) -> None: + self._stream.feed_eof() + + +@dataclass(eq=False) +class StreamWriterWrapper(abc.ByteSendStream): + _stream: asyncio.StreamWriter + + async def send(self, item: bytes) -> None: + self._stream.write(item) + await self._stream.drain() + + async def aclose(self) -> None: + self._stream.close() + + +@dataclass(eq=False) 
+class Process(abc.Process): + _process: asyncio.subprocess.Process + _stdin: StreamWriterWrapper | None + _stdout: StreamReaderWrapper | None + _stderr: StreamReaderWrapper | None + + async def aclose(self) -> None: + if self._stdin: + await self._stdin.aclose() + if self._stdout: + await self._stdout.aclose() + if self._stderr: + await self._stderr.aclose() + + await self.wait() + + async def wait(self) -> int: + return await self._process.wait() + + def terminate(self) -> None: + self._process.terminate() + + def kill(self) -> None: + self._process.kill() + + def send_signal(self, signal: int) -> None: + self._process.send_signal(signal) + + @property + def pid(self) -> int: + return self._process.pid + + @property + def returncode(self) -> int | None: + return self._process.returncode + + @property + def stdin(self) -> abc.ByteSendStream | None: + return self._stdin + + @property + def stdout(self) -> abc.ByteReceiveStream | None: + return self._stdout + + @property + def stderr(self) -> abc.ByteReceiveStream | None: + return self._stderr + + +async def open_process( + command: str | bytes | Sequence[str | bytes], + *, + shell: bool, + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, +) -> Process: + await checkpoint() + if shell: + process = await asyncio.create_subprocess_shell( + cast(Union[str, bytes], command), + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + else: + process = await asyncio.create_subprocess_exec( + *command, + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + + stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None + stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None + stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None + return Process(process, stdin_stream, stdout_stream, stderr_stream) + + +def _forcibly_shutdown_process_pool_on_exit( + workers: set[Process], _task: object +) -> None: + """ + Forcibly shuts down worker processes belonging to this event loop.""" + child_watcher: asyncio.AbstractChildWatcher | None + try: + child_watcher = asyncio.get_event_loop_policy().get_child_watcher() + except NotImplementedError: + child_watcher = None + + # Close as much as possible (w/o async/await) to avoid warnings + for process in workers: + if process.returncode is None: + continue + + process._stdin._stream._transport.close() # type: ignore[union-attr] + process._stdout._stream._transport.close() # type: ignore[union-attr] + process._stderr._stream._transport.close() # type: ignore[union-attr] + process.kill() + if child_watcher: + child_watcher.remove_child_handler(process.pid) + + +async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None: + """ + Shuts down worker processes belonging to this event loop. + + NOTE: this only works when the event loop was started using asyncio.run() or anyio.run(). 
+ + """ + process: Process + try: + await sleep(math.inf) + except asyncio.CancelledError: + for process in workers: + if process.returncode is None: + process.kill() + + for process in workers: + await process.aclose() + + +def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None: + kwargs: dict[str, Any] = ( + {"name": "AnyIO process pool shutdown task"} if _native_task_names else {} + ) + create_task(_shutdown_process_pool_on_exit(workers), **kwargs) + find_root_task().add_done_callback( + partial(_forcibly_shutdown_process_pool_on_exit, workers) + ) + + +# +# Sockets and networking +# + + +class StreamProtocol(asyncio.Protocol): + read_queue: deque[bytes] + read_event: asyncio.Event + write_event: asyncio.Event + exception: Exception | None = None + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.read_queue = deque() + self.read_event = asyncio.Event() + self.write_event = asyncio.Event() + self.write_event.set() + cast(asyncio.Transport, transport).set_write_buffer_limits(0) + + def connection_lost(self, exc: Exception | None) -> None: + if exc: + self.exception = BrokenResourceError() + self.exception.__cause__ = exc + + self.read_event.set() + self.write_event.set() + + def data_received(self, data: bytes) -> None: + self.read_queue.append(data) + self.read_event.set() + + def eof_received(self) -> bool | None: + self.read_event.set() + return True + + def pause_writing(self) -> None: + self.write_event = asyncio.Event() + + def resume_writing(self) -> None: + self.write_event.set() + + +class DatagramProtocol(asyncio.DatagramProtocol): + read_queue: deque[tuple[bytes, IPSockAddrType]] + read_event: asyncio.Event + write_event: asyncio.Event + exception: Exception | None = None + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.read_queue = deque(maxlen=100) # arbitrary value + self.read_event = asyncio.Event() + self.write_event = asyncio.Event() + self.write_event.set() + + def connection_lost(self, exc: Exception | None) -> None: + self.read_event.set() + self.write_event.set() + + def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None: + addr = convert_ipv6_sockaddr(addr) + self.read_queue.append((data, addr)) + self.read_event.set() + + def error_received(self, exc: Exception) -> None: + self.exception = exc + + def pause_writing(self) -> None: + self.write_event.clear() + + def resume_writing(self) -> None: + self.write_event.set() + + +class SocketStream(abc.SocketStream): + def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol): + self._transport = transport + self._protocol = protocol + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + self._closed = False + + @property + def _raw_socket(self) -> socket.socket: + return self._transport.get_extra_info("socket") + + async def receive(self, max_bytes: int = 65536) -> bytes: + with self._receive_guard: + await checkpoint() + + if ( + not self._protocol.read_event.is_set() + and not self._transport.is_closing() + ): + self._transport.resume_reading() + await self._protocol.read_event.wait() + self._transport.pause_reading() + + try: + chunk = self._protocol.read_queue.popleft() + except IndexError: + if self._closed: + raise ClosedResourceError from None + elif self._protocol.exception: + raise self._protocol.exception + else: + raise EndOfStream from None + + if len(chunk) > max_bytes: + # Split the oversized chunk + chunk, leftover = chunk[:max_bytes], 
chunk[max_bytes:] + self._protocol.read_queue.appendleft(leftover) + + # If the read queue is empty, clear the flag so that the next call will block until + # data is available + if not self._protocol.read_queue: + self._protocol.read_event.clear() + + return chunk + + async def send(self, item: bytes) -> None: + with self._send_guard: + await checkpoint() + + if self._closed: + raise ClosedResourceError + elif self._protocol.exception is not None: + raise self._protocol.exception + + try: + self._transport.write(item) + except RuntimeError as exc: + if self._transport.is_closing(): + raise BrokenResourceError from exc + else: + raise + + await self._protocol.write_event.wait() + + async def send_eof(self) -> None: + try: + self._transport.write_eof() + except OSError: + pass + + async def aclose(self) -> None: + if not self._transport.is_closing(): + self._closed = True + try: + self._transport.write_eof() + except OSError: + pass + + self._transport.close() + await sleep(0) + self._transport.abort() + + +class UNIXSocketStream(abc.SocketStream): + _receive_future: asyncio.Future | None = None + _send_future: asyncio.Future | None = None + _closing = False + + def __init__(self, raw_socket: socket.socket): + self.__raw_socket = raw_socket + self._loop = get_running_loop() + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + + @property + def _raw_socket(self) -> socket.socket: + return self.__raw_socket + + def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: + def callback(f: object) -> None: + del self._receive_future + loop.remove_reader(self.__raw_socket) + + f = self._receive_future = asyncio.Future() + self._loop.add_reader(self.__raw_socket, f.set_result, None) + f.add_done_callback(callback) + return f + + def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: + def callback(f: object) -> None: + del self._send_future + loop.remove_writer(self.__raw_socket) + + f = self._send_future = asyncio.Future() + self._loop.add_writer(self.__raw_socket, f.set_result, None) + f.add_done_callback(callback) + return f + + async def send_eof(self) -> None: + with self._send_guard: + self._raw_socket.shutdown(socket.SHUT_WR) + + async def receive(self, max_bytes: int = 65536) -> bytes: + loop = get_running_loop() + await checkpoint() + with self._receive_guard: + while True: + try: + data = self.__raw_socket.recv(max_bytes) + except BlockingIOError: + await self._wait_until_readable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + if not data: + raise EndOfStream + + return data + + async def send(self, item: bytes) -> None: + loop = get_running_loop() + await checkpoint() + with self._send_guard: + view = memoryview(item) + while view: + try: + bytes_sent = self.__raw_socket.send(view) + except BlockingIOError: + await self._wait_until_writable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + view = view[bytes_sent:] + + async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: + if not isinstance(msglen, int) or msglen < 0: + raise ValueError("msglen must be a non-negative integer") + if not isinstance(maxfds, int) or maxfds < 1: + raise ValueError("maxfds must be a positive integer") + + loop = get_running_loop() + fds = array.array("i") + await checkpoint() + with 
self._receive_guard: + while True: + try: + message, ancdata, flags, addr = self.__raw_socket.recvmsg( + msglen, socket.CMSG_LEN(maxfds * fds.itemsize) + ) + except BlockingIOError: + await self._wait_until_readable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + else: + if not message and not ancdata: + raise EndOfStream + + break + + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS: + raise RuntimeError( + f"Received unexpected ancillary data; message = {message!r}, " + f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}" + ) + + fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + + return message, list(fds) + + async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: + if not message: + raise ValueError("message must not be empty") + if not fds: + raise ValueError("fds must not be empty") + + loop = get_running_loop() + filenos: list[int] = [] + for fd in fds: + if isinstance(fd, int): + filenos.append(fd) + elif isinstance(fd, IOBase): + filenos.append(fd.fileno()) + + fdarray = array.array("i", filenos) + await checkpoint() + with self._send_guard: + while True: + try: + # The ignore can be removed after mypy picks up + # https://github.com/python/typeshed/pull/5545 + self.__raw_socket.sendmsg( + [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)] + ) + break + except BlockingIOError: + await self._wait_until_writable(loop) + except OSError as exc: + if self._closing: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + + async def aclose(self) -> None: + if not self._closing: + self._closing = True + if self.__raw_socket.fileno() != -1: + self.__raw_socket.close() + + if self._receive_future: + self._receive_future.set_result(None) + if self._send_future: + self._send_future.set_result(None) + + +class TCPSocketListener(abc.SocketListener): + _accept_scope: CancelScope | None = None + _closed = False + + def __init__(self, raw_socket: socket.socket): + self.__raw_socket = raw_socket + self._loop = cast(asyncio.BaseEventLoop, get_running_loop()) + self._accept_guard = ResourceGuard("accepting connections from") + + @property + def _raw_socket(self) -> socket.socket: + return self.__raw_socket + + async def accept(self) -> abc.SocketStream: + if self._closed: + raise ClosedResourceError + + with self._accept_guard: + await checkpoint() + with CancelScope() as self._accept_scope: + try: + client_sock, _addr = await self._loop.sock_accept(self._raw_socket) + except asyncio.CancelledError: + # Workaround for https://bugs.python.org/issue41317 + try: + self._loop.remove_reader(self._raw_socket) + except (ValueError, NotImplementedError): + pass + + if self._closed: + raise ClosedResourceError from None + + raise + finally: + self._accept_scope = None + + client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + transport, protocol = await self._loop.connect_accepted_socket( + StreamProtocol, client_sock + ) + return SocketStream(transport, protocol) + + async def aclose(self) -> None: + if self._closed: + return + + self._closed = True + if self._accept_scope: + # Workaround for https://bugs.python.org/issue41317 + try: + self._loop.remove_reader(self._raw_socket) + except (ValueError, NotImplementedError): + pass + + self._accept_scope.cancel() + await sleep(0) + + self._raw_socket.close() + + +class 
UNIXSocketListener(abc.SocketListener): + def __init__(self, raw_socket: socket.socket): + self.__raw_socket = raw_socket + self._loop = get_running_loop() + self._accept_guard = ResourceGuard("accepting connections from") + self._closed = False + + async def accept(self) -> abc.SocketStream: + await checkpoint() + with self._accept_guard: + while True: + try: + client_sock, _ = self.__raw_socket.accept() + client_sock.setblocking(False) + return UNIXSocketStream(client_sock) + except BlockingIOError: + f: asyncio.Future = asyncio.Future() + self._loop.add_reader(self.__raw_socket, f.set_result, None) + f.add_done_callback( + lambda _: self._loop.remove_reader(self.__raw_socket) + ) + await f + except OSError as exc: + if self._closed: + raise ClosedResourceError from None + else: + raise BrokenResourceError from exc + + async def aclose(self) -> None: + self._closed = True + self.__raw_socket.close() + + @property + def _raw_socket(self) -> socket.socket: + return self.__raw_socket + + +class UDPSocket(abc.UDPSocket): + def __init__( + self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol + ): + self._transport = transport + self._protocol = protocol + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + self._closed = False + + @property + def _raw_socket(self) -> socket.socket: + return self._transport.get_extra_info("socket") + + async def aclose(self) -> None: + if not self._transport.is_closing(): + self._closed = True + self._transport.close() + + async def receive(self) -> tuple[bytes, IPSockAddrType]: + with self._receive_guard: + await checkpoint() + + # If the buffer is empty, ask for more data + if not self._protocol.read_queue and not self._transport.is_closing(): + self._protocol.read_event.clear() + await self._protocol.read_event.wait() + + try: + return self._protocol.read_queue.popleft() + except IndexError: + if self._closed: + raise ClosedResourceError from None + else: + raise BrokenResourceError from None + + async def send(self, item: UDPPacketType) -> None: + with self._send_guard: + await checkpoint() + await self._protocol.write_event.wait() + if self._closed: + raise ClosedResourceError + elif self._transport.is_closing(): + raise BrokenResourceError + else: + self._transport.sendto(*item) + + +class ConnectedUDPSocket(abc.ConnectedUDPSocket): + def __init__( + self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol + ): + self._transport = transport + self._protocol = protocol + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + self._closed = False + + @property + def _raw_socket(self) -> socket.socket: + return self._transport.get_extra_info("socket") + + async def aclose(self) -> None: + if not self._transport.is_closing(): + self._closed = True + self._transport.close() + + async def receive(self) -> bytes: + with self._receive_guard: + await checkpoint() + + # If the buffer is empty, ask for more data + if not self._protocol.read_queue and not self._transport.is_closing(): + self._protocol.read_event.clear() + await self._protocol.read_event.wait() + + try: + packet = self._protocol.read_queue.popleft() + except IndexError: + if self._closed: + raise ClosedResourceError from None + else: + raise BrokenResourceError from None + + return packet[0] + + async def send(self, item: bytes) -> None: + with self._send_guard: + await checkpoint() + await self._protocol.write_event.wait() + if self._closed: + raise ClosedResourceError + 
elif self._transport.is_closing(): + raise BrokenResourceError + else: + self._transport.sendto(item) + + +async def connect_tcp( + host: str, port: int, local_addr: tuple[str, int] | None = None +) -> SocketStream: + transport, protocol = cast( + Tuple[asyncio.Transport, StreamProtocol], + await get_running_loop().create_connection( + StreamProtocol, host, port, local_addr=local_addr + ), + ) + transport.pause_reading() + return SocketStream(transport, protocol) + + +async def connect_unix(path: str) -> UNIXSocketStream: + await checkpoint() + loop = get_running_loop() + raw_socket = socket.socket(socket.AF_UNIX) + raw_socket.setblocking(False) + while True: + try: + raw_socket.connect(path) + except BlockingIOError: + f: asyncio.Future = asyncio.Future() + loop.add_writer(raw_socket, f.set_result, None) + f.add_done_callback(lambda _: loop.remove_writer(raw_socket)) + await f + except BaseException: + raw_socket.close() + raise + else: + return UNIXSocketStream(raw_socket) + + +async def create_udp_socket( + family: socket.AddressFamily, + local_address: IPSockAddrType | None, + remote_address: IPSockAddrType | None, + reuse_port: bool, +) -> UDPSocket | ConnectedUDPSocket: + result = await get_running_loop().create_datagram_endpoint( + DatagramProtocol, + local_addr=local_address, + remote_addr=remote_address, + family=family, + reuse_port=reuse_port, + ) + transport = result[0] + protocol = result[1] + if protocol.exception: + transport.close() + raise protocol.exception + + if not remote_address: + return UDPSocket(transport, protocol) + else: + return ConnectedUDPSocket(transport, protocol) + + +async def getaddrinfo( + host: bytes | str, + port: str | int | None, + *, + family: int | AddressFamily = 0, + type: int | SocketKind = 0, + proto: int = 0, + flags: int = 0, +) -> GetAddrInfoReturnType: + # https://github.com/python/typeshed/pull/4304 + result = await get_running_loop().getaddrinfo( + host, port, family=family, type=type, proto=proto, flags=flags + ) + return cast(GetAddrInfoReturnType, result) + + +async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> tuple[str, str]: + return await get_running_loop().getnameinfo(sockaddr, flags) + + +_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events") +_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events") + + +async def wait_socket_readable(sock: socket.socket) -> None: + await checkpoint() + try: + read_events = _read_events.get() + except LookupError: + read_events = {} + _read_events.set(read_events) + + if read_events.get(sock): + raise BusyResourceError("reading from") from None + + loop = get_running_loop() + event = read_events[sock] = asyncio.Event() + loop.add_reader(sock, event.set) + try: + await event.wait() + finally: + if read_events.pop(sock, None) is not None: + loop.remove_reader(sock) + readable = True + else: + readable = False + + if not readable: + raise ClosedResourceError + + +async def wait_socket_writable(sock: socket.socket) -> None: + await checkpoint() + try: + write_events = _write_events.get() + except LookupError: + write_events = {} + _write_events.set(write_events) + + if write_events.get(sock): + raise BusyResourceError("writing to") from None + + loop = get_running_loop() + event = write_events[sock] = asyncio.Event() + loop.add_writer(sock.fileno(), event.set) + try: + await event.wait() + finally: + if write_events.pop(sock, None) is not None: + loop.remove_writer(sock) + writable = True + else: + writable = False + + if not writable: + raise 
ClosedResourceError + + +# +# Synchronization +# + + +class Event(BaseEvent): + def __new__(cls) -> Event: + return object.__new__(cls) + + def __init__(self) -> None: + self._event = asyncio.Event() + + def set(self) -> DeprecatedAwaitable: + self._event.set() + return DeprecatedAwaitable(self.set) + + def is_set(self) -> bool: + return self._event.is_set() + + async def wait(self) -> None: + if await self._event.wait(): + await checkpoint() + + def statistics(self) -> EventStatistics: + return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined] + + +class CapacityLimiter(BaseCapacityLimiter): + _total_tokens: float = 0 + + def __new__(cls, total_tokens: float) -> CapacityLimiter: + return object.__new__(cls) + + def __init__(self, total_tokens: float): + self._borrowers: set[Any] = set() + self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict() + self.total_tokens = total_tokens + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.release() + + @property + def total_tokens(self) -> float: + return self._total_tokens + + @total_tokens.setter + def total_tokens(self, value: float) -> None: + if not isinstance(value, int) and not math.isinf(value): + raise TypeError("total_tokens must be an int or math.inf") + if value < 1: + raise ValueError("total_tokens must be >= 1") + + old_value = self._total_tokens + self._total_tokens = value + events = [] + for event in self._wait_queue.values(): + if value <= old_value: + break + + if not event.is_set(): + events.append(event) + old_value += 1 + + for event in events: + event.set() + + @property + def borrowed_tokens(self) -> int: + return len(self._borrowers) + + @property + def available_tokens(self) -> float: + return self._total_tokens - len(self._borrowers) + + def acquire_nowait(self) -> DeprecatedAwaitable: + self.acquire_on_behalf_of_nowait(current_task()) + return DeprecatedAwaitable(self.acquire_nowait) + + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + if borrower in self._borrowers: + raise RuntimeError( + "this borrower is already holding one of this CapacityLimiter's " + "tokens" + ) + + if self._wait_queue or len(self._borrowers) >= self._total_tokens: + raise WouldBlock + + self._borrowers.add(borrower) + return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) + + async def acquire(self) -> None: + return await self.acquire_on_behalf_of(current_task()) + + async def acquire_on_behalf_of(self, borrower: object) -> None: + await checkpoint_if_cancelled() + try: + self.acquire_on_behalf_of_nowait(borrower) + except WouldBlock: + event = asyncio.Event() + self._wait_queue[borrower] = event + try: + await event.wait() + except BaseException: + self._wait_queue.pop(borrower, None) + raise + + self._borrowers.add(borrower) + else: + try: + await cancel_shielded_checkpoint() + except BaseException: + self.release() + raise + + def release(self) -> None: + self.release_on_behalf_of(current_task()) + + def release_on_behalf_of(self, borrower: object) -> None: + try: + self._borrowers.remove(borrower) + except KeyError: + raise RuntimeError( + "this borrower isn't holding any of this CapacityLimiter's " "tokens" + ) from None + + # Notify the next task in line if this limiter has free capacity now + if self._wait_queue and len(self._borrowers) < self._total_tokens: + event = 
self._wait_queue.popitem(last=False)[1] + event.set() + + def statistics(self) -> CapacityLimiterStatistics: + return CapacityLimiterStatistics( + self.borrowed_tokens, + self.total_tokens, + tuple(self._borrowers), + len(self._wait_queue), + ) + + +_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter") + + +def current_default_thread_limiter() -> CapacityLimiter: + try: + return _default_thread_limiter.get() + except LookupError: + limiter = CapacityLimiter(40) + _default_thread_limiter.set(limiter) + return limiter + + +# +# Operating system signals +# + + +class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): + def __init__(self, signals: tuple[int, ...]): + self._signals = signals + self._loop = get_running_loop() + self._signal_queue: deque[int] = deque() + self._future: asyncio.Future = asyncio.Future() + self._handled_signals: set[int] = set() + + def _deliver(self, signum: int) -> None: + self._signal_queue.append(signum) + if not self._future.done(): + self._future.set_result(None) + + def __enter__(self) -> _SignalReceiver: + for sig in set(self._signals): + self._loop.add_signal_handler(sig, self._deliver, sig) + self._handled_signals.add(sig) + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + for sig in self._handled_signals: + self._loop.remove_signal_handler(sig) + return None + + def __aiter__(self) -> _SignalReceiver: + return self + + async def __anext__(self) -> int: + await checkpoint() + if not self._signal_queue: + self._future = asyncio.Future() + await self._future + + return self._signal_queue.popleft() + + +def open_signal_receiver(*signals: int) -> _SignalReceiver: + return _SignalReceiver(signals) + + +# +# Testing and debugging +# + + +def _create_task_info(task: asyncio.Task) -> TaskInfo: + task_state = _task_states.get(task) + if task_state is None: + name = task.get_name() if _native_task_names else None + parent_id = None + else: + name = task_state.name + parent_id = task_state.parent_id + + return TaskInfo(id(task), parent_id, name, get_coro(task)) + + +def get_current_task() -> TaskInfo: + return _create_task_info(current_task()) # type: ignore[arg-type] + + +def get_running_tasks() -> list[TaskInfo]: + return [_create_task_info(task) for task in all_tasks() if not task.done()] + + +async def wait_all_tasks_blocked() -> None: + await checkpoint() + this_task = current_task() + while True: + for task in all_tasks(): + if task is this_task: + continue + + if task._fut_waiter is None or task._fut_waiter.done(): # type: ignore[attr-defined] + await sleep(0.1) + break + else: + return + + +class TestRunner(abc.TestRunner): + def __init__( + self, + debug: bool = False, + use_uvloop: bool = False, + policy: asyncio.AbstractEventLoopPolicy | None = None, + ): + self._exceptions: list[BaseException] = [] + _maybe_set_event_loop_policy(policy, use_uvloop) + self._loop = asyncio.new_event_loop() + self._loop.set_debug(debug) + self._loop.set_exception_handler(self._exception_handler) + asyncio.set_event_loop(self._loop) + + def _cancel_all_tasks(self) -> None: + to_cancel = all_tasks(self._loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + self._loop.run_until_complete( + asyncio.gather(*to_cancel, return_exceptions=True) + ) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + raise cast(BaseException, task.exception()) + + 
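+    # Loop-level exception handler installed in __init__: plain Exception
+    # instances raised in asynchronous callbacks are collected and re-raised
+    # later by _raise_async_exceptions(); anything else falls through to the
+    # loop's default handler.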
def _exception_handler( + self, loop: asyncio.AbstractEventLoop, context: dict[str, Any] + ) -> None: + if isinstance(context.get("exception"), Exception): + self._exceptions.append(context["exception"]) + else: + loop.default_exception_handler(context) + + def _raise_async_exceptions(self) -> None: + # Re-raise any exceptions raised in asynchronous callbacks + if self._exceptions: + exceptions, self._exceptions = self._exceptions, [] + if len(exceptions) == 1: + raise exceptions[0] + elif exceptions: + raise ExceptionGroup(exceptions) + + def close(self) -> None: + try: + self._cancel_all_tasks() + self._loop.run_until_complete(self._loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + self._loop.close() + + def run_asyncgen_fixture( + self, + fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], + kwargs: dict[str, Any], + ) -> Iterable[T_Retval]: + async def fixture_runner() -> None: + agen = fixture_func(**kwargs) + try: + retval = await agen.asend(None) + self._raise_async_exceptions() + except BaseException as exc: + f.set_exception(exc) + return + else: + f.set_result(retval) + + await event.wait() + try: + await agen.asend(None) + except StopAsyncIteration: + pass + else: + await agen.aclose() + raise RuntimeError("Async generator fixture did not stop") + + f = self._loop.create_future() + event = asyncio.Event() + fixture_task = self._loop.create_task(fixture_runner()) + self._loop.run_until_complete(f) + yield f.result() + event.set() + self._loop.run_until_complete(fixture_task) + self._raise_async_exceptions() + + def run_fixture( + self, + fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], + kwargs: dict[str, Any], + ) -> T_Retval: + retval = self._loop.run_until_complete(fixture_func(**kwargs)) + self._raise_async_exceptions() + return retval + + def run_test( + self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] + ) -> None: + try: + self._loop.run_until_complete(test_func(**kwargs)) + except Exception as exc: + self._exceptions.append(exc) + + self._raise_async_exceptions() diff --git a/contrib/python/anyio/anyio/_backends/_trio.py b/contrib/python/anyio/anyio/_backends/_trio.py new file mode 100644 index 0000000000..cf28943509 --- /dev/null +++ b/contrib/python/anyio/anyio/_backends/_trio.py @@ -0,0 +1,996 @@ +from __future__ import annotations + +import array +import math +import socket +from concurrent.futures import Future +from contextvars import copy_context +from dataclasses import dataclass +from functools import partial +from io import IOBase +from os import PathLike +from signal import Signals +from types import TracebackType +from typing import ( + IO, + TYPE_CHECKING, + Any, + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Collection, + Coroutine, + Generic, + Iterable, + Mapping, + NoReturn, + Sequence, + TypeVar, + cast, +) + +import sniffio +import trio.from_thread +from outcome import Error, Outcome, Value +from trio.socket import SocketType as TrioSocketType +from trio.to_thread import run_sync + +from .. 
import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc +from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable +from .._core._eventloop import claim_worker_thread +from .._core._exceptions import ( + BrokenResourceError, + BusyResourceError, + ClosedResourceError, + EndOfStream, +) +from .._core._exceptions import ExceptionGroup as BaseExceptionGroup +from .._core._sockets import convert_ipv6_sockaddr +from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter +from .._core._synchronization import Event as BaseEvent +from .._core._synchronization import ResourceGuard +from .._core._tasks import CancelScope as BaseCancelScope +from ..abc import IPSockAddrType, UDPPacketType + +if TYPE_CHECKING: + from trio_typing import TaskStatus + +try: + from trio import lowlevel as trio_lowlevel +except ImportError: + from trio import hazmat as trio_lowlevel # type: ignore[no-redef] + from trio.hazmat import wait_readable, wait_writable +else: + from trio.lowlevel import wait_readable, wait_writable + +try: + trio_open_process = trio_lowlevel.open_process +except AttributeError: + # isort: off + from trio import ( # type: ignore[attr-defined, no-redef] + open_process as trio_open_process, + ) + +T_Retval = TypeVar("T_Retval") +T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType) + + +# +# Event loop +# + +run = trio.run +current_token = trio.lowlevel.current_trio_token +RunVar = trio.lowlevel.RunVar + + +# +# Miscellaneous +# + +sleep = trio.sleep + + +# +# Timeouts and cancellation +# + + +class CancelScope(BaseCancelScope): + def __new__( + cls, original: trio.CancelScope | None = None, **kwargs: object + ) -> CancelScope: + return object.__new__(cls) + + def __init__(self, original: trio.CancelScope | None = None, **kwargs: Any) -> None: + self.__original = original or trio.CancelScope(**kwargs) + + def __enter__(self) -> CancelScope: + self.__original.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + # https://github.com/python-trio/trio-typing/pull/79 + return self.__original.__exit__( # type: ignore[func-returns-value] + exc_type, exc_val, exc_tb + ) + + def cancel(self) -> DeprecatedAwaitable: + self.__original.cancel() + return DeprecatedAwaitable(self.cancel) + + @property + def deadline(self) -> float: + return self.__original.deadline + + @deadline.setter + def deadline(self, value: float) -> None: + self.__original.deadline = value + + @property + def cancel_called(self) -> bool: + return self.__original.cancel_called + + @property + def shield(self) -> bool: + return self.__original.shield + + @shield.setter + def shield(self, value: bool) -> None: + self.__original.shield = value + + +CancelledError = trio.Cancelled +checkpoint = trio.lowlevel.checkpoint +checkpoint_if_cancelled = trio.lowlevel.checkpoint_if_cancelled +cancel_shielded_checkpoint = trio.lowlevel.cancel_shielded_checkpoint +current_effective_deadline = trio.current_effective_deadline +current_time = trio.current_time + + +# +# Task groups +# + + +class ExceptionGroup(BaseExceptionGroup, trio.MultiError): + pass + + +class TaskGroup(abc.TaskGroup): + def __init__(self) -> None: + self._active = False + self._nursery_manager = trio.open_nursery() + self.cancel_scope = None # type: ignore[assignment] + + async def __aenter__(self) -> TaskGroup: + self._active = True + self._nursery = await self._nursery_manager.__aenter__() + self.cancel_scope = 
CancelScope(self._nursery.cancel_scope) + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + try: + return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) + except trio.MultiError as exc: + raise ExceptionGroup(exc.exceptions) from None + finally: + self._active = False + + def start_soon( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> None: + if not self._active: + raise RuntimeError( + "This task group is not active; no new tasks can be started." + ) + + self._nursery.start_soon(func, *args, name=name) + + async def start( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> object: + if not self._active: + raise RuntimeError( + "This task group is not active; no new tasks can be started." + ) + + return await self._nursery.start(func, *args, name=name) + + +# +# Threads +# + + +async def run_sync_in_worker_thread( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: trio.CapacityLimiter | None = None, +) -> T_Retval: + def wrapper() -> T_Retval: + with claim_worker_thread("trio"): + return func(*args) + + # TODO: remove explicit context copying when trio 0.20 is the minimum requirement + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, None) + return await run_sync( + context.run, wrapper, cancellable=cancellable, limiter=limiter + ) + + +# TODO: remove this workaround when trio 0.20 is the minimum requirement +def run_async_from_thread( + fn: Callable[..., Awaitable[T_Retval]], *args: Any +) -> T_Retval: + async def wrapper() -> T_Retval: + retval: T_Retval + + async def inner() -> None: + nonlocal retval + __tracebackhide__ = True + retval = await fn(*args) + + async with trio.open_nursery() as n: + context.run(n.start_soon, inner) + + __tracebackhide__ = True + return retval # noqa: F821 + + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, "trio") + return trio.from_thread.run(wrapper) + + +def run_sync_from_thread(fn: Callable[..., T_Retval], *args: Any) -> T_Retval: + # TODO: remove explicit context copying when trio 0.20 is the minimum requirement + retval = trio.from_thread.run_sync(copy_context().run, fn, *args) + return cast(T_Retval, retval) + + +class BlockingPortal(abc.BlockingPortal): + def __new__(cls) -> BlockingPortal: + return object.__new__(cls) + + def __init__(self) -> None: + super().__init__() + self._token = trio.lowlevel.current_trio_token() + + def _spawn_task_from_thread( + self, + func: Callable, + args: tuple, + kwargs: dict[str, Any], + name: object, + future: Future, + ) -> None: + context = copy_context() + context.run(sniffio.current_async_library_cvar.set, "trio") + trio.from_thread.run_sync( + context.run, + partial(self._task_group.start_soon, name=name), + self._call_func, + func, + args, + kwargs, + future, + trio_token=self._token, + ) + + +# +# Subprocesses +# + + +@dataclass(eq=False) +class ReceiveStreamWrapper(abc.ByteReceiveStream): + _stream: trio.abc.ReceiveStream + + async def receive(self, max_bytes: int | None = None) -> bytes: + try: + data = await self._stream.receive_some(max_bytes) + except trio.ClosedResourceError as exc: + raise ClosedResourceError from exc.__cause__ + except trio.BrokenResourceError as exc: + raise BrokenResourceError from exc.__cause__ + + if data: + return data + else: + raise EndOfStream + + async def 
aclose(self) -> None: + await self._stream.aclose() + + +@dataclass(eq=False) +class SendStreamWrapper(abc.ByteSendStream): + _stream: trio.abc.SendStream + + async def send(self, item: bytes) -> None: + try: + await self._stream.send_all(item) + except trio.ClosedResourceError as exc: + raise ClosedResourceError from exc.__cause__ + except trio.BrokenResourceError as exc: + raise BrokenResourceError from exc.__cause__ + + async def aclose(self) -> None: + await self._stream.aclose() + + +@dataclass(eq=False) +class Process(abc.Process): + _process: trio.Process + _stdin: abc.ByteSendStream | None + _stdout: abc.ByteReceiveStream | None + _stderr: abc.ByteReceiveStream | None + + async def aclose(self) -> None: + if self._stdin: + await self._stdin.aclose() + if self._stdout: + await self._stdout.aclose() + if self._stderr: + await self._stderr.aclose() + + await self.wait() + + async def wait(self) -> int: + return await self._process.wait() + + def terminate(self) -> None: + self._process.terminate() + + def kill(self) -> None: + self._process.kill() + + def send_signal(self, signal: Signals) -> None: + self._process.send_signal(signal) + + @property + def pid(self) -> int: + return self._process.pid + + @property + def returncode(self) -> int | None: + return self._process.returncode + + @property + def stdin(self) -> abc.ByteSendStream | None: + return self._stdin + + @property + def stdout(self) -> abc.ByteReceiveStream | None: + return self._stdout + + @property + def stderr(self) -> abc.ByteReceiveStream | None: + return self._stderr + + +async def open_process( + command: str | bytes | Sequence[str | bytes], + *, + shell: bool, + stdin: int | IO[Any] | None, + stdout: int | IO[Any] | None, + stderr: int | IO[Any] | None, + cwd: str | bytes | PathLike | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, +) -> Process: + process = await trio_open_process( # type: ignore[misc] + command, # type: ignore[arg-type] + stdin=stdin, + stdout=stdout, + stderr=stderr, + shell=shell, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) + stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None + stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None + stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None + return Process(process, stdin_stream, stdout_stream, stderr_stream) + + +class _ProcessPoolShutdownInstrument(trio.abc.Instrument): + def after_run(self) -> None: + super().after_run() + + +current_default_worker_process_limiter: RunVar = RunVar( + "current_default_worker_process_limiter" +) + + +async def _shutdown_process_pool(workers: set[Process]) -> None: + process: Process + try: + await sleep(math.inf) + except trio.Cancelled: + for process in workers: + if process.returncode is None: + process.kill() + + with CancelScope(shield=True): + for process in workers: + await process.aclose() + + +def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None: + trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers) + + +# +# Sockets and networking +# + + +class _TrioSocketMixin(Generic[T_SockAddr]): + def __init__(self, trio_socket: TrioSocketType) -> None: + self._trio_socket = trio_socket + self._closed = False + + def _check_closed(self) -> None: + if self._closed: + raise ClosedResourceError + if self._trio_socket.fileno() < 0: + raise BrokenResourceError + + @property + def _raw_socket(self) -> socket.socket: + return self._trio_socket._sock # type: 
ignore[attr-defined] + + async def aclose(self) -> None: + if self._trio_socket.fileno() >= 0: + self._closed = True + self._trio_socket.close() + + def _convert_socket_error(self, exc: BaseException) -> NoReturn: + if isinstance(exc, trio.ClosedResourceError): + raise ClosedResourceError from exc + elif self._trio_socket.fileno() < 0 and self._closed: + raise ClosedResourceError from None + elif isinstance(exc, OSError): + raise BrokenResourceError from exc + else: + raise exc + + +class SocketStream(_TrioSocketMixin, abc.SocketStream): + def __init__(self, trio_socket: TrioSocketType) -> None: + super().__init__(trio_socket) + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + + async def receive(self, max_bytes: int = 65536) -> bytes: + with self._receive_guard: + try: + data = await self._trio_socket.recv(max_bytes) + except BaseException as exc: + self._convert_socket_error(exc) + + if data: + return data + else: + raise EndOfStream + + async def send(self, item: bytes) -> None: + with self._send_guard: + view = memoryview(item) + while view: + try: + bytes_sent = await self._trio_socket.send(view) + except BaseException as exc: + self._convert_socket_error(exc) + + view = view[bytes_sent:] + + async def send_eof(self) -> None: + self._trio_socket.shutdown(socket.SHUT_WR) + + +class UNIXSocketStream(SocketStream, abc.UNIXSocketStream): + async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: + if not isinstance(msglen, int) or msglen < 0: + raise ValueError("msglen must be a non-negative integer") + if not isinstance(maxfds, int) or maxfds < 1: + raise ValueError("maxfds must be a positive integer") + + fds = array.array("i") + await checkpoint() + with self._receive_guard: + while True: + try: + message, ancdata, flags, addr = await self._trio_socket.recvmsg( + msglen, socket.CMSG_LEN(maxfds * fds.itemsize) + ) + except BaseException as exc: + self._convert_socket_error(exc) + else: + if not message and not ancdata: + raise EndOfStream + + break + + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS: + raise RuntimeError( + f"Received unexpected ancillary data; message = {message!r}, " + f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}" + ) + + fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + + return message, list(fds) + + async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: + if not message: + raise ValueError("message must not be empty") + if not fds: + raise ValueError("fds must not be empty") + + filenos: list[int] = [] + for fd in fds: + if isinstance(fd, int): + filenos.append(fd) + elif isinstance(fd, IOBase): + filenos.append(fd.fileno()) + + fdarray = array.array("i", filenos) + await checkpoint() + with self._send_guard: + while True: + try: + await self._trio_socket.sendmsg( + [message], + [ + ( + socket.SOL_SOCKET, + socket.SCM_RIGHTS, # type: ignore[list-item] + fdarray, + ) + ], + ) + break + except BaseException as exc: + self._convert_socket_error(exc) + + +class TCPSocketListener(_TrioSocketMixin, abc.SocketListener): + def __init__(self, raw_socket: socket.socket): + super().__init__(trio.socket.from_stdlib_socket(raw_socket)) + self._accept_guard = ResourceGuard("accepting connections from") + + async def accept(self) -> SocketStream: + with self._accept_guard: + try: + trio_socket, _addr = await self._trio_socket.accept() + except BaseException as exc: + 
self._convert_socket_error(exc) + + trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return SocketStream(trio_socket) + + +class UNIXSocketListener(_TrioSocketMixin, abc.SocketListener): + def __init__(self, raw_socket: socket.socket): + super().__init__(trio.socket.from_stdlib_socket(raw_socket)) + self._accept_guard = ResourceGuard("accepting connections from") + + async def accept(self) -> UNIXSocketStream: + with self._accept_guard: + try: + trio_socket, _addr = await self._trio_socket.accept() + except BaseException as exc: + self._convert_socket_error(exc) + + return UNIXSocketStream(trio_socket) + + +class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket): + def __init__(self, trio_socket: TrioSocketType) -> None: + super().__init__(trio_socket) + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + + async def receive(self) -> tuple[bytes, IPSockAddrType]: + with self._receive_guard: + try: + data, addr = await self._trio_socket.recvfrom(65536) + return data, convert_ipv6_sockaddr(addr) + except BaseException as exc: + self._convert_socket_error(exc) + + async def send(self, item: UDPPacketType) -> None: + with self._send_guard: + try: + await self._trio_socket.sendto(*item) + except BaseException as exc: + self._convert_socket_error(exc) + + +class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket): + def __init__(self, trio_socket: TrioSocketType) -> None: + super().__init__(trio_socket) + self._receive_guard = ResourceGuard("reading from") + self._send_guard = ResourceGuard("writing to") + + async def receive(self) -> bytes: + with self._receive_guard: + try: + return await self._trio_socket.recv(65536) + except BaseException as exc: + self._convert_socket_error(exc) + + async def send(self, item: bytes) -> None: + with self._send_guard: + try: + await self._trio_socket.send(item) + except BaseException as exc: + self._convert_socket_error(exc) + + +async def connect_tcp( + host: str, port: int, local_address: IPSockAddrType | None = None +) -> SocketStream: + family = socket.AF_INET6 if ":" in host else socket.AF_INET + trio_socket = trio.socket.socket(family) + trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if local_address: + await trio_socket.bind(local_address) + + try: + await trio_socket.connect((host, port)) + except BaseException: + trio_socket.close() + raise + + return SocketStream(trio_socket) + + +async def connect_unix(path: str) -> UNIXSocketStream: + trio_socket = trio.socket.socket(socket.AF_UNIX) + try: + await trio_socket.connect(path) + except BaseException: + trio_socket.close() + raise + + return UNIXSocketStream(trio_socket) + + +async def create_udp_socket( + family: socket.AddressFamily, + local_address: IPSockAddrType | None, + remote_address: IPSockAddrType | None, + reuse_port: bool, +) -> UDPSocket | ConnectedUDPSocket: + trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) + + if reuse_port: + trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + if local_address: + await trio_socket.bind(local_address) + + if remote_address: + await trio_socket.connect(remote_address) + return ConnectedUDPSocket(trio_socket) + else: + return UDPSocket(trio_socket) + + +getaddrinfo = trio.socket.getaddrinfo +getnameinfo = trio.socket.getnameinfo + + +async def wait_socket_readable(sock: socket.socket) -> None: + try: + await wait_readable(sock) + except trio.ClosedResourceError as exc: + raise 
ClosedResourceError().with_traceback(exc.__traceback__) from None + except trio.BusyResourceError: + raise BusyResourceError("reading from") from None + + +async def wait_socket_writable(sock: socket.socket) -> None: + try: + await wait_writable(sock) + except trio.ClosedResourceError as exc: + raise ClosedResourceError().with_traceback(exc.__traceback__) from None + except trio.BusyResourceError: + raise BusyResourceError("writing to") from None + + +# +# Synchronization +# + + +class Event(BaseEvent): + def __new__(cls) -> Event: + return object.__new__(cls) + + def __init__(self) -> None: + self.__original = trio.Event() + + def is_set(self) -> bool: + return self.__original.is_set() + + async def wait(self) -> None: + return await self.__original.wait() + + def statistics(self) -> EventStatistics: + orig_statistics = self.__original.statistics() + return EventStatistics(tasks_waiting=orig_statistics.tasks_waiting) + + def set(self) -> DeprecatedAwaitable: + self.__original.set() + return DeprecatedAwaitable(self.set) + + +class CapacityLimiter(BaseCapacityLimiter): + def __new__(cls, *args: object, **kwargs: object) -> CapacityLimiter: + return object.__new__(cls) + + def __init__( + self, *args: Any, original: trio.CapacityLimiter | None = None + ) -> None: + self.__original = original or trio.CapacityLimiter(*args) + + async def __aenter__(self) -> None: + return await self.__original.__aenter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.__original.__aexit__(exc_type, exc_val, exc_tb) + + @property + def total_tokens(self) -> float: + return self.__original.total_tokens + + @total_tokens.setter + def total_tokens(self, value: float) -> None: + self.__original.total_tokens = value + + @property + def borrowed_tokens(self) -> int: + return self.__original.borrowed_tokens + + @property + def available_tokens(self) -> float: + return self.__original.available_tokens + + def acquire_nowait(self) -> DeprecatedAwaitable: + self.__original.acquire_nowait() + return DeprecatedAwaitable(self.acquire_nowait) + + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + self.__original.acquire_on_behalf_of_nowait(borrower) + return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) + + async def acquire(self) -> None: + await self.__original.acquire() + + async def acquire_on_behalf_of(self, borrower: object) -> None: + await self.__original.acquire_on_behalf_of(borrower) + + def release(self) -> None: + return self.__original.release() + + def release_on_behalf_of(self, borrower: object) -> None: + return self.__original.release_on_behalf_of(borrower) + + def statistics(self) -> CapacityLimiterStatistics: + orig = self.__original.statistics() + return CapacityLimiterStatistics( + borrowed_tokens=orig.borrowed_tokens, + total_tokens=orig.total_tokens, + borrowers=orig.borrowers, + tasks_waiting=orig.tasks_waiting, + ) + + +_capacity_limiter_wrapper: RunVar = RunVar("_capacity_limiter_wrapper") + + +def current_default_thread_limiter() -> CapacityLimiter: + try: + return _capacity_limiter_wrapper.get() + except LookupError: + limiter = CapacityLimiter( + original=trio.to_thread.current_default_thread_limiter() + ) + _capacity_limiter_wrapper.set(limiter) + return limiter + + +# +# Signal handling +# + + +class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): + _iterator: AsyncIterator[int] + + def __init__(self, signals: 
tuple[Signals, ...]): + self._signals = signals + + def __enter__(self) -> _SignalReceiver: + self._cm = trio.open_signal_receiver(*self._signals) + self._iterator = self._cm.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return self._cm.__exit__(exc_type, exc_val, exc_tb) + + def __aiter__(self) -> _SignalReceiver: + return self + + async def __anext__(self) -> Signals: + signum = await self._iterator.__anext__() + return Signals(signum) + + +def open_signal_receiver(*signals: Signals) -> _SignalReceiver: + return _SignalReceiver(signals) + + +# +# Testing and debugging +# + + +def get_current_task() -> TaskInfo: + task = trio_lowlevel.current_task() + + parent_id = None + if task.parent_nursery and task.parent_nursery.parent_task: + parent_id = id(task.parent_nursery.parent_task) + + return TaskInfo(id(task), parent_id, task.name, task.coro) + + +def get_running_tasks() -> list[TaskInfo]: + root_task = trio_lowlevel.current_root_task() + task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)] + nurseries = root_task.child_nurseries + while nurseries: + new_nurseries: list[trio.Nursery] = [] + for nursery in nurseries: + for task in nursery.child_tasks: + task_infos.append( + TaskInfo(id(task), id(nursery.parent_task), task.name, task.coro) + ) + new_nurseries.extend(task.child_nurseries) + + nurseries = new_nurseries + + return task_infos + + +def wait_all_tasks_blocked() -> Awaitable[None]: + import trio.testing + + return trio.testing.wait_all_tasks_blocked() + + +class TestRunner(abc.TestRunner): + def __init__(self, **options: Any) -> None: + from collections import deque + from queue import Queue + + self._call_queue: Queue[Callable[..., object]] = Queue() + self._result_queue: deque[Outcome] = deque() + self._stop_event: trio.Event | None = None + self._nursery: trio.Nursery | None = None + self._options = options + + async def _trio_main(self) -> None: + self._stop_event = trio.Event() + async with trio.open_nursery() as self._nursery: + await self._stop_event.wait() + + async def _call_func( + self, func: Callable[..., Awaitable[object]], args: tuple, kwargs: dict + ) -> None: + try: + retval = await func(*args, **kwargs) + except BaseException as exc: + self._result_queue.append(Error(exc)) + else: + self._result_queue.append(Value(retval)) + + def _main_task_finished(self, outcome: object) -> None: + self._nursery = None + + def _get_nursery(self) -> trio.Nursery: + if self._nursery is None: + trio.lowlevel.start_guest_run( + self._trio_main, + run_sync_soon_threadsafe=self._call_queue.put, + done_callback=self._main_task_finished, + **self._options, + ) + while self._nursery is None: + self._call_queue.get()() + + return self._nursery + + def _call( + self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object + ) -> T_Retval: + self._get_nursery().start_soon(self._call_func, func, args, kwargs) + while not self._result_queue: + self._call_queue.get()() + + outcome = self._result_queue.pop() + return outcome.unwrap() + + def close(self) -> None: + if self._stop_event: + self._stop_event.set() + while self._nursery is not None: + self._call_queue.get()() + + def run_asyncgen_fixture( + self, + fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], + kwargs: dict[str, Any], + ) -> Iterable[T_Retval]: + async def fixture_runner(*, task_status: TaskStatus[T_Retval]) -> None: + agen = fixture_func(**kwargs) 
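+            # Two-phase protocol for async generator fixtures: the first
+            # asend(None) runs the setup code up to the yield and reports the
+            # yielded value via task_status; once teardown_event fires, the
+            # second asend(None) must raise StopAsyncIteration, otherwise the
+            # fixture yielded more than once and is treated as an error.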
+ retval = await agen.asend(None) + task_status.started(retval) + await teardown_event.wait() + try: + await agen.asend(None) + except StopAsyncIteration: + pass + else: + await agen.aclose() + raise RuntimeError("Async generator fixture did not stop") + + teardown_event = trio.Event() + fixture_value = self._call(lambda: self._get_nursery().start(fixture_runner)) + yield fixture_value + teardown_event.set() + + def run_fixture( + self, + fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], + kwargs: dict[str, Any], + ) -> T_Retval: + return self._call(fixture_func, **kwargs) + + def run_test( + self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] + ) -> None: + self._call(test_func, **kwargs) diff --git a/contrib/python/anyio/anyio/_core/__init__.py b/contrib/python/anyio/anyio/_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/anyio/anyio/_core/__init__.py diff --git a/contrib/python/anyio/anyio/_core/_compat.py b/contrib/python/anyio/anyio/_core/_compat.py new file mode 100644 index 0000000000..22d29ab8ac --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_compat.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from contextlib import AbstractContextManager +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + Callable, + ContextManager, + Generator, + Generic, + Iterable, + List, + TypeVar, + Union, + overload, +) +from warnings import warn + +if TYPE_CHECKING: + from ._testing import TaskInfo +else: + TaskInfo = object + +T = TypeVar("T") +AnyDeprecatedAwaitable = Union[ + "DeprecatedAwaitable", + "DeprecatedAwaitableFloat", + "DeprecatedAwaitableList[T]", + TaskInfo, +] + + +@overload +async def maybe_async(__obj: TaskInfo) -> TaskInfo: + ... + + +@overload +async def maybe_async(__obj: DeprecatedAwaitableFloat) -> float: + ... + + +@overload +async def maybe_async(__obj: DeprecatedAwaitableList[T]) -> list[T]: + ... + + +@overload +async def maybe_async(__obj: DeprecatedAwaitable) -> None: + ... + + +async def maybe_async( + __obj: AnyDeprecatedAwaitable[T], +) -> TaskInfo | float | list[T] | None: + """ + Await on the given object if necessary. + + This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and + methods were converted from coroutine functions into regular functions. + + Do **not** try to use this for any other purpose! + + :return: the result of awaiting on the object if coroutine, or the object itself otherwise + + .. versionadded:: 2.2 + + """ + return __obj._unwrap() + + +class _ContextManagerWrapper: + def __init__(self, cm: ContextManager[T]): + self._cm = cm + + async def __aenter__(self) -> T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return self._cm.__exit__(exc_type, exc_val, exc_tb) + + +def maybe_async_cm( + cm: ContextManager[T] | AsyncContextManager[T], +) -> AsyncContextManager[T]: + """ + Wrap a regular context manager as an async one if necessary. + + This function is intended to bridge the gap between AnyIO 2.x and 3.x where some functions and + methods were changed to return regular context managers instead of async ones. + + :param cm: a regular or async context manager + :return: an async context manager + + .. 
versionadded:: 2.2 + + """ + if not isinstance(cm, AbstractContextManager): + raise TypeError("Given object is not an context manager") + + return _ContextManagerWrapper(cm) + + +def _warn_deprecation( + awaitable: AnyDeprecatedAwaitable[Any], stacklevel: int = 1 +) -> None: + warn( + f'Awaiting on {awaitable._name}() is deprecated. Use "await ' + f"anyio.maybe_async({awaitable._name}(...)) if you have to support both AnyIO 2.x " + f'and 3.x, or just remove the "await" if you are completely migrating to AnyIO 3+.', + DeprecationWarning, + stacklevel=stacklevel + 1, + ) + + +class DeprecatedAwaitable: + def __init__(self, func: Callable[..., DeprecatedAwaitable]): + self._name = f"{func.__module__}.{func.__qualname__}" + + def __await__(self) -> Generator[None, None, None]: + _warn_deprecation(self) + if False: + yield + + def __reduce__(self) -> tuple[type[None], tuple[()]]: + return type(None), () + + def _unwrap(self) -> None: + return None + + +class DeprecatedAwaitableFloat(float): + def __new__( + cls, x: float, func: Callable[..., DeprecatedAwaitableFloat] + ) -> DeprecatedAwaitableFloat: + return super().__new__(cls, x) + + def __init__(self, x: float, func: Callable[..., DeprecatedAwaitableFloat]): + self._name = f"{func.__module__}.{func.__qualname__}" + + def __await__(self) -> Generator[None, None, float]: + _warn_deprecation(self) + if False: + yield + + return float(self) + + def __reduce__(self) -> tuple[type[float], tuple[float]]: + return float, (float(self),) + + def _unwrap(self) -> float: + return float(self) + + +class DeprecatedAwaitableList(List[T]): + def __init__( + self, + iterable: Iterable[T] = (), + *, + func: Callable[..., DeprecatedAwaitableList[T]], + ): + super().__init__(iterable) + self._name = f"{func.__module__}.{func.__qualname__}" + + def __await__(self) -> Generator[None, None, list[T]]: + _warn_deprecation(self) + if False: + yield + + return list(self) + + def __reduce__(self) -> tuple[type[list[T]], tuple[list[T]]]: + return list, (list(self),) + + def _unwrap(self) -> list[T]: + return list(self) + + +class DeprecatedAsyncContextManager(Generic[T], metaclass=ABCMeta): + @abstractmethod + def __enter__(self) -> T: + pass + + @abstractmethod + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + pass + + async def __aenter__(self) -> T: + warn( + f"Using {self.__class__.__name__} as an async context manager has been deprecated. 
" + f'Use "async with anyio.maybe_async_cm(yourcontextmanager) as foo:" if you have to ' + f'support both AnyIO 2.x and 3.x, or just remove the "async" from "async with" if ' + f"you are completely migrating to AnyIO 3+.", + DeprecationWarning, + ) + return self.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return self.__exit__(exc_type, exc_val, exc_tb) diff --git a/contrib/python/anyio/anyio/_core/_eventloop.py b/contrib/python/anyio/anyio/_core/_eventloop.py new file mode 100644 index 0000000000..ae9864851b --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_eventloop.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import math +import sys +import threading +from contextlib import contextmanager +from importlib import import_module +from typing import ( + Any, + Awaitable, + Callable, + Generator, + TypeVar, +) + +import sniffio + +# This must be updated when new backends are introduced +from ._compat import DeprecatedAwaitableFloat + +BACKENDS = "asyncio", "trio" + +T_Retval = TypeVar("T_Retval") +threadlocals = threading.local() + + +def run( + func: Callable[..., Awaitable[T_Retval]], + *args: object, + backend: str = "asyncio", + backend_options: dict[str, Any] | None = None, +) -> T_Retval: + """ + Run the given coroutine function in an asynchronous event loop. + + The current thread must not be already running an event loop. + + :param func: a coroutine function + :param args: positional arguments to ``func`` + :param backend: name of the asynchronous event loop implementation – currently either + ``asyncio`` or ``trio`` + :param backend_options: keyword arguments to call the backend ``run()`` implementation with + (documented :ref:`here <backend options>`) + :return: the return value of the coroutine function + :raises RuntimeError: if an asynchronous event loop is already running in this thread + :raises LookupError: if the named backend is not found + + """ + try: + asynclib_name = sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + pass + else: + raise RuntimeError(f"Already running {asynclib_name} in this thread") + + try: + asynclib = import_module(f"..._backends._{backend}", package=__name__) + except ImportError as exc: + raise LookupError(f"No such backend: {backend}") from exc + + token = None + if sniffio.current_async_library_cvar.get(None) is None: + # Since we're in control of the event loop, we can cache the name of the async library + token = sniffio.current_async_library_cvar.set(backend) + + try: + backend_options = backend_options or {} + return asynclib.run(func, *args, **backend_options) + finally: + if token: + sniffio.current_async_library_cvar.reset(token) + + +async def sleep(delay: float) -> None: + """ + Pause the current task for the specified duration. + + :param delay: the duration, in seconds + + """ + return await get_asynclib().sleep(delay) + + +async def sleep_forever() -> None: + """ + Pause the current task until it's cancelled. + + This is a shortcut for ``sleep(math.inf)``. + + .. versionadded:: 3.1 + + """ + await sleep(math.inf) + + +async def sleep_until(deadline: float) -> None: + """ + Pause the current task until the given time. + + :param deadline: the absolute time to wake up at (according to the internal monotonic clock of + the event loop) + + .. 
versionadded:: 3.1 + + """ + now = current_time() + await sleep(max(deadline - now, 0)) + + +def current_time() -> DeprecatedAwaitableFloat: + """ + Return the current value of the event loop's internal clock. + + :return: the clock value (seconds) + + """ + return DeprecatedAwaitableFloat(get_asynclib().current_time(), current_time) + + +def get_all_backends() -> tuple[str, ...]: + """Return a tuple of the names of all built-in backends.""" + return BACKENDS + + +def get_cancelled_exc_class() -> type[BaseException]: + """Return the current async library's cancellation exception class.""" + return get_asynclib().CancelledError + + +# +# Private API +# + + +@contextmanager +def claim_worker_thread(backend: str) -> Generator[Any, None, None]: + module = sys.modules["anyio._backends._" + backend] + threadlocals.current_async_module = module + try: + yield + finally: + del threadlocals.current_async_module + + +def get_asynclib(asynclib_name: str | None = None) -> Any: + if asynclib_name is None: + asynclib_name = sniffio.current_async_library() + + modulename = "anyio._backends._" + asynclib_name + try: + return sys.modules[modulename] + except KeyError: + return import_module(modulename) diff --git a/contrib/python/anyio/anyio/_core/_exceptions.py b/contrib/python/anyio/anyio/_core/_exceptions.py new file mode 100644 index 0000000000..92ccd77a2d --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_exceptions.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from traceback import format_exception + + +class BrokenResourceError(Exception): + """ + Raised when trying to use a resource that has been rendered unusable due to external causes + (e.g. a send stream whose peer has disconnected). + """ + + +class BrokenWorkerProcess(Exception): + """ + Raised by :func:`run_sync_in_process` if the worker process terminates abruptly or otherwise + misbehaves. + """ + + +class BusyResourceError(Exception): + """Raised when two tasks are trying to read from or write to the same resource concurrently.""" + + def __init__(self, action: str): + super().__init__(f"Another task is already {action} this resource") + + +class ClosedResourceError(Exception): + """Raised when trying to use a resource that has been closed.""" + + +class DelimiterNotFound(Exception): + """ + Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the + maximum number of bytes has been read without the delimiter being found. + """ + + def __init__(self, max_bytes: int) -> None: + super().__init__( + f"The delimiter was not found among the first {max_bytes} bytes" + ) + + +class EndOfStream(Exception): + """Raised when trying to read from a stream that has been closed from the other end.""" + + +class ExceptionGroup(BaseException): + """ + Raised when multiple exceptions have been raised in a task group. 
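+
+    A minimal handling sketch (``task_a`` and ``task_b`` are hypothetical task
+    functions, started via this package's ``create_task_group()``)::
+
+        try:
+            async with create_task_group() as tg:
+                tg.start_soon(task_a)
+                tg.start_soon(task_b)
+        except ExceptionGroup as exc_group:
+            for exc in exc_group.exceptions:
+                print(repr(exc))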
+ + :var ~typing.Sequence[BaseException] exceptions: the sequence of exceptions raised together + """ + + SEPARATOR = "----------------------------\n" + + exceptions: list[BaseException] + + def __str__(self) -> str: + tracebacks = [ + "".join(format_exception(type(exc), exc, exc.__traceback__)) + for exc in self.exceptions + ] + return ( + f"{len(self.exceptions)} exceptions were raised in the task group:\n" + f"{self.SEPARATOR}{self.SEPARATOR.join(tracebacks)}" + ) + + def __repr__(self) -> str: + exception_reprs = ", ".join(repr(exc) for exc in self.exceptions) + return f"<{self.__class__.__name__}: {exception_reprs}>" + + +class IncompleteRead(Exception): + """ + Raised during :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_exactly` or + :meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the + connection is closed before the requested amount of bytes has been read. + """ + + def __init__(self) -> None: + super().__init__( + "The stream was closed before the read operation could be completed" + ) + + +class TypedAttributeLookupError(LookupError): + """ + Raised by :meth:`~anyio.TypedAttributeProvider.extra` when the given typed attribute is not + found and no default value has been given. + """ + + +class WouldBlock(Exception): + """Raised by ``X_nowait`` functions if ``X()`` would block.""" diff --git a/contrib/python/anyio/anyio/_core/_fileio.py b/contrib/python/anyio/anyio/_core/_fileio.py new file mode 100644 index 0000000000..35e8e8af6c --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_fileio.py @@ -0,0 +1,603 @@ +from __future__ import annotations + +import os +import pathlib +import sys +from dataclasses import dataclass +from functools import partial +from os import PathLike +from typing import ( + IO, + TYPE_CHECKING, + Any, + AnyStr, + AsyncIterator, + Callable, + Generic, + Iterable, + Iterator, + Sequence, + cast, + overload, +) + +from .. import to_thread +from ..abc import AsyncResource + +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + +if TYPE_CHECKING: + from _typeshed import OpenBinaryMode, OpenTextMode, ReadableBuffer, WriteableBuffer +else: + ReadableBuffer = OpenBinaryMode = OpenTextMode = WriteableBuffer = object + + +class AsyncFile(AsyncResource, Generic[AnyStr]): + """ + An asynchronous file object. + + This class wraps a standard file object and provides async friendly versions of the following + blocking methods (where available on the original file object): + + * read + * read1 + * readline + * readlines + * readinto + * readinto1 + * write + * writelines + * truncate + * seek + * tell + * flush + + All other methods are directly passed through. + + This class supports the asynchronous context manager protocol which closes the underlying file + at the end of the context block. + + This class also supports asynchronous iteration:: + + async with await open_file(...) 
as f: + async for line in f: + print(line) + """ + + def __init__(self, fp: IO[AnyStr]) -> None: + self._fp: Any = fp + + def __getattr__(self, name: str) -> object: + return getattr(self._fp, name) + + @property + def wrapped(self) -> IO[AnyStr]: + """The wrapped file object.""" + return self._fp + + async def __aiter__(self) -> AsyncIterator[AnyStr]: + while True: + line = await self.readline() + if line: + yield line + else: + break + + async def aclose(self) -> None: + return await to_thread.run_sync(self._fp.close) + + async def read(self, size: int = -1) -> AnyStr: + return await to_thread.run_sync(self._fp.read, size) + + async def read1(self: AsyncFile[bytes], size: int = -1) -> bytes: + return await to_thread.run_sync(self._fp.read1, size) + + async def readline(self) -> AnyStr: + return await to_thread.run_sync(self._fp.readline) + + async def readlines(self) -> list[AnyStr]: + return await to_thread.run_sync(self._fp.readlines) + + async def readinto(self: AsyncFile[bytes], b: WriteableBuffer) -> bytes: + return await to_thread.run_sync(self._fp.readinto, b) + + async def readinto1(self: AsyncFile[bytes], b: WriteableBuffer) -> bytes: + return await to_thread.run_sync(self._fp.readinto1, b) + + @overload + async def write(self: AsyncFile[bytes], b: ReadableBuffer) -> int: + ... + + @overload + async def write(self: AsyncFile[str], b: str) -> int: + ... + + async def write(self, b: ReadableBuffer | str) -> int: + return await to_thread.run_sync(self._fp.write, b) + + @overload + async def writelines( + self: AsyncFile[bytes], lines: Iterable[ReadableBuffer] + ) -> None: + ... + + @overload + async def writelines(self: AsyncFile[str], lines: Iterable[str]) -> None: + ... + + async def writelines(self, lines: Iterable[ReadableBuffer] | Iterable[str]) -> None: + return await to_thread.run_sync(self._fp.writelines, lines) + + async def truncate(self, size: int | None = None) -> int: + return await to_thread.run_sync(self._fp.truncate, size) + + async def seek(self, offset: int, whence: int | None = os.SEEK_SET) -> int: + return await to_thread.run_sync(self._fp.seek, offset, whence) + + async def tell(self) -> int: + return await to_thread.run_sync(self._fp.tell) + + async def flush(self) -> None: + return await to_thread.run_sync(self._fp.flush) + + +@overload +async def open_file( + file: str | PathLike[str] | int, + mode: OpenBinaryMode, + buffering: int = ..., + encoding: str | None = ..., + errors: str | None = ..., + newline: str | None = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] | None = ..., +) -> AsyncFile[bytes]: + ... + + +@overload +async def open_file( + file: str | PathLike[str] | int, + mode: OpenTextMode = ..., + buffering: int = ..., + encoding: str | None = ..., + errors: str | None = ..., + newline: str | None = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] | None = ..., +) -> AsyncFile[str]: + ... + + +async def open_file( + file: str | PathLike[str] | int, + mode: str = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: Callable[[str, int], int] | None = None, +) -> AsyncFile[Any]: + """ + Open a file asynchronously. + + The arguments are exactly the same as for the builtin :func:`open`. 
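+
+    A short usage sketch (the file name is illustrative)::
+
+        async with await open_file("/tmp/example.txt") as f:
+            contents = await f.read()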
+ + :return: an asynchronous file object + + """ + fp = await to_thread.run_sync( + open, file, mode, buffering, encoding, errors, newline, closefd, opener + ) + return AsyncFile(fp) + + +def wrap_file(file: IO[AnyStr]) -> AsyncFile[AnyStr]: + """ + Wrap an existing file as an asynchronous file. + + :param file: an existing file-like object + :return: an asynchronous file object + + """ + return AsyncFile(file) + + +@dataclass(eq=False) +class _PathIterator(AsyncIterator["Path"]): + iterator: Iterator[PathLike[str]] + + async def __anext__(self) -> Path: + nextval = await to_thread.run_sync(next, self.iterator, None, cancellable=True) + if nextval is None: + raise StopAsyncIteration from None + + return Path(cast("PathLike[str]", nextval)) + + +class Path: + """ + An asynchronous version of :class:`pathlib.Path`. + + This class cannot be substituted for :class:`pathlib.Path` or :class:`pathlib.PurePath`, but + it is compatible with the :class:`os.PathLike` interface. + + It implements the Python 3.10 version of :class:`pathlib.Path` interface, except for the + deprecated :meth:`~pathlib.Path.link_to` method. + + Any methods that do disk I/O need to be awaited on. These methods are: + + * :meth:`~pathlib.Path.absolute` + * :meth:`~pathlib.Path.chmod` + * :meth:`~pathlib.Path.cwd` + * :meth:`~pathlib.Path.exists` + * :meth:`~pathlib.Path.expanduser` + * :meth:`~pathlib.Path.group` + * :meth:`~pathlib.Path.hardlink_to` + * :meth:`~pathlib.Path.home` + * :meth:`~pathlib.Path.is_block_device` + * :meth:`~pathlib.Path.is_char_device` + * :meth:`~pathlib.Path.is_dir` + * :meth:`~pathlib.Path.is_fifo` + * :meth:`~pathlib.Path.is_file` + * :meth:`~pathlib.Path.is_mount` + * :meth:`~pathlib.Path.lchmod` + * :meth:`~pathlib.Path.lstat` + * :meth:`~pathlib.Path.mkdir` + * :meth:`~pathlib.Path.open` + * :meth:`~pathlib.Path.owner` + * :meth:`~pathlib.Path.read_bytes` + * :meth:`~pathlib.Path.read_text` + * :meth:`~pathlib.Path.readlink` + * :meth:`~pathlib.Path.rename` + * :meth:`~pathlib.Path.replace` + * :meth:`~pathlib.Path.rmdir` + * :meth:`~pathlib.Path.samefile` + * :meth:`~pathlib.Path.stat` + * :meth:`~pathlib.Path.touch` + * :meth:`~pathlib.Path.unlink` + * :meth:`~pathlib.Path.write_bytes` + * :meth:`~pathlib.Path.write_text` + + Additionally, the following methods return an async iterator yielding :class:`~.Path` objects: + + * :meth:`~pathlib.Path.glob` + * :meth:`~pathlib.Path.iterdir` + * :meth:`~pathlib.Path.rglob` + """ + + __slots__ = "_path", "__weakref__" + + __weakref__: Any + + def __init__(self, *args: str | PathLike[str]) -> None: + self._path: Final[pathlib.Path] = pathlib.Path(*args) + + def __fspath__(self) -> str: + return self._path.__fspath__() + + def __str__(self) -> str: + return self._path.__str__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.as_posix()!r})" + + def __bytes__(self) -> bytes: + return self._path.__bytes__() + + def __hash__(self) -> int: + return self._path.__hash__() + + def __eq__(self, other: object) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__eq__(target) + + def __lt__(self, other: Path) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__lt__(target) + + def __le__(self, other: Path) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__le__(target) + + def __gt__(self, other: Path) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__gt__(target) + + def 
__ge__(self, other: Path) -> bool: + target = other._path if isinstance(other, Path) else other + return self._path.__ge__(target) + + def __truediv__(self, other: Any) -> Path: + return Path(self._path / other) + + def __rtruediv__(self, other: Any) -> Path: + return Path(other) / self + + @property + def parts(self) -> tuple[str, ...]: + return self._path.parts + + @property + def drive(self) -> str: + return self._path.drive + + @property + def root(self) -> str: + return self._path.root + + @property + def anchor(self) -> str: + return self._path.anchor + + @property + def parents(self) -> Sequence[Path]: + return tuple(Path(p) for p in self._path.parents) + + @property + def parent(self) -> Path: + return Path(self._path.parent) + + @property + def name(self) -> str: + return self._path.name + + @property + def suffix(self) -> str: + return self._path.suffix + + @property + def suffixes(self) -> list[str]: + return self._path.suffixes + + @property + def stem(self) -> str: + return self._path.stem + + async def absolute(self) -> Path: + path = await to_thread.run_sync(self._path.absolute) + return Path(path) + + def as_posix(self) -> str: + return self._path.as_posix() + + def as_uri(self) -> str: + return self._path.as_uri() + + def match(self, path_pattern: str) -> bool: + return self._path.match(path_pattern) + + def is_relative_to(self, *other: str | PathLike[str]) -> bool: + try: + self.relative_to(*other) + return True + except ValueError: + return False + + async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None: + func = partial(os.chmod, follow_symlinks=follow_symlinks) + return await to_thread.run_sync(func, self._path, mode) + + @classmethod + async def cwd(cls) -> Path: + path = await to_thread.run_sync(pathlib.Path.cwd) + return cls(path) + + async def exists(self) -> bool: + return await to_thread.run_sync(self._path.exists, cancellable=True) + + async def expanduser(self) -> Path: + return Path(await to_thread.run_sync(self._path.expanduser, cancellable=True)) + + def glob(self, pattern: str) -> AsyncIterator[Path]: + gen = self._path.glob(pattern) + return _PathIterator(gen) + + async def group(self) -> str: + return await to_thread.run_sync(self._path.group, cancellable=True) + + async def hardlink_to(self, target: str | pathlib.Path | Path) -> None: + if isinstance(target, Path): + target = target._path + + await to_thread.run_sync(os.link, target, self) + + @classmethod + async def home(cls) -> Path: + home_path = await to_thread.run_sync(pathlib.Path.home) + return cls(home_path) + + def is_absolute(self) -> bool: + return self._path.is_absolute() + + async def is_block_device(self) -> bool: + return await to_thread.run_sync(self._path.is_block_device, cancellable=True) + + async def is_char_device(self) -> bool: + return await to_thread.run_sync(self._path.is_char_device, cancellable=True) + + async def is_dir(self) -> bool: + return await to_thread.run_sync(self._path.is_dir, cancellable=True) + + async def is_fifo(self) -> bool: + return await to_thread.run_sync(self._path.is_fifo, cancellable=True) + + async def is_file(self) -> bool: + return await to_thread.run_sync(self._path.is_file, cancellable=True) + + async def is_mount(self) -> bool: + return await to_thread.run_sync(os.path.ismount, self._path, cancellable=True) + + def is_reserved(self) -> bool: + return self._path.is_reserved() + + async def is_socket(self) -> bool: + return await to_thread.run_sync(self._path.is_socket, cancellable=True) + + async def is_symlink(self) -> bool: + 
return await to_thread.run_sync(self._path.is_symlink, cancellable=True) + + def iterdir(self) -> AsyncIterator[Path]: + gen = self._path.iterdir() + return _PathIterator(gen) + + def joinpath(self, *args: str | PathLike[str]) -> Path: + return Path(self._path.joinpath(*args)) + + async def lchmod(self, mode: int) -> None: + await to_thread.run_sync(self._path.lchmod, mode) + + async def lstat(self) -> os.stat_result: + return await to_thread.run_sync(self._path.lstat, cancellable=True) + + async def mkdir( + self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False + ) -> None: + await to_thread.run_sync(self._path.mkdir, mode, parents, exist_ok) + + @overload + async def open( + self, + mode: OpenBinaryMode, + buffering: int = ..., + encoding: str | None = ..., + errors: str | None = ..., + newline: str | None = ..., + ) -> AsyncFile[bytes]: + ... + + @overload + async def open( + self, + mode: OpenTextMode = ..., + buffering: int = ..., + encoding: str | None = ..., + errors: str | None = ..., + newline: str | None = ..., + ) -> AsyncFile[str]: + ... + + async def open( + self, + mode: str = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> AsyncFile[Any]: + fp = await to_thread.run_sync( + self._path.open, mode, buffering, encoding, errors, newline + ) + return AsyncFile(fp) + + async def owner(self) -> str: + return await to_thread.run_sync(self._path.owner, cancellable=True) + + async def read_bytes(self) -> bytes: + return await to_thread.run_sync(self._path.read_bytes) + + async def read_text( + self, encoding: str | None = None, errors: str | None = None + ) -> str: + return await to_thread.run_sync(self._path.read_text, encoding, errors) + + def relative_to(self, *other: str | PathLike[str]) -> Path: + return Path(self._path.relative_to(*other)) + + async def readlink(self) -> Path: + target = await to_thread.run_sync(os.readlink, self._path) + return Path(cast(str, target)) + + async def rename(self, target: str | pathlib.PurePath | Path) -> Path: + if isinstance(target, Path): + target = target._path + + await to_thread.run_sync(self._path.rename, target) + return Path(target) + + async def replace(self, target: str | pathlib.PurePath | Path) -> Path: + if isinstance(target, Path): + target = target._path + + await to_thread.run_sync(self._path.replace, target) + return Path(target) + + async def resolve(self, strict: bool = False) -> Path: + func = partial(self._path.resolve, strict=strict) + return Path(await to_thread.run_sync(func, cancellable=True)) + + def rglob(self, pattern: str) -> AsyncIterator[Path]: + gen = self._path.rglob(pattern) + return _PathIterator(gen) + + async def rmdir(self) -> None: + await to_thread.run_sync(self._path.rmdir) + + async def samefile( + self, other_path: str | bytes | int | pathlib.Path | Path + ) -> bool: + if isinstance(other_path, Path): + other_path = other_path._path + + return await to_thread.run_sync( + self._path.samefile, other_path, cancellable=True + ) + + async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result: + func = partial(os.stat, follow_symlinks=follow_symlinks) + return await to_thread.run_sync(func, self._path, cancellable=True) + + async def symlink_to( + self, + target: str | pathlib.Path | Path, + target_is_directory: bool = False, + ) -> None: + if isinstance(target, Path): + target = target._path + + await to_thread.run_sync(self._path.symlink_to, target, target_is_directory) + + async def touch(self, mode: int 
= 0o666, exist_ok: bool = True) -> None: + await to_thread.run_sync(self._path.touch, mode, exist_ok) + + async def unlink(self, missing_ok: bool = False) -> None: + try: + await to_thread.run_sync(self._path.unlink) + except FileNotFoundError: + if not missing_ok: + raise + + def with_name(self, name: str) -> Path: + return Path(self._path.with_name(name)) + + def with_stem(self, stem: str) -> Path: + return Path(self._path.with_name(stem + self._path.suffix)) + + def with_suffix(self, suffix: str) -> Path: + return Path(self._path.with_suffix(suffix)) + + async def write_bytes(self, data: bytes) -> int: + return await to_thread.run_sync(self._path.write_bytes, data) + + async def write_text( + self, + data: str, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> int: + # Path.write_text() does not support the "newline" parameter before Python 3.10 + def sync_write_text() -> int: + with self._path.open( + "w", encoding=encoding, errors=errors, newline=newline + ) as fp: + return fp.write(data) + + return await to_thread.run_sync(sync_write_text) + + +PathLike.register(Path) diff --git a/contrib/python/anyio/anyio/_core/_resources.py b/contrib/python/anyio/anyio/_core/_resources.py new file mode 100644 index 0000000000..b9a5344aef --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_resources.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from ..abc import AsyncResource +from ._tasks import CancelScope + + +async def aclose_forcefully(resource: AsyncResource) -> None: + """ + Close an asynchronous resource in a cancelled scope. + + Doing this closes the resource without waiting on anything. + + :param resource: the resource to close + + """ + with CancelScope() as scope: + scope.cancel() + await resource.aclose() diff --git a/contrib/python/anyio/anyio/_core/_signals.py b/contrib/python/anyio/anyio/_core/_signals.py new file mode 100644 index 0000000000..8ea54af86c --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_signals.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import AsyncIterator + +from ._compat import DeprecatedAsyncContextManager +from ._eventloop import get_asynclib + + +def open_signal_receiver( + *signals: int, +) -> DeprecatedAsyncContextManager[AsyncIterator[int]]: + """ + Start receiving operating system signals. + + :param signals: signals to receive (e.g. ``signal.SIGINT``) + :return: an asynchronous context manager for an asynchronous iterator which yields signal + numbers + + .. warning:: Windows does not support signals natively so it is best to avoid relying on this + in cross-platform applications. + + .. warning:: On asyncio, this permanently replaces any previous signal handler for the given + signals, as set via :meth:`~asyncio.loop.add_signal_handler`. + + """ + return get_asynclib().open_signal_receiver(*signals) diff --git a/contrib/python/anyio/anyio/_core/_sockets.py b/contrib/python/anyio/anyio/_core/_sockets.py new file mode 100644 index 0000000000..e6970bee27 --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_sockets.py @@ -0,0 +1,607 @@ +from __future__ import annotations + +import socket +import ssl +import sys +from ipaddress import IPv6Address, ip_address +from os import PathLike, chmod +from pathlib import Path +from socket import AddressFamily, SocketKind +from typing import Awaitable, List, Tuple, cast, overload + +from .. 
import to_thread
+from ..abc import (
+    ConnectedUDPSocket,
+    IPAddressType,
+    IPSockAddrType,
+    SocketListener,
+    SocketStream,
+    UDPSocket,
+    UNIXSocketStream,
+)
+from ..streams.stapled import MultiListener
+from ..streams.tls import TLSStream
+from ._eventloop import get_asynclib
+from ._resources import aclose_forcefully
+from ._synchronization import Event
+from ._tasks import create_task_group, move_on_after
+
+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal
+
+IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)  # https://bugs.python.org/issue29515
+
+GetAddrInfoReturnType = List[
+    Tuple[AddressFamily, SocketKind, int, str, Tuple[str, int]]
+]
+AnyIPAddressFamily = Literal[
+    AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
+]
+IPAddressFamily = Literal[AddressFamily.AF_INET, AddressFamily.AF_INET6]
+
+
+# tls_hostname given
+@overload
+async def connect_tcp(
+    remote_host: IPAddressType,
+    remote_port: int,
+    *,
+    local_host: IPAddressType | None = ...,
+    ssl_context: ssl.SSLContext | None = ...,
+    tls_standard_compatible: bool = ...,
+    tls_hostname: str,
+    happy_eyeballs_delay: float = ...,
+) -> TLSStream:
+    ...
+
+
+# ssl_context given
+@overload
+async def connect_tcp(
+    remote_host: IPAddressType,
+    remote_port: int,
+    *,
+    local_host: IPAddressType | None = ...,
+    ssl_context: ssl.SSLContext,
+    tls_standard_compatible: bool = ...,
+    tls_hostname: str | None = ...,
+    happy_eyeballs_delay: float = ...,
+) -> TLSStream:
+    ...
+
+
+# tls=True
+@overload
+async def connect_tcp(
+    remote_host: IPAddressType,
+    remote_port: int,
+    *,
+    local_host: IPAddressType | None = ...,
+    tls: Literal[True],
+    ssl_context: ssl.SSLContext | None = ...,
+    tls_standard_compatible: bool = ...,
+    tls_hostname: str | None = ...,
+    happy_eyeballs_delay: float = ...,
+) -> TLSStream:
+    ...
+
+
+# tls=False
+@overload
+async def connect_tcp(
+    remote_host: IPAddressType,
+    remote_port: int,
+    *,
+    local_host: IPAddressType | None = ...,
+    tls: Literal[False],
+    ssl_context: ssl.SSLContext | None = ...,
+    tls_standard_compatible: bool = ...,
+    tls_hostname: str | None = ...,
+    happy_eyeballs_delay: float = ...,
+) -> SocketStream:
+    ...
+
+
+# No TLS arguments
+@overload
+async def connect_tcp(
+    remote_host: IPAddressType,
+    remote_port: int,
+    *,
+    local_host: IPAddressType | None = ...,
+    happy_eyeballs_delay: float = ...,
+) -> SocketStream:
+    ...
+
+
+async def connect_tcp(
+    remote_host: IPAddressType,
+    remote_port: int,
+    *,
+    local_host: IPAddressType | None = None,
+    tls: bool = False,
+    ssl_context: ssl.SSLContext | None = None,
+    tls_standard_compatible: bool = True,
+    tls_hostname: str | None = None,
+    happy_eyeballs_delay: float = 0.25,
+) -> SocketStream | TLSStream:
+    """
+    Connect to a host using the TCP protocol.
+
+    This function implements the stateless version of the Happy Eyeballs algorithm (RFC
+    6555). If ``remote_host`` is a host name that resolves to multiple IP addresses,
+    each one is tried until one connection attempt succeeds. If the first attempt does
+    not connect within 250 milliseconds, a second attempt is started using the next
+    address in the list, and so on. On IPv6 enabled systems, an IPv6 address (if
+    available) is tried first.
+
+    When the connection has been established, a TLS handshake will be done if either
+    ``ssl_context`` or ``tls_hostname`` is not ``None``, or if ``tls`` is ``True``.
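+
+    A quick sketch of both plain and TLS usage (host name and ports are
+    illustrative)::
+
+        stream = await connect_tcp("example.org", 80)
+        secure_stream = await connect_tcp("example.org", 443, tls=True)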
+
+    :param remote_host: the IP address or host name to connect to
+    :param remote_port: port on the target host to connect to
+    :param local_host: the interface address or name to bind the socket to before connecting
+    :param tls: ``True`` to do a TLS handshake with the connected stream and return a
+        :class:`~anyio.streams.tls.TLSStream` instead
+    :param ssl_context: the SSL context object to use (if omitted, a default context is created)
+    :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake before closing
+        the stream and requires that the server does this as well. Otherwise,
+        :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
+        Some protocols, such as HTTP, require this option to be ``False``.
+        See :meth:`~ssl.SSLContext.wrap_socket` for details.
+    :param tls_hostname: host name to check the server certificate against (defaults to the value
+        of ``remote_host``)
+    :param happy_eyeballs_delay: delay (in seconds) before starting the next connection attempt
+    :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream
+    :raises OSError: if the connection attempt fails
+
+    """
+    # Placed here due to https://github.com/python/mypy/issues/7057
+    connected_stream: SocketStream | None = None
+
+    async def try_connect(remote_host: str, event: Event) -> None:
+        nonlocal connected_stream
+        try:
+            stream = await asynclib.connect_tcp(remote_host, remote_port, local_address)
+        except OSError as exc:
+            oserrors.append(exc)
+            return
+        else:
+            if connected_stream is None:
+                connected_stream = stream
+                tg.cancel_scope.cancel()
+            else:
+                await stream.aclose()
+        finally:
+            event.set()
+
+    asynclib = get_asynclib()
+    local_address: IPSockAddrType | None = None
+    family = socket.AF_UNSPEC
+    if local_host:
+        gai_res = await getaddrinfo(str(local_host), None)
+        family, *_, local_address = gai_res[0]
+
+    target_host = str(remote_host)
+    try:
+        addr_obj = ip_address(remote_host)
+    except ValueError:
+        # getaddrinfo() will raise an exception if name resolution fails
+        gai_res = await getaddrinfo(
+            target_host, remote_port, family=family, type=socket.SOCK_STREAM
+        )
+
+        # Organize the list so that the first address is an IPv6 address (if available) and the
+        # second one is an IPv4 address. The rest can be in whatever order.
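+        # Illustrative trace of the reordering below: a resolution result of
+        # [v4_a, v6_a, v4_b] comes out as [v6_a, v4_b, v4_a].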
+ v6_found = v4_found = False + target_addrs: list[tuple[socket.AddressFamily, str]] = [] + for af, *rest, sa in gai_res: + if af == socket.AF_INET6 and not v6_found: + v6_found = True + target_addrs.insert(0, (af, sa[0])) + elif af == socket.AF_INET and not v4_found and v6_found: + v4_found = True + target_addrs.insert(1, (af, sa[0])) + else: + target_addrs.append((af, sa[0])) + else: + if isinstance(addr_obj, IPv6Address): + target_addrs = [(socket.AF_INET6, addr_obj.compressed)] + else: + target_addrs = [(socket.AF_INET, addr_obj.compressed)] + + oserrors: list[OSError] = [] + async with create_task_group() as tg: + for i, (af, addr) in enumerate(target_addrs): + event = Event() + tg.start_soon(try_connect, addr, event) + with move_on_after(happy_eyeballs_delay): + await event.wait() + + if connected_stream is None: + cause = oserrors[0] if len(oserrors) == 1 else asynclib.ExceptionGroup(oserrors) + raise OSError("All connection attempts failed") from cause + + if tls or tls_hostname or ssl_context: + try: + return await TLSStream.wrap( + connected_stream, + server_side=False, + hostname=tls_hostname or str(remote_host), + ssl_context=ssl_context, + standard_compatible=tls_standard_compatible, + ) + except BaseException: + await aclose_forcefully(connected_stream) + raise + + return connected_stream + + +async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream: + """ + Connect to the given UNIX socket. + + Not available on Windows. + + :param path: path to the socket + :return: a socket stream object + + """ + path = str(Path(path)) + return await get_asynclib().connect_unix(path) + + +async def create_tcp_listener( + *, + local_host: IPAddressType | None = None, + local_port: int = 0, + family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC, + backlog: int = 65536, + reuse_port: bool = False, +) -> MultiListener[SocketStream]: + """ + Create a TCP socket listener. + + :param local_port: port number to listen on + :param local_host: IP address of the interface to listen on. If omitted, listen on + all IPv4 and IPv6 interfaces. To listen on all interfaces on a specific address + family, use ``0.0.0.0`` for IPv4 or ``::`` for IPv6. + :param family: address family (used if ``local_host`` was omitted) + :param backlog: maximum number of queued incoming connections (up to a maximum of + 2**16, or 65536) + :param reuse_port: ``True`` to allow multiple sockets to bind to the same + address/port (not supported on Windows) + :return: a list of listener objects + + """ + asynclib = get_asynclib() + backlog = min(backlog, 65536) + local_host = str(local_host) if local_host is not None else None + gai_res = await getaddrinfo( + local_host, # type: ignore[arg-type] + local_port, + family=family, + type=socket.SocketKind.SOCK_STREAM if sys.platform == "win32" else 0, + flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG, + ) + listeners: list[SocketListener] = [] + try: + # The set() is here to work around a glibc bug: + # https://sourceware.org/bugzilla/show_bug.cgi?id=14969 + sockaddr: tuple[str, int] | tuple[str, int, int, int] + for fam, kind, *_, sockaddr in sorted(set(gai_res)): + # Workaround for an uvloop bug where we don't get the correct scope ID for + # IPv6 link-local addresses when passing type=socket.SOCK_STREAM to + # getaddrinfo(): https://github.com/MagicStack/uvloop/issues/539 + if sys.platform != "win32" and kind is not SocketKind.SOCK_STREAM: + continue + + raw_socket = socket.socket(fam) + raw_socket.setblocking(False) + + # For Windows, enable exclusive address use. 
For others, enable address reuse. + if sys.platform == "win32": + raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + else: + raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + if reuse_port: + raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + # If only IPv6 was requested, disable dual stack operation + if fam == socket.AF_INET6: + raw_socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) + + # Workaround for #554 + if "%" in sockaddr[0]: + addr, scope_id = sockaddr[0].split("%", 1) + sockaddr = (addr, sockaddr[1], 0, int(scope_id)) + + raw_socket.bind(sockaddr) + raw_socket.listen(backlog) + listener = asynclib.TCPSocketListener(raw_socket) + listeners.append(listener) + except BaseException: + for listener in listeners: + await listener.aclose() + + raise + + return MultiListener(listeners) + + +async def create_unix_listener( + path: str | PathLike[str], + *, + mode: int | None = None, + backlog: int = 65536, +) -> SocketListener: + """ + Create a UNIX socket listener. + + Not available on Windows. + + :param path: path of the socket + :param mode: permissions to set on the socket + :param backlog: maximum number of queued incoming connections (up to a maximum of 2**16, or + 65536) + :return: a listener object + + .. versionchanged:: 3.0 + If a socket already exists on the file system in the given path, it will be removed first. + + """ + path_str = str(path) + path = Path(path) + if path.is_socket(): + path.unlink() + + backlog = min(backlog, 65536) + raw_socket = socket.socket(socket.AF_UNIX) + raw_socket.setblocking(False) + try: + await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True) + if mode is not None: + await to_thread.run_sync(chmod, path_str, mode, cancellable=True) + + raw_socket.listen(backlog) + return get_asynclib().UNIXSocketListener(raw_socket) + except BaseException: + raw_socket.close() + raise + + +async def create_udp_socket( + family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC, + *, + local_host: IPAddressType | None = None, + local_port: int = 0, + reuse_port: bool = False, +) -> UDPSocket: + """ + Create a UDP socket. + + If ``local_port`` has been given, the socket will be bound to this port on the local + machine, making this socket suitable for providing UDP based services. 
+
+    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
+        ``local_host`` if omitted
+    :param local_host: IP address or host name of the local interface to bind to
+    :param local_port: local port to bind to
+    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
+        (not supported on Windows)
+    :return: a UDP socket
+
+    """
+    if family is AddressFamily.AF_UNSPEC and not local_host:
+        raise ValueError('Either "family" or "local_host" must be given')
+
+    if local_host:
+        gai_res = await getaddrinfo(
+            str(local_host),
+            local_port,
+            family=family,
+            type=socket.SOCK_DGRAM,
+            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
+        )
+        family = cast(AnyIPAddressFamily, gai_res[0][0])
+        local_address = gai_res[0][-1]
+    elif family is AddressFamily.AF_INET6:
+        local_address = ("::", 0)
+    else:
+        local_address = ("0.0.0.0", 0)
+
+    return await get_asynclib().create_udp_socket(
+        family, local_address, None, reuse_port
+    )
+
+
+async def create_connected_udp_socket(
+    remote_host: IPAddressType,
+    remote_port: int,
+    *,
+    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
+    local_host: IPAddressType | None = None,
+    local_port: int = 0,
+    reuse_port: bool = False,
+) -> ConnectedUDPSocket:
+    """
+    Create a connected UDP socket.
+
+    Connected UDP sockets can only communicate with the specified remote host/port, and any packets
+    sent from other sources are dropped.
+
+    :param remote_host: remote host to set as the default target
+    :param remote_port: port on the remote host to set as the default target
+    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
+        ``local_host`` or ``remote_host`` if omitted
+    :param local_host: IP address or host name of the local interface to bind to
+    :param local_port: local port to bind to
+    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
+        (not supported on Windows)
+    :return: a connected UDP socket
+
+    """
+    local_address = None
+    if local_host:
+        gai_res = await getaddrinfo(
+            str(local_host),
+            local_port,
+            family=family,
+            type=socket.SOCK_DGRAM,
+            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
+        )
+        family = cast(AnyIPAddressFamily, gai_res[0][0])
+        local_address = gai_res[0][-1]
+
+    gai_res = await getaddrinfo(
+        str(remote_host), remote_port, family=family, type=socket.SOCK_DGRAM
+    )
+    family = cast(AnyIPAddressFamily, gai_res[0][0])
+    remote_address = gai_res[0][-1]
+
+    return await get_asynclib().create_udp_socket(
+        family, local_address, remote_address, reuse_port
+    )
+
+
+async def getaddrinfo(
+    host: bytearray | bytes | str,
+    port: str | int | None,
+    *,
+    family: int | AddressFamily = 0,
+    type: int | SocketKind = 0,
+    proto: int = 0,
+    flags: int = 0,
+) -> GetAddrInfoReturnType:
+    """
+    Look up a numeric IP address given a host name.
+
+    Internationalized domain names are translated according to the (non-transitional) IDNA 2008
+    standard.
+
+    .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of
+        (host, port), unlike what :func:`socket.getaddrinfo` does.
+
+    :param host: host name
+    :param port: port number
+    :param family: socket family (``AF_INET``, ...)
+    :param type: socket type (``SOCK_STREAM``, ...)
+    :param proto: protocol number
+    :param flags: flags to pass to upstream ``getaddrinfo()``
+    :return: list of tuples containing (family, type, proto, canonname, sockaddr)
+
+    .. 
seealso:: :func:`socket.getaddrinfo` + + """ + # Handle unicode hostnames + if isinstance(host, str): + try: + encoded_host = host.encode("ascii") + except UnicodeEncodeError: + import idna + + encoded_host = idna.encode(host, uts46=True) + else: + encoded_host = host + + gai_res = await get_asynclib().getaddrinfo( + encoded_host, port, family=family, type=type, proto=proto, flags=flags + ) + return [ + (family, type, proto, canonname, convert_ipv6_sockaddr(sockaddr)) + for family, type, proto, canonname, sockaddr in gai_res + ] + + +def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str, str]]: + """ + Look up the host name of an IP address. + + :param sockaddr: socket address (e.g. (ipaddress, port) for IPv4) + :param flags: flags to pass to upstream ``getnameinfo()`` + :return: a tuple of (host name, service name) + + .. seealso:: :func:`socket.getnameinfo` + + """ + return get_asynclib().getnameinfo(sockaddr, flags) + + +def wait_socket_readable(sock: socket.socket) -> Awaitable[None]: + """ + Wait until the given socket has data to be read. + + This does **NOT** work on Windows when using the asyncio backend with a proactor event loop + (default on py3.8+). + + .. warning:: Only use this on raw sockets that have not been wrapped by any higher level + constructs like socket streams! + + :param sock: a socket object + :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the + socket to become readable + :raises ~anyio.BusyResourceError: if another task is already waiting for the socket + to become readable + + """ + return get_asynclib().wait_socket_readable(sock) + + +def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: + """ + Wait until the given socket can be written to. + + This does **NOT** work on Windows when using the asyncio backend with a proactor event loop + (default on py3.8+). + + .. warning:: Only use this on raw sockets that have not been wrapped by any higher level + constructs like socket streams! + + :param sock: a socket object + :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the + socket to become writable + :raises ~anyio.BusyResourceError: if another task is already waiting for the socket + to become writable + + """ + return get_asynclib().wait_socket_writable(sock) + + +# +# Private API +# + + +def convert_ipv6_sockaddr( + sockaddr: tuple[str, int, int, int] | tuple[str, int] +) -> tuple[str, int]: + """ + Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format. + + If the scope ID is nonzero, it is added to the address, separated with ``%``. + Otherwise the flow id and scope id are simply cut off from the tuple. + Any other kinds of socket addresses are returned as-is. 
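+
+    For example (the addresses are illustrative)::
+
+        convert_ipv6_sockaddr(("fe80::1", 1234, 0, 3))  # -> ("fe80::1%3", 1234)
+        convert_ipv6_sockaddr(("::1", 5678, 0, 0))      # -> ("::1", 5678)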
+ + :param sockaddr: the result of :meth:`~socket.socket.getsockname` + :return: the converted socket address + + """ + # This is more complicated than it should be because of MyPy + if isinstance(sockaddr, tuple) and len(sockaddr) == 4: + host, port, flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr) + if scope_id: + # PyPy (as of v7.3.11) leaves the interface name in the result, so + # we discard it and only get the scope ID from the end + # (https://foss.heptapod.net/pypy/pypy/-/issues/3938) + host = host.split("%")[0] + + # Add scope_id to the address + return f"{host}%{scope_id}", port + else: + return host, port + else: + return cast(Tuple[str, int], sockaddr) diff --git a/contrib/python/anyio/anyio/_core/_streams.py b/contrib/python/anyio/anyio/_core/_streams.py new file mode 100644 index 0000000000..54ea2b2baf --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_streams.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import math +from typing import Any, TypeVar, overload + +from ..streams.memory import ( + MemoryObjectReceiveStream, + MemoryObjectSendStream, + MemoryObjectStreamState, +) + +T_Item = TypeVar("T_Item") + + +@overload +def create_memory_object_stream( + max_buffer_size: float = ..., +) -> tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: + ... + + +@overload +def create_memory_object_stream( + max_buffer_size: float = ..., item_type: type[T_Item] = ... +) -> tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]: + ... + + +def create_memory_object_stream( + max_buffer_size: float = 0, item_type: type[T_Item] | None = None +) -> tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: + """ + Create a memory object stream. + + :param max_buffer_size: number of items held in the buffer until ``send()`` starts blocking + :param item_type: type of item, for marking the streams with the right generic type for + static typing (not used at run time) + :return: a tuple of (send stream, receive stream) + + """ + if max_buffer_size != math.inf and not isinstance(max_buffer_size, int): + raise ValueError("max_buffer_size must be either an integer or math.inf") + if max_buffer_size < 0: + raise ValueError("max_buffer_size cannot be negative") + + state: MemoryObjectStreamState = MemoryObjectStreamState(max_buffer_size) + return MemoryObjectSendStream(state), MemoryObjectReceiveStream(state) diff --git a/contrib/python/anyio/anyio/_core/_subprocesses.py b/contrib/python/anyio/anyio/_core/_subprocesses.py new file mode 100644 index 0000000000..1a26ac8c7f --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_subprocesses.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from io import BytesIO +from os import PathLike +from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess +from typing import ( + IO, + Any, + AsyncIterable, + Mapping, + Sequence, + cast, +) + +from ..abc import Process +from ._eventloop import get_asynclib +from ._tasks import create_task_group + + +async def run_process( + command: str | bytes | Sequence[str | bytes], + *, + input: bytes | None = None, + stdout: int | IO[Any] | None = PIPE, + stderr: int | IO[Any] | None = PIPE, + check: bool = True, + cwd: str | bytes | PathLike[str] | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, +) -> CompletedProcess[bytes]: + """ + Run an external command in a subprocess and wait until it completes. + + .. 
seealso:: :func:`subprocess.run` + + :param command: either a string to pass to the shell, or an iterable of strings containing the + executable name or path and its arguments + :param input: bytes passed to the standard input of the subprocess + :param stdout: either :data:`subprocess.PIPE` or :data:`subprocess.DEVNULL` + :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL` or + :data:`subprocess.STDOUT` + :param check: if ``True``, raise :exc:`~subprocess.CalledProcessError` if the process + terminates with a return code other than 0 + :param cwd: If not ``None``, change the working directory to this before running the command + :param env: if not ``None``, this mapping replaces the inherited environment variables from the + parent process + :param start_new_session: if ``true`` the setsid() system call will be made in the child + process prior to the execution of the subprocess. (POSIX only) + :return: an object representing the completed process + :raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process exits with a + nonzero return code + + """ + + async def drain_stream(stream: AsyncIterable[bytes], index: int) -> None: + buffer = BytesIO() + async for chunk in stream: + buffer.write(chunk) + + stream_contents[index] = buffer.getvalue() + + async with await open_process( + command, + stdin=PIPE if input else DEVNULL, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) as process: + stream_contents: list[bytes | None] = [None, None] + try: + async with create_task_group() as tg: + if process.stdout: + tg.start_soon(drain_stream, process.stdout, 0) + if process.stderr: + tg.start_soon(drain_stream, process.stderr, 1) + if process.stdin and input: + await process.stdin.send(input) + await process.stdin.aclose() + + await process.wait() + except BaseException: + process.kill() + raise + + output, errors = stream_contents + if check and process.returncode != 0: + raise CalledProcessError(cast(int, process.returncode), command, output, errors) + + return CompletedProcess(command, cast(int, process.returncode), output, errors) + + +async def open_process( + command: str | bytes | Sequence[str | bytes], + *, + stdin: int | IO[Any] | None = PIPE, + stdout: int | IO[Any] | None = PIPE, + stderr: int | IO[Any] | None = PIPE, + cwd: str | bytes | PathLike[str] | None = None, + env: Mapping[str, str] | None = None, + start_new_session: bool = False, +) -> Process: + """ + Start an external command in a subprocess. + + .. seealso:: :class:`subprocess.Popen` + + :param command: either a string to pass to the shell, or an iterable of strings containing the + executable name or path and its arguments + :param stdin: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, a + file-like object, or ``None`` + :param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, + a file-like object, or ``None`` + :param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, + :data:`subprocess.STDOUT`, a file-like object, or ``None`` + :param cwd: If not ``None``, the working directory is changed before executing + :param env: If env is not ``None``, it must be a mapping that defines the environment + variables for the new process + :param start_new_session: if ``true`` the setsid() system call will be made in the child + process prior to the execution of the subprocess. 
(POSIX only) + :return: an asynchronous process object + + """ + shell = isinstance(command, str) + return await get_asynclib().open_process( + command, + shell=shell, + stdin=stdin, + stdout=stdout, + stderr=stderr, + cwd=cwd, + env=env, + start_new_session=start_new_session, + ) diff --git a/contrib/python/anyio/anyio/_core/_synchronization.py b/contrib/python/anyio/anyio/_core/_synchronization.py new file mode 100644 index 0000000000..783570c7ac --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_synchronization.py @@ -0,0 +1,596 @@ +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from types import TracebackType +from warnings import warn + +from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled +from ._compat import DeprecatedAwaitable +from ._eventloop import get_asynclib +from ._exceptions import BusyResourceError, WouldBlock +from ._tasks import CancelScope +from ._testing import TaskInfo, get_current_task + + +@dataclass(frozen=True) +class EventStatistics: + """ + :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Event.wait` + """ + + tasks_waiting: int + + +@dataclass(frozen=True) +class CapacityLimiterStatistics: + """ + :ivar int borrowed_tokens: number of tokens currently borrowed by tasks + :ivar float total_tokens: total number of available tokens + :ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from this + limiter + :ivar int tasks_waiting: number of tasks waiting on :meth:`~.CapacityLimiter.acquire` or + :meth:`~.CapacityLimiter.acquire_on_behalf_of` + """ + + borrowed_tokens: int + total_tokens: float + borrowers: tuple[object, ...] + tasks_waiting: int + + +@dataclass(frozen=True) +class LockStatistics: + """ + :ivar bool locked: flag indicating if this lock is locked or not + :ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the lock is not + held by any task) + :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Lock.acquire` + """ + + locked: bool + owner: TaskInfo | None + tasks_waiting: int + + +@dataclass(frozen=True) +class ConditionStatistics: + """ + :ivar int tasks_waiting: number of tasks blocked on :meth:`~.Condition.wait` + :ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying :class:`~.Lock` + """ + + tasks_waiting: int + lock_statistics: LockStatistics + + +@dataclass(frozen=True) +class SemaphoreStatistics: + """ + :ivar int tasks_waiting: number of tasks waiting on :meth:`~.Semaphore.acquire` + + """ + + tasks_waiting: int + + +class Event: + def __new__(cls) -> Event: + return get_asynclib().Event() + + def set(self) -> DeprecatedAwaitable: + """Set the flag, notifying all listeners.""" + raise NotImplementedError + + def is_set(self) -> bool: + """Return ``True`` if the flag is set, ``False`` if not.""" + raise NotImplementedError + + async def wait(self) -> None: + """ + Wait until the flag has been set. + + If the flag has already been set when this method is called, it returns immediately. 
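+
+        A typical pattern (the task functions are hypothetical)::
+
+            event = Event()
+
+            async def waiter() -> None:
+                await event.wait()  # returns once another task calls event.set()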
+ + """ + raise NotImplementedError + + def statistics(self) -> EventStatistics: + """Return statistics about the current state of this event.""" + raise NotImplementedError + + +class Lock: + _owner_task: TaskInfo | None = None + + def __init__(self) -> None: + self._waiters: deque[tuple[TaskInfo, Event]] = deque() + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.release() + + async def acquire(self) -> None: + """Acquire the lock.""" + await checkpoint_if_cancelled() + try: + self.acquire_nowait() + except WouldBlock: + task = get_current_task() + event = Event() + token = task, event + self._waiters.append(token) + try: + await event.wait() + except BaseException: + if not event.is_set(): + self._waiters.remove(token) + elif self._owner_task == task: + self.release() + + raise + + assert self._owner_task == task + else: + try: + await cancel_shielded_checkpoint() + except BaseException: + self.release() + raise + + def acquire_nowait(self) -> None: + """ + Acquire the lock, without blocking. + + :raises ~anyio.WouldBlock: if the operation would block + + """ + task = get_current_task() + if self._owner_task == task: + raise RuntimeError("Attempted to acquire an already held Lock") + + if self._owner_task is not None: + raise WouldBlock + + self._owner_task = task + + def release(self) -> DeprecatedAwaitable: + """Release the lock.""" + if self._owner_task != get_current_task(): + raise RuntimeError("The current task is not holding this lock") + + if self._waiters: + self._owner_task, event = self._waiters.popleft() + event.set() + else: + del self._owner_task + + return DeprecatedAwaitable(self.release) + + def locked(self) -> bool: + """Return True if the lock is currently held.""" + return self._owner_task is not None + + def statistics(self) -> LockStatistics: + """ + Return statistics about the current state of this lock. + + .. versionadded:: 3.0 + """ + return LockStatistics(self.locked(), self._owner_task, len(self._waiters)) + + +class Condition: + _owner_task: TaskInfo | None = None + + def __init__(self, lock: Lock | None = None): + self._lock = lock or Lock() + self._waiters: deque[Event] = deque() + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.release() + + def _check_acquired(self) -> None: + if self._owner_task != get_current_task(): + raise RuntimeError("The current task is not holding the underlying lock") + + async def acquire(self) -> None: + """Acquire the underlying lock.""" + await self._lock.acquire() + self._owner_task = get_current_task() + + def acquire_nowait(self) -> None: + """ + Acquire the underlying lock, without blocking. 
+ + :raises ~anyio.WouldBlock: if the operation would block + + """ + self._lock.acquire_nowait() + self._owner_task = get_current_task() + + def release(self) -> DeprecatedAwaitable: + """Release the underlying lock.""" + self._lock.release() + return DeprecatedAwaitable(self.release) + + def locked(self) -> bool: + """Return True if the lock is set.""" + return self._lock.locked() + + def notify(self, n: int = 1) -> None: + """Notify exactly n listeners.""" + self._check_acquired() + for _ in range(n): + try: + event = self._waiters.popleft() + except IndexError: + break + + event.set() + + def notify_all(self) -> None: + """Notify all the listeners.""" + self._check_acquired() + for event in self._waiters: + event.set() + + self._waiters.clear() + + async def wait(self) -> None: + """Wait for a notification.""" + await checkpoint() + event = Event() + self._waiters.append(event) + self.release() + try: + await event.wait() + except BaseException: + if not event.is_set(): + self._waiters.remove(event) + + raise + finally: + with CancelScope(shield=True): + await self.acquire() + + def statistics(self) -> ConditionStatistics: + """ + Return statistics about the current state of this condition. + + .. versionadded:: 3.0 + """ + return ConditionStatistics(len(self._waiters), self._lock.statistics()) + + +class Semaphore: + def __init__(self, initial_value: int, *, max_value: int | None = None): + if not isinstance(initial_value, int): + raise TypeError("initial_value must be an integer") + if initial_value < 0: + raise ValueError("initial_value must be >= 0") + if max_value is not None: + if not isinstance(max_value, int): + raise TypeError("max_value must be an integer or None") + if max_value < initial_value: + raise ValueError( + "max_value must be equal to or higher than initial_value" + ) + + self._value = initial_value + self._max_value = max_value + self._waiters: deque[Event] = deque() + + async def __aenter__(self) -> Semaphore: + await self.acquire() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.release() + + async def acquire(self) -> None: + """Decrement the semaphore value, blocking if necessary.""" + await checkpoint_if_cancelled() + try: + self.acquire_nowait() + except WouldBlock: + event = Event() + self._waiters.append(event) + try: + await event.wait() + except BaseException: + if not event.is_set(): + self._waiters.remove(event) + else: + self.release() + + raise + else: + try: + await cancel_shielded_checkpoint() + except BaseException: + self.release() + raise + + def acquire_nowait(self) -> None: + """ + Acquire the underlying lock, without blocking. + + :raises ~anyio.WouldBlock: if the operation would block + + """ + if self._value == 0: + raise WouldBlock + + self._value -= 1 + + def release(self) -> DeprecatedAwaitable: + """Increment the semaphore value.""" + if self._max_value is not None and self._value == self._max_value: + raise ValueError("semaphore released too many times") + + if self._waiters: + self._waiters.popleft().set() + else: + self._value += 1 + + return DeprecatedAwaitable(self.release) + + @property + def value(self) -> int: + """The current value of the semaphore.""" + return self._value + + @property + def max_value(self) -> int | None: + """The maximum value of the semaphore.""" + return self._max_value + + def statistics(self) -> SemaphoreStatistics: + """ + Return statistics about the current state of this semaphore. 
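Editor's note: ``Semaphore.release()`` above wakes a waiter directly instead of incrementing ``_value``, so a freed slot cannot be stolen by a task that never queued for it. An illustrative sketch capping concurrency at two tasks (names are hypothetical)::

    import anyio

    async def worker(n: int, sem: anyio.Semaphore) -> None:
        async with sem:
            print(f"worker {n} acquired a slot")
            await anyio.sleep(0.1)

    async def main() -> None:
        sem = anyio.Semaphore(2, max_value=2)   # bounded: over-release raises ValueError
        async with anyio.create_task_group() as tg:
            for n in range(5):
                tg.start_soon(worker, n, sem)

    anyio.run(main)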
+ + .. versionadded:: 3.0 + """ + return SemaphoreStatistics(len(self._waiters)) + + +class CapacityLimiter: + def __new__(cls, total_tokens: float) -> CapacityLimiter: + return get_asynclib().CapacityLimiter(total_tokens) + + async def __aenter__(self) -> None: + raise NotImplementedError + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + raise NotImplementedError + + @property + def total_tokens(self) -> float: + """ + The total number of tokens available for borrowing. + + This is a read-write property. If the total number of tokens is increased, the + proportionate number of tasks waiting on this limiter will be granted their tokens. + + .. versionchanged:: 3.0 + The property is now writable. + + """ + raise NotImplementedError + + @total_tokens.setter + def total_tokens(self, value: float) -> None: + raise NotImplementedError + + async def set_total_tokens(self, value: float) -> None: + warn( + "CapacityLimiter.set_total_tokens has been deprecated. Set the value of the" + '"total_tokens" attribute directly.', + DeprecationWarning, + ) + self.total_tokens = value + + @property + def borrowed_tokens(self) -> int: + """The number of tokens that have currently been borrowed.""" + raise NotImplementedError + + @property + def available_tokens(self) -> float: + """The number of tokens currently available to be borrowed""" + raise NotImplementedError + + def acquire_nowait(self) -> DeprecatedAwaitable: + """ + Acquire a token for the current task without waiting for one to become available. + + :raises ~anyio.WouldBlock: if there are no tokens available for borrowing + + """ + raise NotImplementedError + + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: + """ + Acquire a token without waiting for one to become available. + + :param borrower: the entity borrowing a token + :raises ~anyio.WouldBlock: if there are no tokens available for borrowing + + """ + raise NotImplementedError + + async def acquire(self) -> None: + """ + Acquire a token for the current task, waiting if necessary for one to become available. + + """ + raise NotImplementedError + + async def acquire_on_behalf_of(self, borrower: object) -> None: + """ + Acquire a token, waiting if necessary for one to become available. + + :param borrower: the entity borrowing a token + + """ + raise NotImplementedError + + def release(self) -> None: + """ + Release the token held by the current task. + :raises RuntimeError: if the current task has not borrowed a token from this limiter. + + """ + raise NotImplementedError + + def release_on_behalf_of(self, borrower: object) -> None: + """ + Release the token held by the given borrower. + + :raises RuntimeError: if the borrower has not borrowed a token from this limiter. + + """ + raise NotImplementedError + + def statistics(self) -> CapacityLimiterStatistics: + """ + Return statistics about the current state of this limiter. + + .. versionadded:: 3.0 + + """ + raise NotImplementedError + + +def create_lock() -> Lock: + """ + Create an asynchronous lock. + + :return: a lock object + + .. deprecated:: 3.0 + Use :class:`~Lock` directly. + + """ + warn("create_lock() is deprecated -- use Lock() directly", DeprecationWarning) + return Lock() + + +def create_condition(lock: Lock | None = None) -> Condition: + """ + Create an asynchronous condition. + + :param lock: the lock to base the condition object on + :return: a condition object + + .. 
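Editor's note: ``CapacityLimiter`` differs from ``Semaphore`` in that tokens are tracked per borrower and ``total_tokens`` can be resized at runtime. A sketch assuming the 3.x API vendored here (like ``Event``, the limiter must be created inside a running loop because ``__new__`` dispatches to the backend)::

    import anyio

    async def fetch(n: int, limiter: anyio.CapacityLimiter) -> None:
        async with limiter:
            print(f"task {n}: borrowed =", limiter.borrowed_tokens)
            await anyio.sleep(0.1)

    async def main() -> None:
        limiter = anyio.CapacityLimiter(2)
        async with anyio.create_task_group() as tg:
            for n in range(5):
                tg.start_soon(fetch, n, limiter)

    anyio.run(main)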
deprecated:: 3.0 + Use :class:`~Condition` directly. + + """ + warn( + "create_condition() is deprecated -- use Condition() directly", + DeprecationWarning, + ) + return Condition(lock=lock) + + +def create_event() -> Event: + """ + Create an asynchronous event object. + + :return: an event object + + .. deprecated:: 3.0 + Use :class:`~Event` directly. + + """ + warn("create_event() is deprecated -- use Event() directly", DeprecationWarning) + return get_asynclib().Event() + + +def create_semaphore(value: int, *, max_value: int | None = None) -> Semaphore: + """ + Create an asynchronous semaphore. + + :param value: the semaphore's initial value + :param max_value: if set, makes this a "bounded" semaphore that raises :exc:`ValueError` if the + semaphore's value would exceed this number + :return: a semaphore object + + .. deprecated:: 3.0 + Use :class:`~Semaphore` directly. + + """ + warn( + "create_semaphore() is deprecated -- use Semaphore() directly", + DeprecationWarning, + ) + return Semaphore(value, max_value=max_value) + + +def create_capacity_limiter(total_tokens: float) -> CapacityLimiter: + """ + Create a capacity limiter. + + :param total_tokens: the total number of tokens available for borrowing (can be an integer or + :data:`math.inf`) + :return: a capacity limiter object + + .. deprecated:: 3.0 + Use :class:`~CapacityLimiter` directly. + + """ + warn( + "create_capacity_limiter() is deprecated -- use CapacityLimiter() directly", + DeprecationWarning, + ) + return get_asynclib().CapacityLimiter(total_tokens) + + +class ResourceGuard: + __slots__ = "action", "_guarded" + + def __init__(self, action: str): + self.action = action + self._guarded = False + + def __enter__(self) -> None: + if self._guarded: + raise BusyResourceError(self.action) + + self._guarded = True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + self._guarded = False + return None diff --git a/contrib/python/anyio/anyio/_core/_tasks.py b/contrib/python/anyio/anyio/_core/_tasks.py new file mode 100644 index 0000000000..e9d9c2bd67 --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_tasks.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import math +from types import TracebackType +from warnings import warn + +from ..abc._tasks import TaskGroup, TaskStatus +from ._compat import ( + DeprecatedAsyncContextManager, + DeprecatedAwaitable, + DeprecatedAwaitableFloat, +) +from ._eventloop import get_asynclib + + +class _IgnoredTaskStatus(TaskStatus[object]): + def started(self, value: object = None) -> None: + pass + + +TASK_STATUS_IGNORED = _IgnoredTaskStatus() + + +class CancelScope(DeprecatedAsyncContextManager["CancelScope"]): + """ + Wraps a unit of work that can be made separately cancellable. + + :param deadline: The time (clock value) when this scope is cancelled automatically + :param shield: ``True`` to shield the cancel scope from external cancellation + """ + + def __new__( + cls, *, deadline: float = math.inf, shield: bool = False + ) -> CancelScope: + return get_asynclib().CancelScope(shield=shield, deadline=deadline) + + def cancel(self) -> DeprecatedAwaitable: + """Cancel this scope immediately.""" + raise NotImplementedError + + @property + def deadline(self) -> float: + """ + The time (clock value) when this scope is cancelled automatically. + + Will be ``float('inf')`` if no timeout has been set. 
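Editor's note: ``ResourceGuard`` above is the primitive behind the ``BusyResourceError`` raised when two tasks use the same stream concurrently. A sketch of how a resource class might employ it; ``FakeStream`` is hypothetical, and the import path is the private module added in this diff, since top-level re-export varies between anyio versions::

    import anyio
    from anyio import BusyResourceError
    from anyio._core._synchronization import ResourceGuard

    class FakeStream:
        def __init__(self) -> None:
            self._receive_guard = ResourceGuard("reading from")

        async def receive(self) -> bytes:
            with self._receive_guard:    # concurrent entry raises BusyResourceError
                await anyio.sleep(0.1)
                return b"data"

    async def main() -> None:
        stream = FakeStream()
        try:
            async with anyio.create_task_group() as tg:
                tg.start_soon(stream.receive)
                tg.start_soon(stream.receive)
        except BusyResourceError as exc:
            print(exc)   # "Another task is already reading from this resource"

    anyio.run(main)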
+ + """ + raise NotImplementedError + + @deadline.setter + def deadline(self, value: float) -> None: + raise NotImplementedError + + @property + def cancel_called(self) -> bool: + """``True`` if :meth:`cancel` has been called.""" + raise NotImplementedError + + @property + def shield(self) -> bool: + """ + ``True`` if this scope is shielded from external cancellation. + + While a scope is shielded, it will not receive cancellations from outside. + + """ + raise NotImplementedError + + @shield.setter + def shield(self, value: bool) -> None: + raise NotImplementedError + + def __enter__(self) -> CancelScope: + raise NotImplementedError + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + raise NotImplementedError + + +def open_cancel_scope(*, shield: bool = False) -> CancelScope: + """ + Open a cancel scope. + + :param shield: ``True`` to shield the cancel scope from external cancellation + :return: a cancel scope + + .. deprecated:: 3.0 + Use :class:`~CancelScope` directly. + + """ + warn( + "open_cancel_scope() is deprecated -- use CancelScope() directly", + DeprecationWarning, + ) + return get_asynclib().CancelScope(shield=shield) + + +class FailAfterContextManager(DeprecatedAsyncContextManager[CancelScope]): + def __init__(self, cancel_scope: CancelScope): + self._cancel_scope = cancel_scope + + def __enter__(self) -> CancelScope: + return self._cancel_scope.__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + retval = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) + if self._cancel_scope.cancel_called: + raise TimeoutError + + return retval + + +def fail_after(delay: float | None, shield: bool = False) -> FailAfterContextManager: + """ + Create a context manager which raises a :class:`TimeoutError` if does not finish in time. + + :param delay: maximum allowed time (in seconds) before raising the exception, or ``None`` to + disable the timeout + :param shield: ``True`` to shield the cancel scope from external cancellation + :return: a context manager that yields a cancel scope + :rtype: :class:`~typing.ContextManager`\\[:class:`~anyio.CancelScope`\\] + + """ + deadline = ( + (get_asynclib().current_time() + delay) if delay is not None else math.inf + ) + cancel_scope = get_asynclib().CancelScope(deadline=deadline, shield=shield) + return FailAfterContextManager(cancel_scope) + + +def move_on_after(delay: float | None, shield: bool = False) -> CancelScope: + """ + Create a cancel scope with a deadline that expires after the given delay. + + :param delay: maximum allowed time (in seconds) before exiting the context block, or ``None`` + to disable the timeout + :param shield: ``True`` to shield the cancel scope from external cancellation + :return: a cancel scope + + """ + deadline = ( + (get_asynclib().current_time() + delay) if delay is not None else math.inf + ) + return get_asynclib().CancelScope(deadline=deadline, shield=shield) + + +def current_effective_deadline() -> DeprecatedAwaitableFloat: + """ + Return the nearest deadline among all the cancel scopes effective for the current task. 
+ + :return: a clock value from the event loop's internal clock (or ``float('inf')`` if + there is no deadline in effect, or ``float('-inf')`` if the current scope has + been cancelled) + :rtype: float + + """ + return DeprecatedAwaitableFloat( + get_asynclib().current_effective_deadline(), current_effective_deadline + ) + + +def create_task_group() -> TaskGroup: + """ + Create a task group. + + :return: a task group + + """ + return get_asynclib().TaskGroup() diff --git a/contrib/python/anyio/anyio/_core/_testing.py b/contrib/python/anyio/anyio/_core/_testing.py new file mode 100644 index 0000000000..c8191b3866 --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_testing.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import Any, Awaitable, Generator + +from ._compat import DeprecatedAwaitableList, _warn_deprecation +from ._eventloop import get_asynclib + + +class TaskInfo: + """ + Represents an asynchronous task. + + :ivar int id: the unique identifier of the task + :ivar parent_id: the identifier of the parent task, if any + :vartype parent_id: Optional[int] + :ivar str name: the description of the task (if any) + :ivar ~collections.abc.Coroutine coro: the coroutine object of the task + """ + + __slots__ = "_name", "id", "parent_id", "name", "coro" + + def __init__( + self, + id: int, + parent_id: int | None, + name: str | None, + coro: Generator[Any, Any, Any] | Awaitable[Any], + ): + func = get_current_task + self._name = f"{func.__module__}.{func.__qualname__}" + self.id: int = id + self.parent_id: int | None = parent_id + self.name: str | None = name + self.coro: Generator[Any, Any, Any] | Awaitable[Any] = coro + + def __eq__(self, other: object) -> bool: + if isinstance(other, TaskInfo): + return self.id == other.id + + return NotImplemented + + def __hash__(self) -> int: + return hash(self.id) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})" + + def __await__(self) -> Generator[None, None, TaskInfo]: + _warn_deprecation(self) + if False: + yield + + return self + + def _unwrap(self) -> TaskInfo: + return self + + +def get_current_task() -> TaskInfo: + """ + Return the current task. + + :return: a representation of the current task + + """ + return get_asynclib().get_current_task() + + +def get_running_tasks() -> DeprecatedAwaitableList[TaskInfo]: + """ + Return a list of running tasks in the current event loop. 
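Editor's note: ``create_task_group()`` above returns the backend's task group, which joins all child tasks on exit and propagates their exceptions. A minimal structured-concurrency sketch::

    import anyio

    async def child(n: int) -> None:
        await anyio.sleep(n / 10)
        print(f"child {n} done")

    async def main() -> None:
        async with anyio.create_task_group() as tg:
            for n in range(3):
                tg.start_soon(child, n)
        print("all children finished")   # the group waited for every task

    anyio.run(main)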
+ + :return: a list of task info objects + + """ + tasks = get_asynclib().get_running_tasks() + return DeprecatedAwaitableList(tasks, func=get_running_tasks) + + +async def wait_all_tasks_blocked() -> None: + """Wait until all other tasks are waiting for something.""" + await get_asynclib().wait_all_tasks_blocked() diff --git a/contrib/python/anyio/anyio/_core/_typedattr.py b/contrib/python/anyio/anyio/_core/_typedattr.py new file mode 100644 index 0000000000..bf9202eeab --- /dev/null +++ b/contrib/python/anyio/anyio/_core/_typedattr.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import sys +from typing import Any, Callable, Mapping, TypeVar, overload + +from ._exceptions import TypedAttributeLookupError + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + +T_Attr = TypeVar("T_Attr") +T_Default = TypeVar("T_Default") +undefined = object() + + +def typed_attribute() -> Any: + """Return a unique object, used to mark typed attributes.""" + return object() + + +class TypedAttributeSet: + """ + Superclass for typed attribute collections. + + Checks that every public attribute of every subclass has a type annotation. + """ + + def __init_subclass__(cls) -> None: + annotations: dict[str, Any] = getattr(cls, "__annotations__", {}) + for attrname in dir(cls): + if not attrname.startswith("_") and attrname not in annotations: + raise TypeError( + f"Attribute {attrname!r} is missing its type annotation" + ) + + super().__init_subclass__() + + +class TypedAttributeProvider: + """Base class for classes that wish to provide typed extra attributes.""" + + @property + def extra_attributes(self) -> Mapping[T_Attr, Callable[[], T_Attr]]: + """ + A mapping of the extra attributes to callables that return the corresponding values. + + If the provider wraps another provider, the attributes from that wrapper should also be + included in the returned mapping (but the wrapper may override the callables from the + wrapped instance). + + """ + return {} + + @overload + def extra(self, attribute: T_Attr) -> T_Attr: + ... + + @overload + def extra(self, attribute: T_Attr, default: T_Default) -> T_Attr | T_Default: + ... + + @final + def extra(self, attribute: Any, default: object = undefined) -> object: + """ + extra(attribute, default=undefined) + + Return the value of the given typed extra attribute. 
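Editor's note: the typed-attribute machinery above lets objects expose optional, lazily computed metadata without isinstance checks. A sketch of a custom attribute set and provider (``RequestAttribute`` and ``Request`` are made-up names)::

    from typing import Any, Callable, Mapping

    from anyio import TypedAttributeProvider, TypedAttributeSet, typed_attribute

    class RequestAttribute(TypedAttributeSet):
        request_id: int = typed_attribute()
        session_id: str = typed_attribute()

    class Request(TypedAttributeProvider):
        def __init__(self, request_id: int) -> None:
            self._request_id = request_id

        @property
        def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
            return {RequestAttribute.request_id: lambda: self._request_id}

    req = Request(42)
    print(req.extra(RequestAttribute.request_id))          # 42
    print(req.extra(RequestAttribute.session_id, "n/a"))   # falls back to the default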
+ + :param attribute: the attribute (member of a :class:`~TypedAttributeSet`) to look for + :param default: the value that should be returned if no value is found for the attribute + :raises ~anyio.TypedAttributeLookupError: if the search failed and no default value was + given + + """ + try: + return self.extra_attributes[attribute]() + except KeyError: + if default is undefined: + raise TypedAttributeLookupError("Attribute not found") from None + else: + return default diff --git a/contrib/python/anyio/anyio/abc/__init__.py b/contrib/python/anyio/anyio/abc/__init__.py new file mode 100644 index 0000000000..72c34e544e --- /dev/null +++ b/contrib/python/anyio/anyio/abc/__init__.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +__all__ = ( + "AsyncResource", + "IPAddressType", + "IPSockAddrType", + "SocketAttribute", + "SocketStream", + "SocketListener", + "UDPSocket", + "UNIXSocketStream", + "UDPPacketType", + "ConnectedUDPSocket", + "UnreliableObjectReceiveStream", + "UnreliableObjectSendStream", + "UnreliableObjectStream", + "ObjectReceiveStream", + "ObjectSendStream", + "ObjectStream", + "ByteReceiveStream", + "ByteSendStream", + "ByteStream", + "AnyUnreliableByteReceiveStream", + "AnyUnreliableByteSendStream", + "AnyUnreliableByteStream", + "AnyByteReceiveStream", + "AnyByteSendStream", + "AnyByteStream", + "Listener", + "Process", + "Event", + "Condition", + "Lock", + "Semaphore", + "CapacityLimiter", + "CancelScope", + "TaskGroup", + "TaskStatus", + "TestRunner", + "BlockingPortal", +) + +from typing import Any + +from ._resources import AsyncResource +from ._sockets import ( + ConnectedUDPSocket, + IPAddressType, + IPSockAddrType, + SocketAttribute, + SocketListener, + SocketStream, + UDPPacketType, + UDPSocket, + UNIXSocketStream, +) +from ._streams import ( + AnyByteReceiveStream, + AnyByteSendStream, + AnyByteStream, + AnyUnreliableByteReceiveStream, + AnyUnreliableByteSendStream, + AnyUnreliableByteStream, + ByteReceiveStream, + ByteSendStream, + ByteStream, + Listener, + ObjectReceiveStream, + ObjectSendStream, + ObjectStream, + UnreliableObjectReceiveStream, + UnreliableObjectSendStream, + UnreliableObjectStream, +) +from ._subprocesses import Process +from ._tasks import TaskGroup, TaskStatus +from ._testing import TestRunner + +# Re-exported here, for backwards compatibility +# isort: off +from .._core._synchronization import CapacityLimiter, Condition, Event, Lock, Semaphore +from .._core._tasks import CancelScope +from ..from_thread import BlockingPortal + +# Re-export imports so they look like they live directly in this package +key: str +value: Any +for key, value in list(locals().items()): + if getattr(value, "__module__", "").startswith("anyio.abc."): + value.__module__ = __name__ diff --git a/contrib/python/anyio/anyio/abc/_resources.py b/contrib/python/anyio/anyio/abc/_resources.py new file mode 100644 index 0000000000..e0a283fc98 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_resources.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from types import TracebackType +from typing import TypeVar + +T = TypeVar("T") + + +class AsyncResource(metaclass=ABCMeta): + """ + Abstract base class for all closeable asynchronous resources. + + Works as an asynchronous context manager which returns the instance itself on enter, and calls + :meth:`aclose` on exit. 
+ """ + + async def __aenter__(self: T) -> T: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() + + @abstractmethod + async def aclose(self) -> None: + """Close the resource.""" diff --git a/contrib/python/anyio/anyio/abc/_sockets.py b/contrib/python/anyio/anyio/abc/_sockets.py new file mode 100644 index 0000000000..6aac5f7c22 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_sockets.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import socket +from abc import abstractmethod +from contextlib import AsyncExitStack +from io import IOBase +from ipaddress import IPv4Address, IPv6Address +from socket import AddressFamily +from typing import ( + Any, + Callable, + Collection, + Mapping, + Tuple, + TypeVar, + Union, +) + +from .._core._tasks import create_task_group +from .._core._typedattr import ( + TypedAttributeProvider, + TypedAttributeSet, + typed_attribute, +) +from ._streams import ByteStream, Listener, UnreliableObjectStream +from ._tasks import TaskGroup + +IPAddressType = Union[str, IPv4Address, IPv6Address] +IPSockAddrType = Tuple[str, int] +SockAddrType = Union[IPSockAddrType, str] +UDPPacketType = Tuple[bytes, IPSockAddrType] +T_Retval = TypeVar("T_Retval") + + +class SocketAttribute(TypedAttributeSet): + #: the address family of the underlying socket + family: AddressFamily = typed_attribute() + #: the local socket address of the underlying socket + local_address: SockAddrType = typed_attribute() + #: for IP addresses, the local port the underlying socket is bound to + local_port: int = typed_attribute() + #: the underlying stdlib socket object + raw_socket: socket.socket = typed_attribute() + #: the remote address the underlying socket is connected to + remote_address: SockAddrType = typed_attribute() + #: for IP addresses, the remote port the underlying socket is connected to + remote_port: int = typed_attribute() + + +class _SocketProvider(TypedAttributeProvider): + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + from .._core._sockets import convert_ipv6_sockaddr as convert + + attributes: dict[Any, Callable[[], Any]] = { + SocketAttribute.family: lambda: self._raw_socket.family, + SocketAttribute.local_address: lambda: convert( + self._raw_socket.getsockname() + ), + SocketAttribute.raw_socket: lambda: self._raw_socket, + } + try: + peername: tuple[str, int] | None = convert(self._raw_socket.getpeername()) + except OSError: + peername = None + + # Provide the remote address for connected sockets + if peername is not None: + attributes[SocketAttribute.remote_address] = lambda: peername + + # Provide local and remote ports for IP based sockets + if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6): + attributes[ + SocketAttribute.local_port + ] = lambda: self._raw_socket.getsockname()[1] + if peername is not None: + remote_port = peername[1] + attributes[SocketAttribute.remote_port] = lambda: remote_port + + return attributes + + @property + @abstractmethod + def _raw_socket(self) -> socket.socket: + pass + + +class SocketStream(ByteStream, _SocketProvider): + """ + Transports bytes over a socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. 
+ """ + + +class UNIXSocketStream(SocketStream): + @abstractmethod + async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: + """ + Send file descriptors along with a message to the peer. + + :param message: a non-empty bytestring + :param fds: a collection of files (either numeric file descriptors or open file or socket + objects) + """ + + @abstractmethod + async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: + """ + Receive file descriptors along with a message from the peer. + + :param msglen: length of the message to expect from the peer + :param maxfds: maximum number of file descriptors to expect from the peer + :return: a tuple of (message, file descriptors) + """ + + +class SocketListener(Listener[SocketStream], _SocketProvider): + """ + Listens to incoming socket connections. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + @abstractmethod + async def accept(self) -> SocketStream: + """Accept an incoming connection.""" + + async def serve( + self, + handler: Callable[[SocketStream], Any], + task_group: TaskGroup | None = None, + ) -> None: + async with AsyncExitStack() as exit_stack: + if task_group is None: + task_group = await exit_stack.enter_async_context(create_task_group()) + + while True: + stream = await self.accept() + task_group.start_soon(handler, stream) + + +class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider): + """ + Represents an unconnected UDP socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ + + async def sendto(self, data: bytes, host: str, port: int) -> None: + """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).""" + return await self.send((data, (host, port))) + + +class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider): + """ + Represents an connected UDP socket. + + Supports all relevant extra attributes from :class:`~SocketAttribute`. + """ diff --git a/contrib/python/anyio/anyio/abc/_streams.py b/contrib/python/anyio/anyio/abc/_streams.py new file mode 100644 index 0000000000..4fa7ccc9ff --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_streams.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Callable, Generic, TypeVar, Union + +from .._core._exceptions import EndOfStream +from .._core._typedattr import TypedAttributeProvider +from ._resources import AsyncResource +from ._tasks import TaskGroup + +T_Item = TypeVar("T_Item") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class UnreliableObjectReceiveStream( + Generic[T_co], AsyncResource, TypedAttributeProvider +): + """ + An interface for receiving objects. + + This interface makes no guarantees that the received messages arrive in the order in which they + were sent, or that no messages are missed. + + Asynchronously iterating over objects of this type will yield objects matching the given type + parameter. + """ + + def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]: + return self + + async def __anext__(self) -> T_co: + try: + return await self.receive() + except EndOfStream: + raise StopAsyncIteration + + @abstractmethod + async def receive(self) -> T_co: + """ + Receive the next item. 
+ + :raises ~anyio.ClosedResourceError: if the receive stream has been explicitly + closed + :raises ~anyio.EndOfStream: if this stream has been closed from the other end + :raises ~anyio.BrokenResourceError: if this stream has been rendered unusable + due to external causes + """ + + +class UnreliableObjectSendStream( + Generic[T_contra], AsyncResource, TypedAttributeProvider +): + """ + An interface for sending objects. + + This interface makes no guarantees that the messages sent will reach the recipient(s) in the + same order in which they were sent, or at all. + """ + + @abstractmethod + async def send(self, item: T_contra) -> None: + """ + Send an item to the peer(s). + + :param item: the item to send + :raises ~anyio.ClosedResourceError: if the send stream has been explicitly + closed + :raises ~anyio.BrokenResourceError: if this stream has been rendered unusable + due to external causes + """ + + +class UnreliableObjectStream( + UnreliableObjectReceiveStream[T_Item], UnreliableObjectSendStream[T_Item] +): + """ + A bidirectional message stream which does not guarantee the order or reliability of message + delivery. + """ + + +class ObjectReceiveStream(UnreliableObjectReceiveStream[T_co]): + """ + A receive message stream which guarantees that messages are received in the same order in + which they were sent, and that no messages are missed. + """ + + +class ObjectSendStream(UnreliableObjectSendStream[T_contra]): + """ + A send message stream which guarantees that messages are delivered in the same order in which + they were sent, without missing any messages in the middle. + """ + + +class ObjectStream( + ObjectReceiveStream[T_Item], + ObjectSendStream[T_Item], + UnreliableObjectStream[T_Item], +): + """ + A bidirectional message stream which guarantees the order and reliability of message delivery. + """ + + @abstractmethod + async def send_eof(self) -> None: + """ + Send an end-of-file indication to the peer. + + You should not try to send any further data to this stream after calling this method. + This method is idempotent (does nothing on successive calls). + """ + + +class ByteReceiveStream(AsyncResource, TypedAttributeProvider): + """ + An interface for receiving bytes from a single peer. + + Iterating this byte stream will yield a byte string of arbitrary length, but no more than + 65536 bytes. + """ + + def __aiter__(self) -> ByteReceiveStream: + return self + + async def __anext__(self) -> bytes: + try: + return await self.receive() + except EndOfStream: + raise StopAsyncIteration + + @abstractmethod + async def receive(self, max_bytes: int = 65536) -> bytes: + """ + Receive at most ``max_bytes`` bytes from the peer. + + .. note:: Implementors of this interface should not return an empty :class:`bytes` object, + and users should ignore them. + + :param max_bytes: maximum number of bytes to receive + :return: the received bytes + :raises ~anyio.EndOfStream: if this stream has been closed from the other end + """ + + +class ByteSendStream(AsyncResource, TypedAttributeProvider): + """An interface for sending bytes to a single peer.""" + + @abstractmethod + async def send(self, item: bytes) -> None: + """ + Send the given bytes to the peer. + + :param item: the bytes to send + """ + + +class ByteStream(ByteReceiveStream, ByteSendStream): + """A bidirectional byte stream.""" + + @abstractmethod + async def send_eof(self) -> None: + """ + Send an end-of-file indication to the peer. + + You should not try to send any further data to this stream after calling this method. 
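Editor's note: the in-memory object streams shipped elsewhere in this package implement the ``ObjectSendStream``/``ObjectReceiveStream`` interfaces above, including the ``EndOfStream``-terminated iteration shown in ``__anext__``. A sketch::

    import anyio

    async def producer(send_stream) -> None:
        async with send_stream:          # closing signals EndOfStream to the receiver
            for n in range(3):
                await send_stream.send(n)

    async def main() -> None:
        send_stream, receive_stream = anyio.create_memory_object_stream(10)
        async with anyio.create_task_group() as tg:
            tg.start_soon(producer, send_stream)
            async with receive_stream:
                async for item in receive_stream:   # stops on EndOfStream
                    print(item)

    anyio.run(main)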
+ This method is idempotent (does nothing on successive calls). + """ + + +#: Type alias for all unreliable bytes-oriented receive streams. +AnyUnreliableByteReceiveStream = Union[ + UnreliableObjectReceiveStream[bytes], ByteReceiveStream +] +#: Type alias for all unreliable bytes-oriented send streams. +AnyUnreliableByteSendStream = Union[UnreliableObjectSendStream[bytes], ByteSendStream] +#: Type alias for all unreliable bytes-oriented streams. +AnyUnreliableByteStream = Union[UnreliableObjectStream[bytes], ByteStream] +#: Type alias for all bytes-oriented receive streams. +AnyByteReceiveStream = Union[ObjectReceiveStream[bytes], ByteReceiveStream] +#: Type alias for all bytes-oriented send streams. +AnyByteSendStream = Union[ObjectSendStream[bytes], ByteSendStream] +#: Type alias for all bytes-oriented streams. +AnyByteStream = Union[ObjectStream[bytes], ByteStream] + + +class Listener(Generic[T_co], AsyncResource, TypedAttributeProvider): + """An interface for objects that let you accept incoming connections.""" + + @abstractmethod + async def serve( + self, + handler: Callable[[T_co], Any], + task_group: TaskGroup | None = None, + ) -> None: + """ + Accept incoming connections as they come in and start tasks to handle them. + + :param handler: a callable that will be used to handle each accepted connection + :param task_group: the task group that will be used to start tasks for handling each + accepted connection (if omitted, an ad-hoc task group will be created) + """ diff --git a/contrib/python/anyio/anyio/abc/_subprocesses.py b/contrib/python/anyio/anyio/abc/_subprocesses.py new file mode 100644 index 0000000000..704b44a2dd --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_subprocesses.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from abc import abstractmethod +from signal import Signals + +from ._resources import AsyncResource +from ._streams import ByteReceiveStream, ByteSendStream + + +class Process(AsyncResource): + """An asynchronous version of :class:`subprocess.Popen`.""" + + @abstractmethod + async def wait(self) -> int: + """ + Wait until the process exits. + + :return: the exit code of the process + """ + + @abstractmethod + def terminate(self) -> None: + """ + Terminates the process, gracefully if possible. + + On Windows, this calls ``TerminateProcess()``. + On POSIX systems, this sends ``SIGTERM`` to the process. + + .. seealso:: :meth:`subprocess.Popen.terminate` + """ + + @abstractmethod + def kill(self) -> None: + """ + Kills the process. + + On Windows, this calls ``TerminateProcess()``. + On POSIX systems, this sends ``SIGKILL`` to the process. + + .. seealso:: :meth:`subprocess.Popen.kill` + """ + + @abstractmethod + def send_signal(self, signal: Signals) -> None: + """ + Send a signal to the subprocess. + + .. seealso:: :meth:`subprocess.Popen.send_signal` + + :param signal: the signal number (e.g. :data:`signal.SIGHUP`) + """ + + @property + @abstractmethod + def pid(self) -> int: + """The process ID of the process.""" + + @property + @abstractmethod + def returncode(self) -> int | None: + """ + The return code of the process. If the process has not yet terminated, this will be + ``None``. 
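Editor's note: the ``Process`` ABC below is the abstract face of ``anyio.open_process()``. A POSIX-flavoured sketch (assumes an ``echo`` binary on PATH; pipes are the default for stdin/stdout/stderr)::

    import anyio

    async def main() -> None:
        async with await anyio.open_process(["echo", "hello"]) as process:
            assert process.stdout is not None   # PIPE by default
            output = await process.stdout.receive()
            print(output.decode(), end="")
            await process.wait()
            print("exit code:", process.returncode)

    anyio.run(main)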
+ """ + + @property + @abstractmethod + def stdin(self) -> ByteSendStream | None: + """The stream for the standard input of the process.""" + + @property + @abstractmethod + def stdout(self) -> ByteReceiveStream | None: + """The stream for the standard output of the process.""" + + @property + @abstractmethod + def stderr(self) -> ByteReceiveStream | None: + """The stream for the standard error output of the process.""" diff --git a/contrib/python/anyio/anyio/abc/_tasks.py b/contrib/python/anyio/anyio/abc/_tasks.py new file mode 100644 index 0000000000..e48d3c1e97 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_tasks.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import sys +from abc import ABCMeta, abstractmethod +from types import TracebackType +from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, overload +from warnings import warn + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +if TYPE_CHECKING: + from anyio._core._tasks import CancelScope + +T_Retval = TypeVar("T_Retval") +T_contra = TypeVar("T_contra", contravariant=True) + + +class TaskStatus(Protocol[T_contra]): + @overload + def started(self: TaskStatus[None]) -> None: + ... + + @overload + def started(self, value: T_contra) -> None: + ... + + def started(self, value: T_contra | None = None) -> None: + """ + Signal that the task has started. + + :param value: object passed back to the starter of the task + """ + + +class TaskGroup(metaclass=ABCMeta): + """ + Groups several asynchronous tasks together. + + :ivar cancel_scope: the cancel scope inherited by all child tasks + :vartype cancel_scope: CancelScope + """ + + cancel_scope: CancelScope + + async def spawn( + self, + func: Callable[..., Awaitable[Any]], + *args: object, + name: object = None, + ) -> None: + """ + Start a new task in this task group. + + :param func: a coroutine function + :param args: positional arguments to call the function with + :param name: name of the task, for the purposes of introspection and debugging + + .. deprecated:: 3.0 + Use :meth:`start_soon` instead. If your code needs AnyIO 2 compatibility, you + can keep using this until AnyIO 4. + + """ + warn( + 'spawn() is deprecated -- use start_soon() (without the "await") instead', + DeprecationWarning, + ) + self.start_soon(func, *args, name=name) + + @abstractmethod + def start_soon( + self, + func: Callable[..., Awaitable[Any]], + *args: object, + name: object = None, + ) -> None: + """ + Start a new task in this task group. + + :param func: a coroutine function + :param args: positional arguments to call the function with + :param name: name of the task, for the purposes of introspection and debugging + + .. versionadded:: 3.0 + """ + + @abstractmethod + async def start( + self, + func: Callable[..., Awaitable[Any]], + *args: object, + name: object = None, + ) -> Any: + """ + Start a new task and wait until it signals for readiness. + + :param func: a coroutine function + :param args: positional arguments to call the function with + :param name: name of the task, for the purposes of introspection and debugging + :return: the value passed to ``task_status.started()`` + :raises RuntimeError: if the task finishes without calling ``task_status.started()`` + + .. 
versionadded:: 3.0 + """ + + @abstractmethod + async def __aenter__(self) -> TaskGroup: + """Enter the task group context and allow starting new tasks.""" + + @abstractmethod + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + """Exit the task group context waiting for all tasks to finish.""" diff --git a/contrib/python/anyio/anyio/abc/_testing.py b/contrib/python/anyio/anyio/abc/_testing.py new file mode 100644 index 0000000000..ee2cff5cc3 --- /dev/null +++ b/contrib/python/anyio/anyio/abc/_testing.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import types +from abc import ABCMeta, abstractmethod +from collections.abc import AsyncGenerator, Iterable +from typing import Any, Callable, Coroutine, TypeVar + +_T = TypeVar("_T") + + +class TestRunner(metaclass=ABCMeta): + """ + Encapsulates a running event loop. Every call made through this object will use the same event + loop. + """ + + def __enter__(self) -> TestRunner: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> bool | None: + self.close() + return None + + @abstractmethod + def close(self) -> None: + """Close the event loop.""" + + @abstractmethod + def run_asyncgen_fixture( + self, + fixture_func: Callable[..., AsyncGenerator[_T, Any]], + kwargs: dict[str, Any], + ) -> Iterable[_T]: + """ + Run an async generator fixture. + + :param fixture_func: the fixture function + :param kwargs: keyword arguments to call the fixture function with + :return: an iterator yielding the value yielded from the async generator + """ + + @abstractmethod + def run_fixture( + self, + fixture_func: Callable[..., Coroutine[Any, Any, _T]], + kwargs: dict[str, Any], + ) -> _T: + """ + Run an async fixture. + + :param fixture_func: the fixture function + :param kwargs: keyword arguments to call the fixture function with + :return: the return value of the fixture function + """ + + @abstractmethod + def run_test( + self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] + ) -> None: + """ + Run an async test function. + + :param test_func: the test function + :param kwargs: keyword arguments to call the test function with + """ diff --git a/contrib/python/anyio/anyio/from_thread.py b/contrib/python/anyio/anyio/from_thread.py new file mode 100644 index 0000000000..6b76861c70 --- /dev/null +++ b/contrib/python/anyio/anyio/from_thread.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +import threading +from asyncio import iscoroutine +from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from contextlib import AbstractContextManager, contextmanager +from types import TracebackType +from typing import ( + Any, + AsyncContextManager, + Awaitable, + Callable, + ContextManager, + Generator, + Generic, + Iterable, + TypeVar, + cast, + overload, +) +from warnings import warn + +from ._core import _eventloop +from ._core._eventloop import get_asynclib, get_cancelled_exc_class, threadlocals +from ._core._synchronization import Event +from ._core._tasks import CancelScope, create_task_group +from .abc._tasks import TaskStatus + +T_Retval = TypeVar("T_Retval") +T_co = TypeVar("T_co") + + +def run(func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: + """ + Call a coroutine function from a worker thread. 
+ + :param func: a coroutine function + :param args: positional arguments for the callable + :return: the return value of the coroutine function + + """ + try: + asynclib = threadlocals.current_async_module + except AttributeError: + raise RuntimeError("This function can only be run from an AnyIO worker thread") + + return asynclib.run_async_from_thread(func, *args) + + +def run_async_from_thread( + func: Callable[..., Awaitable[T_Retval]], *args: object +) -> T_Retval: + warn( + "run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead", + DeprecationWarning, + ) + return run(func, *args) + + +def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: + """ + Call a function in the event loop thread from a worker thread. + + :param func: a callable + :param args: positional arguments for the callable + :return: the return value of the callable + + """ + try: + asynclib = threadlocals.current_async_module + except AttributeError: + raise RuntimeError("This function can only be run from an AnyIO worker thread") + + return asynclib.run_sync_from_thread(func, *args) + + +def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval: + warn( + "run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead", + DeprecationWarning, + ) + return run_sync(func, *args) + + +class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager): + _enter_future: Future + _exit_future: Future + _exit_event: Event + _exit_exc_info: tuple[ + type[BaseException] | None, BaseException | None, TracebackType | None + ] = (None, None, None) + + def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal): + self._async_cm = async_cm + self._portal = portal + + async def run_async_cm(self) -> bool | None: + try: + self._exit_event = Event() + value = await self._async_cm.__aenter__() + except BaseException as exc: + self._enter_future.set_exception(exc) + raise + else: + self._enter_future.set_result(value) + + try: + # Wait for the sync context manager to exit. + # This next statement can raise `get_cancelled_exc_class()` if + # something went wrong in a task group in this async context + # manager. + await self._exit_event.wait() + finally: + # In case of cancellation, it could be that we end up here before + # `_BlockingAsyncContextManager.__exit__` is called, and an + # `_exit_exc_info` has been set. 
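Editor's note: ``run()`` and ``run_sync()`` above only work from worker threads spawned by anyio itself, since they locate the event loop through ``threadlocals``. A round-trip sketch::

    import anyio
    from anyio import from_thread, to_thread

    def blocking_work() -> str:
        # runs in a worker thread; hop back into the event loop for async work
        from_thread.run(anyio.sleep, 0.1)
        return from_thread.run_sync(lambda: "hello from the event loop")

    async def main() -> None:
        print(await to_thread.run_sync(blocking_work))

    anyio.run(main)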
+ result = await self._async_cm.__aexit__(*self._exit_exc_info) + return result + + def __enter__(self) -> T_co: + self._enter_future = Future() + self._exit_future = self._portal.start_task_soon(self.run_async_cm) + cm = self._enter_future.result() + return cast(T_co, cm) + + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: + self._exit_exc_info = __exc_type, __exc_value, __traceback + self._portal.call(self._exit_event.set) + return self._exit_future.result() + + +class _BlockingPortalTaskStatus(TaskStatus): + def __init__(self, future: Future): + self._future = future + + def started(self, value: object = None) -> None: + self._future.set_result(value) + + +class BlockingPortal: + """An object that lets external threads run code in an asynchronous event loop.""" + + def __new__(cls) -> BlockingPortal: + return get_asynclib().BlockingPortal() + + def __init__(self) -> None: + self._event_loop_thread_id: int | None = threading.get_ident() + self._stop_event = Event() + self._task_group = create_task_group() + self._cancelled_exc_class = get_cancelled_exc_class() + + async def __aenter__(self) -> BlockingPortal: + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.stop() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + def _check_running(self) -> None: + if self._event_loop_thread_id is None: + raise RuntimeError("This portal is not running") + if self._event_loop_thread_id == threading.get_ident(): + raise RuntimeError( + "This method cannot be called from the event loop thread" + ) + + async def sleep_until_stopped(self) -> None: + """Sleep until :meth:`stop` is called.""" + await self._stop_event.wait() + + async def stop(self, cancel_remaining: bool = False) -> None: + """ + Signal the portal to shut down. + + This marks the portal as no longer accepting new calls and exits from + :meth:`sleep_until_stopped`. + + :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` to let them + finish before returning + + """ + self._event_loop_thread_id = None + self._stop_event.set() + if cancel_remaining: + self._task_group.cancel_scope.cancel() + + async def _call_func( + self, func: Callable, args: tuple, kwargs: dict[str, Any], future: Future + ) -> None: + def callback(f: Future) -> None: + if f.cancelled() and self._event_loop_thread_id not in ( + None, + threading.get_ident(), + ): + self.call(scope.cancel) + + try: + retval = func(*args, **kwargs) + if iscoroutine(retval): + with CancelScope() as scope: + if future.cancelled(): + scope.cancel() + else: + future.add_done_callback(callback) + + retval = await retval + except self._cancelled_exc_class: + future.cancel() + except BaseException as exc: + if not future.cancelled(): + future.set_exception(exc) + + # Let base exceptions fall through + if not isinstance(exc, Exception): + raise + else: + if not future.cancelled(): + future.set_result(retval) + finally: + scope = None # type: ignore[assignment] + + def _spawn_task_from_thread( + self, + func: Callable, + args: tuple, + kwargs: dict[str, Any], + name: object, + future: Future, + ) -> None: + """ + Spawn a new task using the given callable. + + Implementors must ensure that the future is resolved when the task finishes. 
+ + :param func: a callable + :param args: positional arguments to be passed to the callable + :param kwargs: keyword arguments to be passed to the callable + :param name: name of the task (will be coerced to a string if not ``None``) + :param future: a future that will resolve to the return value of the callable, or the + exception raised during its execution + + """ + raise NotImplementedError + + @overload + def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object) -> T_Retval: + ... + + @overload + def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: + ... + + def call( + self, func: Callable[..., Awaitable[T_Retval] | T_Retval], *args: object + ) -> T_Retval: + """ + Call the given function in the event loop thread. + + If the callable returns a coroutine object, it is awaited on. + + :param func: any callable + :raises RuntimeError: if the portal is not running or if this method is called from within + the event loop thread + + """ + return cast(T_Retval, self.start_task_soon(func, *args).result()) + + @overload + def spawn_task( + self, + func: Callable[..., Awaitable[T_Retval]], + *args: object, + name: object = None, + ) -> Future[T_Retval]: + ... + + @overload + def spawn_task( + self, func: Callable[..., T_Retval], *args: object, name: object = None + ) -> Future[T_Retval]: + ... + + def spawn_task( + self, + func: Callable[..., Awaitable[T_Retval] | T_Retval], + *args: object, + name: object = None, + ) -> Future[T_Retval]: + """ + Start a task in the portal's task group. + + :param func: the target coroutine function + :param args: positional arguments passed to ``func`` + :param name: name of the task (will be coerced to a string if not ``None``) + :return: a future that resolves with the return value of the callable if the task completes + successfully, or with the exception raised in the task + :raises RuntimeError: if the portal is not running or if this method is called from within + the event loop thread + + .. versionadded:: 2.1 + .. deprecated:: 3.0 + Use :meth:`start_task_soon` instead. If your code needs AnyIO 2 compatibility, you + can keep using this until AnyIO 4. + + """ + warn( + "spawn_task() is deprecated -- use start_task_soon() instead", + DeprecationWarning, + ) + return self.start_task_soon(func, *args, name=name) # type: ignore[arg-type] + + @overload + def start_task_soon( + self, + func: Callable[..., Awaitable[T_Retval]], + *args: object, + name: object = None, + ) -> Future[T_Retval]: + ... + + @overload + def start_task_soon( + self, func: Callable[..., T_Retval], *args: object, name: object = None + ) -> Future[T_Retval]: + ... + + def start_task_soon( + self, + func: Callable[..., Awaitable[T_Retval] | T_Retval], + *args: object, + name: object = None, + ) -> Future[T_Retval]: + """ + Start a task in the portal's task group. + + The task will be run inside a cancel scope which can be cancelled by cancelling the + returned future. + + :param func: the target function + :param args: positional arguments passed to ``func`` + :param name: name of the task (will be coerced to a string if not ``None``) + :return: a future that resolves with the return value of the callable if the + task completes successfully, or with the exception raised in the task + :raises RuntimeError: if the portal is not running or if this method is called + from within the event loop thread + :rtype: concurrent.futures.Future[T_Retval] + + .. 
versionadded:: 3.0 + + """ + self._check_running() + f: Future = Future() + self._spawn_task_from_thread(func, args, {}, name, f) + return f + + def start_task( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> tuple[Future[Any], Any]: + """ + Start a task in the portal's task group and wait until it signals for readiness. + + This method works the same way as :meth:`.abc.TaskGroup.start`. + + :param func: the target function + :param args: positional arguments passed to ``func`` + :param name: name of the task (will be coerced to a string if not ``None``) + :return: a tuple of (future, task_status_value) where the ``task_status_value`` + is the value passed to ``task_status.started()`` from within the target + function + :rtype: tuple[concurrent.futures.Future[Any], Any] + + .. versionadded:: 3.0 + + """ + + def task_done(future: Future) -> None: + if not task_status_future.done(): + if future.cancelled(): + task_status_future.cancel() + elif future.exception(): + task_status_future.set_exception(future.exception()) + else: + exc = RuntimeError( + "Task exited without calling task_status.started()" + ) + task_status_future.set_exception(exc) + + self._check_running() + task_status_future: Future = Future() + task_status = _BlockingPortalTaskStatus(task_status_future) + f: Future = Future() + f.add_done_callback(task_done) + self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f) + return f, task_status_future.result() + + def wrap_async_context_manager( + self, cm: AsyncContextManager[T_co] + ) -> ContextManager[T_co]: + """ + Wrap an async context manager as a synchronous context manager via this portal. + + Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping in the + middle until the synchronous context manager exits. + + :param cm: an asynchronous context manager + :return: a synchronous context manager + + .. versionadded:: 2.1 + + """ + return _BlockingAsyncContextManager(cm, self) + + +def create_blocking_portal() -> BlockingPortal: + """ + Create a portal for running functions in the event loop thread from external threads. + + Use this function in asynchronous code when you need to allow external threads access to the + event loop where your asynchronous code is currently running. + + .. deprecated:: 3.0 + Use :class:`.BlockingPortal` directly. + + """ + warn( + "create_blocking_portal() has been deprecated -- use anyio.from_thread.BlockingPortal() " + "directly", + DeprecationWarning, + ) + return BlockingPortal() + + +@contextmanager +def start_blocking_portal( + backend: str = "asyncio", backend_options: dict[str, Any] | None = None +) -> Generator[BlockingPortal, Any, None]: + """ + Start a new event loop in a new thread and run a blocking portal in its main task. + + The parameters are the same as for :func:`~anyio.run`. + + :param backend: name of the backend + :param backend_options: backend options + :return: a context manager that yields a blocking portal + + .. versionchanged:: 3.0 + Usage as a context manager is now required. 
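Editor's note: ``start_blocking_portal()`` runs a fresh event loop in a background thread and hands back a ``BlockingPortal``, which is the usual way to drive async code from fully synchronous programs. A sketch (``async_add`` is illustrative)::

    import anyio
    from anyio.from_thread import start_blocking_portal

    async def async_add(a: int, b: int) -> int:
        await anyio.sleep(0)
        return a + b

    with start_blocking_portal() as portal:
        print(portal.call(async_add, 1, 2))             # 3
        future = portal.start_task_soon(async_add, 2, 3)
        print(future.result())                          # 5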
+ + """ + + async def run_portal() -> None: + async with BlockingPortal() as portal_: + if future.set_running_or_notify_cancel(): + future.set_result(portal_) + await portal_.sleep_until_stopped() + + future: Future[BlockingPortal] = Future() + with ThreadPoolExecutor(1) as executor: + run_future = executor.submit( + _eventloop.run, + run_portal, # type: ignore[arg-type] + backend=backend, + backend_options=backend_options, + ) + try: + wait( + cast(Iterable[Future], [run_future, future]), + return_when=FIRST_COMPLETED, + ) + except BaseException: + future.cancel() + run_future.cancel() + raise + + if future.done(): + portal = future.result() + cancel_remaining_tasks = False + try: + yield portal + except BaseException: + cancel_remaining_tasks = True + raise + finally: + try: + portal.call(portal.stop, cancel_remaining_tasks) + except RuntimeError: + pass + + run_future.result() diff --git a/contrib/python/anyio/anyio/lowlevel.py b/contrib/python/anyio/anyio/lowlevel.py new file mode 100644 index 0000000000..0e908c6547 --- /dev/null +++ b/contrib/python/anyio/anyio/lowlevel.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import enum +import sys +from dataclasses import dataclass +from typing import Any, Generic, TypeVar, overload +from weakref import WeakKeyDictionary + +from ._core._eventloop import get_asynclib + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +T = TypeVar("T") +D = TypeVar("D") + + +async def checkpoint() -> None: + """ + Check for cancellation and allow the scheduler to switch to another task. + + Equivalent to (but more efficient than):: + + await checkpoint_if_cancelled() + await cancel_shielded_checkpoint() + + + .. versionadded:: 3.0 + + """ + await get_asynclib().checkpoint() + + +async def checkpoint_if_cancelled() -> None: + """ + Enter a checkpoint if the enclosing cancel scope has been cancelled. + + This does not allow the scheduler to switch to a different task. + + .. versionadded:: 3.0 + + """ + await get_asynclib().checkpoint_if_cancelled() + + +async def cancel_shielded_checkpoint() -> None: + """ + Allow the scheduler to switch to another task but without checking for cancellation. + + Equivalent to (but potentially more efficient than):: + + with CancelScope(shield=True): + await checkpoint() + + + .. versionadded:: 3.0 + + """ + await get_asynclib().cancel_shielded_checkpoint() + + +def current_token() -> object: + """Return a backend specific token object that can be used to get back to the event loop.""" + return get_asynclib().current_token() + + +_run_vars: WeakKeyDictionary[Any, dict[str, Any]] = WeakKeyDictionary() +_token_wrappers: dict[Any, _TokenWrapper] = {} + + +@dataclass(frozen=True) +class _TokenWrapper: + __slots__ = "_token", "__weakref__" + _token: object + + +class _NoValueSet(enum.Enum): + NO_VALUE_SET = enum.auto() + + +class RunvarToken(Generic[T]): + __slots__ = "_var", "_value", "_redeemed" + + def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]): + self._var = var + self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value + self._redeemed = False + + +class RunVar(Generic[T]): + """ + Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop. 
+ """ + + __slots__ = "_name", "_default" + + NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET + + _token_wrappers: set[_TokenWrapper] = set() + + def __init__( + self, + name: str, + default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET, + ): + self._name = name + self._default = default + + @property + def _current_vars(self) -> dict[str, T]: + token = current_token() + while True: + try: + return _run_vars[token] + except TypeError: + # Happens when token isn't weak referable (TrioToken). + # This workaround does mean that some memory will leak on Trio until the problem + # is fixed on their end. + token = _TokenWrapper(token) + self._token_wrappers.add(token) + except KeyError: + run_vars = _run_vars[token] = {} + return run_vars + + @overload + def get(self, default: D) -> T | D: + ... + + @overload + def get(self) -> T: + ... + + def get( + self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET + ) -> T | D: + try: + return self._current_vars[self._name] + except KeyError: + if default is not RunVar.NO_VALUE_SET: + return default + elif self._default is not RunVar.NO_VALUE_SET: + return self._default + + raise LookupError( + f'Run variable "{self._name}" has no value and no default set' + ) + + def set(self, value: T) -> RunvarToken[T]: + current_vars = self._current_vars + token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET)) + current_vars[self._name] = value + return token + + def reset(self, token: RunvarToken[T]) -> None: + if token._var is not self: + raise ValueError("This token does not belong to this RunVar") + + if token._redeemed: + raise ValueError("This token has already been used") + + if token._value is _NoValueSet.NO_VALUE_SET: + try: + del self._current_vars[self._name] + except KeyError: + pass + else: + self._current_vars[self._name] = token._value + + token._redeemed = True + + def __repr__(self) -> str: + return f"<RunVar name={self._name!r}>" diff --git a/contrib/python/anyio/anyio/py.typed b/contrib/python/anyio/anyio/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/anyio/anyio/py.typed diff --git a/contrib/python/anyio/anyio/pytest_plugin.py b/contrib/python/anyio/anyio/pytest_plugin.py new file mode 100644 index 0000000000..044ce6914d --- /dev/null +++ b/contrib/python/anyio/anyio/pytest_plugin.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from contextlib import contextmanager +from inspect import isasyncgenfunction, iscoroutinefunction +from typing import Any, Dict, Generator, Tuple, cast + +import pytest +import sniffio + +from ._core._eventloop import get_all_backends, get_asynclib +from .abc import TestRunner + +_current_runner: TestRunner | None = None + + +def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: + if isinstance(backend, str): + return backend, {} + elif isinstance(backend, tuple) and len(backend) == 2: + if isinstance(backend[0], str) and isinstance(backend[1], dict): + return cast(Tuple[str, Dict[str, Any]], backend) + + raise TypeError("anyio_backend must be either a string or tuple of (string, dict)") + + +@contextmanager +def get_runner( + backend_name: str, backend_options: dict[str, Any] +) -> Generator[TestRunner, object, None]: + global _current_runner + if _current_runner: + yield _current_runner + return + + asynclib = get_asynclib(backend_name) + token = None + if sniffio.current_async_library_cvar.get(None) is None: + # Since we're in control of the event loop, we can cache 
the name of the async library + token = sniffio.current_async_library_cvar.set(backend_name) + + try: + backend_options = backend_options or {} + with asynclib.TestRunner(**backend_options) as runner: + _current_runner = runner + yield runner + finally: + _current_runner = None + if token: + sniffio.current_async_library_cvar.reset(token) + + +def pytest_configure(config: Any) -> None: + config.addinivalue_line( + "markers", + "anyio: mark the (coroutine function) test to be run " + "asynchronously via anyio.", + ) + + +def pytest_fixture_setup(fixturedef: Any, request: Any) -> None: + def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def] + backend_name, backend_options = extract_backend_and_options(anyio_backend) + if has_backend_arg: + kwargs["anyio_backend"] = anyio_backend + + with get_runner(backend_name, backend_options) as runner: + if isasyncgenfunction(func): + yield from runner.run_asyncgen_fixture(func, kwargs) + else: + yield runner.run_fixture(func, kwargs) + + # Only apply this to coroutine functions and async generator functions in requests that involve + # the anyio_backend fixture + func = fixturedef.func + if isasyncgenfunction(func) or iscoroutinefunction(func): + if "anyio_backend" in request.fixturenames: + has_backend_arg = "anyio_backend" in fixturedef.argnames + fixturedef.func = wrapper + if not has_backend_arg: + fixturedef.argnames += ("anyio_backend",) + + +@pytest.hookimpl(tryfirst=True) +def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None: + if collector.istestfunction(obj, name): + inner_func = obj.hypothesis.inner_test if hasattr(obj, "hypothesis") else obj + if iscoroutinefunction(inner_func): + marker = collector.get_closest_marker("anyio") + own_markers = getattr(obj, "pytestmark", ()) + if marker or any(marker.name == "anyio" for marker in own_markers): + pytest.mark.usefixtures("anyio_backend")(obj) + + +@pytest.hookimpl(tryfirst=True) +def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None: + def run_with_hypothesis(**kwargs: Any) -> None: + with get_runner(backend_name, backend_options) as runner: + runner.run_test(original_func, kwargs) + + backend = pyfuncitem.funcargs.get("anyio_backend") + if backend: + backend_name, backend_options = extract_backend_and_options(backend) + + if hasattr(pyfuncitem.obj, "hypothesis"): + # Wrap the inner test function unless it's already wrapped + original_func = pyfuncitem.obj.hypothesis.inner_test + if original_func.__qualname__ != run_with_hypothesis.__qualname__: + if iscoroutinefunction(original_func): + pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis + + return None + + if iscoroutinefunction(pyfuncitem.obj): + funcargs = pyfuncitem.funcargs + testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames} + with get_runner(backend_name, backend_options) as runner: + runner.run_test(pyfuncitem.obj, testargs) + + return True + + return None + + +@pytest.fixture(params=get_all_backends()) +def anyio_backend(request: Any) -> Any: + return request.param + + +@pytest.fixture +def anyio_backend_name(anyio_backend: Any) -> str: + if isinstance(anyio_backend, str): + return anyio_backend + else: + return anyio_backend[0] + + +@pytest.fixture +def anyio_backend_options(anyio_backend: Any) -> dict[str, Any]: + if isinstance(anyio_backend, str): + return {} + else: + return anyio_backend[1] diff --git a/contrib/python/anyio/anyio/streams/__init__.py b/contrib/python/anyio/anyio/streams/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- 
/dev/null +++ b/contrib/python/anyio/anyio/streams/__init__.py diff --git a/contrib/python/anyio/anyio/streams/buffered.py b/contrib/python/anyio/anyio/streams/buffered.py new file mode 100644 index 0000000000..11474c16a9 --- /dev/null +++ b/contrib/python/anyio/anyio/streams/buffered.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Mapping + +from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead +from ..abc import AnyByteReceiveStream, ByteReceiveStream + + +@dataclass(eq=False) +class BufferedByteReceiveStream(ByteReceiveStream): + """ + Wraps any bytes-based receive stream and uses a buffer to provide sophisticated receiving + capabilities in the form of a byte stream. + """ + + receive_stream: AnyByteReceiveStream + _buffer: bytearray = field(init=False, default_factory=bytearray) + _closed: bool = field(init=False, default=False) + + async def aclose(self) -> None: + await self.receive_stream.aclose() + self._closed = True + + @property + def buffer(self) -> bytes: + """The bytes currently in the buffer.""" + return bytes(self._buffer) + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.receive_stream.extra_attributes + + async def receive(self, max_bytes: int = 65536) -> bytes: + if self._closed: + raise ClosedResourceError + + if self._buffer: + chunk = bytes(self._buffer[:max_bytes]) + del self._buffer[:max_bytes] + return chunk + elif isinstance(self.receive_stream, ByteReceiveStream): + return await self.receive_stream.receive(max_bytes) + else: + # With a bytes-oriented object stream, we need to handle any surplus bytes we get from + # the receive() call + chunk = await self.receive_stream.receive() + if len(chunk) > max_bytes: + # Save the surplus bytes in the buffer + self._buffer.extend(chunk[max_bytes:]) + return chunk[:max_bytes] + else: + return chunk + + async def receive_exactly(self, nbytes: int) -> bytes: + """ + Read exactly the given amount of bytes from the stream. + + :param nbytes: the number of bytes to read + :return: the bytes read + :raises ~anyio.IncompleteRead: if the stream was closed before the requested + amount of bytes could be read from the stream + + """ + while True: + remaining = nbytes - len(self._buffer) + if remaining <= 0: + retval = self._buffer[:nbytes] + del self._buffer[:nbytes] + return bytes(retval) + + try: + if isinstance(self.receive_stream, ByteReceiveStream): + chunk = await self.receive_stream.receive(remaining) + else: + chunk = await self.receive_stream.receive() + except EndOfStream as exc: + raise IncompleteRead from exc + + self._buffer.extend(chunk) + + async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes: + """ + Read from the stream until the delimiter is found or max_bytes have been read. 
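+
+        For example, a hypothetical reader of CRLF-terminated lines built on this
+        method (``stream`` stands in for any bytes-based receive stream)::
+
+            buffered = BufferedByteReceiveStream(stream)
+            line = await buffered.receive_until(b"\r\n", max_bytes=8192)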
+
+        :param delimiter: the marker to look for in the stream
+        :param max_bytes: maximum number of bytes that will be read before raising
+            :exc:`~anyio.DelimiterNotFound`
+        :return: the bytes read (not including the delimiter)
+        :raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
+            was found
+        :raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
+            bytes read up to the maximum allowed
+
+        """
+        delimiter_size = len(delimiter)
+        offset = 0
+        while True:
+            # Check if the delimiter can be found in the current buffer
+            index = self._buffer.find(delimiter, offset)
+            if index >= 0:
+                found = self._buffer[:index]
+                del self._buffer[: index + len(delimiter)]
+                return bytes(found)
+
+            # Check if the buffer is already at or over the limit
+            if len(self._buffer) >= max_bytes:
+                raise DelimiterNotFound(max_bytes)
+
+            # Read more data into the buffer from the socket
+            try:
+                data = await self.receive_stream.receive()
+            except EndOfStream as exc:
+                raise IncompleteRead from exc
+
+            # Move the offset forward and add the new data to the buffer
+            offset = max(len(self._buffer) - delimiter_size + 1, 0)
+            self._buffer.extend(data)
diff --git a/contrib/python/anyio/anyio/streams/file.py b/contrib/python/anyio/anyio/streams/file.py
new file mode 100644
index 0000000000..2840d40ab6
--- /dev/null
+++ b/contrib/python/anyio/anyio/streams/file.py
@@ -0,0 +1,147 @@
+from __future__ import annotations
+
+from io import SEEK_SET, UnsupportedOperation
+from os import PathLike
+from pathlib import Path
+from typing import Any, BinaryIO, Callable, Mapping, cast
+
+from .. import (
+    BrokenResourceError,
+    ClosedResourceError,
+    EndOfStream,
+    TypedAttributeSet,
+    to_thread,
+    typed_attribute,
+)
+from ..abc import ByteReceiveStream, ByteSendStream
+
+
+class FileStreamAttribute(TypedAttributeSet):
+    #: the open file descriptor
+    file: BinaryIO = typed_attribute()
+    #: the path of the file on the file system, if available (file must be a real file)
+    path: Path = typed_attribute()
+    #: the file number, if available (file must be a real file or a TTY)
+    fileno: int = typed_attribute()
+
+
+class _BaseFileStream:
+    def __init__(self, file: BinaryIO):
+        self._file = file
+
+    async def aclose(self) -> None:
+        await to_thread.run_sync(self._file.close)
+
+    @property
+    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+        attributes: dict[Any, Callable[[], Any]] = {
+            FileStreamAttribute.file: lambda: self._file,
+        }
+
+        if hasattr(self._file, "name"):
+            attributes[FileStreamAttribute.path] = lambda: Path(self._file.name)
+
+        try:
+            self._file.fileno()
+        except UnsupportedOperation:
+            pass
+        else:
+            attributes[FileStreamAttribute.fileno] = lambda: self._file.fileno()
+
+        return attributes
+
+
+class FileReadStream(_BaseFileStream, ByteReceiveStream):
+    """
+    A byte stream that reads from a file in the file system.
+
+    :param file: a file that has been opened for reading in binary mode
+
+    .. versionadded:: 3.0
+    """
+
+    @classmethod
+    async def from_path(cls, path: str | PathLike[str]) -> FileReadStream:
+        """
+        Create a file read stream by opening the given file.
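+
+        A minimal usage sketch (``/tmp/some.bin`` is a placeholder path)::
+
+            async with await FileReadStream.from_path("/tmp/some.bin") as stream:
+                first_chunk = await stream.receive()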
+ + :param path: path of the file to read from + + """ + file = await to_thread.run_sync(Path(path).open, "rb") + return cls(cast(BinaryIO, file)) + + async def receive(self, max_bytes: int = 65536) -> bytes: + try: + data = await to_thread.run_sync(self._file.read, max_bytes) + except ValueError: + raise ClosedResourceError from None + except OSError as exc: + raise BrokenResourceError from exc + + if data: + return data + else: + raise EndOfStream + + async def seek(self, position: int, whence: int = SEEK_SET) -> int: + """ + Seek the file to the given position. + + .. seealso:: :meth:`io.IOBase.seek` + + .. note:: Not all file descriptors are seekable. + + :param position: position to seek the file to + :param whence: controls how ``position`` is interpreted + :return: the new absolute position + :raises OSError: if the file is not seekable + + """ + return await to_thread.run_sync(self._file.seek, position, whence) + + async def tell(self) -> int: + """ + Return the current stream position. + + .. note:: Not all file descriptors are seekable. + + :return: the current absolute position + :raises OSError: if the file is not seekable + + """ + return await to_thread.run_sync(self._file.tell) + + +class FileWriteStream(_BaseFileStream, ByteSendStream): + """ + A byte stream that writes to a file in the file system. + + :param file: a file that has been opened for writing in binary mode + + .. versionadded:: 3.0 + """ + + @classmethod + async def from_path( + cls, path: str | PathLike[str], append: bool = False + ) -> FileWriteStream: + """ + Create a file write stream by opening the given file for writing. + + :param path: path of the file to write to + :param append: if ``True``, open the file for appending; if ``False``, any existing file + at the given path will be truncated + + """ + mode = "ab" if append else "wb" + file = await to_thread.run_sync(Path(path).open, mode) + return cls(cast(BinaryIO, file)) + + async def send(self, item: bytes) -> None: + try: + await to_thread.run_sync(self._file.write, item) + except ValueError: + raise ClosedResourceError from None + except OSError as exc: + raise BrokenResourceError from exc diff --git a/contrib/python/anyio/anyio/streams/memory.py b/contrib/python/anyio/anyio/streams/memory.py new file mode 100644 index 0000000000..a6499c13ff --- /dev/null +++ b/contrib/python/anyio/anyio/streams/memory.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from collections import OrderedDict, deque +from dataclasses import dataclass, field +from types import TracebackType +from typing import Generic, NamedTuple, TypeVar + +from .. 
import ( + BrokenResourceError, + ClosedResourceError, + EndOfStream, + WouldBlock, + get_cancelled_exc_class, +) +from .._core._compat import DeprecatedAwaitable +from ..abc import Event, ObjectReceiveStream, ObjectSendStream +from ..lowlevel import checkpoint + +T_Item = TypeVar("T_Item") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +class MemoryObjectStreamStatistics(NamedTuple): + current_buffer_used: int #: number of items stored in the buffer + #: maximum number of items that can be stored on this stream (or :data:`math.inf`) + max_buffer_size: float + open_send_streams: int #: number of unclosed clones of the send stream + open_receive_streams: int #: number of unclosed clones of the receive stream + tasks_waiting_send: int #: number of tasks blocked on :meth:`MemoryObjectSendStream.send` + #: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive` + tasks_waiting_receive: int + + +@dataclass(eq=False) +class MemoryObjectStreamState(Generic[T_Item]): + max_buffer_size: float = field() + buffer: deque[T_Item] = field(init=False, default_factory=deque) + open_send_channels: int = field(init=False, default=0) + open_receive_channels: int = field(init=False, default=0) + waiting_receivers: OrderedDict[Event, list[T_Item]] = field( + init=False, default_factory=OrderedDict + ) + waiting_senders: OrderedDict[Event, T_Item] = field( + init=False, default_factory=OrderedDict + ) + + def statistics(self) -> MemoryObjectStreamStatistics: + return MemoryObjectStreamStatistics( + len(self.buffer), + self.max_buffer_size, + self.open_send_channels, + self.open_receive_channels, + len(self.waiting_senders), + len(self.waiting_receivers), + ) + + +@dataclass(eq=False) +class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]): + _state: MemoryObjectStreamState[T_co] + _closed: bool = field(init=False, default=False) + + def __post_init__(self) -> None: + self._state.open_receive_channels += 1 + + def receive_nowait(self) -> T_co: + """ + Receive the next item if it can be done without waiting. + + :return: the received item + :raises ~anyio.ClosedResourceError: if this send stream has been closed + :raises ~anyio.EndOfStream: if the buffer is empty and this stream has been + closed from the sending end + :raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks + waiting to send + + """ + if self._closed: + raise ClosedResourceError + + if self._state.waiting_senders: + # Get the item from the next sender + send_event, item = self._state.waiting_senders.popitem(last=False) + self._state.buffer.append(item) + send_event.set() + + if self._state.buffer: + return self._state.buffer.popleft() + elif not self._state.open_send_channels: + raise EndOfStream + + raise WouldBlock + + async def receive(self) -> T_co: + await checkpoint() + try: + return self.receive_nowait() + except WouldBlock: + # Add ourselves in the queue + receive_event = Event() + container: list[T_co] = [] + self._state.waiting_receivers[receive_event] = container + + try: + await receive_event.wait() + except get_cancelled_exc_class(): + # Ignore the immediate cancellation if we already received an item, so as not to + # lose it + if not container: + raise + finally: + self._state.waiting_receivers.pop(receive_event, None) + + if container: + return container[0] + else: + raise EndOfStream + + def clone(self) -> MemoryObjectReceiveStream[T_co]: + """ + Create a clone of this receive stream. + + Each clone can be closed separately. 
Only when all clones have been closed will the + receiving end of the memory stream be considered closed by the sending ends. + + :return: the cloned stream + + """ + if self._closed: + raise ClosedResourceError + + return MemoryObjectReceiveStream(_state=self._state) + + def close(self) -> None: + """ + Close the stream. + + This works the exact same way as :meth:`aclose`, but is provided as a special case for the + benefit of synchronous callbacks. + + """ + if not self._closed: + self._closed = True + self._state.open_receive_channels -= 1 + if self._state.open_receive_channels == 0: + send_events = list(self._state.waiting_senders.keys()) + for event in send_events: + event.set() + + async def aclose(self) -> None: + self.close() + + def statistics(self) -> MemoryObjectStreamStatistics: + """ + Return statistics about the current state of this stream. + + .. versionadded:: 3.0 + """ + return self._state.statistics() + + def __enter__(self) -> MemoryObjectReceiveStream[T_co]: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + +@dataclass(eq=False) +class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]): + _state: MemoryObjectStreamState[T_contra] + _closed: bool = field(init=False, default=False) + + def __post_init__(self) -> None: + self._state.open_send_channels += 1 + + def send_nowait(self, item: T_contra) -> DeprecatedAwaitable: + """ + Send an item immediately if it can be done without waiting. + + :param item: the item to send + :raises ~anyio.ClosedResourceError: if this send stream has been closed + :raises ~anyio.BrokenResourceError: if the stream has been closed from the + receiving end + :raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting + to receive + + """ + if self._closed: + raise ClosedResourceError + if not self._state.open_receive_channels: + raise BrokenResourceError + + if self._state.waiting_receivers: + receive_event, container = self._state.waiting_receivers.popitem(last=False) + container.append(item) + receive_event.set() + elif len(self._state.buffer) < self._state.max_buffer_size: + self._state.buffer.append(item) + else: + raise WouldBlock + + return DeprecatedAwaitable(self.send_nowait) + + async def send(self, item: T_contra) -> None: + await checkpoint() + try: + self.send_nowait(item) + except WouldBlock: + # Wait until there's someone on the receiving end + send_event = Event() + self._state.waiting_senders[send_event] = item + try: + await send_event.wait() + except BaseException: + self._state.waiting_senders.pop(send_event, None) # type: ignore[arg-type] + raise + + if self._state.waiting_senders.pop(send_event, None): # type: ignore[arg-type] + raise BrokenResourceError + + def clone(self) -> MemoryObjectSendStream[T_contra]: + """ + Create a clone of this send stream. + + Each clone can be closed separately. Only when all clones have been closed will the + sending end of the memory stream be considered closed by the receiving ends. + + :return: the cloned stream + + """ + if self._closed: + raise ClosedResourceError + + return MemoryObjectSendStream(_state=self._state) + + def close(self) -> None: + """ + Close the stream. + + This works the exact same way as :meth:`aclose`, but is provided as a special case for the + benefit of synchronous callbacks. 
+ + """ + if not self._closed: + self._closed = True + self._state.open_send_channels -= 1 + if self._state.open_send_channels == 0: + receive_events = list(self._state.waiting_receivers.keys()) + self._state.waiting_receivers.clear() + for event in receive_events: + event.set() + + async def aclose(self) -> None: + self.close() + + def statistics(self) -> MemoryObjectStreamStatistics: + """ + Return statistics about the current state of this stream. + + .. versionadded:: 3.0 + """ + return self._state.statistics() + + def __enter__(self) -> MemoryObjectSendStream[T_contra]: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() diff --git a/contrib/python/anyio/anyio/streams/stapled.py b/contrib/python/anyio/anyio/streams/stapled.py new file mode 100644 index 0000000000..1b2862e3ea --- /dev/null +++ b/contrib/python/anyio/anyio/streams/stapled.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Generic, Mapping, Sequence, TypeVar + +from ..abc import ( + ByteReceiveStream, + ByteSendStream, + ByteStream, + Listener, + ObjectReceiveStream, + ObjectSendStream, + ObjectStream, + TaskGroup, +) + +T_Item = TypeVar("T_Item") +T_Stream = TypeVar("T_Stream") + + +@dataclass(eq=False) +class StapledByteStream(ByteStream): + """ + Combines two byte streams into a single, bidirectional byte stream. + + Extra attributes will be provided from both streams, with the receive stream providing the + values in case of a conflict. + + :param ByteSendStream send_stream: the sending byte stream + :param ByteReceiveStream receive_stream: the receiving byte stream + """ + + send_stream: ByteSendStream + receive_stream: ByteReceiveStream + + async def receive(self, max_bytes: int = 65536) -> bytes: + return await self.receive_stream.receive(max_bytes) + + async def send(self, item: bytes) -> None: + await self.send_stream.send(item) + + async def send_eof(self) -> None: + await self.send_stream.aclose() + + async def aclose(self) -> None: + await self.send_stream.aclose() + await self.receive_stream.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.send_stream.extra_attributes, + **self.receive_stream.extra_attributes, + } + + +@dataclass(eq=False) +class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]): + """ + Combines two object streams into a single, bidirectional object stream. + + Extra attributes will be provided from both streams, with the receive stream providing the + values in case of a conflict. 
+ + :param ObjectSendStream send_stream: the sending object stream + :param ObjectReceiveStream receive_stream: the receiving object stream + """ + + send_stream: ObjectSendStream[T_Item] + receive_stream: ObjectReceiveStream[T_Item] + + async def receive(self) -> T_Item: + return await self.receive_stream.receive() + + async def send(self, item: T_Item) -> None: + await self.send_stream.send(item) + + async def send_eof(self) -> None: + await self.send_stream.aclose() + + async def aclose(self) -> None: + await self.send_stream.aclose() + await self.receive_stream.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.send_stream.extra_attributes, + **self.receive_stream.extra_attributes, + } + + +@dataclass(eq=False) +class MultiListener(Generic[T_Stream], Listener[T_Stream]): + """ + Combines multiple listeners into one, serving connections from all of them at once. + + Any MultiListeners in the given collection of listeners will have their listeners moved into + this one. + + Extra attributes are provided from each listener, with each successive listener overriding any + conflicting attributes from the previous one. + + :param listeners: listeners to serve + :type listeners: Sequence[Listener[T_Stream]] + """ + + listeners: Sequence[Listener[T_Stream]] + + def __post_init__(self) -> None: + listeners: list[Listener[T_Stream]] = [] + for listener in self.listeners: + if isinstance(listener, MultiListener): + listeners.extend(listener.listeners) + del listener.listeners[:] # type: ignore[attr-defined] + else: + listeners.append(listener) + + self.listeners = listeners + + async def serve( + self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None + ) -> None: + from .. import create_task_group + + async with create_task_group() as tg: + for listener in self.listeners: + tg.start_soon(listener.serve, handler, task_group) + + async def aclose(self) -> None: + for listener in self.listeners: + await listener.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + attributes: dict = {} + for listener in self.listeners: + attributes.update(listener.extra_attributes) + + return attributes diff --git a/contrib/python/anyio/anyio/streams/text.py b/contrib/python/anyio/anyio/streams/text.py new file mode 100644 index 0000000000..bba2d3f7df --- /dev/null +++ b/contrib/python/anyio/anyio/streams/text.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import codecs +from dataclasses import InitVar, dataclass, field +from typing import Any, Callable, Mapping + +from ..abc import ( + AnyByteReceiveStream, + AnyByteSendStream, + AnyByteStream, + ObjectReceiveStream, + ObjectSendStream, + ObjectStream, +) + + +@dataclass(eq=False) +class TextReceiveStream(ObjectReceiveStream[str]): + """ + Stream wrapper that decodes bytes to strings using the given encoding. + + Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any completely + received unicode characters as soon as they come in. + + :param transport_stream: any bytes-based receive stream + :param encoding: character encoding to use for decoding bytes to strings (defaults to + ``utf-8``) + :param errors: handling scheme for decoding errors (defaults to ``strict``; see the + `codecs module documentation`_ for a comprehensive list of options) + + .. 
_codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + """ + + transport_stream: AnyByteReceiveStream + encoding: InitVar[str] = "utf-8" + errors: InitVar[str] = "strict" + _decoder: codecs.IncrementalDecoder = field(init=False) + + def __post_init__(self, encoding: str, errors: str) -> None: + decoder_class = codecs.getincrementaldecoder(encoding) + self._decoder = decoder_class(errors=errors) + + async def receive(self) -> str: + while True: + chunk = await self.transport_stream.receive() + decoded = self._decoder.decode(chunk) + if decoded: + return decoded + + async def aclose(self) -> None: + await self.transport_stream.aclose() + self._decoder.reset() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.transport_stream.extra_attributes + + +@dataclass(eq=False) +class TextSendStream(ObjectSendStream[str]): + """ + Sends strings to the wrapped stream as bytes using the given encoding. + + :param AnyByteSendStream transport_stream: any bytes-based send stream + :param str encoding: character encoding to use for encoding strings to bytes (defaults to + ``utf-8``) + :param str errors: handling scheme for encoding errors (defaults to ``strict``; see the + `codecs module documentation`_ for a comprehensive list of options) + + .. _codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + """ + + transport_stream: AnyByteSendStream + encoding: InitVar[str] = "utf-8" + errors: str = "strict" + _encoder: Callable[..., tuple[bytes, int]] = field(init=False) + + def __post_init__(self, encoding: str) -> None: + self._encoder = codecs.getencoder(encoding) + + async def send(self, item: str) -> None: + encoded = self._encoder(item, self.errors)[0] + await self.transport_stream.send(encoded) + + async def aclose(self) -> None: + await self.transport_stream.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.transport_stream.extra_attributes + + +@dataclass(eq=False) +class TextStream(ObjectStream[str]): + """ + A bidirectional stream that decodes bytes to strings on receive and encodes strings to bytes on + send. + + Extra attributes will be provided from both streams, with the receive stream providing the + values in case of a conflict. + + :param AnyByteStream transport_stream: any bytes-based stream + :param str encoding: character encoding to use for encoding/decoding strings to/from bytes + (defaults to ``utf-8``) + :param str errors: handling scheme for encoding errors (defaults to ``strict``; see the + `codecs module documentation`_ for a comprehensive list of options) + + .. 
_codecs module documentation: https://docs.python.org/3/library/codecs.html#codec-objects + """ + + transport_stream: AnyByteStream + encoding: InitVar[str] = "utf-8" + errors: InitVar[str] = "strict" + _receive_stream: TextReceiveStream = field(init=False) + _send_stream: TextSendStream = field(init=False) + + def __post_init__(self, encoding: str, errors: str) -> None: + self._receive_stream = TextReceiveStream( + self.transport_stream, encoding=encoding, errors=errors + ) + self._send_stream = TextSendStream( + self.transport_stream, encoding=encoding, errors=errors + ) + + async def receive(self) -> str: + return await self._receive_stream.receive() + + async def send(self, item: str) -> None: + await self._send_stream.send(item) + + async def send_eof(self) -> None: + await self.transport_stream.send_eof() + + async def aclose(self) -> None: + await self._send_stream.aclose() + await self._receive_stream.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self._send_stream.extra_attributes, + **self._receive_stream.extra_attributes, + } diff --git a/contrib/python/anyio/anyio/streams/tls.py b/contrib/python/anyio/anyio/streams/tls.py new file mode 100644 index 0000000000..9f9e9fd89c --- /dev/null +++ b/contrib/python/anyio/anyio/streams/tls.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import logging +import re +import ssl +from dataclasses import dataclass +from functools import wraps +from typing import Any, Callable, Mapping, Tuple, TypeVar + +from .. import ( + BrokenResourceError, + EndOfStream, + aclose_forcefully, + get_cancelled_exc_class, +) +from .._core._typedattr import TypedAttributeSet, typed_attribute +from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup + +T_Retval = TypeVar("T_Retval") +_PCTRTT = Tuple[Tuple[str, str], ...] +_PCTRTTT = Tuple[_PCTRTT, ...] + + +class TLSAttribute(TypedAttributeSet): + """Contains Transport Layer Security related attributes.""" + + #: the selected ALPN protocol + alpn_protocol: str | None = typed_attribute() + #: the channel binding for type ``tls-unique`` + channel_binding_tls_unique: bytes = typed_attribute() + #: the selected cipher + cipher: tuple[str, str, int] = typed_attribute() + #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` + #: for more information) + peer_certificate: dict[str, str | _PCTRTTT | _PCTRTT] | None = typed_attribute() + #: the peer certificate in binary form + peer_certificate_binary: bytes | None = typed_attribute() + #: ``True`` if this is the server side of the connection + server_side: bool = typed_attribute() + #: ciphers shared by the client during the TLS handshake (``None`` if this is the + #: client side) + shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute() + #: the :class:`~ssl.SSLObject` used for encryption + ssl_object: ssl.SSLObject = typed_attribute() + #: ``True`` if this stream does (and expects) a closing TLS handshake when the + #: stream is being closed + standard_compatible: bool = typed_attribute() + #: the TLS protocol version (e.g. ``TLSv1.2``) + tls_version: str = typed_attribute() + + +@dataclass(eq=False) +class TLSStream(ByteStream): + """ + A stream wrapper that encrypts all sent data and decrypts received data. + + This class has no public initializer; use :meth:`wrap` instead. + All extra attributes from :class:`~TLSAttribute` are supported. 
+
+    :var AnyByteStream transport_stream: the wrapped stream
+
+    """
+
+    transport_stream: AnyByteStream
+    standard_compatible: bool
+    _ssl_object: ssl.SSLObject
+    _read_bio: ssl.MemoryBIO
+    _write_bio: ssl.MemoryBIO
+
+    @classmethod
+    async def wrap(
+        cls,
+        transport_stream: AnyByteStream,
+        *,
+        server_side: bool | None = None,
+        hostname: str | None = None,
+        ssl_context: ssl.SSLContext | None = None,
+        standard_compatible: bool = True,
+    ) -> TLSStream:
+        """
+        Wrap an existing stream with Transport Layer Security.
+
+        This performs a TLS handshake with the peer.
+
+        :param transport_stream: a bytes-transporting stream to wrap
+        :param server_side: ``True`` if this is the server side of the connection,
+            ``False`` if this is the client side (if omitted, will be set to ``False``
+            if ``hostname`` has been provided, ``True`` otherwise). Used only to create
+            a default context when an explicit context has not been provided.
+        :param hostname: host name of the peer (if host name checking is desired)
+        :param ssl_context: the SSLContext object to use (if not provided, a secure
+            default will be created)
+        :param standard_compatible: if ``False``, skip the closing handshake when closing the
+            connection, and don't raise an exception if the peer does the same
+        :raises ~ssl.SSLError: if the TLS handshake fails
+
+        """
+        if server_side is None:
+            server_side = not hostname
+
+        if not ssl_context:
+            purpose = (
+                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
+            )
+            ssl_context = ssl.create_default_context(purpose)
+
+            # Re-enable detection of unexpected EOFs if it was disabled by Python
+            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
+                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
+
+        bio_in = ssl.MemoryBIO()
+        bio_out = ssl.MemoryBIO()
+        ssl_object = ssl_context.wrap_bio(
+            bio_in, bio_out, server_side=server_side, server_hostname=hostname
+        )
+        wrapper = cls(
+            transport_stream=transport_stream,
+            standard_compatible=standard_compatible,
+            _ssl_object=ssl_object,
+            _read_bio=bio_in,
+            _write_bio=bio_out,
+        )
+        await wrapper._call_sslobject_method(ssl_object.do_handshake)
+        return wrapper
+
+    async def _call_sslobject_method(
+        self, func: Callable[..., T_Retval], *args: object
+    ) -> T_Retval:
+        while True:
+            try:
+                result = func(*args)
+            except ssl.SSLWantReadError:
+                try:
+                    # Flush any pending writes first
+                    if self._write_bio.pending:
+                        await self.transport_stream.send(self._write_bio.read())
+
+                    data = await self.transport_stream.receive()
+                except EndOfStream:
+                    self._read_bio.write_eof()
+                except OSError as exc:
+                    self._read_bio.write_eof()
+                    self._write_bio.write_eof()
+                    raise BrokenResourceError from exc
+                else:
+                    self._read_bio.write(data)
+            except ssl.SSLWantWriteError:
+                await self.transport_stream.send(self._write_bio.read())
+            except ssl.SSLSyscallError as exc:
+                self._read_bio.write_eof()
+                self._write_bio.write_eof()
+                raise BrokenResourceError from exc
+            except ssl.SSLError as exc:
+                self._read_bio.write_eof()
+                self._write_bio.write_eof()
+                if (
+                    isinstance(exc, ssl.SSLEOFError)
+                    or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
+                ):
+                    if self.standard_compatible:
+                        raise BrokenResourceError from exc
+                    else:
+                        raise EndOfStream from None
+
+                raise
+            else:
+                # Flush any pending writes first
+                if self._write_bio.pending:
+                    await self.transport_stream.send(self._write_bio.read())
+
+                return result
+
+    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
+        """
+        Perform the TLS closing handshake.
+ + :return: a tuple of (wrapped byte stream, bytes left in the read buffer) + + """ + await self._call_sslobject_method(self._ssl_object.unwrap) + self._read_bio.write_eof() + self._write_bio.write_eof() + return self.transport_stream, self._read_bio.read() + + async def aclose(self) -> None: + if self.standard_compatible: + try: + await self.unwrap() + except BaseException: + await aclose_forcefully(self.transport_stream) + raise + + await self.transport_stream.aclose() + + async def receive(self, max_bytes: int = 65536) -> bytes: + data = await self._call_sslobject_method(self._ssl_object.read, max_bytes) + if not data: + raise EndOfStream + + return data + + async def send(self, item: bytes) -> None: + await self._call_sslobject_method(self._ssl_object.write, item) + + async def send_eof(self) -> None: + tls_version = self.extra(TLSAttribute.tls_version) + match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version) + if match: + major, minor = int(match.group(1)), int(match.group(2) or 0) + if (major, minor) < (1, 3): + raise NotImplementedError( + f"send_eof() requires at least TLSv1.3; current " + f"session uses {tls_version}" + ) + + raise NotImplementedError( + "send_eof() has not yet been implemented for TLS streams" + ) + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.transport_stream.extra_attributes, + TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol, + TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding, + TLSAttribute.cipher: self._ssl_object.cipher, + TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False), + TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert( + True + ), + TLSAttribute.server_side: lambda: self._ssl_object.server_side, + TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers() + if self._ssl_object.server_side + else None, + TLSAttribute.standard_compatible: lambda: self.standard_compatible, + TLSAttribute.ssl_object: lambda: self._ssl_object, + TLSAttribute.tls_version: self._ssl_object.version, + } + + +@dataclass(eq=False) +class TLSListener(Listener[TLSStream]): + """ + A convenience listener that wraps another listener and auto-negotiates a TLS session on every + accepted connection. + + If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is + called to do whatever post-mortem processing is deemed necessary. + + Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute. + + :param Listener listener: the listener to wrap + :param ssl_context: the SSL context object + :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap` + :param handshake_timeout: time limit for the TLS handshake + (passed to :func:`~anyio.fail_after`) + """ + + listener: Listener[Any] + ssl_context: ssl.SSLContext + standard_compatible: bool = True + handshake_timeout: float = 30 + + @staticmethod + async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None: + """ + Handle an exception raised during the TLS handshake. + + This method does 3 things: + + #. Forcefully closes the original stream + #. Logs the exception (unless it was a cancellation exception) using the + ``anyio.streams.tls`` logger + #. 
Reraises the exception if it was a base exception or a cancellation exception + + :param exc: the exception + :param stream: the original stream + + """ + await aclose_forcefully(stream) + + # Log all except cancellation exceptions + if not isinstance(exc, get_cancelled_exc_class()): + logging.getLogger(__name__).exception("Error during TLS handshake") + + # Only reraise base exceptions and cancellation exceptions + if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()): + raise + + async def serve( + self, + handler: Callable[[TLSStream], Any], + task_group: TaskGroup | None = None, + ) -> None: + @wraps(handler) + async def handler_wrapper(stream: AnyByteStream) -> None: + from .. import fail_after + + try: + with fail_after(self.handshake_timeout): + wrapped_stream = await TLSStream.wrap( + stream, + ssl_context=self.ssl_context, + standard_compatible=self.standard_compatible, + ) + except BaseException as exc: + await self.handle_handshake_error(exc, stream) + else: + await handler(wrapped_stream) + + await self.listener.serve(handler_wrapper, task_group) + + async def aclose(self) -> None: + await self.listener.aclose() + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + TLSAttribute.standard_compatible: lambda: self.standard_compatible, + } diff --git a/contrib/python/anyio/anyio/to_process.py b/contrib/python/anyio/anyio/to_process.py new file mode 100644 index 0000000000..7ba9d44198 --- /dev/null +++ b/contrib/python/anyio/anyio/to_process.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import os +import pickle +import subprocess +import sys +from collections import deque +from importlib.util import module_from_spec, spec_from_file_location +from typing import Callable, TypeVar, cast + +from ._core._eventloop import current_time, get_asynclib, get_cancelled_exc_class +from ._core._exceptions import BrokenWorkerProcess +from ._core._subprocesses import open_process +from ._core._synchronization import CapacityLimiter +from ._core._tasks import CancelScope, fail_after +from .abc import ByteReceiveStream, ByteSendStream, Process +from .lowlevel import RunVar, checkpoint_if_cancelled +from .streams.buffered import BufferedByteReceiveStream + +WORKER_MAX_IDLE_TIME = 300 # 5 minutes + +T_Retval = TypeVar("T_Retval") +_process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers") +_process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar( + "_process_pool_idle_workers" +) +_default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_limiter") + + +async def run_sync( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + """ + Call the given function with the given arguments in a worker process. + + If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled, + the worker process running it will be abruptly terminated using SIGKILL (or + ``terminateProcess()`` on Windows). + + :param func: a callable + :param args: positional arguments for the callable + :param cancellable: ``True`` to allow cancellation of the operation while it's running + :param limiter: capacity limiter to use to limit the total amount of processes running + (if omitted, the default limiter is used) + :return: an awaitable that yields the return value of the function. 
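+
+    A minimal usage sketch (``crunch`` stands in for any picklable, module-level,
+    CPU-bound callable of your own)::
+
+        import anyio
+        from anyio import to_process
+
+        async def main() -> None:
+            result = await to_process.run_sync(crunch, 1_000_000)
+
+        anyio.run(main)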
+ + """ + + async def send_raw_command(pickled_cmd: bytes) -> object: + try: + await stdin.send(pickled_cmd) + response = await buffered.receive_until(b"\n", 50) + status, length = response.split(b" ") + if status not in (b"RETURN", b"EXCEPTION"): + raise RuntimeError( + f"Worker process returned unexpected response: {response!r}" + ) + + pickled_response = await buffered.receive_exactly(int(length)) + except BaseException as exc: + workers.discard(process) + try: + process.kill() + with CancelScope(shield=True): + await process.aclose() + except ProcessLookupError: + pass + + if isinstance(exc, get_cancelled_exc_class()): + raise + else: + raise BrokenWorkerProcess from exc + + retval = pickle.loads(pickled_response) + if status == b"EXCEPTION": + assert isinstance(retval, BaseException) + raise retval + else: + return retval + + # First pickle the request before trying to reserve a worker process + await checkpoint_if_cancelled() + request = pickle.dumps(("run", func, args), protocol=pickle.HIGHEST_PROTOCOL) + + # If this is the first run in this event loop thread, set up the necessary variables + try: + workers = _process_pool_workers.get() + idle_workers = _process_pool_idle_workers.get() + except LookupError: + workers = set() + idle_workers = deque() + _process_pool_workers.set(workers) + _process_pool_idle_workers.set(idle_workers) + get_asynclib().setup_process_pool_exit_at_shutdown(workers) + + async with (limiter or current_default_process_limiter()): + # Pop processes from the pool (starting from the most recently used) until we find one that + # hasn't exited yet + process: Process + while idle_workers: + process, idle_since = idle_workers.pop() + if process.returncode is None: + stdin = cast(ByteSendStream, process.stdin) + buffered = BufferedByteReceiveStream( + cast(ByteReceiveStream, process.stdout) + ) + + # Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME seconds or + # longer + now = current_time() + killed_processes: list[Process] = [] + while idle_workers: + if now - idle_workers[0][1] < WORKER_MAX_IDLE_TIME: + break + + process, idle_since = idle_workers.popleft() + process.kill() + workers.remove(process) + killed_processes.append(process) + + with CancelScope(shield=True): + for process in killed_processes: + await process.aclose() + + break + + workers.remove(process) + else: + command = [sys.executable, "-u", "-m", __name__] + process = await open_process( + command, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) + try: + stdin = cast(ByteSendStream, process.stdin) + buffered = BufferedByteReceiveStream( + cast(ByteReceiveStream, process.stdout) + ) + with fail_after(20): + message = await buffered.receive(6) + + if message != b"READY\n": + raise BrokenWorkerProcess( + f"Worker process returned unexpected response: {message!r}" + ) + + main_module_path = getattr(sys.modules["__main__"], "__file__", None) + pickled = pickle.dumps( + ("init", sys.path, main_module_path), + protocol=pickle.HIGHEST_PROTOCOL, + ) + await send_raw_command(pickled) + except (BrokenWorkerProcess, get_cancelled_exc_class()): + raise + except BaseException as exc: + process.kill() + raise BrokenWorkerProcess( + "Error during worker process initialization" + ) from exc + + workers.add(process) + + with CancelScope(shield=not cancellable): + try: + return cast(T_Retval, await send_raw_command(request)) + finally: + if process in workers: + idle_workers.append((process, current_time())) + + +def current_default_process_limiter() -> CapacityLimiter: + """ + Return the 
capacity limiter that is used by default to limit the number of worker processes. + + :return: a capacity limiter object + + """ + try: + return _default_process_limiter.get() + except LookupError: + limiter = CapacityLimiter(os.cpu_count() or 2) + _default_process_limiter.set(limiter) + return limiter + + +def process_worker() -> None: + # Redirect standard streams to os.devnull so that user code won't interfere with the + # parent-worker communication + stdin = sys.stdin + stdout = sys.stdout + sys.stdin = open(os.devnull) + sys.stdout = open(os.devnull, "w") + + stdout.buffer.write(b"READY\n") + while True: + retval = exception = None + try: + command, *args = pickle.load(stdin.buffer) + except EOFError: + return + except BaseException as exc: + exception = exc + else: + if command == "run": + func, args = args + try: + retval = func(*args) + except BaseException as exc: + exception = exc + elif command == "init": + main_module_path: str | None + sys.path, main_module_path = args + del sys.modules["__main__"] + if main_module_path: + # Load the parent's main module but as __mp_main__ instead of __main__ + # (like multiprocessing does) to avoid infinite recursion + try: + spec = spec_from_file_location("__mp_main__", main_module_path) + if spec and spec.loader: + main = module_from_spec(spec) + spec.loader.exec_module(main) + sys.modules["__main__"] = main + except BaseException as exc: + exception = exc + + try: + if exception is not None: + status = b"EXCEPTION" + pickled = pickle.dumps(exception, pickle.HIGHEST_PROTOCOL) + else: + status = b"RETURN" + pickled = pickle.dumps(retval, pickle.HIGHEST_PROTOCOL) + except BaseException as exc: + exception = exc + status = b"EXCEPTION" + pickled = pickle.dumps(exc, pickle.HIGHEST_PROTOCOL) + + stdout.buffer.write(b"%s %d\n" % (status, len(pickled))) + stdout.buffer.write(pickled) + + # Respect SIGTERM + if isinstance(exception, SystemExit): + raise exception + + +if __name__ == "__main__": + process_worker() diff --git a/contrib/python/anyio/anyio/to_thread.py b/contrib/python/anyio/anyio/to_thread.py new file mode 100644 index 0000000000..9315d1ecf1 --- /dev/null +++ b/contrib/python/anyio/anyio/to_thread.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Callable, TypeVar +from warnings import warn + +from ._core._eventloop import get_asynclib +from .abc import CapacityLimiter + +T_Retval = TypeVar("T_Retval") + + +async def run_sync( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + """ + Call the given function with the given arguments in a worker thread. + + If the ``cancellable`` option is enabled and the task waiting for its completion is cancelled, + the thread will still run its course but its return value (or any raised exception) will be + ignored. + + :param func: a callable + :param args: positional arguments for the callable + :param cancellable: ``True`` to allow cancellation of the operation + :param limiter: capacity limiter to use to limit the total amount of threads running + (if omitted, the default limiter is used) + :return: an awaitable that yields the return value of the function. 
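+
+    A minimal usage sketch (``Path("data.csv").read_text`` is only an example of
+    a blocking callable)::
+
+        from pathlib import Path
+
+        from anyio import to_thread
+
+        async def read_file() -> str:
+            return await to_thread.run_sync(Path("data.csv").read_text)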
+ + """ + return await get_asynclib().run_sync_in_worker_thread( + func, *args, cancellable=cancellable, limiter=limiter + ) + + +async def run_sync_in_worker_thread( + func: Callable[..., T_Retval], + *args: object, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + warn( + "run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead", + DeprecationWarning, + ) + return await run_sync(func, *args, cancellable=cancellable, limiter=limiter) + + +def current_default_thread_limiter() -> CapacityLimiter: + """ + Return the capacity limiter that is used by default to limit the number of concurrent threads. + + :return: a capacity limiter object + + """ + return get_asynclib().current_default_thread_limiter() + + +def current_default_worker_thread_limiter() -> CapacityLimiter: + warn( + "current_default_worker_thread_limiter() has been deprecated, " + "use anyio.to_thread.current_default_thread_limiter() instead", + DeprecationWarning, + ) + return current_default_thread_limiter() diff --git a/contrib/python/anyio/ya.make b/contrib/python/anyio/ya.make new file mode 100644 index 0000000000..f8534a7d6c --- /dev/null +++ b/contrib/python/anyio/ya.make @@ -0,0 +1,70 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(3.7.1) + +LICENSE(MIT) + +PEERDIR( + contrib/python/idna + contrib/python/sniffio +) + +NO_LINT() + +NO_CHECK_IMPORTS( + anyio._backends._trio + anyio.pytest_plugin +) + +PY_SRCS( + TOP_LEVEL + anyio/__init__.py + anyio/_backends/__init__.py + anyio/_backends/_asyncio.py + anyio/_backends/_trio.py + anyio/_core/__init__.py + anyio/_core/_compat.py + anyio/_core/_eventloop.py + anyio/_core/_exceptions.py + anyio/_core/_fileio.py + anyio/_core/_resources.py + anyio/_core/_signals.py + anyio/_core/_sockets.py + anyio/_core/_streams.py + anyio/_core/_subprocesses.py + anyio/_core/_synchronization.py + anyio/_core/_tasks.py + anyio/_core/_testing.py + anyio/_core/_typedattr.py + anyio/abc/__init__.py + anyio/abc/_resources.py + anyio/abc/_sockets.py + anyio/abc/_streams.py + anyio/abc/_subprocesses.py + anyio/abc/_tasks.py + anyio/abc/_testing.py + anyio/from_thread.py + anyio/lowlevel.py + anyio/pytest_plugin.py + anyio/streams/__init__.py + anyio/streams/buffered.py + anyio/streams/file.py + anyio/streams/memory.py + anyio/streams/stapled.py + anyio/streams/text.py + anyio/streams/tls.py + anyio/to_process.py + anyio/to_thread.py +) + +RESOURCE_FILES( + PREFIX contrib/python/anyio/ + .dist-info/METADATA + .dist-info/entry_points.txt + .dist-info/top_level.txt + anyio/py.typed +) + +END() diff --git a/contrib/python/h11/.dist-info/METADATA b/contrib/python/h11/.dist-info/METADATA new file mode 100644 index 0000000000..cf12a82f19 --- /dev/null +++ b/contrib/python/h11/.dist-info/METADATA @@ -0,0 +1,193 @@ +Metadata-Version: 2.1 +Name: h11 +Version: 0.14.0 +Summary: A pure-Python, bring-your-own-I/O implementation of HTTP/1.1 +Home-page: https://github.com/python-hyper/h11 +Author: Nathaniel J. 
Smith
+Author-email: njs@pobox.com
+License: MIT
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Topic :: Internet :: WWW/HTTP
+Classifier: Topic :: System :: Networking
+Requires-Python: >=3.7
+License-File: LICENSE.txt
+Requires-Dist: typing-extensions ; python_version < "3.8"
+
+h11
+===
+
+.. image:: https://travis-ci.org/python-hyper/h11.svg?branch=master
+   :target: https://travis-ci.org/python-hyper/h11
+   :alt: Automated test status
+
+.. image:: https://codecov.io/gh/python-hyper/h11/branch/master/graph/badge.svg
+   :target: https://codecov.io/gh/python-hyper/h11
+   :alt: Test coverage
+
+.. image:: https://readthedocs.org/projects/h11/badge/?version=latest
+   :target: http://h11.readthedocs.io/en/latest/?badge=latest
+   :alt: Documentation Status
+
+This is a little HTTP/1.1 library written from scratch in Python,
+heavily inspired by `hyper-h2 <https://hyper-h2.readthedocs.io/>`_.
+
+It's a "bring-your-own-I/O" library; h11 contains no IO code
+whatsoever. This means you can hook h11 up to your favorite network
+API, and that could be anything you want: synchronous, threaded,
+asynchronous, or your own implementation of `RFC 6214
+<https://tools.ietf.org/html/rfc6214>`_ -- h11 won't judge you.
+(Compare this to the current state of the art, where every time a `new
+network API <https://trio.readthedocs.io/>`_ comes along, someone
+gets to start over reimplementing the entire HTTP protocol from
+scratch.) Cory Benfield made an `excellent blog post describing the
+benefits of this approach
+<https://lukasa.co.uk/2015/10/The_New_Hyper/>`_, or if you like video
+then here's his `PyCon 2016 talk on the same theme
+<https://www.youtube.com/watch?v=7cC3_jGwl_U>`_.
+
+This also means that h11 is not immediately useful out of the box:
+it's a toolkit for building programs that speak HTTP, not something
+that could directly replace ``requests`` or ``twisted.web`` or
+whatever. But h11 makes it much easier to implement something like
+``requests`` or ``twisted.web``.
+
+At a high level, working with h11 goes like this:
+
+1) First, create an ``h11.Connection`` object to track the state of a
+   single HTTP/1.1 connection.
+
+2) When you read data off the network, pass it to
+   ``conn.receive_data(...)``; you'll get back a list of objects
+   representing high-level HTTP "events".
+
+3) When you want to send a high-level HTTP event, create the
+   corresponding "event" object and pass it to ``conn.send(...)``;
+   this will give you back some bytes that you can then push out
+   through the network.
+
+For example, a client might instantiate and then send a
+``h11.Request`` object, then zero or more ``h11.Data`` objects for the
+request body (e.g., if this is a POST), and then a
+``h11.EndOfMessage`` to indicate the end of the message. The
+server would then send back a ``h11.Response``, some ``h11.Data``, and
+its own ``h11.EndOfMessage``. If either side violates the protocol,
+you'll get a ``h11.ProtocolError`` exception.
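+
+As a rough sketch of the client half of that exchange (``sock`` is assumed to
+be an already-connected socket; error handling is omitted):
+
+.. code-block:: python
+
+    import h11
+
+    conn = h11.Connection(our_role=h11.CLIENT)
+    request = h11.Request(
+        method="GET", target="/", headers=[("Host", "example.com")]
+    )
+    sock.sendall(conn.send(request))
+    sock.sendall(conn.send(h11.EndOfMessage()))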
+ +h11 is suitable for implementing both servers and clients, and has a +pleasantly symmetric API: the events you send as a client are exactly +the ones that you receive as a server and vice-versa. + +`Here's an example of a tiny HTTP client +<https://github.com/python-hyper/h11/blob/master/examples/basic-client.py>`_ + +It also has `a fine manual <https://h11.readthedocs.io/>`_. + +FAQ +--- + +*Whyyyyy?* + +I wanted to play with HTTP in `Curio +<https://curio.readthedocs.io/en/latest/tutorial.html>`__ and `Trio +<https://trio.readthedocs.io>`__, which at the time didn't have any +HTTP libraries. So I thought, no big deal, Python has, like, a dozen +different implementations of HTTP, surely I can find one that's +reusable. I didn't find one, but I did find Cory's call-to-arms +blog-post. So I figured, well, fine, if I have to implement HTTP from +scratch, at least I can make sure no-one *else* has to ever again. + +*Should I use it?* + +Maybe. You should be aware that it's a very young project. But, it's +feature complete and has an exhaustive test-suite and complete docs, +so the next step is for people to try using it and see how it goes +:-). If you do then please let us know -- if nothing else we'll want +to talk to you before making any incompatible changes! + +*What are the features/limitations?* + +Roughly speaking, it's trying to be a robust, complete, and non-hacky +implementation of the first "chapter" of the HTTP/1.1 spec: `RFC 7230: +HTTP/1.1 Message Syntax and Routing +<https://tools.ietf.org/html/rfc7230>`_. That is, it mostly focuses on +implementing HTTP at the level of taking bytes on and off the wire, +and the headers related to that, and tries to be anal about spec +conformance. It doesn't know about higher-level concerns like URL +routing, conditional GETs, cross-origin cookie policies, or content +negotiation. But it does know how to take care of framing, +cross-version differences in keep-alive handling, and the "obsolete +line folding" rule, so you can focus your energies on the hard / +interesting parts for your application, and it tries to support the +full specification in the sense that any useful HTTP/1.1 conformant +application should be able to use h11. + +It's pure Python, and has no dependencies outside of the standard +library. + +It has a test suite with 100.0% coverage for both statements and +branches. + +Currently it supports Python 3 (testing on 3.7-3.10) and PyPy 3. +The last Python 2-compatible version was h11 0.11.x. +(Originally it had a Cython wrapper for `http-parser +<https://github.com/nodejs/http-parser>`_ and a beautiful nested state +machine implemented with ``yield from`` to postprocess the output. But +I had to take these out -- the new *parser* needs fewer lines-of-code +than the old *parser wrapper*, is written in pure Python, uses no +exotic language syntax, and has more features. It's sad, really; that +old state machine was really slick. I just need a few sentences here +to mourn that.) + +I don't know how fast it is. I haven't benchmarked or profiled it yet, +so it's probably got a few pointless hot spots, and I've been trying +to err on the side of simplicity and robustness instead of +micro-optimization. But at the architectural level I tried hard to +avoid fundamentally bad decisions, e.g., I believe that all the +parsing algorithms remain linear-time even in the face of pathological +input like slowloris, and there are no byte-by-byte loops. 
(I also +believe that it maintains bounded memory usage in the face of +arbitrary/pathological input.) + +The whole library is ~800 lines-of-code. You can read and understand +the whole thing in less than an hour. Most of the energy invested in +this so far has been spent on trying to keep things simple by +minimizing special-cases and ad hoc state manipulation; even though it +is now quite small and simple, I'm still annoyed that I haven't +figured out how to make it even smaller and simpler. (Unfortunately, +HTTP does not lend itself to simplicity.) + +The API is ~feature complete and I don't expect the general outlines +to change much, but you can't judge an API's ergonomics until you +actually document and use it, so I'd expect some changes in the +details. + +*How do I try it?* + +.. code-block:: sh + + $ pip install h11 + $ git clone git@github.com:python-hyper/h11 + $ cd h11/examples + $ python basic-client.py + +and go from there. + +*License?* + +MIT + +*Code of conduct?* + +Contributors are requested to follow our `code of conduct +<https://github.com/python-hyper/h11/blob/master/CODE_OF_CONDUCT.md>`_ in +all project spaces. diff --git a/contrib/python/h11/.dist-info/top_level.txt b/contrib/python/h11/.dist-info/top_level.txt new file mode 100644 index 0000000000..0d24def711 --- /dev/null +++ b/contrib/python/h11/.dist-info/top_level.txt @@ -0,0 +1 @@ +h11 diff --git a/contrib/python/h11/LICENSE.txt b/contrib/python/h11/LICENSE.txt new file mode 100644 index 0000000000..8f080eae84 --- /dev/null +++ b/contrib/python/h11/LICENSE.txt @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2016 Nathaniel J. Smith <njs@pobox.com> and other contributors + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/h11/README.rst b/contrib/python/h11/README.rst new file mode 100644 index 0000000000..56e277e3d1 --- /dev/null +++ b/contrib/python/h11/README.rst @@ -0,0 +1,168 @@ +h11 +=== + +.. image:: https://travis-ci.org/python-hyper/h11.svg?branch=master + :target: https://travis-ci.org/python-hyper/h11 + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-hyper/h11/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-hyper/h11 + :alt: Test coverage + +.. image:: https://readthedocs.org/projects/h11/badge/?version=latest + :target: http://h11.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +This is a little HTTP/1.1 library written from scratch in Python, +heavily inspired by `hyper-h2 <https://hyper-h2.readthedocs.io/>`_. 
+
+It's a "bring-your-own-I/O" library; h11 contains no IO code
+whatsoever. This means you can hook h11 up to your favorite network
+API, and that could be anything you want: synchronous, threaded,
+asynchronous, or your own implementation of `RFC 6214
+<https://tools.ietf.org/html/rfc6214>`_ -- h11 won't judge you.
+(Compare this to the current state of the art, where every time a `new
+network API <https://trio.readthedocs.io/>`_ comes along then someone
+gets to start over reimplementing the entire HTTP protocol from
+scratch.) Cory Benfield made an `excellent blog post describing the
+benefits of this approach
+<https://lukasa.co.uk/2015/10/The_New_Hyper/>`_, or if you like video
+then here's his `PyCon 2016 talk on the same theme
+<https://www.youtube.com/watch?v=7cC3_jGwl_U>`_.
+
+This also means that h11 is not immediately useful out of the box:
+it's a toolkit for building programs that speak HTTP, not something
+that could directly replace ``requests`` or ``twisted.web`` or
+whatever. But h11 makes it much easier to implement something like
+``requests`` or ``twisted.web``.
+
+At a high level, working with h11 goes like this:
+
+1) First, create an ``h11.Connection`` object to track the state of a
+   single HTTP/1.1 connection.
+
+2) When you read data off the network, pass it to
+   ``conn.receive_data(...)``; you'll get back a list of objects
+   representing high-level HTTP "events".
+
+3) When you want to send a high-level HTTP event, create the
+   corresponding "event" object and pass it to ``conn.send(...)``;
+   this will give you back some bytes that you can then push out
+   through the network.
+
+For example, a client might instantiate and then send a
+``h11.Request`` object, then zero or more ``h11.Data`` objects for the
+request body (e.g., if this is a POST), and then a
+``h11.EndOfMessage`` to indicate the end of the message. The server
+would then send back a ``h11.Response``, some ``h11.Data``, and its
+own ``h11.EndOfMessage``. If either side violates the protocol,
+you'll get a ``h11.ProtocolError`` exception.
+
+h11 is suitable for implementing both servers and clients, and has a
+pleasantly symmetric API: the events you send as a client are exactly
+the ones that you receive as a server and vice-versa.
+
+`Here's an example of a tiny HTTP client
+<https://github.com/python-hyper/h11/blob/master/examples/basic-client.py>`_
+
+It also has `a fine manual <https://h11.readthedocs.io/>`_.
+
+FAQ
+---
+
+*Whyyyyy?*
+
+I wanted to play with HTTP in `Curio
+<https://curio.readthedocs.io/en/latest/tutorial.html>`__ and `Trio
+<https://trio.readthedocs.io>`__, which at the time didn't have any
+HTTP libraries. So I thought, no big deal, Python has, like, a dozen
+different implementations of HTTP, surely I can find one that's
+reusable. I didn't find one, but I did find Cory's call-to-arms
+blog-post. So I figured, well, fine, if I have to implement HTTP from
+scratch, at least I can make sure no-one *else* has to ever again.
+
+*Should I use it?*
+
+Maybe. You should be aware that it's a very young project. But, it's
+feature complete and has an exhaustive test-suite and complete docs,
+so the next step is for people to try using it and see how it goes
+:-). If you do then please let us know -- if nothing else we'll want
+to talk to you before making any incompatible changes!
+ +*What are the features/limitations?* + +Roughly speaking, it's trying to be a robust, complete, and non-hacky +implementation of the first "chapter" of the HTTP/1.1 spec: `RFC 7230: +HTTP/1.1 Message Syntax and Routing +<https://tools.ietf.org/html/rfc7230>`_. That is, it mostly focuses on +implementing HTTP at the level of taking bytes on and off the wire, +and the headers related to that, and tries to be anal about spec +conformance. It doesn't know about higher-level concerns like URL +routing, conditional GETs, cross-origin cookie policies, or content +negotiation. But it does know how to take care of framing, +cross-version differences in keep-alive handling, and the "obsolete +line folding" rule, so you can focus your energies on the hard / +interesting parts for your application, and it tries to support the +full specification in the sense that any useful HTTP/1.1 conformant +application should be able to use h11. + +It's pure Python, and has no dependencies outside of the standard +library. + +It has a test suite with 100.0% coverage for both statements and +branches. + +Currently it supports Python 3 (testing on 3.7-3.10) and PyPy 3. +The last Python 2-compatible version was h11 0.11.x. +(Originally it had a Cython wrapper for `http-parser +<https://github.com/nodejs/http-parser>`_ and a beautiful nested state +machine implemented with ``yield from`` to postprocess the output. But +I had to take these out -- the new *parser* needs fewer lines-of-code +than the old *parser wrapper*, is written in pure Python, uses no +exotic language syntax, and has more features. It's sad, really; that +old state machine was really slick. I just need a few sentences here +to mourn that.) + +I don't know how fast it is. I haven't benchmarked or profiled it yet, +so it's probably got a few pointless hot spots, and I've been trying +to err on the side of simplicity and robustness instead of +micro-optimization. But at the architectural level I tried hard to +avoid fundamentally bad decisions, e.g., I believe that all the +parsing algorithms remain linear-time even in the face of pathological +input like slowloris, and there are no byte-by-byte loops. (I also +believe that it maintains bounded memory usage in the face of +arbitrary/pathological input.) + +The whole library is ~800 lines-of-code. You can read and understand +the whole thing in less than an hour. Most of the energy invested in +this so far has been spent on trying to keep things simple by +minimizing special-cases and ad hoc state manipulation; even though it +is now quite small and simple, I'm still annoyed that I haven't +figured out how to make it even smaller and simpler. (Unfortunately, +HTTP does not lend itself to simplicity.) + +The API is ~feature complete and I don't expect the general outlines +to change much, but you can't judge an API's ergonomics until you +actually document and use it, so I'd expect some changes in the +details. + +*How do I try it?* + +.. code-block:: sh + + $ pip install h11 + $ git clone git@github.com:python-hyper/h11 + $ cd h11/examples + $ python basic-client.py + +and go from there. + +*License?* + +MIT + +*Code of conduct?* + +Contributors are requested to follow our `code of conduct +<https://github.com/python-hyper/h11/blob/master/CODE_OF_CONDUCT.md>`_ in +all project spaces. 
diff --git a/contrib/python/h11/h11/__init__.py b/contrib/python/h11/h11/__init__.py new file mode 100644 index 0000000000..989e92c345 --- /dev/null +++ b/contrib/python/h11/h11/__init__.py @@ -0,0 +1,62 @@ +# A highish-level implementation of the HTTP/1.1 wire protocol (RFC 7230), +# containing no networking code at all, loosely modelled on hyper-h2's generic +# implementation of HTTP/2 (and in particular the h2.connection.H2Connection +# class). There's still a bunch of subtle details you need to get right if you +# want to make this actually useful, because it doesn't implement all the +# semantics to check that what you're asking to write to the wire is sensible, +# but at least it gets you out of dealing with the wire itself. + +from h11._connection import Connection, NEED_DATA, PAUSED +from h11._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from h11._state import ( + CLIENT, + CLOSED, + DONE, + ERROR, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) +from h11._util import LocalProtocolError, ProtocolError, RemoteProtocolError +from h11._version import __version__ + +PRODUCT_ID = "python-h11/" + __version__ + + +__all__ = ( + "Connection", + "NEED_DATA", + "PAUSED", + "ConnectionClosed", + "Data", + "EndOfMessage", + "Event", + "InformationalResponse", + "Request", + "Response", + "CLIENT", + "CLOSED", + "DONE", + "ERROR", + "IDLE", + "MUST_CLOSE", + "SEND_BODY", + "SEND_RESPONSE", + "SERVER", + "SWITCHED_PROTOCOL", + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", +) diff --git a/contrib/python/h11/h11/_abnf.py b/contrib/python/h11/h11/_abnf.py new file mode 100644 index 0000000000..933587fba2 --- /dev/null +++ b/contrib/python/h11/h11/_abnf.py @@ -0,0 +1,132 @@ +# We use native strings for all the re patterns, to take advantage of string +# formatting, and then convert to bytestrings when compiling the final re +# objects. + +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#whitespace +# OWS = *( SP / HTAB ) +# ; optional whitespace +OWS = r"[ \t]*" + +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#rule.token.separators +# token = 1*tchar +# +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +# ; any VCHAR, except delimiters +token = r"[-!#$%&'*+.^_`|~0-9a-zA-Z]+" + +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#header.fields +# field-name = token +field_name = token + +# The standard says: +# +# field-value = *( field-content / obs-fold ) +# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +# field-vchar = VCHAR / obs-text +# obs-fold = CRLF 1*( SP / HTAB ) +# ; obsolete line folding +# ; see Section 3.2.4 +# +# https://tools.ietf.org/html/rfc5234#appendix-B.1 +# +# VCHAR = %x21-7E +# ; visible (printing) characters +# +# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#rule.quoted-string +# obs-text = %x80-FF +# +# However, the standard definition of field-content is WRONG! It disallows +# fields containing a single visible character surrounded by whitespace, +# e.g. "foo a bar". +# +# See: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 +# +# So our definition of field_content attempts to fix it up... 
+#
+# Also, we allow lots of control characters, because apparently people assume
+# that they're legal in practice (e.g., google analytics makes cookies with
+# \x01 in them!):
+#   https://github.com/python-hyper/h11/issues/57
+# We still don't allow NUL or whitespace, because those are often treated as
+# meta-characters and letting them through can lead to nasty issues like SSRF.
+vchar = r"[\x21-\x7e]"
+vchar_or_obs_text = r"[^\x00\s]"
+field_vchar = vchar_or_obs_text
+field_content = r"{field_vchar}+(?:[ \t]+{field_vchar}+)*".format(**globals())
+
+# We handle obs-fold at a different level, and our fixed-up field_content
+# already grows to swallow the whole value, so ? instead of *
+field_value = r"({field_content})?".format(**globals())
+
+#  header-field   = field-name ":" OWS field-value OWS
+header_field = (
+    r"(?P<field_name>{field_name})"
+    r":"
+    r"{OWS}"
+    r"(?P<field_value>{field_value})"
+    r"{OWS}".format(**globals())
+)
+
+# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#request.line
+#
+#   request-line   = method SP request-target SP HTTP-version CRLF
+#   method         = token
+#   HTTP-version   = HTTP-name "/" DIGIT "." DIGIT
+#   HTTP-name      = %x48.54.54.50 ; "HTTP", case-sensitive
+#
+# request-target is complicated (see RFC 7230 sec 5.3) -- could be path, full
+# URL, host+port (for connect), or even "*", but in any case we are guaranteed
+# that it consists of the visible printing characters.
+method = token
+request_target = r"{vchar}+".format(**globals())
+http_version = r"HTTP/(?P<http_version>[0-9]\.[0-9])"
+request_line = (
+    r"(?P<method>{method})"
+    r" "
+    r"(?P<target>{request_target})"
+    r" "
+    r"{http_version}".format(**globals())
+)
+
+# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#status.line
+#
+#   status-line = HTTP-version SP status-code SP reason-phrase CRLF
+#   status-code = 3DIGIT
+#   reason-phrase = *( HTAB / SP / VCHAR / obs-text )
+status_code = r"[0-9]{3}"
+reason_phrase = r"([ \t]|{vchar_or_obs_text})*".format(**globals())
+status_line = (
+    r"{http_version}"
+    r" "
+    r"(?P<status_code>{status_code})"
+    # However, there are apparently a few too many servers out there that just
+    # leave out the reason phrase:
+    # https://github.com/scrapy/scrapy/issues/345#issuecomment-281756036
+    # https://github.com/seanmonstar/httparse/issues/29
+    # so make it optional. ?: is a non-capturing group.
+    r"(?: (?P<reason>{reason_phrase}))?".format(**globals())
)
+
+HEXDIG = r"[0-9A-Fa-f]"
+# Actually
+#
+#      chunk-size     = 1*HEXDIG
+#
+# but we impose an upper-limit to avoid ridiculosity. len(str(2**64)) == 20
+chunk_size = r"({HEXDIG}){{1,20}}".format(**globals())
+# Actually
+#
+#     chunk-ext      = *( ";" chunk-ext-name [ "=" chunk-ext-val ] )
+#
+# but we aren't parsing the things so we don't really care.
+chunk_ext = r";.*"
+chunk_header = (
+    r"(?P<chunk_size>{chunk_size})"
+    r"(?P<chunk_ext>{chunk_ext})?"
+    r"{OWS}\r\n".format(
+        **globals()
+    )  # Even though the specification does not allow for extra whitespace,
+    # we are lenient with trailing whitespace because some servers in the wild use it.
+)
diff --git a/contrib/python/h11/h11/_connection.py b/contrib/python/h11/h11/_connection.py
new file mode 100644
index 0000000000..d175270759
--- /dev/null
+++ b/contrib/python/h11/h11/_connection.py
@@ -0,0 +1,633 @@
+# This contains the main Connection class. Everything in h11 revolves around
+# this.
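+#
+# As a rough orientation, the server side of the API looks like this (an
+# illustrative, untested sketch, not part of this module; data_from_network
+# stands in for bytes read from your own transport):
+#
+#     conn = Connection(our_role=SERVER)
+#     conn.receive_data(data_from_network)
+#     event = conn.next_event()  # e.g. a Request event
+#     if type(event) is Request:
+#         bytes_out = conn.send(Response(status_code=200, headers=[]))
+#         bytes_out += conn.send(EndOfMessage())
+#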
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union + +from ._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from ._headers import get_comma_header, has_expect_100_continue, set_comma_header +from ._readers import READERS, ReadersType +from ._receivebuffer import ReceiveBuffer +from ._state import ( + _SWITCH_CONNECT, + _SWITCH_UPGRADE, + CLIENT, + ConnectionState, + DONE, + ERROR, + MIGHT_SWITCH_PROTOCOL, + SEND_BODY, + SERVER, + SWITCHED_PROTOCOL, +) +from ._util import ( # Import the internal things we need + LocalProtocolError, + RemoteProtocolError, + Sentinel, +) +from ._writers import WRITERS, WritersType + +# Everything in __all__ gets re-exported as part of the h11 public API. +__all__ = ["Connection", "NEED_DATA", "PAUSED"] + + +class NEED_DATA(Sentinel, metaclass=Sentinel): + pass + + +class PAUSED(Sentinel, metaclass=Sentinel): + pass + + +# If we ever have this much buffered without it making a complete parseable +# event, we error out. The only time we really buffer is when reading the +# request/response line + headers together, so this is effectively the limit on +# the size of that. +# +# Some precedents for defaults: +# - node.js: 80 * 1024 +# - tomcat: 8 * 1024 +# - IIS: 16 * 1024 +# - Apache: <8 KiB per line> +DEFAULT_MAX_INCOMPLETE_EVENT_SIZE = 16 * 1024 + +# RFC 7230's rules for connection lifecycles: +# - If either side says they want to close the connection, then the connection +# must close. +# - HTTP/1.1 defaults to keep-alive unless someone says Connection: close +# - HTTP/1.0 defaults to close unless both sides say Connection: keep-alive +# (and even this is a mess -- e.g. if you're implementing a proxy then +# sending Connection: keep-alive is forbidden). +# +# We simplify life by simply not supporting keep-alive with HTTP/1.0 peers. So +# our rule is: +# - If someone says Connection: close, we will close +# - If someone uses HTTP/1.0, we will close. +def _keep_alive(event: Union[Request, Response]) -> bool: + connection = get_comma_header(event.headers, b"connection") + if b"close" in connection: + return False + if getattr(event, "http_version", b"1.1") < b"1.1": + return False + return True + + +def _body_framing( + request_method: bytes, event: Union[Request, Response] +) -> Tuple[str, Union[Tuple[()], Tuple[int]]]: + # Called when we enter SEND_BODY to figure out framing information for + # this body. + # + # These are the only two events that can trigger a SEND_BODY state: + assert type(event) in (Request, Response) + # Returns one of: + # + # ("content-length", count) + # ("chunked", ()) + # ("http/1.0", ()) + # + # which are (lookup key, *args) for constructing body reader/writer + # objects. + # + # Reference: https://tools.ietf.org/html/rfc7230#section-3.3.3 + # + # Step 1: some responses always have an empty body, regardless of what the + # headers say. + if type(event) is Response: + if ( + event.status_code in (204, 304) + or request_method == b"HEAD" + or (request_method == b"CONNECT" and 200 <= event.status_code < 300) + ): + return ("content-length", (0,)) + # Section 3.3.3 also lists another case -- responses with status_code + # < 200. For us these are InformationalResponses, not Responses, so + # they can't get into this function in the first place. 
+ assert event.status_code >= 200 + + # Step 2: check for Transfer-Encoding (T-E beats C-L): + transfer_encodings = get_comma_header(event.headers, b"transfer-encoding") + if transfer_encodings: + assert transfer_encodings == [b"chunked"] + return ("chunked", ()) + + # Step 3: check for Content-Length + content_lengths = get_comma_header(event.headers, b"content-length") + if content_lengths: + return ("content-length", (int(content_lengths[0]),)) + + # Step 4: no applicable headers; fallback/default depends on type + if type(event) is Request: + return ("content-length", (0,)) + else: + return ("http/1.0", ()) + + +################################################################ +# +# The main Connection class +# +################################################################ + + +class Connection: + """An object encapsulating the state of an HTTP connection. + + Args: + our_role: If you're implementing a client, pass :data:`h11.CLIENT`. If + you're implementing a server, pass :data:`h11.SERVER`. + + max_incomplete_event_size (int): + The maximum number of bytes we're willing to buffer of an + incomplete event. In practice this mostly sets a limit on the + maximum size of the request/response line + headers. If this is + exceeded, then :meth:`next_event` will raise + :exc:`RemoteProtocolError`. + + """ + + def __init__( + self, + our_role: Type[Sentinel], + max_incomplete_event_size: int = DEFAULT_MAX_INCOMPLETE_EVENT_SIZE, + ) -> None: + self._max_incomplete_event_size = max_incomplete_event_size + # State and role tracking + if our_role not in (CLIENT, SERVER): + raise ValueError("expected CLIENT or SERVER, not {!r}".format(our_role)) + self.our_role = our_role + self.their_role: Type[Sentinel] + if our_role is CLIENT: + self.their_role = SERVER + else: + self.their_role = CLIENT + self._cstate = ConnectionState() + + # Callables for converting data->events or vice-versa given the + # current state + self._writer = self._get_io_object(self.our_role, None, WRITERS) + self._reader = self._get_io_object(self.their_role, None, READERS) + + # Holds any unprocessed received data + self._receive_buffer = ReceiveBuffer() + # If this is true, then it indicates that the incoming connection was + # closed *after* the end of whatever's in self._receive_buffer: + self._receive_buffer_closed = False + + # Extra bits of state that don't fit into the state machine. + # + # These two are only used to interpret framing headers for figuring + # out how to read/write response bodies. their_http_version is also + # made available as a convenient public API. + self.their_http_version: Optional[bytes] = None + self._request_method: Optional[bytes] = None + # This is pure flow-control and doesn't at all affect the set of legal + # transitions, so no need to bother ConnectionState with it: + self.client_is_waiting_for_100_continue = False + + @property + def states(self) -> Dict[Type[Sentinel], Type[Sentinel]]: + """A dictionary like:: + + {CLIENT: <client state>, SERVER: <server state>} + + See :ref:`state-machine` for details. + + """ + return dict(self._cstate.states) + + @property + def our_state(self) -> Type[Sentinel]: + """The current state of whichever role we are playing. See + :ref:`state-machine` for details. + """ + return self._cstate.states[self.our_role] + + @property + def their_state(self) -> Type[Sentinel]: + """The current state of whichever role we are NOT playing. See + :ref:`state-machine` for details. 
+ """ + return self._cstate.states[self.their_role] + + @property + def they_are_waiting_for_100_continue(self) -> bool: + return self.their_role is CLIENT and self.client_is_waiting_for_100_continue + + def start_next_cycle(self) -> None: + """Attempt to reset our connection state for a new request/response + cycle. + + If both client and server are in :data:`DONE` state, then resets them + both to :data:`IDLE` state in preparation for a new request/response + cycle on this same connection. Otherwise, raises a + :exc:`LocalProtocolError`. + + See :ref:`keepalive-and-pipelining`. + + """ + old_states = dict(self._cstate.states) + self._cstate.start_next_cycle() + self._request_method = None + # self.their_http_version gets left alone, since it presumably lasts + # beyond a single request/response cycle + assert not self.client_is_waiting_for_100_continue + self._respond_to_state_changes(old_states) + + def _process_error(self, role: Type[Sentinel]) -> None: + old_states = dict(self._cstate.states) + self._cstate.process_error(role) + self._respond_to_state_changes(old_states) + + def _server_switch_event(self, event: Event) -> Optional[Type[Sentinel]]: + if type(event) is InformationalResponse and event.status_code == 101: + return _SWITCH_UPGRADE + if type(event) is Response: + if ( + _SWITCH_CONNECT in self._cstate.pending_switch_proposals + and 200 <= event.status_code < 300 + ): + return _SWITCH_CONNECT + return None + + # All events go through here + def _process_event(self, role: Type[Sentinel], event: Event) -> None: + # First, pass the event through the state machine to make sure it + # succeeds. + old_states = dict(self._cstate.states) + if role is CLIENT and type(event) is Request: + if event.method == b"CONNECT": + self._cstate.process_client_switch_proposal(_SWITCH_CONNECT) + if get_comma_header(event.headers, b"upgrade"): + self._cstate.process_client_switch_proposal(_SWITCH_UPGRADE) + server_switch_event = None + if role is SERVER: + server_switch_event = self._server_switch_event(event) + self._cstate.process_event(role, type(event), server_switch_event) + + # Then perform the updates triggered by it. + + if type(event) is Request: + self._request_method = event.method + + if role is self.their_role and type(event) in ( + Request, + Response, + InformationalResponse, + ): + event = cast(Union[Request, Response, InformationalResponse], event) + self.their_http_version = event.http_version + + # Keep alive handling + # + # RFC 7230 doesn't really say what one should do if Connection: close + # shows up on a 1xx InformationalResponse. I think the idea is that + # this is not supposed to happen. In any case, if it does happen, we + # ignore it. 
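+        # (Concretely: _keep_alive() returns False for a "Connection: close"
+        # header or an HTTP/1.0 peer, and once disabled, keep-alive stays
+        # disabled for the rest of the connection.)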
+ if type(event) in (Request, Response) and not _keep_alive( + cast(Union[Request, Response], event) + ): + self._cstate.process_keep_alive_disabled() + + # 100-continue + if type(event) is Request and has_expect_100_continue(event): + self.client_is_waiting_for_100_continue = True + if type(event) in (InformationalResponse, Response): + self.client_is_waiting_for_100_continue = False + if role is CLIENT and type(event) in (Data, EndOfMessage): + self.client_is_waiting_for_100_continue = False + + self._respond_to_state_changes(old_states, event) + + def _get_io_object( + self, + role: Type[Sentinel], + event: Optional[Event], + io_dict: Union[ReadersType, WritersType], + ) -> Optional[Callable[..., Any]]: + # event may be None; it's only used when entering SEND_BODY + state = self._cstate.states[role] + if state is SEND_BODY: + # Special case: the io_dict has a dict of reader/writer factories + # that depend on the request/response framing. + framing_type, args = _body_framing( + cast(bytes, self._request_method), cast(Union[Request, Response], event) + ) + return io_dict[SEND_BODY][framing_type](*args) # type: ignore[index] + else: + # General case: the io_dict just has the appropriate reader/writer + # for this state + return io_dict.get((role, state)) # type: ignore[return-value] + + # This must be called after any action that might have caused + # self._cstate.states to change. + def _respond_to_state_changes( + self, + old_states: Dict[Type[Sentinel], Type[Sentinel]], + event: Optional[Event] = None, + ) -> None: + # Update reader/writer + if self.our_state != old_states[self.our_role]: + self._writer = self._get_io_object(self.our_role, event, WRITERS) + if self.their_state != old_states[self.their_role]: + self._reader = self._get_io_object(self.their_role, event, READERS) + + @property + def trailing_data(self) -> Tuple[bytes, bool]: + """Data that has been received, but not yet processed, represented as + a tuple with two elements, where the first is a byte-string containing + the unprocessed data itself, and the second is a bool that is True if + the receive connection was closed. + + See :ref:`switching-protocols` for discussion of why you'd want this. + """ + return (bytes(self._receive_buffer), self._receive_buffer_closed) + + def receive_data(self, data: bytes) -> None: + """Add data to our internal receive buffer. + + This does not actually do any processing on the data, just stores + it. To trigger processing, you have to call :meth:`next_event`. + + Args: + data (:term:`bytes-like object`): + The new data that was just received. + + Special case: If *data* is an empty byte-string like ``b""``, + then this indicates that the remote side has closed the + connection (end of file). Normally this is convenient, because + standard Python APIs like :meth:`file.read` or + :meth:`socket.recv` use ``b""`` to indicate end-of-file, while + other failures to read are indicated using other mechanisms + like raising :exc:`TimeoutError`. When using such an API you + can just blindly pass through whatever you get from ``read`` + to :meth:`receive_data`, and everything will work. + + But, if you have an API where reading an empty string is a + valid non-EOF condition, then you need to be aware of this and + make sure to check for such strings and avoid passing them to + :meth:`receive_data`. + + Returns: + Nothing, but after calling this you should call :meth:`next_event` + to parse the newly received data. 
+ + Raises: + RuntimeError: + Raised if you pass an empty *data*, indicating EOF, and then + pass a non-empty *data*, indicating more data that somehow + arrived after the EOF. + + (Calling ``receive_data(b"")`` multiple times is fine, + and equivalent to calling it once.) + + """ + if data: + if self._receive_buffer_closed: + raise RuntimeError("received close, then received more data?") + self._receive_buffer += data + else: + self._receive_buffer_closed = True + + def _extract_next_receive_event( + self, + ) -> Union[Event, Type[NEED_DATA], Type[PAUSED]]: + state = self.their_state + # We don't pause immediately when they enter DONE, because even in + # DONE state we can still process a ConnectionClosed() event. But + # if we have data in our buffer, then we definitely aren't getting + # a ConnectionClosed() immediately and we need to pause. + if state is DONE and self._receive_buffer: + return PAUSED + if state is MIGHT_SWITCH_PROTOCOL or state is SWITCHED_PROTOCOL: + return PAUSED + assert self._reader is not None + event = self._reader(self._receive_buffer) + if event is None: + if not self._receive_buffer and self._receive_buffer_closed: + # In some unusual cases (basically just HTTP/1.0 bodies), EOF + # triggers an actual protocol event; in that case, we want to + # return that event, and then the state will change and we'll + # get called again to generate the actual ConnectionClosed(). + if hasattr(self._reader, "read_eof"): + event = self._reader.read_eof() # type: ignore[attr-defined] + else: + event = ConnectionClosed() + if event is None: + event = NEED_DATA + return event # type: ignore[no-any-return] + + def next_event(self) -> Union[Event, Type[NEED_DATA], Type[PAUSED]]: + """Parse the next event out of our receive buffer, update our internal + state, and return it. + + This is a mutating operation -- think of it like calling :func:`next` + on an iterator. + + Returns: + : One of three things: + + 1) An event object -- see :ref:`events`. + + 2) The special constant :data:`NEED_DATA`, which indicates that + you need to read more data from your socket and pass it to + :meth:`receive_data` before this method will be able to return + any more events. + + 3) The special constant :data:`PAUSED`, which indicates that we + are not in a state where we can process incoming data (usually + because the peer has finished their part of the current + request/response cycle, and you have not yet called + :meth:`start_next_cycle`). See :ref:`flow-control` for details. + + Raises: + RemoteProtocolError: + The peer has misbehaved. You should close the connection + (possibly after sending some kind of 4xx response). + + Once this method returns :class:`ConnectionClosed` once, then all + subsequent calls will also return :class:`ConnectionClosed`. + + If this method raises any exception besides :exc:`RemoteProtocolError` + then that's a bug -- if it happens please file a bug report! + + If this method raises any exception then it also sets + :attr:`Connection.their_state` to :data:`ERROR` -- see + :ref:`error-handling` for discussion. 
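+
+        A typical driver loop looks roughly like this (an illustrative
+        sketch rather than official documentation; ``sock`` stands in for
+        your own socket object)::
+
+            while True:
+                event = conn.next_event()
+                if event is h11.NEED_DATA:
+                    conn.receive_data(sock.recv(2048))
+                    continue
+                ...  # process the event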
+ + """ + + if self.their_state is ERROR: + raise RemoteProtocolError("Can't receive data when peer state is ERROR") + try: + event = self._extract_next_receive_event() + if event not in [NEED_DATA, PAUSED]: + self._process_event(self.their_role, cast(Event, event)) + if event is NEED_DATA: + if len(self._receive_buffer) > self._max_incomplete_event_size: + # 431 is "Request header fields too large" which is pretty + # much the only situation where we can get here + raise RemoteProtocolError( + "Receive buffer too long", error_status_hint=431 + ) + if self._receive_buffer_closed: + # We're still trying to complete some event, but that's + # never going to happen because no more data is coming + raise RemoteProtocolError("peer unexpectedly closed connection") + return event + except BaseException as exc: + self._process_error(self.their_role) + if isinstance(exc, LocalProtocolError): + exc._reraise_as_remote_protocol_error() + else: + raise + + def send(self, event: Event) -> Optional[bytes]: + """Convert a high-level event into bytes that can be sent to the peer, + while updating our internal state machine. + + Args: + event: The :ref:`event <events>` to send. + + Returns: + If ``type(event) is ConnectionClosed``, then returns + ``None``. Otherwise, returns a :term:`bytes-like object`. + + Raises: + LocalProtocolError: + Sending this event at this time would violate our + understanding of the HTTP/1.1 protocol. + + If this method raises any exception then it also sets + :attr:`Connection.our_state` to :data:`ERROR` -- see + :ref:`error-handling` for discussion. + + """ + data_list = self.send_with_data_passthrough(event) + if data_list is None: + return None + else: + return b"".join(data_list) + + def send_with_data_passthrough(self, event: Event) -> Optional[List[bytes]]: + """Identical to :meth:`send`, except that in situations where + :meth:`send` returns a single :term:`bytes-like object`, this instead + returns a list of them -- and when sending a :class:`Data` event, this + list is guaranteed to contain the exact object you passed in as + :attr:`Data.data`. See :ref:`sendfile` for discussion. + + """ + if self.our_state is ERROR: + raise LocalProtocolError("Can't send data when our state is ERROR") + try: + if type(event) is Response: + event = self._clean_up_response_headers_for_sending(event) + # We want to call _process_event before calling the writer, + # because if someone tries to do something invalid then this will + # give a sensible error message, while our writers all just assume + # they will only receive valid events. But, _process_event might + # change self._writer. So we have to do a little dance: + writer = self._writer + self._process_event(self.our_role, event) + if type(event) is ConnectionClosed: + return None + else: + # In any situation where writer is None, process_event should + # have raised ProtocolError + assert writer is not None + data_list: List[bytes] = [] + writer(event, data_list.append) + return data_list + except: + self._process_error(self.our_role) + raise + + def send_failed(self) -> None: + """Notify the state machine that we failed to send the data it gave + us. + + This causes :attr:`Connection.our_state` to immediately become + :data:`ERROR` -- see :ref:`error-handling` for discussion. + + """ + self._process_error(self.our_role) + + # When sending a Response, we take responsibility for a few things: + # + # - Sometimes you MUST set Connection: close. We take care of those + # times. 
(You can also set it yourself if you want, and if you do then
+    #   we'll respect that and close the connection at the right time. But you
+    #   don't have to worry about that unless you want to.)
+    #
+    # - The user has to set Content-Length if they want it. Otherwise, for
+    #   responses that have bodies (e.g. not HEAD), then we will automatically
+    #   select the right mechanism for streaming a body of unknown length,
+    #   which depends on the peer's HTTP version.
+    #
+    # This function's *only* responsibility is making sure headers are set up
+    # right -- everything downstream just looks at the headers. There are no
+    # side channels.
+    def _clean_up_response_headers_for_sending(self, response: Response) -> Response:
+        assert type(response) is Response
+
+        headers = response.headers
+        need_close = False
+
+        # HEAD requests need some special handling: they always act like they
+        # have Content-Length: 0, and that's how _body_framing treats
+        # them. But their headers are supposed to match what we would send if
+        # the request was a GET. (Technically there is one deviation allowed:
+        # we're allowed to leave out the framing headers -- see
+        # https://tools.ietf.org/html/rfc7231#section-4.3.2 . But it's just as
+        # easy to get them right.)
+        method_for_choosing_headers = cast(bytes, self._request_method)
+        if method_for_choosing_headers == b"HEAD":
+            method_for_choosing_headers = b"GET"
+        framing_type, _ = _body_framing(method_for_choosing_headers, response)
+        if framing_type in ("chunked", "http/1.0"):
+            # This response has a body of unknown length.
+            # If our peer is HTTP/1.1, we use Transfer-Encoding: chunked
+            # If our peer is HTTP/1.0, we use no framing headers, and close the
+            # connection afterwards.
+            #
+            # Make sure to clear Content-Length (in principle user could have
+            # set both and then we ignored Content-Length b/c
+            # Transfer-Encoding overwrote it -- this would be naughty of them,
+            # but the HTTP spec says that if our peer does this then we have
+            # to fix it instead of erroring out, so we'll accord the user the
+            # same respect).
+            headers = set_comma_header(headers, b"content-length", [])
+            if self.their_http_version is None or self.their_http_version < b"1.1":
+                # Either we never got a valid request and are sending back an
+                # error (their_http_version is None), so we assume the worst;
+                # or else we did get a valid HTTP/1.0 request, so we know that
+                # they don't understand chunked encoding.
+                headers = set_comma_header(headers, b"transfer-encoding", [])
+                # This is actually redundant ATM, since currently we
+                # unconditionally disable keep-alive when talking to HTTP/1.0
+                # peers.
But let's be defensive just in case we add
+                # Connection: keep-alive support later:
+                if self._request_method != b"HEAD":
+                    need_close = True
+            else:
+                headers = set_comma_header(headers, b"transfer-encoding", [b"chunked"])
+
+        if not self._cstate.keep_alive or need_close:
+            # Make sure Connection: close is set
+            connection = set(get_comma_header(headers, b"connection"))
+            connection.discard(b"keep-alive")
+            connection.add(b"close")
+            headers = set_comma_header(headers, b"connection", sorted(connection))
+
+        return Response(
+            headers=headers,
+            status_code=response.status_code,
+            http_version=response.http_version,
+            reason=response.reason,
+        )
diff --git a/contrib/python/h11/h11/_events.py b/contrib/python/h11/h11/_events.py
new file mode 100644
index 0000000000..075bf8a469
--- /dev/null
+++ b/contrib/python/h11/h11/_events.py
@@ -0,0 +1,369 @@
+# High level events that make up HTTP/1.1 conversations. Loosely inspired by
+# the corresponding events in hyper-h2:
+#
+#     http://python-hyper.org/h2/en/stable/api.html#events
+#
+# Don't subclass these. Stuff will break.
+
+import re
+from abc import ABC
+from dataclasses import dataclass, field
+from typing import Any, cast, Dict, List, Tuple, Union
+
+from ._abnf import method, request_target
+from ._headers import Headers, normalize_and_validate
+from ._util import bytesify, LocalProtocolError, validate
+
+# Everything in __all__ gets re-exported as part of the h11 public API.
+__all__ = [
+    "Event",
+    "Request",
+    "InformationalResponse",
+    "Response",
+    "Data",
+    "EndOfMessage",
+    "ConnectionClosed",
+]
+
+method_re = re.compile(method.encode("ascii"))
+request_target_re = re.compile(request_target.encode("ascii"))
+
+
+class Event(ABC):
+    """
+    Base class for h11 events.
+    """
+
+    __slots__ = ()
+
+
+@dataclass(init=False, frozen=True)
+class Request(Event):
+    """The beginning of an HTTP request.
+
+    Fields:
+
+    .. attribute:: method
+
+       An HTTP method, e.g. ``b"GET"`` or ``b"POST"``. Always a byte
+       string. :term:`Bytes-like objects <bytes-like object>` and native
+       strings containing only ascii characters will be automatically
+       converted to byte strings.
+
+    .. attribute:: target
+
+       The target of an HTTP request, e.g. ``b"/index.html"``, or one of the
+       more exotic formats described in `RFC 7230, section 5.3
+       <https://tools.ietf.org/html/rfc7230#section-5.3>`_. Always a byte
+       string. :term:`Bytes-like objects <bytes-like object>` and native
+       strings containing only ascii characters will be automatically
+       converted to byte strings.
+
+    .. attribute:: headers
+
+       Request headers, represented as a list of (name, value) pairs. See
+       :ref:`the header normalization rules <headers-format>` for details.
+
+    .. attribute:: http_version
+
+       The HTTP protocol version, represented as a byte string like
+       ``b"1.1"``. See :ref:`the HTTP version normalization rules
+       <http_version-format>` for details.
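+
+    For example, constructing one by hand might look like this (a minimal
+    sketch; the resulting object would then be passed to
+    :meth:`Connection.send`)::
+
+        req = h11.Request(
+            method="GET",
+            target="/",
+            headers=[("Host", "example.org")],
+        )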
+ + """ + + __slots__ = ("method", "headers", "target", "http_version") + + method: bytes + headers: Headers + target: bytes + http_version: bytes + + def __init__( + self, + *, + method: Union[bytes, str], + headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]], + target: Union[bytes, str], + http_version: Union[bytes, str] = b"1.1", + _parsed: bool = False, + ) -> None: + super().__init__() + if isinstance(headers, Headers): + object.__setattr__(self, "headers", headers) + else: + object.__setattr__( + self, "headers", normalize_and_validate(headers, _parsed=_parsed) + ) + if not _parsed: + object.__setattr__(self, "method", bytesify(method)) + object.__setattr__(self, "target", bytesify(target)) + object.__setattr__(self, "http_version", bytesify(http_version)) + else: + object.__setattr__(self, "method", method) + object.__setattr__(self, "target", target) + object.__setattr__(self, "http_version", http_version) + + # "A server MUST respond with a 400 (Bad Request) status code to any + # HTTP/1.1 request message that lacks a Host header field and to any + # request message that contains more than one Host header field or a + # Host header field with an invalid field-value." + # -- https://tools.ietf.org/html/rfc7230#section-5.4 + host_count = 0 + for name, value in self.headers: + if name == b"host": + host_count += 1 + if self.http_version == b"1.1" and host_count == 0: + raise LocalProtocolError("Missing mandatory Host: header") + if host_count > 1: + raise LocalProtocolError("Found multiple Host: headers") + + validate(method_re, self.method, "Illegal method characters") + validate(request_target_re, self.target, "Illegal target characters") + + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(init=False, frozen=True) +class _ResponseBase(Event): + __slots__ = ("headers", "http_version", "reason", "status_code") + + headers: Headers + http_version: bytes + reason: bytes + status_code: int + + def __init__( + self, + *, + headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]], + status_code: int, + http_version: Union[bytes, str] = b"1.1", + reason: Union[bytes, str] = b"", + _parsed: bool = False, + ) -> None: + super().__init__() + if isinstance(headers, Headers): + object.__setattr__(self, "headers", headers) + else: + object.__setattr__( + self, "headers", normalize_and_validate(headers, _parsed=_parsed) + ) + if not _parsed: + object.__setattr__(self, "reason", bytesify(reason)) + object.__setattr__(self, "http_version", bytesify(http_version)) + if not isinstance(status_code, int): + raise LocalProtocolError("status code must be integer") + # Because IntEnum objects are instances of int, but aren't + # duck-compatible (sigh), see gh-72. + object.__setattr__(self, "status_code", int(status_code)) + else: + object.__setattr__(self, "reason", reason) + object.__setattr__(self, "http_version", http_version) + object.__setattr__(self, "status_code", status_code) + + self.__post_init__() + + def __post_init__(self) -> None: + pass + + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(init=False, frozen=True) +class InformationalResponse(_ResponseBase): + """An HTTP informational response. + + Fields: + + .. attribute:: status_code + + The status code of this response, as an integer. For an + :class:`InformationalResponse`, this is always in the range [100, + 200). + + .. attribute:: headers + + Request headers, represented as a list of (name, value) pairs. 
See
+       :ref:`the header normalization rules <headers-format>` for
+       details.
+
+    .. attribute:: http_version
+
+       The HTTP protocol version, represented as a byte string like
+       ``b"1.1"``. See :ref:`the HTTP version normalization rules
+       <http_version-format>` for details.
+
+    .. attribute:: reason
+
+       The reason phrase of this response, as a byte string. For example:
+       ``b"OK"``, or ``b"Not Found"``.
+
+    """
+
+    def __post_init__(self) -> None:
+        if not (100 <= self.status_code < 200):
+            raise LocalProtocolError(
+                "InformationalResponse status_code should be in range "
+                "[100, 200), not {}".format(self.status_code)
+            )
+
+    # This is an unhashable type.
+    __hash__ = None  # type: ignore
+
+
+@dataclass(init=False, frozen=True)
+class Response(_ResponseBase):
+    """The beginning of an HTTP response.
+
+    Fields:
+
+    .. attribute:: status_code
+
+       The status code of this response, as an integer. For a
+       :class:`Response`, this is always in the range [200,
+       1000).
+
+    .. attribute:: headers
+
+       Request headers, represented as a list of (name, value) pairs. See
+       :ref:`the header normalization rules <headers-format>` for details.
+
+    .. attribute:: http_version
+
+       The HTTP protocol version, represented as a byte string like
+       ``b"1.1"``. See :ref:`the HTTP version normalization rules
+       <http_version-format>` for details.
+
+    .. attribute:: reason
+
+       The reason phrase of this response, as a byte string. For example:
+       ``b"OK"``, or ``b"Not Found"``.
+
+    """
+
+    def __post_init__(self) -> None:
+        if not (200 <= self.status_code < 1000):
+            raise LocalProtocolError(
+                "Response status_code should be in range [200, 1000), not {}".format(
+                    self.status_code
+                )
+            )
+
+    # This is an unhashable type.
+    __hash__ = None  # type: ignore
+
+
+@dataclass(init=False, frozen=True)
+class Data(Event):
+    """Part of an HTTP message body.
+
+    Fields:
+
+    .. attribute:: data
+
+       A :term:`bytes-like object` containing part of a message body. Or, if
+       using the ``combine=False`` argument to :meth:`Connection.send`, then
+       any object that your socket writing code knows what to do with, and for
+       which calling :func:`len` returns the number of bytes that will be
+       written -- see :ref:`sendfile` for details.
+
+    .. attribute:: chunk_start
+
+       A marker that indicates whether this data object is from the start of a
+       chunked transfer encoding chunk. This field is ignored when a Data
+       event is provided to :meth:`Connection.send`: it is only valid on
+       events emitted from :meth:`Connection.next_event`. You probably
+       shouldn't use this attribute at all; see
+       :ref:`chunk-delimiters-are-bad` for details.
+
+    .. attribute:: chunk_end
+
+       A marker that indicates whether this data object is the last for a
+       given chunked transfer encoding chunk. This field is ignored when
+       a Data event is provided to :meth:`Connection.send`: it is only valid
+       on events emitted from :meth:`Connection.next_event`. You probably
+       shouldn't use this attribute at all; see
+       :ref:`chunk-delimiters-are-bad` for details.
+
+    """
+
+    __slots__ = ("data", "chunk_start", "chunk_end")
+
+    data: bytes
+    chunk_start: bool
+    chunk_end: bool
+
+    def __init__(
+        self, data: bytes, chunk_start: bool = False, chunk_end: bool = False
+    ) -> None:
+        object.__setattr__(self, "data", data)
+        object.__setattr__(self, "chunk_start", chunk_start)
+        object.__setattr__(self, "chunk_end", chunk_end)
+
+    # This is an unhashable type.
+    __hash__ = None  # type: ignore
+
+
+# XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that
+# are forbidden to be sent in a trailer, since processing them as if they were
+# present in the header section might bypass external security filters."
+# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part
+# Unfortunately, the list of forbidden fields is long and vague :-/
+@dataclass(init=False, frozen=True)
+class EndOfMessage(Event):
+    """The end of an HTTP message.
+
+    Fields:
+
+    .. attribute:: headers
+
+       Default value: ``[]``
+
+       Any trailing headers attached to this message, represented as a list of
+       (name, value) pairs. See :ref:`the header normalization rules
+       <headers-format>` for details.
+
+       Must be empty unless ``Transfer-Encoding: chunked`` is in use.
+
+    """
+
+    __slots__ = ("headers",)
+
+    headers: Headers
+
+    def __init__(
+        self,
+        *,
+        headers: Union[
+            Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]], None
+        ] = None,
+        _parsed: bool = False,
+    ) -> None:
+        super().__init__()
+        if headers is None:
+            headers = Headers([])
+        elif not isinstance(headers, Headers):
+            headers = normalize_and_validate(headers, _parsed=_parsed)
+
+        object.__setattr__(self, "headers", headers)
+
+    # This is an unhashable type.
+    __hash__ = None  # type: ignore
+
+
+@dataclass(frozen=True)
+class ConnectionClosed(Event):
+    """This event indicates that the sender has closed their outgoing
+    connection.
+
+    Note that this does not necessarily mean that they can't *receive* further
+    data, because TCP connections are composed of two one-way channels which
+    can be closed independently. See :ref:`closing` for details.
+
+    No fields.
+    """
+
+    pass
diff --git a/contrib/python/h11/h11/_headers.py b/contrib/python/h11/h11/_headers.py
new file mode 100644
index 0000000000..b97d020b63
--- /dev/null
+++ b/contrib/python/h11/h11/_headers.py
@@ -0,0 +1,278 @@
+import re
+from typing import AnyStr, cast, List, overload, Sequence, Tuple, TYPE_CHECKING, Union
+
+from ._abnf import field_name, field_value
+from ._util import bytesify, LocalProtocolError, validate
+
+if TYPE_CHECKING:
+    from ._events import Request
+
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal  # type: ignore
+
+
+# Facts
+# -----
+#
+# Headers are:
+#   keys: case-insensitive ascii
+#   values: mixture of ascii and raw bytes
+#
+# "Historically, HTTP has allowed field content with text in the ISO-8859-1
+# charset [ISO-8859-1], supporting other charsets only through use of
+# [RFC2047] encoding. In practice, most HTTP header field values use only a
+# subset of the US-ASCII charset [USASCII]. Newly defined header fields SHOULD
+# limit their field values to US-ASCII octets. A recipient SHOULD treat other
+# octets in field content (obs-text) as opaque data."
+# And it deprecates all non-ascii values
+#
+# Leading/trailing whitespace in header names is forbidden
+#
+# Values get leading/trailing whitespace stripped
+#
+# Content-Disposition actually needs to contain unicode semantically; to
+# accomplish this it has a terrifically weird way of encoding the filename
+# itself as ascii (and even this still has lots of cross-browser
+# incompatibilities)
+#
+# Order is important:
+# "a proxy MUST NOT change the order of these field values when forwarding a
+# message"
+# (and there are several headers where the order indicates a preference)
+#
+# Multiple occurrences of the same header:
+# "A sender MUST NOT generate multiple header fields with the same field name
+# in a message unless either the entire field value for that header field is
+# defined as a comma-separated list [or the header is Set-Cookie which gets a
+# special exception]" - RFC 7230. (cookies are in RFC 6265)
+#
+# So every header aside from Set-Cookie can be merged by b", ".join if it
+# occurs repeatedly. But, of course, they can't necessarily be split by
+# .split(b","), because quoting.
+#
+# Given all this mess (case insensitive, duplicates allowed, order is
+# important, ...), there doesn't appear to be any standard way to handle
+# headers in Python -- they're almost like dicts, but... actually just
+# aren't. For now we punt and just use a super simple representation: headers
+# are a list of pairs
+#
+#   [(name1, value1), (name2, value2), ...]
+#
+# where all entries are bytestrings, names are lowercase and have no
+# leading/trailing whitespace, and values are bytestrings with no
+# leading/trailing whitespace. Searching and updating are done via naive O(n)
+# methods.
+#
+# Maybe a dict-of-lists would be better?
+
+_content_length_re = re.compile(rb"[0-9]+")
+_field_name_re = re.compile(field_name.encode("ascii"))
+_field_value_re = re.compile(field_value.encode("ascii"))
+
+
+class Headers(Sequence[Tuple[bytes, bytes]]):
+    """
+    A list-like interface that allows iterating over headers as byte-pairs
+    of (lowercased-name, value).
+
+    Internally we actually store the representation as three-tuples,
+    including both the raw original casing, in order to preserve casing
+    over-the-wire, and the lowercased name, for case-insensitive comparisons.
+ + r = Request( + method="GET", + target="/", + headers=[("Host", "example.org"), ("Connection", "keep-alive")], + http_version="1.1", + ) + assert r.headers == [ + (b"host", b"example.org"), + (b"connection", b"keep-alive") + ] + assert r.headers.raw_items() == [ + (b"Host", b"example.org"), + (b"Connection", b"keep-alive") + ] + """ + + __slots__ = "_full_items" + + def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None: + self._full_items = full_items + + def __bool__(self) -> bool: + return bool(self._full_items) + + def __eq__(self, other: object) -> bool: + return list(self) == list(other) # type: ignore + + def __len__(self) -> int: + return len(self._full_items) + + def __repr__(self) -> str: + return "<Headers(%s)>" % repr(list(self)) + + def __getitem__(self, idx: int) -> Tuple[bytes, bytes]: # type: ignore[override] + _, name, value = self._full_items[idx] + return (name, value) + + def raw_items(self) -> List[Tuple[bytes, bytes]]: + return [(raw_name, value) for raw_name, _, value in self._full_items] + + +HeaderTypes = Union[ + List[Tuple[bytes, bytes]], + List[Tuple[bytes, str]], + List[Tuple[str, bytes]], + List[Tuple[str, str]], +] + + +@overload +def normalize_and_validate(headers: Headers, _parsed: Literal[True]) -> Headers: + ... + + +@overload +def normalize_and_validate(headers: HeaderTypes, _parsed: Literal[False]) -> Headers: + ... + + +@overload +def normalize_and_validate( + headers: Union[Headers, HeaderTypes], _parsed: bool = False +) -> Headers: + ... + + +def normalize_and_validate( + headers: Union[Headers, HeaderTypes], _parsed: bool = False +) -> Headers: + new_headers = [] + seen_content_length = None + saw_transfer_encoding = False + for name, value in headers: + # For headers coming out of the parser, we can safely skip some steps, + # because it always returns bytes and has already run these regexes + # over the data: + if not _parsed: + name = bytesify(name) + value = bytesify(value) + validate(_field_name_re, name, "Illegal header name {!r}", name) + validate(_field_value_re, value, "Illegal header value {!r}", value) + assert isinstance(name, bytes) + assert isinstance(value, bytes) + + raw_name = name + name = name.lower() + if name == b"content-length": + lengths = {length.strip() for length in value.split(b",")} + if len(lengths) != 1: + raise LocalProtocolError("conflicting Content-Length headers") + value = lengths.pop() + validate(_content_length_re, value, "bad Content-Length") + if seen_content_length is None: + seen_content_length = value + new_headers.append((raw_name, name, value)) + elif seen_content_length != value: + raise LocalProtocolError("conflicting Content-Length headers") + elif name == b"transfer-encoding": + # "A server that receives a request message with a transfer coding + # it does not understand SHOULD respond with 501 (Not + # Implemented)." 
+            # https://tools.ietf.org/html/rfc7230#section-3.3.1
+            if saw_transfer_encoding:
+                raise LocalProtocolError(
+                    "multiple Transfer-Encoding headers", error_status_hint=501
+                )
+            # "All transfer-coding names are case-insensitive"
+            # -- https://tools.ietf.org/html/rfc7230#section-4
+            value = value.lower()
+            if value != b"chunked":
+                raise LocalProtocolError(
+                    "Only Transfer-Encoding: chunked is supported",
+                    error_status_hint=501,
+                )
+            saw_transfer_encoding = True
+            new_headers.append((raw_name, name, value))
+        else:
+            new_headers.append((raw_name, name, value))
+    return Headers(new_headers)
+
+
+def get_comma_header(headers: Headers, name: bytes) -> List[bytes]:
+    # Should only be used for headers whose value is a list of
+    # comma-separated, case-insensitive values.
+    #
+    # The header name `name` is expected to be lower-case bytes.
+    #
+    # Connection: meets these criteria (including case insensitivity).
+    #
+    # Content-Length: technically is just a single value (1*DIGIT), but the
+    # standard makes reference to implementations that do multiple values, and
+    # using this doesn't hurt. Ditto, case insensitivity doesn't hurt things
+    # either way.
+    #
+    # Transfer-Encoding: is more complex (allows for quoted strings), so
+    # splitting on , is actually wrong. For example, this is legal:
+    #
+    #    Transfer-Encoding: foo; options="1,2", chunked
+    #
+    # and should be parsed as
+    #
+    #    foo; options="1,2"
+    #    chunked
+    #
+    # but this naive function will parse it as
+    #
+    #    foo; options="1
+    #    2"
+    #    chunked
+    #
+    # However, this is okay because the only thing we are going to do with
+    # any Transfer-Encoding is reject ones that aren't just "chunked", so
+    # both of these will be treated the same anyway.
+    #
+    # Expect: the only legal value is the literal string
+    # "100-continue". Splitting on commas is harmless. Case insensitive.
+    #
+    out: List[bytes] = []
+    for _, found_name, found_raw_value in headers._full_items:
+        if found_name == name:
+            found_raw_value = found_raw_value.lower()
+            for found_split_value in found_raw_value.split(b","):
+                found_split_value = found_split_value.strip()
+                if found_split_value:
+                    out.append(found_split_value)
+    return out
+
+
+def set_comma_header(headers: Headers, name: bytes, new_values: List[bytes]) -> Headers:
+    # The header name `name` is expected to be lower-case bytes.
+    #
+    # Note that when we store the header we use title casing for the header
+    # names, in order to match the conventional HTTP header style.
+    #
+    # Simply calling `.title()` is a blunt approach, but it's correct
+    # here given the cases where we're using `set_comma_header`...
+    #
+    # Connection, Content-Length, Transfer-Encoding.
+    new_headers: List[Tuple[bytes, bytes]] = []
+    for found_raw_name, found_name, found_raw_value in headers._full_items:
+        if found_name != name:
+            new_headers.append((found_raw_name, found_raw_value))
+    for new_value in new_values:
+        new_headers.append((name.title(), new_value))
+    return normalize_and_validate(new_headers)
+
+
+def has_expect_100_continue(request: "Request") -> bool:
+    # https://tools.ietf.org/html/rfc7231#section-5.1.1
+    # "A server that receives a 100-continue expectation in an HTTP/1.0 request
+    # MUST ignore that expectation."
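+    # Editorial sketch of the behaviour implemented below (not upstream
+    # code; `req` is a hypothetical Request built elsewhere):
+    #
+    #     req = Request(method="POST", target="/upload",
+    #                   headers=[("Host", "example.com"),
+    #                            ("Expect", "100-continue")])
+    #     has_expect_100_continue(req)   # -> True
+    #
+    # The same request constructed with http_version="1.0" returns False.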
+ if request.http_version < b"1.1": + return False + expect = get_comma_header(request.headers, b"expect") + return b"100-continue" in expect diff --git a/contrib/python/h11/h11/_readers.py b/contrib/python/h11/h11/_readers.py new file mode 100644 index 0000000000..08a9574da4 --- /dev/null +++ b/contrib/python/h11/h11/_readers.py @@ -0,0 +1,247 @@ +# Code to read HTTP data +# +# Strategy: each reader is a callable which takes a ReceiveBuffer object, and +# either: +# 1) consumes some of it and returns an Event +# 2) raises a LocalProtocolError (for consistency -- e.g. we call validate() +# and it might raise a LocalProtocolError, so simpler just to always use +# this) +# 3) returns None, meaning "I need more data" +# +# If they have a .read_eof attribute, then this will be called if an EOF is +# received -- but this is optional. Either way, the actual ConnectionClosed +# event will be generated afterwards. +# +# READERS is a dict describing how to pick a reader. It maps states to either: +# - a reader +# - or, for body readers, a dict of per-framing reader factories + +import re +from typing import Any, Callable, Dict, Iterable, NoReturn, Optional, Tuple, Type, Union + +from ._abnf import chunk_header, header_field, request_line, status_line +from ._events import Data, EndOfMessage, InformationalResponse, Request, Response +from ._receivebuffer import ReceiveBuffer +from ._state import ( + CLIENT, + CLOSED, + DONE, + IDLE, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, +) +from ._util import LocalProtocolError, RemoteProtocolError, Sentinel, validate + +__all__ = ["READERS"] + +header_field_re = re.compile(header_field.encode("ascii")) +obs_fold_re = re.compile(rb"[ \t]+") + + +def _obsolete_line_fold(lines: Iterable[bytes]) -> Iterable[bytes]: + it = iter(lines) + last: Optional[bytes] = None + for line in it: + match = obs_fold_re.match(line) + if match: + if last is None: + raise LocalProtocolError("continuation line at start of headers") + if not isinstance(last, bytearray): + # Cast to a mutable type, avoiding copy on append to ensure O(n) time + last = bytearray(last) + last += b" " + last += line[match.end() :] + else: + if last is not None: + yield last + last = line + if last is not None: + yield last + + +def _decode_header_lines( + lines: Iterable[bytes], +) -> Iterable[Tuple[bytes, bytes]]: + for line in _obsolete_line_fold(lines): + matches = validate(header_field_re, line, "illegal header line: {!r}", line) + yield (matches["field_name"], matches["field_value"]) + + +request_line_re = re.compile(request_line.encode("ascii")) + + +def maybe_read_from_IDLE_client(buf: ReceiveBuffer) -> Optional[Request]: + lines = buf.maybe_extract_lines() + if lines is None: + if buf.is_next_line_obviously_invalid_request_line(): + raise LocalProtocolError("illegal request line") + return None + if not lines: + raise LocalProtocolError("no request line received") + matches = validate( + request_line_re, lines[0], "illegal request line: {!r}", lines[0] + ) + return Request( + headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches + ) + + +status_line_re = re.compile(status_line.encode("ascii")) + + +def maybe_read_from_SEND_RESPONSE_server( + buf: ReceiveBuffer, +) -> Union[InformationalResponse, Response, None]: + lines = buf.maybe_extract_lines() + if lines is None: + if buf.is_next_line_obviously_invalid_request_line(): + raise LocalProtocolError("illegal request line") + return None + if not lines: + raise LocalProtocolError("no response line received") + matches = 
validate(status_line_re, lines[0], "illegal status line: {!r}", lines[0]) + http_version = ( + b"1.1" if matches["http_version"] is None else matches["http_version"] + ) + reason = b"" if matches["reason"] is None else matches["reason"] + status_code = int(matches["status_code"]) + class_: Union[Type[InformationalResponse], Type[Response]] = ( + InformationalResponse if status_code < 200 else Response + ) + return class_( + headers=list(_decode_header_lines(lines[1:])), + _parsed=True, + status_code=status_code, + reason=reason, + http_version=http_version, + ) + + +class ContentLengthReader: + def __init__(self, length: int) -> None: + self._length = length + self._remaining = length + + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: + if self._remaining == 0: + return EndOfMessage() + data = buf.maybe_extract_at_most(self._remaining) + if data is None: + return None + self._remaining -= len(data) + return Data(data=data) + + def read_eof(self) -> NoReturn: + raise RemoteProtocolError( + "peer closed connection without sending complete message body " + "(received {} bytes, expected {})".format( + self._length - self._remaining, self._length + ) + ) + + +chunk_header_re = re.compile(chunk_header.encode("ascii")) + + +class ChunkedReader: + def __init__(self) -> None: + self._bytes_in_chunk = 0 + # After reading a chunk, we have to throw away the trailing \r\n; if + # this is >0 then we discard that many bytes before resuming regular + # de-chunkification. + self._bytes_to_discard = 0 + self._reading_trailer = False + + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: + if self._reading_trailer: + lines = buf.maybe_extract_lines() + if lines is None: + return None + return EndOfMessage(headers=list(_decode_header_lines(lines))) + if self._bytes_to_discard > 0: + data = buf.maybe_extract_at_most(self._bytes_to_discard) + if data is None: + return None + self._bytes_to_discard -= len(data) + if self._bytes_to_discard > 0: + return None + # else, fall through and read some more + assert self._bytes_to_discard == 0 + if self._bytes_in_chunk == 0: + # We need to refill our chunk count + chunk_header = buf.maybe_extract_next_line() + if chunk_header is None: + return None + matches = validate( + chunk_header_re, + chunk_header, + "illegal chunk header: {!r}", + chunk_header, + ) + # XX FIXME: we discard chunk extensions. Does anyone care? 
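+            # Editorial example (not upstream text): for the wire bytes
+            # b"a;name=value\r\n", the regex yields chunk_size == b"a"; the
+            # ";name=value" chunk extension is dropped, and int(b"a", base=16)
+            # == 10 bytes of chunk data are expected to follow.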
+ self._bytes_in_chunk = int(matches["chunk_size"], base=16) + if self._bytes_in_chunk == 0: + self._reading_trailer = True + return self(buf) + chunk_start = True + else: + chunk_start = False + assert self._bytes_in_chunk > 0 + data = buf.maybe_extract_at_most(self._bytes_in_chunk) + if data is None: + return None + self._bytes_in_chunk -= len(data) + if self._bytes_in_chunk == 0: + self._bytes_to_discard = 2 + chunk_end = True + else: + chunk_end = False + return Data(data=data, chunk_start=chunk_start, chunk_end=chunk_end) + + def read_eof(self) -> NoReturn: + raise RemoteProtocolError( + "peer closed connection without sending complete message body " + "(incomplete chunked read)" + ) + + +class Http10Reader: + def __call__(self, buf: ReceiveBuffer) -> Optional[Data]: + data = buf.maybe_extract_at_most(999999999) + if data is None: + return None + return Data(data=data) + + def read_eof(self) -> EndOfMessage: + return EndOfMessage() + + +def expect_nothing(buf: ReceiveBuffer) -> None: + if buf: + raise LocalProtocolError("Got data when expecting EOF") + return None + + +ReadersType = Dict[ + Union[Type[Sentinel], Tuple[Type[Sentinel], Type[Sentinel]]], + Union[Callable[..., Any], Dict[str, Callable[..., Any]]], +] + +READERS: ReadersType = { + (CLIENT, IDLE): maybe_read_from_IDLE_client, + (SERVER, IDLE): maybe_read_from_SEND_RESPONSE_server, + (SERVER, SEND_RESPONSE): maybe_read_from_SEND_RESPONSE_server, + (CLIENT, DONE): expect_nothing, + (CLIENT, MUST_CLOSE): expect_nothing, + (CLIENT, CLOSED): expect_nothing, + (SERVER, DONE): expect_nothing, + (SERVER, MUST_CLOSE): expect_nothing, + (SERVER, CLOSED): expect_nothing, + SEND_BODY: { + "chunked": ChunkedReader, + "content-length": ContentLengthReader, + "http/1.0": Http10Reader, + }, +} diff --git a/contrib/python/h11/h11/_receivebuffer.py b/contrib/python/h11/h11/_receivebuffer.py new file mode 100644 index 0000000000..e5c4e08a56 --- /dev/null +++ b/contrib/python/h11/h11/_receivebuffer.py @@ -0,0 +1,153 @@ +import re +import sys +from typing import List, Optional, Union + +__all__ = ["ReceiveBuffer"] + + +# Operations we want to support: +# - find next \r\n or \r\n\r\n (\n or \n\n are also acceptable), +# or wait until there is one +# - read at-most-N bytes +# Goals: +# - on average, do this fast +# - worst case, do this in O(n) where n is the number of bytes processed +# Plan: +# - store bytearray, offset, how far we've searched for a separator token +# - use the how-far-we've-searched data to avoid rescanning +# - while doing a stream of uninterrupted processing, advance offset instead +# of constantly copying +# WARNING: +# - I haven't benchmarked or profiled any of this yet. +# +# Note that starting in Python 3.4, deleting the initial n bytes from a +# bytearray is amortized O(n), thanks to some excellent work by Antoine +# Martin: +# +# https://bugs.python.org/issue19087 +# +# This means that if we only supported 3.4+, we could get rid of the code here +# involving self._start and self.compress, because it's doing exactly the same +# thing that bytearray now does internally. +# +# BUT unfortunately, we still support 2.7, and reading short segments out of a +# long buffer MUST be O(bytes read) to avoid DoS issues, so we can't actually +# delete this code. 
Yet: +# +# https://pythonclock.org/ +# +# (Two things to double-check first though: make sure PyPy also has the +# optimization, and benchmark to make sure it's a win, since we do have a +# slightly clever thing where we delay calling compress() until we've +# processed a whole event, which could in theory be slightly more efficient +# than the internal bytearray support.) +blank_line_regex = re.compile(b"\n\r?\n", re.MULTILINE) + + +class ReceiveBuffer: + def __init__(self) -> None: + self._data = bytearray() + self._next_line_search = 0 + self._multiple_lines_search = 0 + + def __iadd__(self, byteslike: Union[bytes, bytearray]) -> "ReceiveBuffer": + self._data += byteslike + return self + + def __bool__(self) -> bool: + return bool(len(self)) + + def __len__(self) -> int: + return len(self._data) + + # for @property unprocessed_data + def __bytes__(self) -> bytes: + return bytes(self._data) + + def _extract(self, count: int) -> bytearray: + # extracting an initial slice of the data buffer and return it + out = self._data[:count] + del self._data[:count] + + self._next_line_search = 0 + self._multiple_lines_search = 0 + + return out + + def maybe_extract_at_most(self, count: int) -> Optional[bytearray]: + """ + Extract a fixed number of bytes from the buffer. + """ + out = self._data[:count] + if not out: + return None + + return self._extract(count) + + def maybe_extract_next_line(self) -> Optional[bytearray]: + """ + Extract the first line, if it is completed in the buffer. + """ + # Only search in buffer space that we've not already looked at. + search_start_index = max(0, self._next_line_search - 1) + partial_idx = self._data.find(b"\r\n", search_start_index) + + if partial_idx == -1: + self._next_line_search = len(self._data) + return None + + # + 2 is to compensate len(b"\r\n") + idx = partial_idx + 2 + + return self._extract(idx) + + def maybe_extract_lines(self) -> Optional[List[bytearray]]: + """ + Extract everything up to the first blank line, and return a list of lines. + """ + # Handle the case where we have an immediate empty line. + if self._data[:1] == b"\n": + self._extract(1) + return [] + + if self._data[:2] == b"\r\n": + self._extract(2) + return [] + + # Only search in buffer space that we've not already looked at. + match = blank_line_regex.search(self._data, self._multiple_lines_search) + if match is None: + self._multiple_lines_search = max(0, len(self._data) - 2) + return None + + # Truncate the buffer and return it. + idx = match.span(0)[-1] + out = self._extract(idx) + lines = out.split(b"\n") + + for line in lines: + if line.endswith(b"\r"): + del line[-1] + + assert lines[-2] == lines[-1] == b"" + + del lines[-2:] + + return lines + + # In theory we should wait until `\r\n` before starting to validate + # incoming data. However it's interesting to detect (very) invalid data + # early given they might not even contain `\r\n` at all (hence only + # timeout will get rid of them). + # This is not a 100% effective detection but more of a cheap sanity check + # allowing for early abort in some useful cases. + # This is especially interesting when peer is messing up with HTTPS and + # sent us a TLS stream where we were expecting plain HTTP given all + # versions of TLS so far start handshake with a 0x16 message type code. 
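+    # Editorial example (not upstream text): a TLS ClientHello starts with
+    # the record-type byte 0x16, which is below 0x21, so the check below
+    # rejects such a stream immediately instead of waiting for a b"\r\n"
+    # that will never arrive.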
+ def is_next_line_obviously_invalid_request_line(self) -> bool: + try: + # HTTP header line must not contain non-printable characters + # and should not start with a space + return self._data[0] < 0x21 + except IndexError: + return False diff --git a/contrib/python/h11/h11/_state.py b/contrib/python/h11/h11/_state.py new file mode 100644 index 0000000000..3593430a74 --- /dev/null +++ b/contrib/python/h11/h11/_state.py @@ -0,0 +1,367 @@ +################################################################ +# The core state machine +################################################################ +# +# Rule 1: everything that affects the state machine and state transitions must +# live here in this file. As much as possible goes into the table-based +# representation, but for the bits that don't quite fit, the actual code and +# state must nonetheless live here. +# +# Rule 2: this file does not know about what role we're playing; it only knows +# about HTTP request/response cycles in the abstract. This ensures that we +# don't cheat and apply different rules to local and remote parties. +# +# +# Theory of operation +# =================== +# +# Possibly the simplest way to think about this is that we actually have 5 +# different state machines here. Yes, 5. These are: +# +# 1) The client state, with its complicated automaton (see the docs) +# 2) The server state, with its complicated automaton (see the docs) +# 3) The keep-alive state, with possible states {True, False} +# 4) The SWITCH_CONNECT state, with possible states {False, True} +# 5) The SWITCH_UPGRADE state, with possible states {False, True} +# +# For (3)-(5), the first state listed is the initial state. +# +# (1)-(3) are stored explicitly in member variables. The last +# two are stored implicitly in the pending_switch_proposals set as: +# (state of 4) == (_SWITCH_CONNECT in pending_switch_proposals) +# (state of 5) == (_SWITCH_UPGRADE in pending_switch_proposals) +# +# And each of these machines has two different kinds of transitions: +# +# a) Event-triggered +# b) State-triggered +# +# Event triggered is the obvious thing that you'd think it is: some event +# happens, and if it's the right event at the right time then a transition +# happens. But there are somewhat complicated rules for which machines can +# "see" which events. (As a rule of thumb, if a machine "sees" an event, this +# means two things: the event can affect the machine, and if the machine is +# not in a state where it expects that event then it's an error.) These rules +# are: +# +# 1) The client machine sees all h11.events objects emitted by the client. +# +# 2) The server machine sees all h11.events objects emitted by the server. +# +# It also sees the client's Request event. +# +# And sometimes, server events are annotated with a _SWITCH_* event. For +# example, we can have a (Response, _SWITCH_CONNECT) event, which is +# different from a regular Response event. +# +# 3) The keep-alive machine sees the process_keep_alive_disabled() event +# (which is derived from Request/Response events), and this event +# transitions it from True -> False, or from False -> False. There's no way +# to transition back. +# +# 4&5) The _SWITCH_* machines transition from False->True when we get a +# Request that proposes the relevant type of switch (via +# process_client_switch_proposals), and they go from True->False when we +# get a Response that has no _SWITCH_* annotation. +# +# So that's event-triggered transitions. +# +# State-triggered transitions are less standard. 
What they do here is couple +# the machines together. The way this works is, when certain *joint* +# configurations of states are achieved, then we automatically transition to a +# new *joint* state. So, for example, if we're ever in a joint state with +# +# client: DONE +# keep-alive: False +# +# then the client state immediately transitions to: +# +# client: MUST_CLOSE +# +# This is fundamentally different from an event-based transition, because it +# doesn't matter how we arrived at the {client: DONE, keep-alive: False} state +# -- maybe the client transitioned SEND_BODY -> DONE, or keep-alive +# transitioned True -> False. Either way, once this precondition is satisfied, +# this transition is immediately triggered. +# +# What if two conflicting state-based transitions get enabled at the same +# time? In practice there's only one case where this arises (client DONE -> +# MIGHT_SWITCH_PROTOCOL versus DONE -> MUST_CLOSE), and we resolve it by +# explicitly prioritizing the DONE -> MIGHT_SWITCH_PROTOCOL transition. +# +# Implementation +# -------------- +# +# The event-triggered transitions for the server and client machines are all +# stored explicitly in a table. Ditto for the state-triggered transitions that +# involve just the server and client state. +# +# The transitions for the other machines, and the state-triggered transitions +# that involve the other machines, are written out as explicit Python code. +# +# It'd be nice if there were some cleaner way to do all this. This isn't +# *too* terrible, but I feel like it could probably be better. +# +# WARNING +# ------- +# +# The script that generates the state machine diagrams for the docs knows how +# to read out the EVENT_TRIGGERED_TRANSITIONS and STATE_TRIGGERED_TRANSITIONS +# tables. But it can't automatically read the transitions that are written +# directly in Python code. So if you touch those, you need to also update the +# script to keep it in sync! +from typing import cast, Dict, Optional, Set, Tuple, Type, Union + +from ._events import * +from ._util import LocalProtocolError, Sentinel + +# Everything in __all__ gets re-exported as part of the h11 public API. 
+__all__ = [ + "CLIENT", + "SERVER", + "IDLE", + "SEND_RESPONSE", + "SEND_BODY", + "DONE", + "MUST_CLOSE", + "CLOSED", + "MIGHT_SWITCH_PROTOCOL", + "SWITCHED_PROTOCOL", + "ERROR", +] + + +class CLIENT(Sentinel, metaclass=Sentinel): + pass + + +class SERVER(Sentinel, metaclass=Sentinel): + pass + + +# States +class IDLE(Sentinel, metaclass=Sentinel): + pass + + +class SEND_RESPONSE(Sentinel, metaclass=Sentinel): + pass + + +class SEND_BODY(Sentinel, metaclass=Sentinel): + pass + + +class DONE(Sentinel, metaclass=Sentinel): + pass + + +class MUST_CLOSE(Sentinel, metaclass=Sentinel): + pass + + +class CLOSED(Sentinel, metaclass=Sentinel): + pass + + +class ERROR(Sentinel, metaclass=Sentinel): + pass + + +# Switch types +class MIGHT_SWITCH_PROTOCOL(Sentinel, metaclass=Sentinel): + pass + + +class SWITCHED_PROTOCOL(Sentinel, metaclass=Sentinel): + pass + + +class _SWITCH_UPGRADE(Sentinel, metaclass=Sentinel): + pass + + +class _SWITCH_CONNECT(Sentinel, metaclass=Sentinel): + pass + + +EventTransitionType = Dict[ + Type[Sentinel], + Dict[ + Type[Sentinel], + Dict[Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]], Type[Sentinel]], + ], +] + +EVENT_TRIGGERED_TRANSITIONS: EventTransitionType = { + CLIENT: { + IDLE: {Request: SEND_BODY, ConnectionClosed: CLOSED}, + SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE}, + DONE: {ConnectionClosed: CLOSED}, + MUST_CLOSE: {ConnectionClosed: CLOSED}, + CLOSED: {ConnectionClosed: CLOSED}, + MIGHT_SWITCH_PROTOCOL: {}, + SWITCHED_PROTOCOL: {}, + ERROR: {}, + }, + SERVER: { + IDLE: { + ConnectionClosed: CLOSED, + Response: SEND_BODY, + # Special case: server sees client Request events, in this form + (Request, CLIENT): SEND_RESPONSE, + }, + SEND_RESPONSE: { + InformationalResponse: SEND_RESPONSE, + Response: SEND_BODY, + (InformationalResponse, _SWITCH_UPGRADE): SWITCHED_PROTOCOL, + (Response, _SWITCH_CONNECT): SWITCHED_PROTOCOL, + }, + SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE}, + DONE: {ConnectionClosed: CLOSED}, + MUST_CLOSE: {ConnectionClosed: CLOSED}, + CLOSED: {ConnectionClosed: CLOSED}, + SWITCHED_PROTOCOL: {}, + ERROR: {}, + }, +} + +StateTransitionType = Dict[ + Tuple[Type[Sentinel], Type[Sentinel]], Dict[Type[Sentinel], Type[Sentinel]] +] + +# NB: there are also some special-case state-triggered transitions hard-coded +# into _fire_state_triggered_transitions below. +STATE_TRIGGERED_TRANSITIONS: StateTransitionType = { + # (Client state, Server state) -> new states + # Protocol negotiation + (MIGHT_SWITCH_PROTOCOL, SWITCHED_PROTOCOL): {CLIENT: SWITCHED_PROTOCOL}, + # Socket shutdown + (CLOSED, DONE): {SERVER: MUST_CLOSE}, + (CLOSED, IDLE): {SERVER: MUST_CLOSE}, + (ERROR, DONE): {SERVER: MUST_CLOSE}, + (DONE, CLOSED): {CLIENT: MUST_CLOSE}, + (IDLE, CLOSED): {CLIENT: MUST_CLOSE}, + (DONE, ERROR): {CLIENT: MUST_CLOSE}, +} + + +class ConnectionState: + def __init__(self) -> None: + # Extra bits of state that don't quite fit into the state model. + + # If this is False then it enables the automatic DONE -> MUST_CLOSE + # transition. Don't set this directly; call .keep_alive_disabled() + self.keep_alive = True + + # This is a subset of {UPGRADE, CONNECT}, containing the proposals + # made by the client for switching protocols. 
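+        # Editorial example (not upstream text): after the client proposes an
+        # upgrade, process_client_switch_proposal(_SWITCH_UPGRADE) makes this
+        # set {_SWITCH_UPGRADE}; a plain (non-switch) Response from the server
+        # then clears it again inside process_event().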
+ self.pending_switch_proposals: Set[Type[Sentinel]] = set() + + self.states: Dict[Type[Sentinel], Type[Sentinel]] = {CLIENT: IDLE, SERVER: IDLE} + + def process_error(self, role: Type[Sentinel]) -> None: + self.states[role] = ERROR + self._fire_state_triggered_transitions() + + def process_keep_alive_disabled(self) -> None: + self.keep_alive = False + self._fire_state_triggered_transitions() + + def process_client_switch_proposal(self, switch_event: Type[Sentinel]) -> None: + self.pending_switch_proposals.add(switch_event) + self._fire_state_triggered_transitions() + + def process_event( + self, + role: Type[Sentinel], + event_type: Type[Event], + server_switch_event: Optional[Type[Sentinel]] = None, + ) -> None: + _event_type: Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]] = event_type + if server_switch_event is not None: + assert role is SERVER + if server_switch_event not in self.pending_switch_proposals: + raise LocalProtocolError( + "Received server {} event without a pending proposal".format( + server_switch_event + ) + ) + _event_type = (event_type, server_switch_event) + if server_switch_event is None and _event_type is Response: + self.pending_switch_proposals = set() + self._fire_event_triggered_transitions(role, _event_type) + # Special case: the server state does get to see Request + # events. + if _event_type is Request: + assert role is CLIENT + self._fire_event_triggered_transitions(SERVER, (Request, CLIENT)) + self._fire_state_triggered_transitions() + + def _fire_event_triggered_transitions( + self, + role: Type[Sentinel], + event_type: Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]], + ) -> None: + state = self.states[role] + try: + new_state = EVENT_TRIGGERED_TRANSITIONS[role][state][event_type] + except KeyError: + event_type = cast(Type[Event], event_type) + raise LocalProtocolError( + "can't handle event type {} when role={} and state={}".format( + event_type.__name__, role, self.states[role] + ) + ) from None + self.states[role] = new_state + + def _fire_state_triggered_transitions(self) -> None: + # We apply these rules repeatedly until converging on a fixed point + while True: + start_states = dict(self.states) + + # It could happen that both these special-case transitions are + # enabled at the same time: + # + # DONE -> MIGHT_SWITCH_PROTOCOL + # DONE -> MUST_CLOSE + # + # For example, this will always be true of a HTTP/1.0 client + # requesting CONNECT. If this happens, the protocol switch takes + # priority. From there the client will either go to + # SWITCHED_PROTOCOL, in which case it's none of our business when + # they close the connection, or else the server will deny the + # request, in which case the client will go back to DONE and then + # from there to MUST_CLOSE. + if self.pending_switch_proposals: + if self.states[CLIENT] is DONE: + self.states[CLIENT] = MIGHT_SWITCH_PROTOCOL + + if not self.pending_switch_proposals: + if self.states[CLIENT] is MIGHT_SWITCH_PROTOCOL: + self.states[CLIENT] = DONE + + if not self.keep_alive: + for role in (CLIENT, SERVER): + if self.states[role] is DONE: + self.states[role] = MUST_CLOSE + + # Tabular state-triggered transitions + joint_state = (self.states[CLIENT], self.states[SERVER]) + changes = STATE_TRIGGERED_TRANSITIONS.get(joint_state, {}) + self.states.update(changes) + + if self.states == start_states: + # Fixed point reached + return + + def start_next_cycle(self) -> None: + if self.states != {CLIENT: DONE, SERVER: DONE}: + raise LocalProtocolError( + "not in a reusable state. 
self.states={}".format(self.states)
+            )
+        # Can't reach DONE/DONE with any of these active, but still, let's be
+        # sure.
+        assert self.keep_alive
+        assert not self.pending_switch_proposals
+        self.states = {CLIENT: IDLE, SERVER: IDLE}
diff --git a/contrib/python/h11/h11/_util.py b/contrib/python/h11/h11/_util.py
new file mode 100644
index 0000000000..6718445290
--- /dev/null
+++ b/contrib/python/h11/h11/_util.py
@@ -0,0 +1,135 @@
+from typing import Any, Dict, NoReturn, Pattern, Tuple, Type, TypeVar, Union
+
+__all__ = [
+    "ProtocolError",
+    "LocalProtocolError",
+    "RemoteProtocolError",
+    "validate",
+    "bytesify",
+]
+
+
+class ProtocolError(Exception):
+    """Exception indicating a violation of the HTTP/1.1 protocol.
+
+    This is an abstract base class, with two concrete subclasses:
+    :exc:`LocalProtocolError`, which indicates that you tried to do something
+    that HTTP/1.1 says is illegal, and :exc:`RemoteProtocolError`, which
+    indicates that the remote peer tried to do something that HTTP/1.1 says is
+    illegal. See :ref:`error-handling` for details.
+
+    In addition to the normal :exc:`Exception` features, it has one attribute:
+
+    .. attribute:: error_status_hint
+
+       This gives a suggestion as to what status code a server might use if
+       this error occurred as part of a request.
+
+       For a :exc:`RemoteProtocolError`, this is useful as a suggestion for
+       how you might want to respond to a misbehaving peer, if you're
+       implementing a server.
+
+       For a :exc:`LocalProtocolError`, this can be taken as a suggestion for
+       how your peer might have responded to *you* if h11 had allowed you to
+       continue.
+
+       The default is 400 Bad Request, a generic catch-all for protocol
+       violations.
+
+    """
+
+    def __init__(self, msg: str, error_status_hint: int = 400) -> None:
+        if type(self) is ProtocolError:
+            raise TypeError("tried to directly instantiate ProtocolError")
+        Exception.__init__(self, msg)
+        self.error_status_hint = error_status_hint
+
+
+# Strategy: there are a number of public APIs where a LocalProtocolError can
+# be raised (send(), all the different event constructors, ...), and only one
+# public API where RemoteProtocolError can be raised
+# (receive_data()). Therefore we always raise LocalProtocolError internally,
+# and then receive_data will translate this into a RemoteProtocolError.
+#
+# Internally:
+#   LocalProtocolError is the generic "ProtocolError".
+# Externally:
+#   LocalProtocolError is for local errors and RemoteProtocolError is for
+#   remote errors.
class LocalProtocolError(ProtocolError):
+    def _reraise_as_remote_protocol_error(self) -> NoReturn:
+        # After catching a LocalProtocolError, use this method to re-raise it
+        # as a RemoteProtocolError. This method must be called from inside an
+        # except: block.
+        #
+        # An easy way to get an equivalent RemoteProtocolError is just to
+        # modify 'self' in place.
+        self.__class__ = RemoteProtocolError  # type: ignore
+        # But the re-raising is somewhat non-trivial -- you might think that
+        # now that we've modified the in-flight exception object, that just
+        # doing 'raise' to re-raise it would be enough. But it turns out that
+        # this doesn't work, because Python tracks the exception type
+        # (exc_info[0]) separately from the exception object (exc_info[1]),
+        # and we only modified the latter. So we really do need to re-raise
+        # the new type explicitly.
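+        # (Editorial aside, not upstream code: the expected call site is an
+        # "except LocalProtocolError as exc:" handler in receive_data-style
+        # code, which calls exc._reraise_as_remote_protocol_error() so the
+        # peer's violation surfaces to the caller as a RemoteProtocolError.)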
+        # On py3, the traceback is part of the exception object, so our
+        # in-place modification preserved it and we can just re-raise:
+        raise self
+
+
+class RemoteProtocolError(ProtocolError):
+    pass
+
+
+def validate(
+    regex: Pattern[bytes], data: bytes, msg: str = "malformed data", *format_args: Any
+) -> Dict[str, bytes]:
+    match = regex.fullmatch(data)
+    if not match:
+        if format_args:
+            msg = msg.format(*format_args)
+        raise LocalProtocolError(msg)
+    return match.groupdict()
+
+
+# Sentinel values
+#
+# - Inherit identity-based comparison and hashing from object
+# - Have a nice repr
+# - Have a *bonus property*: type(sentinel) is sentinel
+#
+# The bonus property is useful if you want to take the return value from
+# next_event() and do some sort of dispatch based on type(event).
+
+_T_Sentinel = TypeVar("_T_Sentinel", bound="Sentinel")
+
+
+class Sentinel(type):
+    def __new__(
+        cls: Type[_T_Sentinel],
+        name: str,
+        bases: Tuple[type, ...],
+        namespace: Dict[str, Any],
+        **kwds: Any
+    ) -> _T_Sentinel:
+        assert bases == (Sentinel,)
+        v = super().__new__(cls, name, bases, namespace, **kwds)
+        v.__class__ = v  # type: ignore
+        return v
+
+    def __repr__(self) -> str:
+        return self.__name__
+
+
+# Used for methods, request targets, HTTP versions, header names, and header
+# values. Accepts ascii-strings, or bytes/bytearray/memoryview/..., and always
+# returns bytes.
+def bytesify(s: Union[bytes, bytearray, memoryview, int, str]) -> bytes:
+    # Fast-path:
+    if type(s) is bytes:
+        return s
+    if isinstance(s, str):
+        s = s.encode("ascii")
+    if isinstance(s, int):
+        raise TypeError("expected bytes-like object, not int")
+    return bytes(s)
diff --git a/contrib/python/h11/h11/_version.py b/contrib/python/h11/h11/_version.py
new file mode 100644
index 0000000000..4c89113056
--- /dev/null
+++ b/contrib/python/h11/h11/_version.py
@@ -0,0 +1,16 @@
+# This file must be kept very simple, because it is consumed from several
+# places -- it is imported by h11/__init__.py, execfile'd by setup.py, etc.
+
+# We use a simple scheme:
+#   1.0.0 -> 1.0.0+dev -> 1.1.0 -> 1.1.0+dev
+# where the +dev versions are never released into the wild, they're just what
+# we stick into the VCS in between releases.
+#
+# This is compatible with PEP 440:
+#   http://legacy.python.org/dev/peps/pep-0440/
+# via the use of the "local suffix" "+dev", which is disallowed on index
+# servers and causes 1.0.0+dev to sort after plain 1.0.0, which is what we
+# want. (Contrast with the special suffix 1.0.0.dev, which sorts *before*
+# 1.0.0.)
+
+__version__ = "0.14.0"
diff --git a/contrib/python/h11/h11/_writers.py b/contrib/python/h11/h11/_writers.py
new file mode 100644
index 0000000000..939cdb912a
--- /dev/null
+++ b/contrib/python/h11/h11/_writers.py
@@ -0,0 +1,145 @@
+# Code to write HTTP data
+#
+# Strategy: each writer takes an event + a write-some-bytes function, which it
+# calls.
+#
+# WRITERS is a dict describing how to pick a writer.
It maps states to either:
+# - a writer
+# - or, for body writers, a dict of framing-dependent writer factories
+
+from typing import Any, Callable, Dict, List, Tuple, Type, Union
+
+from ._events import Data, EndOfMessage, Event, InformationalResponse, Request, Response
+from ._headers import Headers
+from ._state import CLIENT, IDLE, SEND_BODY, SEND_RESPONSE, SERVER
+from ._util import LocalProtocolError, Sentinel
+
+__all__ = ["WRITERS"]
+
+Writer = Callable[[bytes], Any]
+
+
+def write_headers(headers: Headers, write: Writer) -> None:
+    # "Since the Host field-value is critical information for handling a
+    # request, a user agent SHOULD generate Host as the first header field
+    # following the request-line." - RFC 7230
+    raw_items = headers._full_items
+    for raw_name, name, value in raw_items:
+        if name == b"host":
+            write(b"%s: %s\r\n" % (raw_name, value))
+    for raw_name, name, value in raw_items:
+        if name != b"host":
+            write(b"%s: %s\r\n" % (raw_name, value))
+    write(b"\r\n")
+
+
+def write_request(request: Request, write: Writer) -> None:
+    if request.http_version != b"1.1":
+        raise LocalProtocolError("I only send HTTP/1.1")
+    write(b"%s %s HTTP/1.1\r\n" % (request.method, request.target))
+    write_headers(request.headers, write)
+
+
+# Shared between InformationalResponse and Response
+def write_any_response(
+    response: Union[InformationalResponse, Response], write: Writer
+) -> None:
+    if response.http_version != b"1.1":
+        raise LocalProtocolError("I only send HTTP/1.1")
+    status_bytes = str(response.status_code).encode("ascii")
+    # We don't bother sending ascii status messages like "OK"; they're
+    # optional and ignored by the protocol. (But the space after the numeric
+    # status code is mandatory.)
+    #
+    # XX FIXME: could at least make an effort to pull out the status message
+    # from stdlib's http.HTTPStatus table. Or maybe just steal their enums
+    # (either by import or copy/paste). We already accept them as status codes
+    # since they're of type IntEnum < int.
+    write(b"HTTP/1.1 %s %s\r\n" % (status_bytes, response.reason))
+    write_headers(response.headers, write)
+
+
+class BodyWriter:
+    def __call__(self, event: Event, write: Writer) -> None:
+        if type(event) is Data:
+            self.send_data(event.data, write)
+        elif type(event) is EndOfMessage:
+            self.send_eom(event.headers, write)
+        else:  # pragma: no cover
+            assert False
+
+    def send_data(self, data: bytes, write: Writer) -> None:
+        pass
+
+    def send_eom(self, headers: Headers, write: Writer) -> None:
+        pass
+
+
+#
+# These are all careful not to do anything to 'data' except call len(data) and
+# write(data). This allows us to transparently pass-through funny objects,
+# like placeholder objects referring to files on disk that will be sent via
+# sendfile(2).
+#
+class ContentLengthWriter(BodyWriter):
+    def __init__(self, length: int) -> None:
+        self._length = length
+
+    def send_data(self, data: bytes, write: Writer) -> None:
+        self._length -= len(data)
+        if self._length < 0:
+            raise LocalProtocolError("Too much data for declared Content-Length")
+        write(data)
+
+    def send_eom(self, headers: Headers, write: Writer) -> None:
+        if self._length != 0:
+            raise LocalProtocolError("Too little data for declared Content-Length")
+        if headers:
+            raise LocalProtocolError("Content-Length and trailers don't mix")
+
+
+class ChunkedWriter(BodyWriter):
+    def send_data(self, data: bytes, write: Writer) -> None:
+        # if we encoded 0-length data in the naive way, it would look like an
+        # end-of-message.
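+        # Editorial example (not upstream text): a write of b"" would
+        # otherwise emit b"0\r\n" + b"" + b"\r\n" == b"0\r\n\r\n", which is
+        # exactly the chunked end-of-body marker, so empty writes are skipped.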
+ if not data: + return + write(b"%x\r\n" % len(data)) + write(data) + write(b"\r\n") + + def send_eom(self, headers: Headers, write: Writer) -> None: + write(b"0\r\n") + write_headers(headers, write) + + +class Http10Writer(BodyWriter): + def send_data(self, data: bytes, write: Writer) -> None: + write(data) + + def send_eom(self, headers: Headers, write: Writer) -> None: + if headers: + raise LocalProtocolError("can't send trailers to HTTP/1.0 client") + # no need to close the socket ourselves, that will be taken care of by + # Connection: close machinery + + +WritersType = Dict[ + Union[Tuple[Type[Sentinel], Type[Sentinel]], Type[Sentinel]], + Union[ + Dict[str, Type[BodyWriter]], + Callable[[Union[InformationalResponse, Response], Writer], None], + Callable[[Request, Writer], None], + ], +] + +WRITERS: WritersType = { + (CLIENT, IDLE): write_request, + (SERVER, IDLE): write_any_response, + (SERVER, SEND_RESPONSE): write_any_response, + SEND_BODY: { + "chunked": ChunkedWriter, + "content-length": ContentLengthWriter, + "http/1.0": Http10Writer, + }, +} diff --git a/contrib/python/h11/h11/py.typed b/contrib/python/h11/h11/py.typed new file mode 100644 index 0000000000..f5642f79f2 --- /dev/null +++ b/contrib/python/h11/h11/py.typed @@ -0,0 +1 @@ +Marker diff --git a/contrib/python/h11/ya.make b/contrib/python/h11/ya.make new file mode 100644 index 0000000000..48fcc1a654 --- /dev/null +++ b/contrib/python/h11/ya.make @@ -0,0 +1,33 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(0.14.0) + +LICENSE(MIT) + +NO_LINT() + +PY_SRCS( + TOP_LEVEL + h11/__init__.py + h11/_abnf.py + h11/_connection.py + h11/_events.py + h11/_headers.py + h11/_readers.py + h11/_receivebuffer.py + h11/_state.py + h11/_util.py + h11/_version.py + h11/_writers.py +) + +RESOURCE_FILES( + PREFIX contrib/python/h11/ + .dist-info/METADATA + .dist-info/top_level.txt + h11/py.typed +) + +END() diff --git a/contrib/python/httpcore/.dist-info/METADATA b/contrib/python/httpcore/.dist-info/METADATA new file mode 100644 index 0000000000..3776738caf --- /dev/null +++ b/contrib/python/httpcore/.dist-info/METADATA @@ -0,0 +1,547 @@ +Metadata-Version: 2.1 +Name: httpcore +Version: 0.18.0 +Summary: A minimal low-level HTTP client. 
+Project-URL: Documentation, https://www.encode.io/httpcore +Project-URL: Homepage, https://www.encode.io/httpcore/ +Project-URL: Source, https://github.com/encode/httpcore +Author-email: Tom Christie <tom@tomchristie.com> +License-Expression: BSD-3-Clause +License-File: LICENSE.md +Classifier: Development Status :: 3 - Alpha +Classifier: Environment :: Web Environment +Classifier: Framework :: AsyncIO +Classifier: Framework :: Trio +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Topic :: Internet :: WWW/HTTP +Requires-Python: >=3.8 +Requires-Dist: anyio<5.0,>=3.0 +Requires-Dist: certifi +Requires-Dist: h11<0.15,>=0.13 +Requires-Dist: sniffio==1.* +Provides-Extra: http2 +Requires-Dist: h2<5,>=3; extra == 'http2' +Provides-Extra: socks +Requires-Dist: socksio==1.*; extra == 'socks' +Description-Content-Type: text/markdown + +# HTTP Core + +[![Test Suite](https://github.com/encode/httpcore/workflows/Test%20Suite/badge.svg)](https://github.com/encode/httpcore/actions) +[![Package version](https://badge.fury.io/py/httpcore.svg)](https://pypi.org/project/httpcore/) + +> *Do one thing, and do it well.* + +The HTTP Core package provides a minimal low-level HTTP client, which does +one thing only. Sending HTTP requests. + +It does not provide any high level model abstractions over the API, +does not handle redirects, multipart uploads, building authentication headers, +transparent HTTP caching, URL parsing, session cookie handling, +content or charset decoding, handling JSON, environment based configuration +defaults, or any of that Jazz. + +Some things HTTP Core does do: + +* Sending HTTP requests. +* Thread-safe / task-safe connection pooling. +* HTTP(S) proxy & SOCKS proxy support. +* Supports HTTP/1.1 and HTTP/2. +* Provides both sync and async interfaces. +* Async backend support for `asyncio` and `trio`. + +## Requirements + +Python 3.8+ + +## Installation + +For HTTP/1.1 only support, install with: + +```shell +$ pip install httpcore +``` + +For HTTP/1.1 and HTTP/2 support, install with: + +```shell +$ pip install httpcore[http2] +``` + +For SOCKS proxy support, install with: + +```shell +$ pip install httpcore[socks] +``` + +# Sending requests + +Send an HTTP request: + +```python +import httpcore + +response = httpcore.request("GET", "https://www.example.com/") + +print(response) +# <Response [200]> +print(response.status) +# 200 +print(response.headers) +# [(b'Accept-Ranges', b'bytes'), (b'Age', b'557328'), (b'Cache-Control', b'max-age=604800'), ...] +print(response.content) +# b'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>\n\n<meta charset="utf-8"/>\n ...' +``` + +The top-level `httpcore.request()` function is provided for convenience. In practice whenever you're working with `httpcore` you'll want to use the connection pooling functionality that it provides. + +```python +import httpcore + +http = httpcore.ConnectionPool() +response = http.request("GET", "https://www.example.com/") +``` + +Once you're ready to get going, [head over to the documentation](https://www.encode.io/httpcore/). 
+
+## Motivation
+
+You *probably* don't want to be using HTTP Core directly. It might make sense if
+you're writing something like a proxy service in Python, and you just want
+something at the lowest possible level, but more typically you'll want to use
+a higher level client library, such as `httpx`.
+
+The motivation for `httpcore` is:
+
+* To provide a reusable low-level client library, that other packages can then build on top of.
+* To provide a *really clear interface split* between the networking code and client logic,
+  so that each is easier to understand and reason about in isolation.
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+
+## 0.18.0 (September 8th, 2023)
+
+- Add support for HTTPS proxies. (#745, #786)
+- Drop Python 3.7 support. (#727)
+- Handle `sni_hostname` extension with SOCKS proxy. (#774)
+- Handle HTTP/1.1 half-closed connections gracefully. (#641)
+- Change the type of `Extensions` from `Mapping[Str, Any]` to `MutableMapping[Str, Any]`. (#762)
+
+## 0.17.3 (July 5th, 2023)
+
+- Support async cancellations, ensuring that the connection pool is left in a clean state when cancellations occur. (#726)
+- The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699)
+- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. (#730)
+- Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717)
+- Drop Python 3.7 support. (#727)
+
+## 0.17.2 (May 23rd, 2023)
+
+- Add `socket_options` argument to `ConnectionPool` and `HTTPProxy` classes. (#668)
+- Improve logging with per-module logger names. (#690)
+- Add `sni_hostname` request extension. (#696)
+- Resolve race condition during import of `anyio` package. (#692)
+- Enable TCP_NODELAY for all synchronous sockets. (#651)
+
+## 0.17.1 (May 17th, 2023)
+
+- If 'retries' is set, then allow retries if an SSL handshake error occurs. (#669)
+- Improve correctness of tracebacks on network exceptions, by raising properly chained exceptions. (#678)
+- Prevent connection-hanging behaviour when HTTP/2 connections are closed by a server-sent 'GoAway' frame. (#679)
+- Fix edge-case exception when removing requests from the connection pool. (#680)
+- Fix pool timeout edge-case. (#688)
+
+## 0.17.0 (March 16th, 2023)
+
+- Add DEBUG level logging. (#648)
+- Respect HTTP/2 max concurrent streams when settings updates are sent by server. (#652)
+- Increase the allowable HTTP header size to 100kB. (#647)
+- Add `retries` option to SOCKS proxy classes. (#643)
+
+## 0.16.3 (December 20th, 2022)
+
+- Allow `ws` and `wss` schemes. Allows us to properly support websocket upgrade connections. (#625)
+- Forwarding HTTP proxies use a connection-per-remote-host. Required by some proxy implementations. (#637)
+- Don't raise `RuntimeError` when closing a connection pool with active connections. Removes some error cases when cancellations are used. (#631)
+- Lazy import `anyio`, so that it's no longer a hard dependency, and isn't imported if unused. (#639)
+
+## 0.16.2 (November 25th, 2022)
+
+- Revert 'Fix async cancellation behaviour', which introduced race conditions.
(#627)
+- Raise `RuntimeError` if attempting to use UNIX domain sockets on Windows. (#619)
+
+## 0.16.1 (November 17th, 2022)
+
+- Fix HTTP/1.1 interim informational responses, such as "100 Continue". (#605)
+
+## 0.16.0 (October 11th, 2022)
+
+- Support HTTP/1.1 informational responses. (#581)
+- Fix async cancellation behaviour. (#580)
+- Support `h11` 0.14. (#579)
+
+## 0.15.0 (May 17th, 2022)
+
+- Drop Python 3.6 support (#535)
+- Ensure HTTP proxy CONNECT requests include `timeout` configuration. (#506)
+- Switch to explicit `typing.Optional` for type hints. (#513)
+- For `trio` map OSError exceptions to `ConnectError`. (#543)
+
+## 0.14.7 (February 4th, 2022)
+
+- Requests which raise a PoolTimeout need to be removed from the pool queue. (#502)
+- Fix AttributeError that happened when Socks5Connection was terminated. (#501)
+
+## 0.14.6 (February 1st, 2022)
+
+- Fix SOCKS support for `http://` URLs. (#492)
+- Resolve race condition around exceptions during streaming a response. (#491)
+
+## 0.14.5 (January 18th, 2022)
+
+- SOCKS proxy support. (#478)
+- Add proxy_auth argument to HTTPProxy. (#481)
+- Improve error message on 'RemoteProtocolError' exception when server disconnects without sending a response. (#479)
+
+## 0.14.4 (January 5th, 2022)
+
+- Support HTTP/2 on HTTPS tunnelling proxies. (#468)
+- Fix proxy headers missing on HTTP forwarding. (#456)
+- Only instantiate SSL context if required. (#457)
+- More robust HTTP/2 handling. (#253, #439, #440, #441)
+
+## 0.14.3 (November 17th, 2021)
+
+- Fix race condition when removing closed connections from the pool. (#437)
+
+## 0.14.2 (November 16th, 2021)
+
+- Failed connections no longer remain in the pool. (Pull #433)
+
+## 0.14.1 (November 12th, 2021)
+
+- `max_connections` becomes optional. (Pull #429)
+- `certifi` is now included in the install dependencies. (Pull #428)
+- `h2` is now strictly optional. (Pull #428)
+
+## 0.14.0 (November 11th, 2021)
+
+The 0.14 release is a complete reworking of `httpcore`, comprehensively addressing some underlying issues in the connection pooling, as well as substantially redesigning the API to be more user friendly.
+
+Some of the lower-level API design also makes the components more easily testable in isolation, and the package now has 100% test coverage.
+
+See [discussion #419](https://github.com/encode/httpcore/discussions/419) for a little more background.
+
+There's some other neat bits in there too, such as the "trace" extension, which gives a hook into inspecting the internal events that occur during the request/response cycle. This extension is needed for the HTTPX cli, in order to...
+
+* Log the point at which the connection is established, and the IP/port on which it is made.
+* Determine if the outgoing request should log as HTTP/1.1 or HTTP/2, rather than having to assume it's HTTP/2 if the --http2 flag was passed. (Which may not actually be true.)
+* Log SSL version info / certificate info.
+
+Note that `curio` support is not currently available in 0.14.0. If you're using `httpcore` with `curio` please get in touch, so we can assess if we ought to prioritize it as a feature or not.
+
+## 0.13.7 (September 13th, 2021)
+
+- Fix broken error messaging when URL scheme is missing, or a non HTTP(S) scheme is used. (Pull #403)
+
+## 0.13.6 (June 15th, 2021)
+
+### Fixed
+
+- Close sockets when read or write timeouts occur. (Pull #365)
+
+## 0.13.5 (June 14th, 2021)
+
+### Fixed
+
+- Resolved niggles with AnyIO EOF behaviours.
(Pull #358, #362)
+
+## 0.13.4 (June 9th, 2021)
+
+### Added
+
+- Improved error messaging when URL scheme is missing, or a non HTTP(S) scheme is used. (Pull #354)
+
+### Fixed
+
+- Switched to `anyio` as the default backend implementation when running with `asyncio`. Resolves some awkward [TLS timeout issues](https://github.com/encode/httpx/discussions/1511).
+
+## 0.13.3 (May 6th, 2021)
+
+### Added
+
+- Support HTTP/2 prior knowledge, using `httpcore.SyncConnectionPool(http1=False)`. (Pull #333)
+
+### Fixed
+
+- Handle cases where environment does not provide `select.poll` support. (Pull #331)
+
+## 0.13.2 (April 29th, 2021)
+
+### Added
+
+- Improve error message for specific case of `RemoteProtocolError` where server disconnects without sending a response. (Pull #313)
+
+## 0.13.1 (April 28th, 2021)
+
+### Fixed
+
+- More resilient testing for closed connections. (Pull #311)
+- Don't raise exceptions on ungraceful connection closes. (Pull #310)
+
+## 0.13.0 (April 21st, 2021)
+
+The 0.13 release updates the core API in order to match the HTTPX Transport API,
+introduced in HTTPX 0.18 onwards.
+
+An example of making requests with the new interface is:
+
+```python
+with httpcore.SyncConnectionPool() as http:
+    status_code, headers, stream, extensions = http.handle_request(
+        method=b'GET',
+        url=(b'https', b'example.org', 443, b'/'),
+        headers=[(b'host', b'example.org'), (b'user-agent', b'httpcore')],
+        stream=httpcore.ByteStream(b''),
+        extensions={}
+    )
+    body = stream.read()
+    print(status_code, body)
+```
+
+### Changed
+
+- The `.request()` method is now `handle_request()`. (Pull #296)
+- The `.arequest()` method is now `.handle_async_request()`. (Pull #296)
+- The `headers` argument is no longer optional. (Pull #296)
+- The `stream` argument is no longer optional. (Pull #296)
+- The `ext` argument is now named `extensions`, and is no longer optional. (Pull #296)
+- The `"reason"` extension keyword is now named `"reason_phrase"`. (Pull #296)
+- The `"reason_phrase"` and `"http_version"` extensions now use byte strings for their values. (Pull #296)
+- The `httpcore.PlainByteStream()` class becomes `httpcore.ByteStream()`. (Pull #296)
+
+### Added
+
+- Streams now support a `.read()` interface. (Pull #296)
+
+### Fixed
+
+- Task cancellation no longer leaks connections from the connection pool. (Pull #305)
+
+## 0.12.3 (December 7th, 2020)
+
+### Fixed
+
+- Abort SSL connections on close rather than waiting for remote EOF when using `asyncio`. (Pull #167)
+- Fix exception raised in case of connect timeouts when using the `anyio` backend. (Pull #236)
+- Fix `Host` header precedence for `:authority` in HTTP/2. (Pull #241, #243)
+- Handle extra edge case when detecting for socket readability when using `asyncio`. (Pull #242, #244)
+- Fix `asyncio` SSL warning when using proxy tunneling. (Pull #249)
+
+## 0.12.2 (November 20th, 2020)
+
+### Fixed
+
+- Properly wrap connect errors on the asyncio backend. (Pull #235)
+- Fix `ImportError` occurring on Python 3.9 when using the HTTP/1.1 sync client in a multithreaded context. (Pull #237)
+
+## 0.12.1 (November 7th, 2020)
+
+### Added
+
+- Add connect retries. (Pull #221)
+
+### Fixed
+
+- Tweak detection of dropped connections, resolving an issue with open files limits on Linux. (Pull #185)
+- Avoid leaking connections when establishing an HTTP tunnel to a proxy has failed. (Pull #223)
+- Properly wrap OS errors when using `trio`.
+ +### Changed + +- The `.request()` method is now `handle_request()`. (Pull #296) +- The `.arequest()` method is now `.handle_async_request()`. (Pull #296) +- The `headers` argument is no longer optional. (Pull #296) +- The `stream` argument is no longer optional. (Pull #296) +- The `ext` argument is now named `extensions`, and is no longer optional. (Pull #296) +- The `"reason"` extension keyword is now named `"reason_phrase"`. (Pull #296) +- The `"reason_phrase"` and `"http_version"` extensions now use byte strings for their values. (Pull #296) +- The `httpcore.PlainByteStream()` class becomes `httpcore.ByteStream()`. (Pull #296) + +### Added + +- Streams now support a `.read()` interface. (Pull #296) + +### Fixed + +- Task cancellation no longer leaks connections from the connection pool. (Pull #305) + +## 0.12.3 (December 7th, 2020) + +### Fixed + +- Abort SSL connections on close rather than waiting for remote EOF when using `asyncio`. (Pull #167) +- Fix exception raised in case of connect timeouts when using the `anyio` backend. (Pull #236) +- Fix `Host` header precedence for `:authority` in HTTP/2. (Pull #241, #243) +- Handle extra edge case when detecting for socket readability when using `asyncio`. (Pull #242, #244) +- Fix `asyncio` SSL warning when using proxy tunneling. (Pull #249) + +## 0.12.2 (November 20th, 2020) + +### Fixed + +- Properly wrap connect errors on the asyncio backend. (Pull #235) +- Fix `ImportError` occurring on Python 3.9 when using the HTTP/1.1 sync client in a multithreaded context. (Pull #237) + +## 0.12.1 (November 7th, 2020) + +### Added + +- Add connect retries. (Pull #221) + +### Fixed + +- Tweak detection of dropped connections, resolving an issue with open files limits on Linux. (Pull #185) +- Avoid leaking connections when an attempt to establish an HTTP tunnel to a proxy has failed. (Pull #223) +- Properly wrap OS errors when using `trio`. (Pull #225) + +## 0.12.0 (October 6th, 2020) + +### Changed + +- HTTP header casing is now preserved, rather than always sent in lowercase. (#216 and python-hyper/h11#104) + +### Added + +- Add Python 3.9 to officially supported versions. + +### Fixed + +- Gracefully handle a stdlib asyncio bug when a connection is closed while it is in a paused-for-reading state. (#201) + +## 0.11.1 (September 28th, 2020) + +### Fixed + +- Add await to async semaphore release() coroutine (#197) +- Drop incorrect curio classifier (#192) + +## 0.11.0 (September 22nd, 2020) + +The Transport API with 0.11.0 has a couple of significant changes. + +Firstly, we've changed the request interface in order to allow extensions, which will later enable us to support features +such as trailing headers, HTTP/2 server push, and CONNECT/Upgrade connections. + +The interface changes from: + +```python +def request(method, url, headers, stream, timeout): + return (http_version, status_code, reason, headers, stream) +``` + +To instead include an optional dictionary of extensions on the request and response: + +```python +def request(method, url, headers, stream, ext): + return (status_code, headers, stream, ext) +``` + +Having an open-ended extensions point will allow us to later add support for various optional features that wouldn't otherwise be supported without these API changes. + +In particular: + +* Trailing headers support. +* HTTP/2 server push. +* sendfile. +* Exposing raw connection on CONNECT, Upgrade, HTTP/2 bi-di streaming. +* Exposing debug information out of the API, including template name, template context. + +Currently extensions are limited to: + +* request: `timeout` - Optional. Timeout dictionary. +* response: `http_version` - Optional. Include the HTTP version used on the response. +* response: `reason` - Optional. Include the reason phrase used on the response. Only valid with HTTP/1.*. + +See https://github.com/encode/httpx/issues/1274#issuecomment-694884553 for the history behind this.
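+
+For example, a caller could pass a timeout through on the request, and read the negotiated HTTP version back off the response extensions (a sketch; `transport` is assumed to be e.g. a `SyncConnectionPool` instance from this era of the API):
+
+```python
+status_code, headers, stream, ext = transport.request(
+    method=b'GET',
+    url=(b'https', b'example.org', 443, b'/'),
+    headers=[(b'host', b'example.org')],
+    stream=httpcore.PlainByteStream(b''),
+    ext={'timeout': {'connect': 5.0, 'read': 5.0}},
+)
+print(ext.get('http_version'), ext.get('reason'))
+```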
+ +Secondly, the async version of `request` is now namespaced as `arequest`. + +This allows concrete transports to support both sync and async implementations on the same class. + +### Added + +- Add curio support. (Pull #168) +- Add anyio support, with `backend="anyio"`. (Pull #169) + +### Changed + +- Update the Transport API to use 'ext' for optional extensions. (Pull #190) +- Update the Transport API to use `.request` and `.arequest` so implementations can support both sync and async. (Pull #189) + +## 0.10.2 (August 20th, 2020) + +### Added + +- Added Unix Domain Socket support. (Pull #139) + +### Fixed + +- Always include the port on proxy CONNECT requests. (Pull #154) +- Fix `max_keepalive_connections` configuration. (Pull #153) +- Fixes behaviour in HTTP/1.1 where server disconnects can be used to signal the end of the response body. (Pull #164) + +## 0.10.1 (August 7th, 2020) + +- Include `max_keepalive_connections` on `AsyncHTTPProxy`/`SyncHTTPProxy` classes. + +## 0.10.0 (August 7th, 2020) + +The most notable change in the 0.10.0 release is that HTTP/2 support is now fully optional. + +Use either `pip install httpcore` for HTTP/1.1 support only, or `pip install httpcore[http2]` for HTTP/1.1 and HTTP/2 support. + +### Added + +- HTTP/2 support becomes optional. (Pull #121, #130) +- Add `local_address=...` support. (Pull #100, #134) +- Add `PlainByteStream`, `IteratorByteStream`, `AsyncIteratorByteStream`. The `AsyncByteStream` and `SyncByteStream` classes are now pure interface classes. (#133) +- Add `LocalProtocolError`, `RemoteProtocolError` exceptions. (Pull #129) +- Add `UnsupportedProtocol` exception. (Pull #128) +- Add `.get_connection_info()` method. (Pull #102, #137) +- Add better TRACE logs. (Pull #101) + +### Changed + +- `max_keepalive` is deprecated in favour of `max_keepalive_connections`. (Pull #140) + +### Fixed + +- Improve handling of server disconnects. (Pull #112) + +## 0.9.1 (May 27th, 2020) + +### Fixed + +- Proper host resolution for sync case, including IPv6 support. (Pull #97) +- Close outstanding connections when connection pool is closed. (Pull #98) + +## 0.9.0 (May 21st, 2020) + +### Changed + +- URL port becomes an `Optional[int]` instead of `int`. (Pull #92) + +### Fixed + +- Honor HTTP/2 max concurrent streams settings. (Pull #89, #90) +- Remove incorrect debug log. (Pull #83) + +## 0.8.4 (May 11th, 2020) + +### Added + +- Logging via HTTPCORE_LOG_LEVEL and HTTPX_LOG_LEVEL environment variables +and TRACE level logging. (Pull #79) + +### Fixed + +- Reuse of connections on HTTP/2 in close concurrency situations. (Pull #81) + +## 0.8.3 (May 6th, 2020) + +### Fixed + +- Include `Host` and `Accept` headers on proxy "CONNECT" requests. +- De-duplicate any headers also contained in proxy_headers. +- HTTP/2 flag not being passed down to proxy connections. + +## 0.8.2 (May 3rd, 2020) + +### Fixed + +- Fix connections using proxy forwarding requests not being added to the +connection pool properly. (Pull #70) + +## 0.8.1 (April 30th, 2020) + +### Changed + +- Allow inheritance of both `httpcore.AsyncByteStream`, `httpcore.SyncByteStream` without type conflicts. + +## 0.8.0 (April 30th, 2020) + +### Fixed + +- Fixed tunnel proxy support. + +### Added + +- New `TimeoutException` base class. + +## 0.7.0 (March 5th, 2020) + +- First integration with HTTPX. diff --git a/contrib/python/httpcore/.dist-info/top_level.txt b/contrib/python/httpcore/.dist-info/top_level.txt new file mode 100644 index 0000000000..613e43507b --- /dev/null +++ b/contrib/python/httpcore/.dist-info/top_level.txt @@ -0,0 +1,4 @@ +httpcore +httpcore/_async +httpcore/_backends +httpcore/_sync diff --git a/contrib/python/httpcore/LICENSE.md b/contrib/python/httpcore/LICENSE.md new file mode 100644 index 0000000000..311b2b56c5 --- /dev/null +++ b/contrib/python/httpcore/LICENSE.md @@ -0,0 +1,27 @@ +Copyright © 2020, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/contrib/python/httpcore/README.md b/contrib/python/httpcore/README.md new file mode 100644 index 0000000000..66a2150016 --- /dev/null +++ b/contrib/python/httpcore/README.md @@ -0,0 +1,91 @@ +# HTTP Core + +[![Test Suite](https://github.com/encode/httpcore/workflows/Test%20Suite/badge.svg)](https://github.com/encode/httpcore/actions) +[![Package version](https://badge.fury.io/py/httpcore.svg)](https://pypi.org/project/httpcore/) + +> *Do one thing, and do it well.* + +The HTTP Core package provides a minimal low-level HTTP client, which does +one thing only. Sending HTTP requests. + +It does not provide any high level model abstractions over the API, +does not handle redirects, multipart uploads, building authentication headers, +transparent HTTP caching, URL parsing, session cookie handling, +content or charset decoding, handling JSON, environment-based configuration +defaults, or any of that Jazz. + +Some things HTTP Core does do: + +* Sending HTTP requests. +* Thread-safe / task-safe connection pooling. +* HTTP(S) proxy & SOCKS proxy support. +* Supports HTTP/1.1 and HTTP/2. +* Provides both sync and async interfaces. +* Async backend support for `asyncio` and `trio`. + +## Requirements + +Python 3.8+ + +## Installation + +For HTTP/1.1 only support, install with: + +```shell +$ pip install httpcore +``` + +For HTTP/1.1 and HTTP/2 support, install with: + +```shell +$ pip install httpcore[http2] +``` + +For SOCKS proxy support, install with: + +```shell +$ pip install httpcore[socks] +``` + +## Sending requests + +Send an HTTP request: + +```python +import httpcore + +response = httpcore.request("GET", "https://www.example.com/") + +print(response) +# <Response [200]> +print(response.status) +# 200 +print(response.headers) +# [(b'Accept-Ranges', b'bytes'), (b'Age', b'557328'), (b'Cache-Control', b'max-age=604800'), ...] +print(response.content) +# b'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>\n\n<meta charset="utf-8"/>\n ...' +``` + +The top-level `httpcore.request()` function is provided for convenience. In practice whenever you're working with `httpcore` you'll want to use the connection pooling functionality that it provides. + +```python +import httpcore + +http = httpcore.ConnectionPool() +response = http.request("GET", "https://www.example.com/") +```
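+
+The same interface is also available in async flavour, for use with `asyncio` or `trio` (a minimal `asyncio` sketch):
+
+```python
+import asyncio
+import httpcore
+
+async def main():
+    async with httpcore.AsyncConnectionPool() as http:
+        response = await http.request("GET", "https://www.example.com/")
+        print(response.status)
+
+asyncio.run(main())
+```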
+ +Once you're ready to get going, [head over to the documentation](https://www.encode.io/httpcore/). + +## Motivation + +You *probably* don't want to be using HTTP Core directly. It might make sense if +you're writing something like a proxy service in Python, and you just want +something at the lowest possible level, but more typically you'll want to use +a higher level client library, such as `httpx`. + +The motivation for `httpcore` is: + +* To provide a reusable low-level client library that other packages can then build on top of. +* To provide a *really clear interface split* between the networking code and client logic, + so that each is easier to understand and reason about in isolation. diff --git a/contrib/python/httpcore/httpcore/__init__.py b/contrib/python/httpcore/httpcore/__init__.py new file mode 100644 index 0000000000..65abe9716a --- /dev/null +++ b/contrib/python/httpcore/httpcore/__init__.py @@ -0,0 +1,139 @@ +from ._api import request, stream +from ._async import ( + AsyncConnectionInterface, + AsyncConnectionPool, + AsyncHTTP2Connection, + AsyncHTTP11Connection, + AsyncHTTPConnection, + AsyncHTTPProxy, + AsyncSOCKSProxy, +) +from ._backends.base import ( + SOCKET_OPTION, + AsyncNetworkBackend, + AsyncNetworkStream, + NetworkBackend, + NetworkStream, +) +from ._backends.mock import AsyncMockBackend, AsyncMockStream, MockBackend, MockStream +from ._backends.sync import SyncBackend +from ._exceptions import ( + ConnectError, + ConnectionNotAvailable, + ConnectTimeout, + LocalProtocolError, + NetworkError, + PoolTimeout, + ProtocolError, + ProxyError, + ReadError, + ReadTimeout, + RemoteProtocolError, + TimeoutException, + UnsupportedProtocol, + WriteError, + WriteTimeout, +) +from ._models import URL, Origin, Request, Response +from ._ssl import default_ssl_context +from ._sync import ( + ConnectionInterface, + ConnectionPool, + HTTP2Connection, + HTTP11Connection, + HTTPConnection, + HTTPProxy, + SOCKSProxy, +) + +# The 'httpcore.AnyIOBackend' class is conditional on 'anyio' being installed. +try: + from ._backends.anyio import AnyIOBackend +except ImportError: # pragma: nocover + + class AnyIOBackend: # type: ignore + def __init__(self, *args, **kwargs): # type: ignore + msg = ( + "Attempted to use 'httpcore.AnyIOBackend' but 'anyio' is not installed." + ) + raise RuntimeError(msg) + + +# The 'httpcore.TrioBackend' class is conditional on 'trio' being installed. +try: + from ._backends.trio import TrioBackend +except ImportError: # pragma: nocover + + class TrioBackend: # type: ignore + def __init__(self, *args, **kwargs): # type: ignore + msg = "Attempted to use 'httpcore.TrioBackend' but 'trio' is not installed."
+ raise RuntimeError(msg) + + +__all__ = [ + # top-level requests + "request", + "stream", + # models + "Origin", + "URL", + "Request", + "Response", + # async + "AsyncHTTPConnection", + "AsyncConnectionPool", + "AsyncHTTPProxy", + "AsyncHTTP11Connection", + "AsyncHTTP2Connection", + "AsyncConnectionInterface", + "AsyncSOCKSProxy", + # sync + "HTTPConnection", + "ConnectionPool", + "HTTPProxy", + "HTTP11Connection", + "HTTP2Connection", + "ConnectionInterface", + "SOCKSProxy", + # network backends, implementations + "SyncBackend", + "AnyIOBackend", + "TrioBackend", + # network backends, mock implementations + "AsyncMockBackend", + "AsyncMockStream", + "MockBackend", + "MockStream", + # network backends, interface + "AsyncNetworkStream", + "AsyncNetworkBackend", + "NetworkStream", + "NetworkBackend", + # util + "default_ssl_context", + "SOCKET_OPTION", + # exceptions + "ConnectionNotAvailable", + "ProxyError", + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", + "UnsupportedProtocol", + "TimeoutException", + "PoolTimeout", + "ConnectTimeout", + "ReadTimeout", + "WriteTimeout", + "NetworkError", + "ConnectError", + "ReadError", + "WriteError", +] + +__version__ = "0.18.0" + + +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + setattr(__locals[__name], "__module__", "httpcore") # noqa diff --git a/contrib/python/httpcore/httpcore/_api.py b/contrib/python/httpcore/httpcore/_api.py new file mode 100644 index 0000000000..854235f5f6 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_api.py @@ -0,0 +1,92 @@ +from contextlib import contextmanager +from typing import Iterator, Optional, Union + +from ._models import URL, Extensions, HeaderTypes, Response +from ._sync.connection_pool import ConnectionPool + + +def request( + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterator[bytes], None] = None, + extensions: Optional[Extensions] = None, +) -> Response: + """ + Sends an HTTP request, returning the response. + + ``` + response = httpcore.request("GET", "https://www.example.com/") + ``` + + Arguments: + method: The HTTP method for the request. Typically one of `"GET"`, + `"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`. + url: The URL of the HTTP request. Either as an instance of `httpcore.URL`, + or as str/bytes. + headers: The HTTP request headers. Either as a dictionary of str/bytes, + or as a list of two-tuples of str/bytes. + content: The content of the request body. Either as bytes, + or as a bytes iterator. + extensions: A dictionary of optional extra information included on the request. + Possible keys include `"timeout"`. + + Returns: + An instance of `httpcore.Response`. + """ + with ConnectionPool() as pool: + return pool.request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + + +@contextmanager +def stream( + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterator[bytes], None] = None, + extensions: Optional[Extensions] = None, +) -> Iterator[Response]: + """ + Sends an HTTP request, returning the response within a context manager. + + ``` + with httpcore.stream("GET", "https://www.example.com/") as response: + ... + ``` + + When using the `stream()` function, the body of the response will not be + automatically read.
If you want to access the response body you should + either use `content = response.read()`, or `for chunk in response.iter_stream()`. + + Arguments: + method: The HTTP method for the request. Typically one of `"GET"`, + `"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`. + url: The URL of the HTTP request. Either as an instance of `httpcore.URL`, + or as str/bytes. + headers: The HTTP request headers. Either as a dictionary of str/bytes, + or as a list of two-tuples of str/bytes. + content: The content of the request body. Either as bytes, + or as a bytes iterator. + extensions: A dictionary of optional extra information included on the request. + Possible keys include `"timeout"`. + + Returns: + An instance of `httpcore.Response`. + """ + with ConnectionPool() as pool: + with pool.stream( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) as response: + yield response diff --git a/contrib/python/httpcore/httpcore/_async/__init__.py b/contrib/python/httpcore/httpcore/_async/__init__.py new file mode 100644 index 0000000000..88dc7f01e1 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/__init__.py @@ -0,0 +1,39 @@ +from .connection import AsyncHTTPConnection +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .http_proxy import AsyncHTTPProxy +from .interfaces import AsyncConnectionInterface + +try: + from .http2 import AsyncHTTP2Connection +except ImportError: # pragma: nocover + + class AsyncHTTP2Connection: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use http2 support, but the `h2` package is not " + "installed. Use 'pip install httpcore[http2]'." + ) + + +try: + from .socks_proxy import AsyncSOCKSProxy +except ImportError: # pragma: nocover + + class AsyncSOCKSProxy: # type: ignore + def __init__(self, *args, **kwargs) -> None: # type: ignore + raise RuntimeError( + "Attempted to use SOCKS support, but the `socksio` package is not " + "installed. Use 'pip install httpcore[socks]'." + ) + + +__all__ = [ + "AsyncHTTPConnection", + "AsyncConnectionPool", + "AsyncHTTPProxy", + "AsyncHTTP11Connection", + "AsyncHTTP2Connection", + "AsyncConnectionInterface", + "AsyncSOCKSProxy", +] diff --git a/contrib/python/httpcore/httpcore/_async/connection.py b/contrib/python/httpcore/httpcore/_async/connection.py new file mode 100644 index 0000000000..45ee22a63d --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/connection.py @@ -0,0 +1,222 @@ +import itertools +import logging +import ssl +from types import TracebackType +from typing import Iterable, Iterator, Optional, Type + +from .._backends.auto import AutoBackend +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream +from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. + + +logger = logging.getLogger("httpcore.connection") + + +def exponential_backoff(factor: float) -> Iterator[float]: + """ + Generate a geometric sequence that has a ratio of 2 and starts with 0.
+ + For example: + - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...` + - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...` + """ + yield 0 + for n in itertools.count(): + yield factor * 2**n + + +class AsyncHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + origin: Origin, + ssl_context: Optional[ssl.SSLContext] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + self._origin = origin + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._network_backend: AsyncNetworkBackend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._connection: Optional[AsyncConnectionInterface] = None + self._connect_failed: bool = False + self._request_lock = AsyncLock() + self._socket_options = socket_options + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection to {self._origin}" + ) + + async with self._request_lock: + if self._connection is None: + try: + stream = await self._connect(request) + + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): + raise ConnectionNotAvailable() + + return await self._connection.handle_async_request(request) + + async def _connect(self, request: Request) -> AsyncNetworkStream: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + retries_left = self._retries + delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) + + while True: + try: + if self._uds is None: + kwargs = { + "host": self._origin.host.decode("ascii"), + "port": self._origin.port, + "local_address": self._local_address, + "timeout": timeout, + "socket_options": self._socket_options, + } + async with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = await self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + else: + kwargs = { + "path": self._uds, + "timeout": timeout, + "socket_options": self._socket_options, + } + async with Trace( + "connect_unix_socket", logger, request, kwargs + ) as trace: + stream = await self._network_backend.connect_unix_socket( + **kwargs + ) + trace.return_value = stream + + if self._origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + 
"ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + return stream + except (ConnectError, ConnectTimeout): + if retries_left <= 0: + raise + retries_left -= 1 + delay = next(delays) + async with Trace("retry", logger, request, kwargs) as trace: + await self._network_backend.sleep(delay) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + async def aclose(self) -> None: + if self._connection is not None: + async with Trace("close", logger, None, {}): + await self._connection.aclose() + + def is_available(self) -> bool: + if self._connection is None: + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
+ + async def __aenter__(self) -> "AsyncHTTPConnection": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.aclose() diff --git a/contrib/python/httpcore/httpcore/_async/connection_pool.py b/contrib/python/httpcore/httpcore/_async/connection_pool.py new file mode 100644 index 0000000000..ddc0510e60 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/connection_pool.py @@ -0,0 +1,356 @@ +import ssl +import sys +from types import TracebackType +from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type + +from .._backends.auto import AutoBackend +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Request, Response +from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation +from .connection import AsyncHTTPConnection +from .interfaces import AsyncConnectionInterface, AsyncRequestInterface + + +class RequestStatus: + def __init__(self, request: Request): + self.request = request + self.connection: Optional[AsyncConnectionInterface] = None + self._connection_acquired = AsyncEvent() + + def set_connection(self, connection: AsyncConnectionInterface) -> None: + assert self.connection is None + self.connection = connection + self._connection_acquired.set() + + def unset_connection(self) -> None: + assert self.connection is not None + self.connection = None + self._connection_acquired = AsyncEvent() + + async def wait_for_connection( + self, timeout: Optional[float] = None + ) -> AsyncConnectionInterface: + if self.connection is None: + await self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection + + +class AsyncConnectionPool(AsyncRequestInterface): + """ + A connection pool for making HTTP requests. + """ + + def __init__( + self, + ssl_context: Optional[ssl.SSLContext] = None, + max_connections: Optional[int] = 10, + max_keepalive_connections: Optional[int] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish a + connection. + local_address: Local address to connect from. 
Can also be used to connect + using a particular address family. Using `local_address="0.0.0.0"` + will connect using an `AF_INET` address (IPv4), while using + `local_address="::"` will connect using an `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + socket_options: Socket options to include on the TCP socket when + the connection is established. + """ + self._ssl_context = ssl_context + + self._max_connections = ( + sys.maxsize if max_connections is None else max_connections + ) + self._max_keepalive_connections = ( + sys.maxsize + if max_keepalive_connections is None + else max_keepalive_connections + ) + self._max_keepalive_connections = min( + self._max_connections, self._max_keepalive_connections + ) + + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._pool: List[AsyncConnectionInterface] = [] + self._requests: List[RequestStatus] = [] + self._pool_lock = AsyncLock() + self._network_backend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._socket_options = socket_options + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + return AsyncHTTPConnection( + origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + retries=self._retries, + local_address=self._local_address, + uds=self._uds, + network_backend=self._network_backend, + socket_options=self._socket_options, + ) + + @property + def connections(self) -> List[AsyncConnectionInterface]: + """ + Return a list of the connections currently in the pool. + + For example: + + ```python + >>> pool.connections + [ + <AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>, + <AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]>, + <AsyncHTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>, + ] + ``` + """ + return list(self._pool) + + async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool: + """ + Attempt to provide a connection that can handle the given origin. + """ + origin = status.request.url.origin + + # If there are queued requests in front of us, then don't acquire a + # connection. We handle requests strictly in order. + waiting = [s for s in self._requests if s.connection is None] + if waiting and waiting[0] is not status: + return False + + # Reuse an existing connection if one is currently available. + for idx, connection in enumerate(self._pool): + if connection.can_handle_request(origin) and connection.is_available(): + self._pool.pop(idx) + self._pool.insert(0, connection) + status.set_connection(connection) + return True + + # If the pool is currently full, attempt to close one idle connection. + if len(self._pool) >= self._max_connections: + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.is_idle(): + await connection.aclose() + self._pool.pop(idx) + break + + # If the pool is still full, then we cannot acquire a connection. + if len(self._pool) >= self._max_connections: + return False + + # Otherwise create a new connection.
+ connection = self.create_connection(origin) + self._pool.insert(0, connection) + status.set_connection(connection) + return True + + async def _close_expired_connections(self) -> None: + """ + Clean up the connection pool by closing off any connections that have expired. + """ + # Close any connections that have expired their keep-alive time. + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.has_expired(): + await connection.aclose() + self._pool.pop(idx) + + # If the pool size exceeds the maximum number of allowed keep-alive connections, + # then close off idle connections as required. + pool_size = len(self._pool) + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.is_idle() and pool_size > self._max_keepalive_connections: + await connection.aclose() + self._pool.pop(idx) + pool_size -= 1 + + async def handle_async_request(self, request: Request) -> Response: + """ + Send an HTTP request, and return an HTTP response. + + This is the core implementation that is called into by `.request()` or `.stream()`. + """ + scheme = request.url.scheme.decode() + if scheme == "": + raise UnsupportedProtocol( + "Request URL is missing an 'http://' or 'https://' protocol." + ) + if scheme not in ("http", "https", "ws", "wss"): + raise UnsupportedProtocol( + f"Request URL has an unsupported protocol '{scheme}://'." + ) + + status = RequestStatus(request) + + async with self._pool_lock: + self._requests.append(status) + await self._close_expired_connections() + await self._attempt_to_acquire_connection(status) + + while True: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("pool", None) + try: + connection = await status.wait_for_connection(timeout=timeout) + except BaseException as exc: + # If we timeout here, or if the task is cancelled, then make + # sure to remove the request from the queue before bubbling + # up the exception. + async with self._pool_lock: + # Ensure only remove when task exists. + if status in self._requests: + self._requests.remove(status) + raise exc + + try: + response = await connection.handle_async_request(request) + except ConnectionNotAvailable: + # The ConnectionNotAvailable exception is a special case, that + # indicates we need to retry the request on a new connection. + # + # The most common case where this can occur is when multiple + # requests are queued waiting for a single connection, which + # might end up as an HTTP/2 connection, but which actually ends + # up as HTTP/1.1. + async with self._pool_lock: + # Maintain our position in the request queue, but reset the + # status so that the request becomes queued again. + status.unset_connection() + await self._attempt_to_acquire_connection(status) + except BaseException as exc: + with AsyncShieldCancellation(): + await self.response_closed(status) + raise exc + else: + break + + # When we return the response, we wrap the stream in a special class + # that handles notifying the connection pool once the response + # has been released. + assert isinstance(response.stream, AsyncIterable) + return Response( + status=response.status, + headers=response.headers, + content=ConnectionPoolByteStream(response.stream, self, status), + extensions=response.extensions, + ) + + async def response_closed(self, status: RequestStatus) -> None: + """ + This method acts as a callback once the request/response cycle is complete. + + It is called into from the `ConnectionPoolByteStream.aclose()` method. 
+ """ + assert status.connection is not None + connection = status.connection + + async with self._pool_lock: + # Update the state of the connection pool. + if status in self._requests: + self._requests.remove(status) + + if connection.is_closed() and connection in self._pool: + self._pool.remove(connection) + + # Since we've had a response closed, it's possible we'll now be able + # to service one or more requests that are currently pending. + for status in self._requests: + if status.connection is None: + acquired = await self._attempt_to_acquire_connection(status) + # If we could not acquire a connection for a queued request + # then we don't need to check anymore requests that are + # queued later behind it. + if not acquired: + break + + # Housekeeping. + await self._close_expired_connections() + + async def aclose(self) -> None: + """ + Close any connections in the pool. + """ + async with self._pool_lock: + for connection in self._pool: + await connection.aclose() + self._pool = [] + self._requests = [] + + async def __aenter__(self) -> "AsyncConnectionPool": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.aclose() + + +class ConnectionPoolByteStream: + """ + A wrapper around the response byte stream, that additionally handles + notifying the connection pool when the response has been closed. + """ + + def __init__( + self, + stream: AsyncIterable[bytes], + pool: AsyncConnectionPool, + status: RequestStatus, + ) -> None: + self._stream = stream + self._pool = pool + self._status = status + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for part in self._stream: + yield part + + async def aclose(self) -> None: + try: + if hasattr(self._stream, "aclose"): + await self._stream.aclose() + finally: + with AsyncShieldCancellation(): + await self._pool.response_closed(self._status) diff --git a/contrib/python/httpcore/httpcore/_async/http11.py b/contrib/python/httpcore/httpcore/_async/http11.py new file mode 100644 index 0000000000..32fa3a6f23 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/http11.py @@ -0,0 +1,343 @@ +import enum +import logging +import time +from types import TracebackType +from typing import ( + AsyncIterable, + AsyncIterator, + List, + Optional, + Tuple, + Type, + Union, + cast, +) + +import h11 + +from .._backends.base import AsyncNetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, + WriteError, + map_exceptions, +) +from .._models import Origin, Request, Response +from .._synchronization import AsyncLock, AsyncShieldCancellation +from .._trace import Trace +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.http11") + + +# A subset of `h11.Event` types supported by `_send_event` +H11SendEvent = Union[ + h11.Request, + h11.Data, + h11.EndOfMessage, +] + + +class HTTPConnectionState(enum.IntEnum): + NEW = 0 + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class AsyncHTTP11Connection(AsyncConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024 + + def __init__( + self, + origin: Origin, + stream: AsyncNetworkStream, + keepalive_expiry: Optional[float] = None, + ) -> None: + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: Optional[float] = keepalive_expiry + self._expire_at: Optional[float] = None + self._state = 
+ + +class AsyncHTTP11Connection(AsyncConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024 + + def __init__( + self, + origin: Origin, + stream: AsyncNetworkStream, + keepalive_expiry: Optional[float] = None, + ) -> None: + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: Optional[float] = keepalive_expiry + self._expire_at: Optional[float] = None + self._state = HTTPConnectionState.NEW + self._state_lock = AsyncLock() + self._request_count = 0 + self._h11_state = h11.Connection( + our_role=h11.CLIENT, + max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, + ) + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + async with self._state_lock: + if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): + self._request_count += 1 + self._state = HTTPConnectionState.ACTIVE + self._expire_at = None + else: + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request} + try: + async with Trace( + "send_request_headers", logger, request, kwargs + ) as trace: + await self._send_request_headers(**kwargs) + async with Trace("send_request_body", logger, request, kwargs) as trace: + await self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we suppress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well-formed HTTP + # error response. + pass + + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + ) = await self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": self._network_stream, + }, + ) + except BaseException as exc: + with AsyncShieldCancellation(): + async with Trace("response_closed", logger, request) as trace: + await self._response_closed() + raise exc + + # Sending the request... + + async def _send_request_headers(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): + event = h11.Request( + method=request.method, + target=request.url.target, + headers=request.headers, + ) + await self._send_event(event, timeout=timeout) + + async def _send_request_body(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + assert isinstance(request.stream, AsyncIterable) + async for chunk in request.stream: + event = h11.Data(data=chunk) + await self._send_event(event, timeout=timeout) + + await self._send_event(h11.EndOfMessage(), timeout=timeout) + + async def _send_event( + self, event: h11.Event, timeout: Optional[float] = None + ) -> None: + bytes_to_send = self._h11_state.send(event) + if bytes_to_send is not None: + await self._network_stream.write(bytes_to_send, timeout=timeout) + + # Receiving the response...
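+    #
+    # A typical sequence of h11 events for a successful response is (a sketch):
+    #
+    #     h11.Response -> h11.Data ... h11.Data -> h11.EndOfMessage
+    #
+    # possibly preceded by one or more h11.InformationalResponse events.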
+ + async def _receive_response_headers( + self, request: Request + ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = await self._receive_event(timeout=timeout) + if isinstance(event, h11.Response): + break + if ( + isinstance(event, h11.InformationalResponse) + and event.status_code == 101 + ): + break + + http_version = b"HTTP/" + event.http_version + + # h11 version 0.11+ supports a `raw_items` interface to get the + # raw header casing, rather than the enforced lowercase headers. + headers = event.headers.raw_items() + + return http_version, event.status_code, event.reason, headers + + async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + while True: + event = await self._receive_event(timeout=timeout) + if isinstance(event, h11.Data): + yield bytes(event.data) + elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): + break + + async def _receive_event( + self, timeout: Optional[float] = None + ) -> Union[h11.Event, Type[h11.PAUSED]]: + while True: + with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): + event = self._h11_state.next_event() + + if event is h11.NEED_DATA: + data = await self._network_stream.read( + self.READ_NUM_BYTES, timeout=timeout + ) + + # If we feed this case through h11 we'll raise an exception like: + # + # httpcore.RemoteProtocolError: can't handle event type + # ConnectionClosed when role=SERVER and state=SEND_RESPONSE + # + # Which is accurate, but not very informative from an end-user + # perspective. Instead we handle this case distinctly and treat + # it as a ConnectError. + if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE: + msg = "Server disconnected without sending a response." + raise RemoteProtocolError(msg) + + self._h11_state.receive_data(data) + else: + # mypy fails to narrow the type in the if statement above + return cast(Union[h11.Event, Type[h11.PAUSED]], event) + + async def _response_closed(self) -> None: + async with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + await self.aclose() + + # Once the connection is no longer required... + + async def aclose(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._state = HTTPConnectionState.CLOSED + await self._network_stream.aclose() + + # The AsyncConnectionInterface methods provide information about the state of + # the connection, allowing for a connection pooling implementation to + # determine when to reuse and when to close the connection... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + # Note that HTTP/1.1 connections in the "NEW" state are not treated as + # being "available". The control flow which created the connection will + # be able to send an outgoing request, but the connection will not be + # acquired from the connection pool for any other request.
+ return self._state == HTTPConnectionState.IDLE + + def has_expired(self) -> bool: + now = time.monotonic() + keepalive_expired = self._expire_at is not None and now > self._expire_at + + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. + server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + + return keepalive_expired or server_disconnected + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/1.1, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + async def __aenter__(self) -> "AsyncHTTP11Connection": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.aclose() + + +class HTTP11ConnectionByteStream: + def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + self._closed = False + + async def __aiter__(self) -> AsyncIterator[bytes]: + kwargs = {"request": self._request} + try: + async with Trace("receive_response_body", logger, self._request, kwargs): + async for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. 
+ with AsyncShieldCancellation(): + await self.aclose() + raise exc + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + async with Trace("response_closed", logger, self._request): + await self._connection._response_closed() diff --git a/contrib/python/httpcore/httpcore/_async/http2.py b/contrib/python/httpcore/httpcore/_async/http2.py new file mode 100644 index 0000000000..8dc776ffa0 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/http2.py @@ -0,0 +1,589 @@ +import enum +import logging +import time +import types +import typing + +import h2.config +import h2.connection +import h2.events +import h2.exceptions +import h2.settings + +from .._backends.base import AsyncNetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, +) +from .._models import Origin, Request, Response +from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation +from .._trace import Trace +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.http2") + + +def has_body_headers(request: Request) -> bool: + return any( + k.lower() == b"content-length" or k.lower() == b"transfer-encoding" + for k, v in request.headers + ) + + +class HTTPConnectionState(enum.IntEnum): + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class AsyncHTTP2Connection(AsyncConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + CONFIG = h2.config.H2Configuration(validate_inbound_headers=False) + + def __init__( + self, + origin: Origin, + stream: AsyncNetworkStream, + keepalive_expiry: typing.Optional[float] = None, + ): + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: typing.Optional[float] = keepalive_expiry + self._h2_state = h2.connection.H2Connection(config=self.CONFIG) + self._state = HTTPConnectionState.IDLE + self._expire_at: typing.Optional[float] = None + self._request_count = 0 + self._init_lock = AsyncLock() + self._state_lock = AsyncLock() + self._read_lock = AsyncLock() + self._write_lock = AsyncLock() + self._sent_connection_init = False + self._used_all_stream_ids = False + self._connection_error = False + + # Mapping from stream ID to response stream events. + self._events: typing.Dict[ + int, + typing.Union[ + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ], + ] = {} + + # Connection terminated events are stored as state since + # we need to handle them for all streams. + self._connection_terminated: typing.Optional[ + h2.events.ConnectionTerminated + ] = None + + self._read_exception: typing.Optional[Exception] = None + self._write_exception: typing.Optional[Exception] = None + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + # This cannot occur in normal operation, since the connection pool + # will only send requests on connections that handle them. + # It's in place simply for resilience as a guard against incorrect + # usage, for anyone working directly with httpcore connections. 
+ raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + async with self._state_lock: + if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + self._request_count += 1 + self._expire_at = None + self._state = HTTPConnectionState.ACTIVE + else: + raise ConnectionNotAvailable() + + async with self._init_lock: + if not self._sent_connection_init: + try: + kwargs = {"request": request} + async with Trace("send_connection_init", logger, request, kwargs): + await self._send_connection_init(**kwargs) + except BaseException as exc: + with AsyncShieldCancellation(): + await self.aclose() + raise exc + + self._sent_connection_init = True + + # Initially start with just 1 until the remote server provides + # its max_concurrent_streams value + self._max_streams = 1 + + local_settings_max_streams = ( + self._h2_state.local_settings.max_concurrent_streams + ) + self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams) + + for _ in range(local_settings_max_streams - self._max_streams): + await self._max_streams_semaphore.acquire() + + await self._max_streams_semaphore.acquire() + + try: + stream_id = self._h2_state.get_next_available_stream_id() + self._events[stream_id] = [] + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + self._request_count -= 1 + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request, "stream_id": stream_id} + async with Trace("send_request_headers", logger, request, kwargs): + await self._send_request_headers(request=request, stream_id=stream_id) + async with Trace("send_request_body", logger, request, kwargs): + await self._send_request_body(request=request, stream_id=stream_id) + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + status, headers = await self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + return Response( + status=status, + headers=headers, + content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), + extensions={ + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + }, + ) + except BaseException as exc: # noqa: PIE786 + with AsyncShieldCancellation(): + kwargs = {"stream_id": stream_id} + async with Trace("response_closed", logger, request, kwargs): + await self._response_closed(stream_id=stream_id) + + if isinstance(exc, h2.exceptions.ProtocolError): + # One case where h2 can raise a protocol error is when a + # closed frame has been seen by the state machine. + # + # This happens when one stream is reading, and encounters + # a GOAWAY event. Other flows of control may then raise + # a protocol error at any point they interact with the 'h2_state'. + # + # In this case we'll have stored the event, and should raise + # it as a RemoteProtocolError. + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) + # If h2 raises a protocol error in some other state then we + # must somehow have made a protocol violation. + raise LocalProtocolError(exc) # pragma: nocover + + raise exc + + async def _send_connection_init(self, request: Request) -> None: + """ + The HTTP/2 connection requires some initial setup before we can start + using individual request/response streams on it. 
+ """ + # Need to set these manually here instead of manipulating via + # __setitem__() otherwise the H2Connection will emit SettingsUpdate + # frames in addition to sending the undesired defaults. + self._h2_state.local_settings = h2.settings.Settings( + client=True, + initial_values={ + # Disable PUSH_PROMISE frames from the server since we don't do anything + # with them for now. Maybe when we support caching? + h2.settings.SettingCodes.ENABLE_PUSH: 0, + # These two are taken from h2 for safe defaults + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100, + h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536, + }, + ) + + # Some websites (*cough* Yahoo *cough*) balk at this setting being + # present in the initial handshake since it's not defined in the original + # RFC despite the RFC mandating ignoring settings you don't know about. + del self._h2_state.local_settings[ + h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL + ] + + self._h2_state.initiate_connection() + self._h2_state.increment_flow_control_window(2**24) + await self._write_outgoing_data(request) + + # Sending the request... + + async def _send_request_headers(self, request: Request, stream_id: int) -> None: + """ + Send the request headers to a given stream ID. + """ + end_stream = not has_body_headers(request) + + # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'. + # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require + # HTTP/1.1 style headers, and map them appropriately if we end up on + # an HTTP/2 connection. + authority = [v for k, v in request.headers if k.lower() == b"host"][0] + + headers = [ + (b":method", request.method), + (b":authority", authority), + (b":scheme", request.url.scheme), + (b":path", request.url.target), + ] + [ + (k.lower(), v) + for k, v in request.headers + if k.lower() + not in ( + b"host", + b"transfer-encoding", + ) + ] + + self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) + self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id) + await self._write_outgoing_data(request) + + async def _send_request_body(self, request: Request, stream_id: int) -> None: + """ + Iterate over the request body sending it to a given stream ID. + """ + if not has_body_headers(request): + return + + assert isinstance(request.stream, typing.AsyncIterable) + async for data in request.stream: + await self._send_stream_data(request, stream_id, data) + await self._send_end_stream(request, stream_id) + + async def _send_stream_data( + self, request: Request, stream_id: int, data: bytes + ) -> None: + """ + Send a single chunk of data in one or more data frames. + """ + while data: + max_flow = await self._wait_for_outgoing_flow(request, stream_id) + chunk_size = min(len(data), max_flow) + chunk, data = data[:chunk_size], data[chunk_size:] + self._h2_state.send_data(stream_id, chunk) + await self._write_outgoing_data(request) + + async def _send_end_stream(self, request: Request, stream_id: int) -> None: + """ + Send an empty data frame on on a given stream ID with the END_STREAM flag set. + """ + self._h2_state.end_stream(stream_id) + await self._write_outgoing_data(request) + + # Receiving the response... + + async def _receive_response( + self, request: Request, stream_id: int + ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]: + """ + Return the response status code and headers for a given stream ID. 
+ """ + while True: + event = await self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.ResponseReceived): + break + + status_code = 200 + headers = [] + for k, v in event.headers: + if k == b":status": + status_code = int(v.decode("ascii", errors="ignore")) + elif not k.startswith(b":"): + headers.append((k, v)) + + return (status_code, headers) + + async def _receive_response_body( + self, request: Request, stream_id: int + ) -> typing.AsyncIterator[bytes]: + """ + Iterator that returns the bytes of the response body for a given stream ID. + """ + while True: + event = await self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.DataReceived): + amount = event.flow_controlled_length + self._h2_state.acknowledge_received_data(amount, stream_id) + await self._write_outgoing_data(request) + yield event.data + elif isinstance(event, h2.events.StreamEnded): + break + + async def _receive_stream_event( + self, request: Request, stream_id: int + ) -> typing.Union[ + h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded + ]: + """ + Return the next available event for a given stream ID. + + Will read more data from the network if required. + """ + while not self._events.get(stream_id): + await self._receive_events(request, stream_id) + event = self._events[stream_id].pop(0) + if isinstance(event, h2.events.StreamReset): + raise RemoteProtocolError(event) + return event + + async def _receive_events( + self, request: Request, stream_id: typing.Optional[int] = None + ) -> None: + """ + Read some data from the network until we see one or more events + for a given stream ID. + """ + async with self._read_lock: + if self._connection_terminated is not None: + last_stream_id = self._connection_terminated.last_stream_id + if stream_id and last_stream_id and stream_id > last_stream_id: + self._request_count -= 1 + raise ConnectionNotAvailable() + raise RemoteProtocolError(self._connection_terminated) + + # This conditional is a bit icky. We don't want to block reading if we've + # actually got an event to return for a given stream. We need to do that + # check *within* the atomic read lock. Though it also need to be optional, + # because when we call it from `_wait_for_outgoing_flow` we *do* want to + # block until we've available flow control, event when we have events + # pending for the stream ID we're attempting to send on. 
+ if stream_id is None or not self._events.get(stream_id): + events = await self._read_incoming_data(request) + for event in events: + if isinstance(event, h2.events.RemoteSettingsChanged): + async with Trace( + "receive_remote_settings", logger, request + ) as trace: + await self._receive_remote_settings_change(event) + trace.return_value = event + + elif isinstance( + event, + ( + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ), + ): + if event.stream_id in self._events: + self._events[event.stream_id].append(event) + + elif isinstance(event, h2.events.ConnectionTerminated): + self._connection_terminated = event + + await self._write_outgoing_data(request) + + async def _receive_remote_settings_change(self, event: h2.events.Event) -> None: + max_concurrent_streams = event.changed_settings.get( + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS + ) + if max_concurrent_streams: + new_max_streams = min( + max_concurrent_streams.new_value, + self._h2_state.local_settings.max_concurrent_streams, + ) + if new_max_streams and new_max_streams != self._max_streams: + while new_max_streams > self._max_streams: + await self._max_streams_semaphore.release() + self._max_streams += 1 + while new_max_streams < self._max_streams: + await self._max_streams_semaphore.acquire() + self._max_streams -= 1 + + async def _response_closed(self, stream_id: int) -> None: + await self._max_streams_semaphore.release() + del self._events[stream_id] + async with self._state_lock: + if self._connection_terminated and not self._events: + await self.aclose() + + elif self._state == HTTPConnectionState.ACTIVE and not self._events: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + if self._used_all_stream_ids: # pragma: nocover + await self.aclose() + + async def aclose(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._h2_state.close_connection() + self._state = HTTPConnectionState.CLOSED + await self._network_stream.aclose() + + # Wrappers around network read/write operations... + + async def _read_incoming_data( + self, request: Request + ) -> typing.List[h2.events.Event]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + if self._read_exception is not None: + raise self._read_exception # pragma: nocover + + try: + data = await self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + except Exception as exc: + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future reads. + # (For example, this means that a single read timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. 
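+            #
+            # (For instance, a single `ReadTimeout` surfaced here is cached in
+            # `self._read_exception` and re-raised by every later read attempt,
+            # so all in-flight streams fail promptly rather than each timing
+            # out in turn.)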
+ self._read_exception = exc + self._connection_error = True + raise exc + + events: typing.List[h2.events.Event] = self._h2_state.receive_data(data) + + return events + + async def _write_outgoing_data(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + async with self._write_lock: + data_to_send = self._h2_state.data_to_send() + + if self._write_exception is not None: + raise self._write_exception # pragma: nocover + + try: + await self._network_stream.write(data_to_send, timeout) + except Exception as exc: # pragma: nocover + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future write. + # (For example, this means that a single write timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._write_exception = exc + self._connection_error = True + raise exc + + # Flow control... + + async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. + https://tools.ietf.org/html/rfc7540#section-6.9 + """ + local_flow: int = self._h2_state.local_flow_control_window(stream_id) + max_frame_size: int = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + while flow == 0: + await self._receive_events(request) + local_flow = self._h2_state.local_flow_control_window(stream_id) + max_frame_size = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + return flow + + # Interface for connection pooling... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + return ( + self._state != HTTPConnectionState.CLOSED + and not self._connection_error + and not self._used_all_stream_ids + and not ( + self._h2_state.state_machine.state + == h2.connection.ConnectionState.CLOSED + ) + ) + + def has_expired(self) -> bool: + now = time.monotonic() + return self._expire_at is not None and now > self._expire_at + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/2, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
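+    #
+    # A hypothetical sketch of such direct usage (`origin` and `stream` are
+    # assumed to be an `Origin` and an already-open network stream):
+    #
+    #     async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
+    #         response = await conn.request("GET", "https://www.example.com/")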
+ + async def __aenter__(self) -> "AsyncHTTP2Connection": + return self + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[types.TracebackType] = None, + ) -> None: + await self.aclose() + + +class HTTP2ConnectionByteStream: + def __init__( + self, connection: AsyncHTTP2Connection, request: Request, stream_id: int + ) -> None: + self._connection = connection + self._request = request + self._stream_id = stream_id + self._closed = False + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + kwargs = {"request": self._request, "stream_id": self._stream_id} + try: + async with Trace("receive_response_body", logger, self._request, kwargs): + async for chunk in self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with AsyncShieldCancellation(): + await self.aclose() + raise exc + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + kwargs = {"stream_id": self._stream_id} + async with Trace("response_closed", logger, self._request, kwargs): + await self._connection._response_closed(stream_id=self._stream_id) diff --git a/contrib/python/httpcore/httpcore/_async/http_proxy.py b/contrib/python/httpcore/httpcore/_async/http_proxy.py new file mode 100644 index 0000000000..4aa7d8741a --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/http_proxy.py @@ -0,0 +1,368 @@ +import logging +import ssl +from base64 import b64encode +from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union + +from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend +from .._exceptions import ProxyError +from .._models import ( + URL, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, +) +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .connection import AsyncHTTPConnection +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]] +HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]] + + +logger = logging.getLogger("httpcore.proxy") + + +def merge_headers( + default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, + override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, +) -> List[Tuple[bytes, bytes]]: + """ + Append default_headers and override_headers, de-duplicating if a key exists + in both cases. + """ + default_headers = [] if default_headers is None else list(default_headers) + override_headers = [] if override_headers is None else list(override_headers) + has_override = set(key.lower() for key, value in override_headers) + default_headers = [ + (key, value) + for key, value in default_headers + if key.lower() not in has_override + ] + return default_headers + override_headers + + +def build_auth_header(username: bytes, password: bytes) -> bytes: + userpass = username + b":" + password + return b"Basic " + b64encode(userpass) + + +class AsyncHTTPProxy(AsyncConnectionPool): + """ + A connection pool that sends requests via an HTTP proxy. 
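+
+    A minimal usage sketch (illustrative only; assumes the pool is used as an
+    async context manager):
+
+        async with AsyncHTTPProxy(proxy_url="http://127.0.0.1:8080/") as proxy:
+            response = await proxy.request("GET", "https://www.example.com/")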
+ """ + + def __init__( + self, + proxy_url: Union[URL, bytes, str], + proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None, + proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None, + ssl_context: Optional[ssl.SSLContext] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + max_connections: Optional[int] = 10, + max_keepalive_connections: Optional[int] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + proxy_auth: Any proxy authentication as a two-tuple of + (username, password). May be either bytes or ascii-only str. + proxy_headers: Any HTTP headers to use for the proxy requests. + For example `{"Proxy-Authorization": "Basic <username>:<password>"}`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. 
+ """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + local_address=local_address, + uds=uds, + socket_options=socket_options, + ) + + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if ( + self._proxy_url.scheme == b"http" and proxy_ssl_context is not None + ): # pragma: no cover + raise RuntimeError( + "The `proxy_ssl_context` argument is not allowed for the http scheme" + ) + + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + if proxy_auth is not None: + username = enforce_bytes(proxy_auth[0], name="proxy_auth") + password = enforce_bytes(proxy_auth[1], name="proxy_auth") + authorization = build_auth_header(username, password) + self._proxy_headers = [ + (b"Proxy-Authorization", authorization) + ] + self._proxy_headers + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + if origin.scheme == b"http": + return AsyncForwardHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + proxy_ssl_context=self._proxy_ssl_context, + ) + return AsyncTunnelHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + ssl_context=self._ssl_context, + proxy_ssl_context=self._proxy_ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class AsyncForwardHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None, + keepalive_expiry: Optional[float] = None, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + ) -> None: + self._connection = AsyncHTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._remote_origin = remote_origin + + async def handle_async_request(self, request: Request) -> Response: + headers = merge_headers(self._proxy_headers, request.headers) + url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=bytes(request.url), + ) + proxy_request = Request( + method=request.method, + url=url, + headers=headers, + content=request.stream, + extensions=request.extensions, + ) + return await self._connection.handle_async_request(proxy_request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + await self._connection.aclose() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> 
bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + +class AsyncTunnelHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + ssl_context: Optional[ssl.SSLContext] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + network_backend: Optional[AsyncNetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + self._connection: AsyncConnectionInterface = AsyncHTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._connect_lock = AsyncLock() + self._connected = False + + async def handle_async_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) + + async with self._connect_lock: + if not self._connected: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, + ) + connect_headers = merge_headers( + [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers + ) + connect_request = Request( + method=b"CONNECT", + url=connect_url, + headers=connect_headers, + extensions=request.extensions, + ) + connect_response = await self._connection.handle_async_request( + connect_request + ) + + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + await self._connection.aclose() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + 
keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + return await self._connection.handle_async_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + await self._connection.aclose() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/contrib/python/httpcore/httpcore/_async/interfaces.py b/contrib/python/httpcore/httpcore/_async/interfaces.py new file mode 100644 index 0000000000..c998dd2763 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/interfaces.py @@ -0,0 +1,135 @@ +from contextlib import asynccontextmanager +from typing import AsyncIterator, Optional, Union + +from .._models import ( + URL, + Extensions, + HeaderTypes, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, + include_request_headers, +) + + +class AsyncRequestInterface: + async def request( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, AsyncIterator[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> Response: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = await self.handle_async_request(request) + try: + await response.aread() + finally: + await response.aclose() + return response + + @asynccontextmanager + async def stream( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, AsyncIterator[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> AsyncIterator[Response]: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. 
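+        # (For example: a fixed `bytes` body of length 5 gains `Content-Length: 5`,
+        # while an async byte-iterator body gains `Transfer-Encoding: chunked`.)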
+        headers = include_request_headers(headers, url=url, content=content)
+
+        request = Request(
+            method=method,
+            url=url,
+            headers=headers,
+            content=content,
+            extensions=extensions,
+        )
+        response = await self.handle_async_request(request)
+        try:
+            yield response
+        finally:
+            await response.aclose()
+
+    async def handle_async_request(self, request: Request) -> Response:
+        raise NotImplementedError()  # pragma: nocover
+
+
+class AsyncConnectionInterface(AsyncRequestInterface):
+    async def aclose(self) -> None:
+        raise NotImplementedError()  # pragma: nocover
+
+    def info(self) -> str:
+        raise NotImplementedError()  # pragma: nocover
+
+    def can_handle_request(self, origin: Origin) -> bool:
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_available(self) -> bool:
+        """
+        Return `True` if the connection is currently able to accept an
+        outgoing request.
+
+        An HTTP/1.1 connection will only be available if it is currently idle.
+
+        An HTTP/2 connection will be available so long as the stream ID space is
+        not yet exhausted, and the connection is not in an error state.
+
+        While the connection is being established we may not yet know if it is going
+        to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
+        treated as being available, but might ultimately raise `NewConnectionRequired`
+        exceptions if multiple requests are attempted over a connection that ends up
+        being established as HTTP/1.1.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def has_expired(self) -> bool:
+        """
+        Return `True` if the connection is in a state where it should be closed.
+
+        This either means that the connection is idle and it has passed the
+        expiry time on its keep-alive, or that the server has sent an EOF.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_idle(self) -> bool:
+        """
+        Return `True` if the connection is currently idle.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_closed(self) -> bool:
+        """
+        Return `True` if the connection has been closed.
+
+        Used when a response is closed to determine if the connection may be
+        returned to the connection pool or not.
+ """ + raise NotImplementedError() # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_async/socks_proxy.py b/contrib/python/httpcore/httpcore/_async/socks_proxy.py new file mode 100644 index 0000000000..08a065d6d1 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_async/socks_proxy.py @@ -0,0 +1,342 @@ +import logging +import ssl +import typing + +from socksio import socks5 + +from .._backends.auto import AutoBackend +from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream +from .._exceptions import ConnectionNotAvailable, ProxyError +from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface + +logger = logging.getLogger("httpcore.socks") + + +AUTH_METHODS = { + b"\x00": "NO AUTHENTICATION REQUIRED", + b"\x01": "GSSAPI", + b"\x02": "USERNAME/PASSWORD", + b"\xff": "NO ACCEPTABLE METHODS", +} + +REPLY_CODES = { + b"\x00": "Succeeded", + b"\x01": "General SOCKS server failure", + b"\x02": "Connection not allowed by ruleset", + b"\x03": "Network unreachable", + b"\x04": "Host unreachable", + b"\x05": "Connection refused", + b"\x06": "TTL expired", + b"\x07": "Command not supported", + b"\x08": "Address type not supported", +} + + +async def _init_socks5_connection( + stream: AsyncNetworkStream, + *, + host: bytes, + port: int, + auth: typing.Optional[typing.Tuple[bytes, bytes]] = None, +) -> None: + conn = socks5.SOCKS5Connection() + + # Auth method request + auth_method = ( + socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED + if auth is None + else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD + ) + conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method])) + outgoing_bytes = conn.data_to_send() + await stream.write(outgoing_bytes) + + # Auth method response + incoming_bytes = await stream.read(max_bytes=4096) + response = conn.receive_data(incoming_bytes) + assert isinstance(response, socks5.SOCKS5AuthReply) + if response.method != auth_method: + requested = AUTH_METHODS.get(auth_method, "UNKNOWN") + responded = AUTH_METHODS.get(response.method, "UNKNOWN") + raise ProxyError( + f"Requested {requested} from proxy server, but got {responded}." 
+        )
+
+    if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
+        # Username/password request
+        assert auth is not None
+        username, password = auth
+        conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
+        outgoing_bytes = conn.data_to_send()
+        await stream.write(outgoing_bytes)
+
+        # Username/password response
+        incoming_bytes = await stream.read(max_bytes=4096)
+        response = conn.receive_data(incoming_bytes)
+        assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
+        if not response.success:
+            raise ProxyError("Invalid username/password")
+
+    # Connect request
+    conn.send(
+        socks5.SOCKS5CommandRequest.from_address(
+            socks5.SOCKS5Command.CONNECT, (host, port)
+        )
+    )
+    outgoing_bytes = conn.data_to_send()
+    await stream.write(outgoing_bytes)
+
+    # Connect response
+    incoming_bytes = await stream.read(max_bytes=4096)
+    response = conn.receive_data(incoming_bytes)
+    assert isinstance(response, socks5.SOCKS5Reply)
+    if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
+        reply_code = REPLY_CODES.get(response.reply_code, "UNKNOWN")
+        raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
+
+
+class AsyncSOCKSProxy(AsyncConnectionPool):
+    """
+    A connection pool that sends requests via a SOCKS5 proxy.
+    """
+
+    def __init__(
+        self,
+        proxy_url: typing.Union[URL, bytes, str],
+        proxy_auth: typing.Optional[
+            typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
+        ] = None,
+        ssl_context: typing.Optional[ssl.SSLContext] = None,
+        max_connections: typing.Optional[int] = 10,
+        max_keepalive_connections: typing.Optional[int] = None,
+        keepalive_expiry: typing.Optional[float] = None,
+        http1: bool = True,
+        http2: bool = False,
+        retries: int = 0,
+        network_backend: typing.Optional[AsyncNetworkBackend] = None,
+    ) -> None:
+        """
+        A connection pool for making HTTP requests via a SOCKS5 proxy.
+
+        Parameters:
+            proxy_url: The URL to use when connecting to the proxy server.
+                For example `"socks5://127.0.0.1:1080/"`.
+            proxy_auth: Any proxy authentication as a two-tuple of
+                (username, password). May be either bytes or ascii-only str.
+            ssl_context: An SSL context to use for verifying connections.
+                If not specified, the default `httpcore.default_ssl_context()`
+                will be used.
+            max_connections: The maximum number of concurrent HTTP connections that
+                the pool should allow. Any attempt to send a request on a pool that
+                would exceed this amount will block until a connection is available.
+            max_keepalive_connections: The maximum number of idle HTTP connections
+                that will be maintained in the pool.
+            keepalive_expiry: The duration in seconds that an idle HTTP connection
+                may be maintained for before being expired from the pool.
+            http1: A boolean indicating if HTTP/1.1 requests should be supported
+                by the connection pool. Defaults to True.
+            http2: A boolean indicating if HTTP/2 requests should be supported by
+                the connection pool. Defaults to False.
+            retries: The maximum number of retries when trying to establish
+                a connection.
+            network_backend: A backend instance to use for handling network I/O.
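+
+        A minimal usage sketch (illustrative only; assumes a SOCKS5 proxy is
+        listening locally):
+
+            proxy = AsyncSOCKSProxy(proxy_url="socks5://127.0.0.1:1080/")
+            response = await proxy.request("GET", "https://www.example.com/")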
+ """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + ) + self._ssl_context = ssl_context + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if proxy_auth is not None: + username, password = proxy_auth + username_bytes = enforce_bytes(username, name="proxy_auth") + password_bytes = enforce_bytes(password, name="proxy_auth") + self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = ( + username_bytes, + password_bytes, + ) + else: + self._proxy_auth = None + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + return AsyncSocks5Connection( + proxy_origin=self._proxy_url.origin, + remote_origin=origin, + proxy_auth=self._proxy_auth, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class AsyncSocks5Connection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None, + ssl_context: typing.Optional[ssl.SSLContext] = None, + keepalive_expiry: typing.Optional[float] = None, + http1: bool = True, + http2: bool = False, + network_backend: typing.Optional[AsyncNetworkBackend] = None, + ) -> None: + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._proxy_auth = proxy_auth + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + + self._network_backend: AsyncNetworkBackend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._connect_lock = AsyncLock() + self._connection: typing.Optional[AsyncConnectionInterface] = None + self._connect_failed = False + + async def handle_async_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + async with self._connect_lock: + if self._connection is None: + try: + # Connect to the proxy + kwargs = { + "host": self._proxy_origin.host.decode("ascii"), + "port": self._proxy_origin.port, + "timeout": timeout, + } + with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = await self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + + # Connect to the remote host using socks5 + kwargs = { + "stream": stream, + "host": self._remote_origin.host.decode("ascii"), + "port": self._remote_origin.port, + "auth": self._proxy_auth, + } + with Trace( + "setup_socks5_connection", logger, request, kwargs + ) as trace: + await _init_socks5_connection(**kwargs) + trace.return_value = stream + + # Upgrade the stream to SSL + if self._remote_origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ( + ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ) + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be 
using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or ( + self._http2 and not self._http1 + ): # pragma: nocover + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): # pragma: nocover + raise ConnectionNotAvailable() + + return await self._connection.handle_async_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + if self._connection is not None: + await self._connection.aclose() + + def is_available(self) -> bool: + if self._connection is None: # pragma: nocover + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._remote_origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: # pragma: nocover + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: # pragma: nocover + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/contrib/python/httpcore/httpcore/_backends/__init__.py b/contrib/python/httpcore/httpcore/_backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/__init__.py diff --git a/contrib/python/httpcore/httpcore/_backends/anyio.py b/contrib/python/httpcore/httpcore/_backends/anyio.py new file mode 100644 index 0000000000..1ed5228dbd --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/anyio.py @@ -0,0 +1,145 @@ +import ssl +import typing + +import anyio + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class AnyIOStream(AsyncNetworkStream): + def __init__(self, stream: anyio.abc.ByteStream) -> None: + self._stream = stream + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + exc_map = { + TimeoutError: ReadTimeout, + anyio.BrokenResourceError: ReadError, + anyio.ClosedResourceError: ReadError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + try: + return await self._stream.receive(max_bytes=max_bytes) + except anyio.EndOfStream: # pragma: 
nocover + return b"" + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + if not buffer: + return + + exc_map = { + TimeoutError: WriteTimeout, + anyio.BrokenResourceError: WriteError, + anyio.ClosedResourceError: WriteError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + await self._stream.send(item=buffer) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> AsyncNetworkStream: + exc_map = { + TimeoutError: ConnectTimeout, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + try: + with anyio.fail_after(timeout): + ssl_stream = await anyio.streams.tls.TLSStream.wrap( + self._stream, + ssl_context=ssl_context, + hostname=server_hostname, + standard_compatible=False, + server_side=False, + ) + except Exception as exc: # pragma: nocover + await self.aclose() + raise exc + return AnyIOStream(ssl_stream) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object": + return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None) + if info == "client_addr": + return self._stream.extra(anyio.abc.SocketAttribute.local_address, None) + if info == "server_addr": + return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None) + if info == "socket": + return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + if info == "is_readable": + sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + return is_socket_readable(sock) + return None + + +class AnyIOBackend(AsyncNetworkBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + if socket_options is None: + socket_options = [] # pragma: no cover + exc_map = { + TimeoutError: ConnectTimeout, + OSError: ConnectError, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + stream: anyio.abc.ByteStream = await anyio.connect_tcp( + remote_host=host, + remote_port=port, + local_host=local_address, + ) + # By default TCP sockets opened in `asyncio` include TCP_NODELAY. 
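+                # (Each SOCKET_OPTION entry is a setsockopt-style tuple, e.g.
+                # `(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)`, unpacked below.)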
+ for option in socket_options: + stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return AnyIOStream(stream) + + async def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] + exc_map = { + TimeoutError: ConnectTimeout, + OSError: ConnectError, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + stream: anyio.abc.ByteStream = await anyio.connect_unix(path) + for option in socket_options: + stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return AnyIOStream(stream) + + async def sleep(self, seconds: float) -> None: + await anyio.sleep(seconds) # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_backends/auto.py b/contrib/python/httpcore/httpcore/_backends/auto.py new file mode 100644 index 0000000000..b612ba071c --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/auto.py @@ -0,0 +1,52 @@ +import typing +from typing import Optional + +import sniffio + +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class AutoBackend(AsyncNetworkBackend): + async def _init_backend(self) -> None: + if not (hasattr(self, "_backend")): + backend = sniffio.current_async_library() + if backend == "trio": + from .trio import TrioBackend + + self._backend: AsyncNetworkBackend = TrioBackend() + else: + from .anyio import AnyIOBackend + + self._backend = AnyIOBackend() + + async def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + await self._init_backend() + return await self._backend.connect_tcp( + host, + port, + timeout=timeout, + local_address=local_address, + socket_options=socket_options, + ) + + async def connect_unix_socket( + self, + path: str, + timeout: Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: # pragma: nocover + await self._init_backend() + return await self._backend.connect_unix_socket( + path, timeout=timeout, socket_options=socket_options + ) + + async def sleep(self, seconds: float) -> None: # pragma: nocover + await self._init_backend() + return await self._backend.sleep(seconds) diff --git a/contrib/python/httpcore/httpcore/_backends/base.py b/contrib/python/httpcore/httpcore/_backends/base.py new file mode 100644 index 0000000000..6cadedb5f9 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/base.py @@ -0,0 +1,103 @@ +import ssl +import time +import typing + +SOCKET_OPTION = typing.Union[ + typing.Tuple[int, int, int], + typing.Tuple[int, int, typing.Union[bytes, bytearray]], + typing.Tuple[int, int, None, int], +] + + +class NetworkStream: + def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + raise NotImplementedError() # pragma: nocover + + def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + raise NotImplementedError() # pragma: nocover + + def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> 
"NetworkStream": + raise NotImplementedError() # pragma: nocover + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + +class NetworkBackend: + def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def sleep(self, seconds: float) -> None: + time.sleep(seconds) # pragma: nocover + + +class AsyncNetworkStream: + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + raise NotImplementedError() # pragma: nocover + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + raise NotImplementedError() # pragma: nocover + + async def aclose(self) -> None: + raise NotImplementedError() # pragma: nocover + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> "AsyncNetworkStream": + raise NotImplementedError() # pragma: nocover + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + +class AsyncNetworkBackend: + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + async def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + async def sleep(self, seconds: float) -> None: + raise NotImplementedError() # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_backends/mock.py b/contrib/python/httpcore/httpcore/_backends/mock.py new file mode 100644 index 0000000000..f7aefebf51 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/mock.py @@ -0,0 +1,142 @@ +import ssl +import typing +from typing import Optional + +from .._exceptions import ReadError +from .base import ( + SOCKET_OPTION, + AsyncNetworkBackend, + AsyncNetworkStream, + NetworkBackend, + NetworkStream, +) + + +class MockSSLObject: + def __init__(self, http2: bool): + self._http2 = http2 + + def selected_alpn_protocol(self) -> str: + return "h2" if self._http2 else "http/1.1" + + +class MockStream(NetworkStream): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + self._closed = False + + def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + if self._closed: + raise ReadError("Connection closed") + if not self._buffer: + return b"" + return self._buffer.pop(0) + + def write(self, buffer: bytes, timeout: Optional[float] = None) -> None: + pass + + def close(self) -> None: + self._closed = True + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: Optional[str] = None, + timeout: Optional[float] = None, + ) -> NetworkStream: + return self + + def get_extra_info(self, info: 
str) -> typing.Any: + return MockSSLObject(http2=self._http2) if info == "ssl_object" else None + + def __repr__(self) -> str: + return "<httpcore.MockStream>" + + +class MockBackend(NetworkBackend): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + return MockStream(list(self._buffer), http2=self._http2) + + def connect_unix_socket( + self, + path: str, + timeout: Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + return MockStream(list(self._buffer), http2=self._http2) + + def sleep(self, seconds: float) -> None: + pass + + +class AsyncMockStream(AsyncNetworkStream): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + self._closed = False + + async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes: + if self._closed: + raise ReadError("Connection closed") + if not self._buffer: + return b"" + return self._buffer.pop(0) + + async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None: + pass + + async def aclose(self) -> None: + self._closed = True + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: Optional[str] = None, + timeout: Optional[float] = None, + ) -> AsyncNetworkStream: + return self + + def get_extra_info(self, info: str) -> typing.Any: + return MockSSLObject(http2=self._http2) if info == "ssl_object" else None + + def __repr__(self) -> str: + return "<httpcore.AsyncMockStream>" + + +class AsyncMockBackend(AsyncNetworkBackend): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + async def connect_tcp( + self, + host: str, + port: int, + timeout: Optional[float] = None, + local_address: Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + return AsyncMockStream(list(self._buffer), http2=self._http2) + + async def connect_unix_socket( + self, + path: str, + timeout: Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + return AsyncMockStream(list(self._buffer), http2=self._http2) + + async def sleep(self, seconds: float) -> None: + pass diff --git a/contrib/python/httpcore/httpcore/_backends/sync.py b/contrib/python/httpcore/httpcore/_backends/sync.py new file mode 100644 index 0000000000..f2dbd32afa --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/sync.py @@ -0,0 +1,245 @@ +import socket +import ssl +import sys +import typing +from functools import partial + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ExceptionMapping, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import SOCKET_OPTION, NetworkBackend, NetworkStream + + +class TLSinTLSStream(NetworkStream): # pragma: no cover + """ + Because the standard `SSLContext.wrap_socket` method does + not work for `SSLSocket` objects, we need this class + to implement TLS stream using an underlying `SSLObject` + instance in order to support TLS on top of TLS. 
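+
+    (Sketch of the approach: the wrapped `SSLObject` performs TLS through a
+    pair of in-memory `ssl.MemoryBIO` buffers, and `_perform_io` below pumps
+    those buffers over the already-encrypted outer socket.)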
+ """ + + # Defined in RFC 8449 + TLS_RECORD_SIZE = 16384 + + def __init__( + self, + sock: socket.socket, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ): + self._sock = sock + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + + self.ssl_obj = ssl_context.wrap_bio( + incoming=self._incoming, + outgoing=self._outgoing, + server_hostname=server_hostname, + ) + + self._sock.settimeout(timeout) + self._perform_io(self.ssl_obj.do_handshake) + + def _perform_io( + self, + func: typing.Callable[..., typing.Any], + ) -> typing.Any: + ret = None + + while True: + errno = None + try: + ret = func() + except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e: + errno = e.errno + + self._sock.sendall(self._outgoing.read()) + + if errno == ssl.SSL_ERROR_WANT_READ: + buf = self._sock.recv(self.TLS_RECORD_SIZE) + + if buf: + self._incoming.write(buf) + else: + self._incoming.write_eof() + if errno is None: + return ret + + def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + return typing.cast( + bytes, self._perform_io(partial(self.ssl_obj.read, max_bytes)) + ) + + def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + while buffer: + nsent = self._perform_io(partial(self.ssl_obj.write, buffer)) + buffer = buffer[nsent:] + + def close(self) -> None: + self._sock.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> "NetworkStream": + raise NotImplementedError() + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object": + return self.ssl_obj + if info == "client_addr": + return self._sock.getsockname() + if info == "server_addr": + return self._sock.getpeername() + if info == "socket": + return self._sock + if info == "is_readable": + return is_socket_readable(self._sock) + return None + + +class SyncStream(NetworkStream): + def __init__(self, sock: socket.socket) -> None: + self._sock = sock + + def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + return self._sock.recv(max_bytes) + + def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + if not buffer: + return + + exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} + with map_exceptions(exc_map): + while buffer: + self._sock.settimeout(timeout) + n = self._sock.send(buffer) + buffer = buffer[n:] + + def close(self) -> None: + self._sock.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> NetworkStream: + if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover + raise RuntimeError( + "Attempted to add a TLS layer on top of the existing " + "TLS stream, which is not supported by httpcore package" + ) + + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + try: + if 
isinstance(self._sock, ssl.SSLSocket): # pragma: no cover + # If the underlying socket has already been upgraded + # to the TLS layer (i.e. is an instance of SSLSocket), + # we need some additional smarts to support TLS-in-TLS. + return TLSinTLSStream( + self._sock, ssl_context, server_hostname, timeout + ) + else: + self._sock.settimeout(timeout) + sock = ssl_context.wrap_socket( + self._sock, server_hostname=server_hostname + ) + except Exception as exc: # pragma: nocover + self.close() + raise exc + return SyncStream(sock) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket): + return self._sock._sslobj # type: ignore + if info == "client_addr": + return self._sock.getsockname() + if info == "server_addr": + return self._sock.getpeername() + if info == "socket": + return self._sock + if info == "is_readable": + return is_socket_readable(self._sock) + return None + + +class SyncBackend(NetworkBackend): + def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + # Note that we automatically include `TCP_NODELAY` + # in addition to any other custom socket options. + if socket_options is None: + socket_options = [] # pragma: no cover + address = (host, port) + source_address = None if local_address is None else (local_address, 0) + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + + with map_exceptions(exc_map): + sock = socket.create_connection( + address, + timeout, + source_address=source_address, + ) + for option in socket_options: + sock.setsockopt(*option) # pragma: no cover + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return SyncStream(sock) + + def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: # pragma: nocover + if sys.platform == "win32": + raise RuntimeError( + "Attempted to connect to a UNIX socket on a Windows system." 
+ ) + if socket_options is None: + socket_options = [] + + exc_map: ExceptionMapping = { + socket.timeout: ConnectTimeout, + OSError: ConnectError, + } + with map_exceptions(exc_map): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + for option in socket_options: + sock.setsockopt(*option) + sock.settimeout(timeout) + sock.connect(path) + return SyncStream(sock) diff --git a/contrib/python/httpcore/httpcore/_backends/trio.py b/contrib/python/httpcore/httpcore/_backends/trio.py new file mode 100644 index 0000000000..b1626d28e2 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_backends/trio.py @@ -0,0 +1,161 @@ +import ssl +import typing + +import trio + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ExceptionMapping, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream + + +class TrioStream(AsyncNetworkStream): + def __init__(self, stream: trio.abc.Stream) -> None: + self._stream = stream + + async def read( + self, max_bytes: int, timeout: typing.Optional[float] = None + ) -> bytes: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ReadTimeout, + trio.BrokenResourceError: ReadError, + trio.ClosedResourceError: ReadError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + data: bytes = await self._stream.receive_some(max_bytes=max_bytes) + return data + + async def write( + self, buffer: bytes, timeout: typing.Optional[float] = None + ) -> None: + if not buffer: + return + + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: WriteTimeout, + trio.BrokenResourceError: WriteError, + trio.ClosedResourceError: WriteError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + await self._stream.send_all(data=buffer) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> AsyncNetworkStream: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + } + ssl_stream = trio.SSLStream( + self._stream, + ssl_context=ssl_context, + server_hostname=server_hostname, + https_compatible=True, + server_side=False, + ) + with map_exceptions(exc_map): + try: + with trio.fail_after(timeout_or_inf): + await ssl_stream.do_handshake() + except Exception as exc: # pragma: nocover + await self.aclose() + raise exc + return TrioStream(ssl_stream) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object" and isinstance(self._stream, trio.SSLStream): + # Type checkers cannot see `_ssl_object` attribute because trio._ssl.SSLStream uses __getattr__/__setattr__. 
+ # Tracked at https://github.com/python-trio/trio/issues/542 + return self._stream._ssl_object # type: ignore[attr-defined] + if info == "client_addr": + return self._get_socket_stream().socket.getsockname() + if info == "server_addr": + return self._get_socket_stream().socket.getpeername() + if info == "socket": + stream = self._stream + while isinstance(stream, trio.SSLStream): + stream = stream.transport_stream + assert isinstance(stream, trio.SocketStream) + return stream.socket + if info == "is_readable": + socket = self.get_extra_info("socket") + return socket.is_readable() + return None + + def _get_socket_stream(self) -> trio.SocketStream: + stream = self._stream + while isinstance(stream, trio.SSLStream): + stream = stream.transport_stream + assert isinstance(stream, trio.SocketStream) + return stream + + +class TrioBackend(AsyncNetworkBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + # By default for TCP sockets, trio enables TCP_NODELAY. + # https://trio.readthedocs.io/en/stable/reference-io.html#trio.SocketStream + if socket_options is None: + socket_options = [] # pragma: no cover + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + OSError: ConnectError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + stream: trio.abc.Stream = await trio.open_tcp_stream( + host=host, port=port, local_address=local_address + ) + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return TrioStream(stream) + + async def connect_unix_socket( + self, + path: str, + timeout: typing.Optional[float] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: # pragma: nocover + if socket_options is None: + socket_options = [] + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map: ExceptionMapping = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + OSError: ConnectError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + stream: trio.abc.Stream = await trio.open_unix_socket(path) + for option in socket_options: + stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + return TrioStream(stream) + + async def sleep(self, seconds: float) -> None: + await trio.sleep(seconds) # pragma: nocover diff --git a/contrib/python/httpcore/httpcore/_exceptions.py b/contrib/python/httpcore/httpcore/_exceptions.py new file mode 100644 index 0000000000..81e7fc61dd --- /dev/null +++ b/contrib/python/httpcore/httpcore/_exceptions.py @@ -0,0 +1,81 @@ +import contextlib +from typing import Iterator, Mapping, Type + +ExceptionMapping = Mapping[Type[Exception], Type[Exception]] + + +@contextlib.contextmanager +def map_exceptions(map: ExceptionMapping) -> Iterator[None]: + try: + yield + except Exception as exc: # noqa: PIE786 + for from_exc, to_exc in map.items(): + if isinstance(exc, from_exc): + raise to_exc(exc) from exc + raise # pragma: nocover + + +class ConnectionNotAvailable(Exception): + pass + + +class ProxyError(Exception): + pass + + +class UnsupportedProtocol(Exception): + pass + + +class ProtocolError(Exception): + pass + + +class 
RemoteProtocolError(ProtocolError):
+    pass
+
+
+class LocalProtocolError(ProtocolError):
+    pass
+
+
+# Timeout errors
+
+
+class TimeoutException(Exception):
+    pass
+
+
+class PoolTimeout(TimeoutException):
+    pass
+
+
+class ConnectTimeout(TimeoutException):
+    pass
+
+
+class ReadTimeout(TimeoutException):
+    pass
+
+
+class WriteTimeout(TimeoutException):
+    pass
+
+
+# Network errors
+
+
+class NetworkError(Exception):
+    pass
+
+
+class ConnectError(NetworkError):
+    pass
+
+
+class ReadError(NetworkError):
+    pass
+
+
+class WriteError(NetworkError):
+    pass
diff --git a/contrib/python/httpcore/httpcore/_models.py b/contrib/python/httpcore/httpcore/_models.py
new file mode 100644
index 0000000000..11bfcd84f0
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/_models.py
@@ -0,0 +1,484 @@
+from typing import (
+    Any,
+    AsyncIterable,
+    AsyncIterator,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+from urllib.parse import urlparse
+
+# Functions for typechecking...
+
+
+HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]]
+HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]]
+HeaderTypes = Union[HeadersAsSequence, HeadersAsMapping, None]
+
+Extensions = MutableMapping[str, Any]
+
+
+def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
+    """
+    Any arguments that are ultimately represented as bytes can be specified
+    either as bytes or as strings.
+
+    However, we enforce that any string arguments must only contain characters in
+    the plain ASCII range. chr(0)...chr(127). If you need to use characters
+    outside that range then be precise, and use a byte-wise argument.
+    """
+    if isinstance(value, str):
+        try:
+            return value.encode("ascii")
+        except UnicodeEncodeError:
+            raise TypeError(f"{name} strings may not include unicode characters.")
+    elif isinstance(value, bytes):
+        return value
+
+    seen_type = type(value).__name__
+    raise TypeError(f"{name} must be bytes or str, but got {seen_type}.")
+
+
+def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":
+    """
+    Type check for URL parameters.
+    """
+    if isinstance(value, (bytes, str)):
+        return URL(value)
+    elif isinstance(value, URL):
+        return value
+
+    seen_type = type(value).__name__
+    raise TypeError(f"{name} must be a URL, bytes, or str, but got {seen_type}.")
+
+
+def enforce_headers(
+    value: Union[HeadersAsMapping, HeadersAsSequence, None] = None, *, name: str
+) -> List[Tuple[bytes, bytes]]:
+    """
+    Convenience function that ensures all items in request or response headers
+    are either bytes or strings in the plain ASCII range.
+    """
+    if value is None:
+        return []
+    elif isinstance(value, Mapping):
+        return [
+            (
+                enforce_bytes(k, name="header name"),
+                enforce_bytes(v, name="header value"),
+            )
+            for k, v in value.items()
+        ]
+    elif isinstance(value, Sequence):
+        return [
+            (
+                enforce_bytes(k, name="header name"),
+                enforce_bytes(v, name="header value"),
+            )
+            for k, v in value
+        ]
+
+    seen_type = type(value).__name__
+    raise TypeError(
+        f"{name} must be a mapping or sequence of two-tuples, but got {seen_type}."
+    )
+
+
+def enforce_stream(
+    value: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None], *, name: str
+) -> Union[Iterable[bytes], AsyncIterable[bytes]]:
+    if value is None:
+        return ByteStream(b"")
+    elif isinstance(value, bytes):
+        return ByteStream(value)
+    return value
+
+
+# * https://tools.ietf.org/html/rfc3986#section-3.2.3
+# * https://url.spec.whatwg.org/#url-miscellaneous
+# * https://url.spec.whatwg.org/#scheme-state
+DEFAULT_PORTS = {
+    b"ftp": 21,
+    b"http": 80,
+    b"https": 443,
+    b"ws": 80,
+    b"wss": 443,
+}
+
+
+def include_request_headers(
+    headers: List[Tuple[bytes, bytes]],
+    *,
+    url: "URL",
+    content: Union[None, bytes, Iterable[bytes], AsyncIterable[bytes]],
+) -> List[Tuple[bytes, bytes]]:
+    headers_set = set(k.lower() for k, v in headers)
+
+    if b"host" not in headers_set:
+        default_port = DEFAULT_PORTS.get(url.scheme)
+        if url.port is None or url.port == default_port:
+            header_value = url.host
+        else:
+            header_value = b"%b:%d" % (url.host, url.port)
+        headers = [(b"Host", header_value)] + headers
+
+    if (
+        content is not None
+        and b"content-length" not in headers_set
+        and b"transfer-encoding" not in headers_set
+    ):
+        if isinstance(content, bytes):
+            content_length = str(len(content)).encode("ascii")
+            headers += [(b"Content-Length", content_length)]
+        else:
+            headers += [(b"Transfer-Encoding", b"chunked")]  # pragma: nocover
+
+    return headers
+
+
+# Interfaces for byte streams...
+
+
+class ByteStream:
+    """
+    A container for non-streaming content that supports both sync and async
+    stream iteration.
+    """
+
+    def __init__(self, content: bytes) -> None:
+        self._content = content
+
+    def __iter__(self) -> Iterator[bytes]:
+        yield self._content
+
+    async def __aiter__(self) -> AsyncIterator[bytes]:
+        yield self._content
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} [{len(self._content)} bytes]>"
+
+
+class Origin:
+    def __init__(self, scheme: bytes, host: bytes, port: int) -> None:
+        self.scheme = scheme
+        self.host = host
+        self.port = port
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, Origin)
+            and self.scheme == other.scheme
+            and self.host == other.host
+            and self.port == other.port
+        )
+
+    def __str__(self) -> str:
+        scheme = self.scheme.decode("ascii")
+        host = self.host.decode("ascii")
+        port = str(self.port)
+        return f"{scheme}://{host}:{port}"
+
+
+class URL:
+    """
+    Represents the URL against which an HTTP request may be made.
+
+    The URL may either be specified as a plain string, for convenience:
+
+    ```python
+    url = httpcore.URL("https://www.example.com/")
+    ```
+
+    Or be constructed with explicitly pre-parsed components:
+
+    ```python
+    url = httpcore.URL(scheme=b'https', host=b'www.example.com', port=None, target=b'/')
+    ```
+
+    Using this second more explicit style allows integrations that are using
+    `httpcore` to pass through URLs that have already been parsed in order to use
+    libraries such as `rfc-3986` rather than relying on the stdlib. It also ensures
+    that URL parsing is treated identically at both the networking level and at any
+    higher layers of abstraction.
+
+    The four components are important here, as they allow the URL to be precisely
+    specified in a pre-parsed format. They also allow certain types of request to
+    be created that could not otherwise be expressed.
+
+    For example, an HTTP request to `http://www.example.com/` forwarded via a proxy
+    at `http://localhost:8080`...
+
+    ```python
+    # Constructs an HTTP request with a complete URL as the target:
+    # GET http://www.example.com/ HTTP/1.1
+    url = httpcore.URL(
+        scheme=b'http',
+        host=b'localhost',
+        port=8080,
+        target=b'http://www.example.com/'
+    )
+    request = httpcore.Request(
+        method="GET",
+        url=url
+    )
+    ```
+
+    Another example is constructing an `OPTIONS *` request...
+
+    ```python
+    # Constructs an 'OPTIONS *' HTTP request:
+    # OPTIONS * HTTP/1.1
+    url = httpcore.URL(scheme=b'https', host=b'www.example.com', target=b'*')
+    request = httpcore.Request(method="OPTIONS", url=url)
+    ```
+
+    This kind of request is not possible to formulate with a URL string,
+    because the `/` delimiter is always used to demarcate the target from the
+    host/port portion of the URL.
+
+    For convenience, string-like arguments may be specified either as strings or
+    as bytes. However, once a request is issued over-the-wire, the URL
+    components are always ultimately required to be a bytewise representation.
+
+    In order to avoid any ambiguity over character encodings, when strings are used
+    as arguments, they must be strictly limited to the ASCII range `chr(0)`-`chr(127)`.
+    If you require a bytewise representation that is outside this range you must
+    handle the character encoding directly, and pass a bytes instance.
+    """
+
+    def __init__(
+        self,
+        url: Union[bytes, str] = "",
+        *,
+        scheme: Union[bytes, str] = b"",
+        host: Union[bytes, str] = b"",
+        port: Optional[int] = None,
+        target: Union[bytes, str] = b"",
+    ) -> None:
+        """
+        Parameters:
+            url: The complete URL as a string or bytes.
+            scheme: The URL scheme as a string or bytes.
+                Typically either `"http"` or `"https"`.
+            host: The URL host as a string or bytes. Such as `"www.example.com"`.
+            port: The port to connect to. Either an integer or `None`.
+            target: The target of the HTTP request. Such as `"/items?search=red"`.
+        """
+        if url:
+            parsed = urlparse(enforce_bytes(url, name="url"))
+            self.scheme = parsed.scheme
+            self.host = parsed.hostname or b""
+            self.port = parsed.port
+            self.target = (parsed.path or b"/") + (
+                b"?" + parsed.query if parsed.query else b""
+            )
+        else:
+            self.scheme = enforce_bytes(scheme, name="scheme")
+            self.host = enforce_bytes(host, name="host")
+            self.port = port
+            self.target = enforce_bytes(target, name="target")
+
+    @property
+    def origin(self) -> Origin:
+        default_port = {
+            b"http": 80,
+            b"https": 443,
+            b"ws": 80,
+            b"wss": 443,
+            b"socks5": 1080,
+        }[self.scheme]
+        return Origin(
+            scheme=self.scheme, host=self.host, port=self.port or default_port
+        )
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, URL)
+            and other.scheme == self.scheme
+            and other.host == self.host
+            and other.port == self.port
+            and other.target == self.target
+        )
+
+    def __bytes__(self) -> bytes:
+        if self.port is None:
+            return b"%b://%b%b" % (self.scheme, self.host, self.target)
+        return b"%b://%b:%d%b" % (self.scheme, self.host, self.port, self.target)
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(scheme={self.scheme!r}, "
+            f"host={self.host!r}, port={self.port!r}, target={self.target!r})"
+        )
+
+
+class Request:
+    """
+    An HTTP request.
+    """
+
+    def __init__(
+        self,
+        method: Union[bytes, str],
+        url: Union[URL, bytes, str],
+        *,
+        headers: HeaderTypes = None,
+        content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
+        extensions: Optional[Extensions] = None,
+    ) -> None:
+        """
+        Parameters:
+            method: The HTTP request method, either as a string or bytes.
+                For example: `GET`.
+            url: The request URL, either as a `URL` instance, or as a string or bytes.
+                For example: `"https://www.example.com".`
+            headers: The HTTP request headers.
+            content: The content of the request body.
+            extensions: A dictionary of optional extra information included on
+                the request. Possible keys include `"timeout"`, and `"trace"`.
+        """
+        self.method: bytes = enforce_bytes(method, name="method")
+        self.url: URL = enforce_url(url, name="url")
+        self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
+            headers, name="headers"
+        )
+        self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
+            content, name="content"
+        )
+        self.extensions = {} if extensions is None else extensions
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} [{self.method!r}]>"
+
+
+class Response:
+    """
+    An HTTP response.
+    """
+
+    def __init__(
+        self,
+        status: int,
+        *,
+        headers: HeaderTypes = None,
+        content: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None] = None,
+        extensions: Optional[Extensions] = None,
+    ) -> None:
+        """
+        Parameters:
+            status: The HTTP status code of the response. For example `200`.
+            headers: The HTTP response headers.
+            content: The content of the response body.
+            extensions: A dictionary of optional extra information included on
+                the response. Possible keys include `"http_version"`,
+                `"reason_phrase"`, and `"network_stream"`.
+        """
+        self.status: int = status
+        self.headers: List[Tuple[bytes, bytes]] = enforce_headers(
+            headers, name="headers"
+        )
+        self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream(
+            content, name="content"
+        )
+        self.extensions = {} if extensions is None else extensions
+
+        self._stream_consumed = False
+
+    @property
+    def content(self) -> bytes:
+        if not hasattr(self, "_content"):
+            if isinstance(self.stream, Iterable):
+                raise RuntimeError(
+                    "Attempted to access 'response.content' on a streaming response. "
+                    "Call 'response.read()' first."
+                )
+            else:
+                raise RuntimeError(
+                    "Attempted to access 'response.content' on a streaming response. "
+                    "Call 'await response.aread()' first."
+                )
+        return self._content
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} [{self.status}]>"
+
+    # Sync interface...
+
+    def read(self) -> bytes:
+        if not isinstance(self.stream, Iterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to read an asynchronous response using 'response.read()'. "
+                "You should use 'await response.aread()' instead."
+            )
+        if not hasattr(self, "_content"):
+            self._content = b"".join([part for part in self.iter_stream()])
+        return self._content
+
+    def iter_stream(self) -> Iterator[bytes]:
+        if not isinstance(self.stream, Iterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to stream an asynchronous response using 'for ... in "
+                "response.iter_stream()'. "
+                "You should use 'async for ... in response.aiter_stream()' instead."
+            )
+        if self._stream_consumed:
+            raise RuntimeError(
+                "Attempted to call 'for ... in response.iter_stream()' more than once."
+            )
+        self._stream_consumed = True
+        for chunk in self.stream:
+            yield chunk
+
+    def close(self) -> None:
+        if not isinstance(self.stream, Iterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to close an asynchronous response using 'response.close()'. "
+                "You should use 'await response.aclose()' instead."
+            )
+        if hasattr(self.stream, "close"):
+            self.stream.close()
+
+    # Async interface...
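+    #
+    # These mirror the sync methods above. As a usage sketch (the transport
+    # setup is illustrative, not part of this module): a response produced by
+    # an async connection pool is consumed with `await response.aread()` or
+    # `async for chunk in response.aiter_stream()`, and released with
+    # `await response.aclose()`, never with the sync equivalents above.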
+
+    async def aread(self) -> bytes:
+        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to read a synchronous response using "
+                "'await response.aread()'. "
+                "You should use 'response.read()' instead."
+            )
+        if not hasattr(self, "_content"):
+            self._content = b"".join([part async for part in self.aiter_stream()])
+        return self._content
+
+    async def aiter_stream(self) -> AsyncIterator[bytes]:
+        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to stream a synchronous response using 'async for ... in "
+                "response.aiter_stream()'. "
+                "You should use 'for ... in response.iter_stream()' instead."
+            )
+        if self._stream_consumed:
+            raise RuntimeError(
+                "Attempted to call 'async for ... in response.aiter_stream()' "
+                "more than once."
+            )
+        self._stream_consumed = True
+        async for chunk in self.stream:
+            yield chunk
+
+    async def aclose(self) -> None:
+        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to close a synchronous response using "
+                "'await response.aclose()'. "
+                "You should use 'response.close()' instead."
+            )
+        if hasattr(self.stream, "aclose"):
+            await self.stream.aclose()
diff --git a/contrib/python/httpcore/httpcore/_ssl.py b/contrib/python/httpcore/httpcore/_ssl.py
new file mode 100644
index 0000000000..c99c5a6794
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/_ssl.py
@@ -0,0 +1,9 @@
+import ssl
+
+import certifi
+
+
+def default_ssl_context() -> ssl.SSLContext:
+    context = ssl.create_default_context()
+    context.load_verify_locations(certifi.where())
+    return context
diff --git a/contrib/python/httpcore/httpcore/_sync/__init__.py b/contrib/python/httpcore/httpcore/_sync/__init__.py
new file mode 100644
index 0000000000..b476d76d9a
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/_sync/__init__.py
@@ -0,0 +1,39 @@
+from .connection import HTTPConnection
+from .connection_pool import ConnectionPool
+from .http11 import HTTP11Connection
+from .http_proxy import HTTPProxy
+from .interfaces import ConnectionInterface
+
+try:
+    from .http2 import HTTP2Connection
+except ImportError:  # pragma: nocover
+
+    class HTTP2Connection:  # type: ignore
+        def __init__(self, *args, **kwargs) -> None:  # type: ignore
+            raise RuntimeError(
+                "Attempted to use http2 support, but the `h2` package is not "
+                "installed. Use 'pip install httpcore[http2]'."
+            )
+
+
+try:
+    from .socks_proxy import SOCKSProxy
+except ImportError:  # pragma: nocover
+
+    class SOCKSProxy:  # type: ignore
+        def __init__(self, *args, **kwargs) -> None:  # type: ignore
+            raise RuntimeError(
+                "Attempted to use SOCKS support, but the `socksio` package is not "
+                "installed. Use 'pip install httpcore[socks]'."
+ ) + + +__all__ = [ + "HTTPConnection", + "ConnectionPool", + "HTTPProxy", + "HTTP11Connection", + "HTTP2Connection", + "ConnectionInterface", + "SOCKSProxy", +] diff --git a/contrib/python/httpcore/httpcore/_sync/connection.py b/contrib/python/httpcore/httpcore/_sync/connection.py new file mode 100644 index 0000000000..81e4172a21 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/connection.py @@ -0,0 +1,222 @@ +import itertools +import logging +import ssl +from types import TracebackType +from typing import Iterable, Iterator, Optional, Type + +from .._backends.sync import SyncBackend +from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream +from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface + +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. + + +logger = logging.getLogger("httpcore.connection") + + +def exponential_backoff(factor: float) -> Iterator[float]: + """ + Generate a geometric sequence that has a ratio of 2 and starts with 0. + + For example: + - `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...` + - `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...` + """ + yield 0 + for n in itertools.count(): + yield factor * 2**n + + +class HTTPConnection(ConnectionInterface): + def __init__( + self, + origin: Origin, + ssl_context: Optional[ssl.SSLContext] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + self._origin = origin + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._retries = retries + self._local_address = local_address + self._uds = uds + + self._network_backend: NetworkBackend = ( + SyncBackend() if network_backend is None else network_backend + ) + self._connection: Optional[ConnectionInterface] = None + self._connect_failed: bool = False + self._request_lock = Lock() + self._socket_options = socket_options + + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection to {self._origin}" + ) + + with self._request_lock: + if self._connection is None: + try: + stream = self._connect(request) + + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + except Exception as exc: + self._connect_failed = True + raise exc + elif not self._connection.is_available(): + raise ConnectionNotAvailable() + + return self._connection.handle_request(request) + + def _connect(self, request: Request) -> NetworkStream: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = 
request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + retries_left = self._retries + delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) + + while True: + try: + if self._uds is None: + kwargs = { + "host": self._origin.host.decode("ascii"), + "port": self._origin.port, + "local_address": self._local_address, + "timeout": timeout, + "socket_options": self._socket_options, + } + with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + else: + kwargs = { + "path": self._uds, + "timeout": timeout, + "socket_options": self._socket_options, + } + with Trace( + "connect_unix_socket", logger, request, kwargs + ) as trace: + stream = self._network_backend.connect_unix_socket( + **kwargs + ) + trace.return_value = stream + + if self._origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + return stream + except (ConnectError, ConnectTimeout): + if retries_left <= 0: + raise + retries_left -= 1 + delay = next(delays) + with Trace("retry", logger, request, kwargs) as trace: + self._network_backend.sleep(delay) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def close(self) -> None: + if self._connection is not None: + with Trace("close", logger, None, {}): + self._connection.close() + + def is_available(self) -> bool: + if self._connection is None: + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: + return self._connect_failed + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: + return "CONNECTION FAILED" if self._connect_failed else "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
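+    #
+    # A minimal sketch of that direct usage (the origin and request values
+    # here are hypothetical examples, not part of this module):
+    #
+    #     origin = Origin(b"https", b"example.com", 443)
+    #     with HTTPConnection(origin=origin) as connection:
+    #         response = connection.handle_request(
+    #             Request(b"GET", "https://example.com/")
+    #         )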
+
+    def __enter__(self) -> "HTTPConnection":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]] = None,
+        exc_value: Optional[BaseException] = None,
+        traceback: Optional[TracebackType] = None,
+    ) -> None:
+        self.close()
diff --git a/contrib/python/httpcore/httpcore/_sync/connection_pool.py b/contrib/python/httpcore/httpcore/_sync/connection_pool.py
new file mode 100644
index 0000000000..dbcaff1fcf
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/_sync/connection_pool.py
@@ -0,0 +1,356 @@
+import ssl
+import sys
+from types import TracebackType
+from typing import Iterable, Iterator, List, Optional, Type
+
+from .._backends.sync import SyncBackend
+from .._backends.base import SOCKET_OPTION, NetworkBackend
+from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
+from .._models import Origin, Request, Response
+from .._synchronization import Event, Lock, ShieldCancellation
+from .connection import HTTPConnection
+from .interfaces import ConnectionInterface, RequestInterface
+
+
+class RequestStatus:
+    def __init__(self, request: Request):
+        self.request = request
+        self.connection: Optional[ConnectionInterface] = None
+        self._connection_acquired = Event()
+
+    def set_connection(self, connection: ConnectionInterface) -> None:
+        assert self.connection is None
+        self.connection = connection
+        self._connection_acquired.set()
+
+    def unset_connection(self) -> None:
+        assert self.connection is not None
+        self.connection = None
+        self._connection_acquired = Event()
+
+    def wait_for_connection(
+        self, timeout: Optional[float] = None
+    ) -> ConnectionInterface:
+        if self.connection is None:
+            self._connection_acquired.wait(timeout=timeout)
+        assert self.connection is not None
+        return self.connection
+
+
+class ConnectionPool(RequestInterface):
+    """
+    A connection pool for making HTTP requests.
+    """
+
+    def __init__(
+        self,
+        ssl_context: Optional[ssl.SSLContext] = None,
+        max_connections: Optional[int] = 10,
+        max_keepalive_connections: Optional[int] = None,
+        keepalive_expiry: Optional[float] = None,
+        http1: bool = True,
+        http2: bool = False,
+        retries: int = 0,
+        local_address: Optional[str] = None,
+        uds: Optional[str] = None,
+        network_backend: Optional[NetworkBackend] = None,
+        socket_options: Optional[Iterable[SOCKET_OPTION]] = None,
+    ) -> None:
+        """
+        A connection pool for making HTTP requests.
+
+        Parameters:
+            ssl_context: An SSL context to use for verifying connections.
+                If not specified, the default `httpcore.default_ssl_context()`
+                will be used.
+            max_connections: The maximum number of concurrent HTTP connections that
+                the pool should allow. Any attempt to send a request on a pool that
+                would exceed this amount will block until a connection is available.
+            max_keepalive_connections: The maximum number of idle HTTP connections
+                that will be maintained in the pool.
+            keepalive_expiry: The duration in seconds that an idle HTTP connection
+                may be maintained for before being expired from the pool.
+            http1: A boolean indicating if HTTP/1.1 requests should be supported
+                by the connection pool. Defaults to True.
+            http2: A boolean indicating if HTTP/2 requests should be supported by
+                the connection pool. Defaults to False.
+            retries: The maximum number of retries when trying to establish a
+                connection.
+            local_address: Local address to connect from. Can also be used to
+                connect using a particular address family. Using
+                `local_address="0.0.0.0"` will connect using an `AF_INET`
+                address (IPv4), while using `local_address="::"` will connect
+                using an `AF_INET6` address (IPv6).
+            uds: Path to a Unix Domain Socket to use instead of TCP sockets.
+            network_backend: A backend instance to use for handling network I/O.
+            socket_options: Socket options that have to be included
+                in the TCP socket when the connection is established.
+        """
+        self._ssl_context = ssl_context
+
+        self._max_connections = (
+            sys.maxsize if max_connections is None else max_connections
+        )
+        self._max_keepalive_connections = (
+            sys.maxsize
+            if max_keepalive_connections is None
+            else max_keepalive_connections
+        )
+        self._max_keepalive_connections = min(
+            self._max_connections, self._max_keepalive_connections
+        )
+
+        self._keepalive_expiry = keepalive_expiry
+        self._http1 = http1
+        self._http2 = http2
+        self._retries = retries
+        self._local_address = local_address
+        self._uds = uds
+
+        self._pool: List[ConnectionInterface] = []
+        self._requests: List[RequestStatus] = []
+        self._pool_lock = Lock()
+        self._network_backend = (
+            SyncBackend() if network_backend is None else network_backend
+        )
+        self._socket_options = socket_options
+
+    def create_connection(self, origin: Origin) -> ConnectionInterface:
+        return HTTPConnection(
+            origin=origin,
+            ssl_context=self._ssl_context,
+            keepalive_expiry=self._keepalive_expiry,
+            http1=self._http1,
+            http2=self._http2,
+            retries=self._retries,
+            local_address=self._local_address,
+            uds=self._uds,
+            network_backend=self._network_backend,
+            socket_options=self._socket_options,
+        )
+
+    @property
+    def connections(self) -> List[ConnectionInterface]:
+        """
+        Return a list of the connections currently in the pool.
+
+        For example:
+
+        ```python
+        >>> pool.connections
+        [
+            <HTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>,
+            <HTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]>,
+            <HTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>,
+        ]
+        ```
+        """
+        return list(self._pool)
+
+    def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
+        """
+        Attempt to provide a connection that can handle the given origin.
+        """
+        origin = status.request.url.origin
+
+        # If there are queued requests in front of us, then don't acquire a
+        # connection. We handle requests strictly in order.
+        waiting = [s for s in self._requests if s.connection is None]
+        if waiting and waiting[0] is not status:
+            return False
+
+        # Reuse an existing connection if one is currently available.
+        for idx, connection in enumerate(self._pool):
+            if connection.can_handle_request(origin) and connection.is_available():
+                self._pool.pop(idx)
+                self._pool.insert(0, connection)
+                status.set_connection(connection)
+                return True
+
+        # If the pool is currently full, attempt to close one idle connection.
+        if len(self._pool) >= self._max_connections:
+            for idx, connection in reversed(list(enumerate(self._pool))):
+                if connection.is_idle():
+                    connection.close()
+                    self._pool.pop(idx)
+                    break
+
+        # If the pool is still full, then we cannot acquire a connection.
+        if len(self._pool) >= self._max_connections:
+            return False
+
+        # Otherwise create a new connection.
+        connection = self.create_connection(origin)
+        self._pool.insert(0, connection)
+        status.set_connection(connection)
+        return True
+
+    def _close_expired_connections(self) -> None:
+        """
+        Clean up the connection pool by closing off any connections that have expired.
+ """ + # Close any connections that have expired their keep-alive time. + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.has_expired(): + connection.close() + self._pool.pop(idx) + + # If the pool size exceeds the maximum number of allowed keep-alive connections, + # then close off idle connections as required. + pool_size = len(self._pool) + for idx, connection in reversed(list(enumerate(self._pool))): + if connection.is_idle() and pool_size > self._max_keepalive_connections: + connection.close() + self._pool.pop(idx) + pool_size -= 1 + + def handle_request(self, request: Request) -> Response: + """ + Send an HTTP request, and return an HTTP response. + + This is the core implementation that is called into by `.request()` or `.stream()`. + """ + scheme = request.url.scheme.decode() + if scheme == "": + raise UnsupportedProtocol( + "Request URL is missing an 'http://' or 'https://' protocol." + ) + if scheme not in ("http", "https", "ws", "wss"): + raise UnsupportedProtocol( + f"Request URL has an unsupported protocol '{scheme}://'." + ) + + status = RequestStatus(request) + + with self._pool_lock: + self._requests.append(status) + self._close_expired_connections() + self._attempt_to_acquire_connection(status) + + while True: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("pool", None) + try: + connection = status.wait_for_connection(timeout=timeout) + except BaseException as exc: + # If we timeout here, or if the task is cancelled, then make + # sure to remove the request from the queue before bubbling + # up the exception. + with self._pool_lock: + # Ensure only remove when task exists. + if status in self._requests: + self._requests.remove(status) + raise exc + + try: + response = connection.handle_request(request) + except ConnectionNotAvailable: + # The ConnectionNotAvailable exception is a special case, that + # indicates we need to retry the request on a new connection. + # + # The most common case where this can occur is when multiple + # requests are queued waiting for a single connection, which + # might end up as an HTTP/2 connection, but which actually ends + # up as HTTP/1.1. + with self._pool_lock: + # Maintain our position in the request queue, but reset the + # status so that the request becomes queued again. + status.unset_connection() + self._attempt_to_acquire_connection(status) + except BaseException as exc: + with ShieldCancellation(): + self.response_closed(status) + raise exc + else: + break + + # When we return the response, we wrap the stream in a special class + # that handles notifying the connection pool once the response + # has been released. + assert isinstance(response.stream, Iterable) + return Response( + status=response.status, + headers=response.headers, + content=ConnectionPoolByteStream(response.stream, self, status), + extensions=response.extensions, + ) + + def response_closed(self, status: RequestStatus) -> None: + """ + This method acts as a callback once the request/response cycle is complete. + + It is called into from the `ConnectionPoolByteStream.close()` method. + """ + assert status.connection is not None + connection = status.connection + + with self._pool_lock: + # Update the state of the connection pool. 
+            if status in self._requests:
+                self._requests.remove(status)
+
+            if connection.is_closed() and connection in self._pool:
+                self._pool.remove(connection)
+
+            # Since we've had a response closed, it's possible we'll now be able
+            # to service one or more requests that are currently pending.
+            for status in self._requests:
+                if status.connection is None:
+                    acquired = self._attempt_to_acquire_connection(status)
+                    # If we could not acquire a connection for a queued request
+                    # then we don't need to check any more requests that are
+                    # queued later behind it.
+                    if not acquired:
+                        break
+
+            # Housekeeping.
+            self._close_expired_connections()
+
+    def close(self) -> None:
+        """
+        Close any connections in the pool.
+        """
+        with self._pool_lock:
+            for connection in self._pool:
+                connection.close()
+            self._pool = []
+            self._requests = []
+
+    def __enter__(self) -> "ConnectionPool":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]] = None,
+        exc_value: Optional[BaseException] = None,
+        traceback: Optional[TracebackType] = None,
+    ) -> None:
+        self.close()
+
+
+class ConnectionPoolByteStream:
+    """
+    A wrapper around the response byte stream, that additionally handles
+    notifying the connection pool when the response has been closed.
+    """
+
+    def __init__(
+        self,
+        stream: Iterable[bytes],
+        pool: ConnectionPool,
+        status: RequestStatus,
+    ) -> None:
+        self._stream = stream
+        self._pool = pool
+        self._status = status
+
+    def __iter__(self) -> Iterator[bytes]:
+        for part in self._stream:
+            yield part
+
+    def close(self) -> None:
+        try:
+            if hasattr(self._stream, "close"):
+                self._stream.close()
+        finally:
+            with ShieldCancellation():
+                self._pool.response_closed(self._status)
diff --git a/contrib/python/httpcore/httpcore/_sync/http11.py b/contrib/python/httpcore/httpcore/_sync/http11.py
new file mode 100644
index 0000000000..0cc100e3ff
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/_sync/http11.py
@@ -0,0 +1,343 @@
+import enum
+import logging
+import time
+from types import TracebackType
+from typing import (
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
+
+import h11
+
+from .._backends.base import NetworkStream
+from .._exceptions import (
+    ConnectionNotAvailable,
+    LocalProtocolError,
+    RemoteProtocolError,
+    WriteError,
+    map_exceptions,
+)
+from .._models import Origin, Request, Response
+from .._synchronization import Lock, ShieldCancellation
+from .._trace import Trace
+from .interfaces import ConnectionInterface
+
+logger = logging.getLogger("httpcore.http11")
+
+
+# A subset of `h11.Event` types supported by `_send_event`
+H11SendEvent = Union[
+    h11.Request,
+    h11.Data,
+    h11.EndOfMessage,
+]
+
+
+class HTTPConnectionState(enum.IntEnum):
+    NEW = 0
+    ACTIVE = 1
+    IDLE = 2
+    CLOSED = 3
+
+
+class HTTP11Connection(ConnectionInterface):
+    READ_NUM_BYTES = 64 * 1024
+    MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024
+
+    def __init__(
+        self,
+        origin: Origin,
+        stream: NetworkStream,
+        keepalive_expiry: Optional[float] = None,
+    ) -> None:
+        self._origin = origin
+        self._network_stream = stream
+        self._keepalive_expiry: Optional[float] = keepalive_expiry
+        self._expire_at: Optional[float] = None
+        self._state = HTTPConnectionState.NEW
+        self._state_lock = Lock()
+        self._request_count = 0
+        self._h11_state = h11.Connection(
+            our_role=h11.CLIENT,
+            max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
+        )
+
+    def handle_request(self, request: Request) -> Response:
+        if not self.can_handle_request(request.url.origin):
+            raise RuntimeError(
+                f"Attempted to send request to {request.url.origin} on connection "
+                f"to {self._origin}"
+            )
+
+        with self._state_lock:
+            if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
+                self._request_count += 1
+                self._state = HTTPConnectionState.ACTIVE
+                self._expire_at = None
+            else:
+                raise ConnectionNotAvailable()
+
+        try:
+            kwargs = {"request": request}
+            try:
+                with Trace(
+                    "send_request_headers", logger, request, kwargs
+                ) as trace:
+                    self._send_request_headers(**kwargs)
+                with Trace("send_request_body", logger, request, kwargs) as trace:
+                    self._send_request_body(**kwargs)
+            except WriteError:
+                # If we get a write error while we're writing the request,
+                # then we suppress this error and move on to attempting to
+                # read the response. Servers can sometimes close the request
+                # pre-emptively and then respond with a well-formed HTTP
+                # error response.
+                pass
+
+            with Trace(
+                "receive_response_headers", logger, request, kwargs
+            ) as trace:
+                (
+                    http_version,
+                    status,
+                    reason_phrase,
+                    headers,
+                ) = self._receive_response_headers(**kwargs)
+                trace.return_value = (
+                    http_version,
+                    status,
+                    reason_phrase,
+                    headers,
+                )
+
+            return Response(
+                status=status,
+                headers=headers,
+                content=HTTP11ConnectionByteStream(self, request),
+                extensions={
+                    "http_version": http_version,
+                    "reason_phrase": reason_phrase,
+                    "network_stream": self._network_stream,
+                },
+            )
+        except BaseException as exc:
+            with ShieldCancellation():
+                with Trace("response_closed", logger, request) as trace:
+                    self._response_closed()
+            raise exc
+
+    # Sending the request...
+
+    def _send_request_headers(self, request: Request) -> None:
+        timeouts = request.extensions.get("timeout", {})
+        timeout = timeouts.get("write", None)
+
+        with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
+            event = h11.Request(
+                method=request.method,
+                target=request.url.target,
+                headers=request.headers,
+            )
+        self._send_event(event, timeout=timeout)
+
+    def _send_request_body(self, request: Request) -> None:
+        timeouts = request.extensions.get("timeout", {})
+        timeout = timeouts.get("write", None)
+
+        assert isinstance(request.stream, Iterable)
+        for chunk in request.stream:
+            event = h11.Data(data=chunk)
+            self._send_event(event, timeout=timeout)
+
+        self._send_event(h11.EndOfMessage(), timeout=timeout)
+
+    def _send_event(
+        self, event: h11.Event, timeout: Optional[float] = None
+    ) -> None:
+        bytes_to_send = self._h11_state.send(event)
+        if bytes_to_send is not None:
+            self._network_stream.write(bytes_to_send, timeout=timeout)
+
+    # Receiving the response...
+
+    def _receive_response_headers(
+        self, request: Request
+    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
+        timeouts = request.extensions.get("timeout", {})
+        timeout = timeouts.get("read", None)
+
+        while True:
+            event = self._receive_event(timeout=timeout)
+            if isinstance(event, h11.Response):
+                break
+            if (
+                isinstance(event, h11.InformationalResponse)
+                and event.status_code == 101
+            ):
+                break
+
+        http_version = b"HTTP/" + event.http_version
+
+        # h11 version 0.11+ supports a `raw_items` interface to get the
+        # raw header casing, rather than the enforced lowercase headers.
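+        # For example, a header the server sent as "Content-Type" is preserved
+        # as b"Content-Type" here, where the plain `event.headers` mapping
+        # would yield the normalized b"content-type".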
+        headers = event.headers.raw_items()
+
+        return http_version, event.status_code, event.reason, headers
+
+    def _receive_response_body(self, request: Request) -> Iterator[bytes]:
+        timeouts = request.extensions.get("timeout", {})
+        timeout = timeouts.get("read", None)
+
+        while True:
+            event = self._receive_event(timeout=timeout)
+            if isinstance(event, h11.Data):
+                yield bytes(event.data)
+            elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
+                break
+
+    def _receive_event(
+        self, timeout: Optional[float] = None
+    ) -> Union[h11.Event, Type[h11.PAUSED]]:
+        while True:
+            with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
+                event = self._h11_state.next_event()
+
+            if event is h11.NEED_DATA:
+                data = self._network_stream.read(
+                    self.READ_NUM_BYTES, timeout=timeout
+                )
+
+                # If we feed this case through h11 we'll raise an exception like:
+                #
+                #     httpcore.RemoteProtocolError: can't handle event type
+                #     ConnectionClosed when role=SERVER and state=SEND_RESPONSE
+                #
+                # Which is accurate, but not very informative from an end-user
+                # perspective. Instead we handle this case distinctly and treat
+                # it as a ConnectError.
+                if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE:
+                    msg = "Server disconnected without sending a response."
+                    raise RemoteProtocolError(msg)
+
+                self._h11_state.receive_data(data)
+            else:
+                # mypy fails to narrow the type in the above if statement
+                return cast(Union[h11.Event, Type[h11.PAUSED]], event)
+
+    def _response_closed(self) -> None:
+        with self._state_lock:
+            if (
+                self._h11_state.our_state is h11.DONE
+                and self._h11_state.their_state is h11.DONE
+            ):
+                self._state = HTTPConnectionState.IDLE
+                self._h11_state.start_next_cycle()
+                if self._keepalive_expiry is not None:
+                    now = time.monotonic()
+                    self._expire_at = now + self._keepalive_expiry
+            else:
+                self.close()
+
+    # Once the connection is no longer required...
+
+    def close(self) -> None:
+        # Note that this method unilaterally closes the connection, and does
+        # not have any kind of locking in place around it.
+        self._state = HTTPConnectionState.CLOSED
+        self._network_stream.close()
+
+    # The ConnectionInterface methods provide information about the state of
+    # the connection, allowing for a connection pooling implementation to
+    # determine when to reuse and when to close the connection...
+
+    def can_handle_request(self, origin: Origin) -> bool:
+        return origin == self._origin
+
+    def is_available(self) -> bool:
+        # Note that HTTP/1.1 connections in the "NEW" state are not treated as
+        # being "available". The control flow which created the connection will
+        # be able to send an outgoing request, but the connection will not be
+        # acquired from the connection pool for any other request.
+        return self._state == HTTPConnectionState.IDLE
+
+    def has_expired(self) -> bool:
+        now = time.monotonic()
+        keepalive_expired = self._expire_at is not None and now > self._expire_at
+
+        # If the HTTP connection is idle but the socket is readable, then the
+        # only valid state is that the socket is about to return b"", indicating
+        # a server-initiated disconnect.
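+        # (A concrete example: a keep-alive server times out and closes the
+        # connection; the resulting FIN makes the idle socket readable, and
+        # treating that as expiry stops the pool from reusing a dead socket.)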
+ server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + + return keepalive_expired or server_disconnected + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/1.1, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. + + def __enter__(self) -> "HTTP11Connection": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + self.close() + + +class HTTP11ConnectionByteStream: + def __init__(self, connection: HTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + self._closed = False + + def __iter__(self) -> Iterator[bytes]: + kwargs = {"request": self._request} + try: + with Trace("receive_response_body", logger, self._request, kwargs): + for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with ShieldCancellation(): + self.close() + raise exc + + def close(self) -> None: + if not self._closed: + self._closed = True + with Trace("response_closed", logger, self._request): + self._connection._response_closed() diff --git a/contrib/python/httpcore/httpcore/_sync/http2.py b/contrib/python/httpcore/httpcore/_sync/http2.py new file mode 100644 index 0000000000..d141d459a5 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/http2.py @@ -0,0 +1,589 @@ +import enum +import logging +import time +import types +import typing + +import h2.config +import h2.connection +import h2.events +import h2.exceptions +import h2.settings + +from .._backends.base import NetworkStream +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, +) +from .._models import Origin, Request, Response +from .._synchronization import Lock, Semaphore, ShieldCancellation +from .._trace import Trace +from .interfaces import ConnectionInterface + +logger = logging.getLogger("httpcore.http2") + + +def has_body_headers(request: Request) -> bool: + return any( + k.lower() == b"content-length" or k.lower() == b"transfer-encoding" + for k, v in request.headers + ) + + +class HTTPConnectionState(enum.IntEnum): + ACTIVE = 1 + IDLE = 2 + CLOSED = 3 + + +class HTTP2Connection(ConnectionInterface): + READ_NUM_BYTES = 64 * 1024 + CONFIG = h2.config.H2Configuration(validate_inbound_headers=False) + + def __init__( + self, + origin: Origin, + stream: NetworkStream, + keepalive_expiry: typing.Optional[float] = None, + ): + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: typing.Optional[float] = keepalive_expiry + self._h2_state = h2.connection.H2Connection(config=self.CONFIG) + self._state = HTTPConnectionState.IDLE + self._expire_at: 
typing.Optional[float] = None
+        self._request_count = 0
+        self._init_lock = Lock()
+        self._state_lock = Lock()
+        self._read_lock = Lock()
+        self._write_lock = Lock()
+        self._sent_connection_init = False
+        self._used_all_stream_ids = False
+        self._connection_error = False
+
+        # Mapping from stream ID to response stream events.
+        self._events: typing.Dict[
+            int,
+            typing.List[
+                typing.Union[
+                    h2.events.ResponseReceived,
+                    h2.events.DataReceived,
+                    h2.events.StreamEnded,
+                    h2.events.StreamReset,
+                ]
+            ],
+        ] = {}
+
+        # Connection terminated events are stored as state since
+        # we need to handle them for all streams.
+        self._connection_terminated: typing.Optional[
+            h2.events.ConnectionTerminated
+        ] = None
+
+        self._read_exception: typing.Optional[Exception] = None
+        self._write_exception: typing.Optional[Exception] = None
+
+    def handle_request(self, request: Request) -> Response:
+        if not self.can_handle_request(request.url.origin):
+            # This cannot occur in normal operation, since the connection pool
+            # will only send requests on connections that handle them.
+            # It's in place simply for resilience as a guard against incorrect
+            # usage, for anyone working directly with httpcore connections.
+            raise RuntimeError(
+                f"Attempted to send request to {request.url.origin} on connection "
+                f"to {self._origin}"
+            )
+
+        with self._state_lock:
+            if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
+                self._request_count += 1
+                self._expire_at = None
+                self._state = HTTPConnectionState.ACTIVE
+            else:
+                raise ConnectionNotAvailable()
+
+        with self._init_lock:
+            if not self._sent_connection_init:
+                try:
+                    kwargs = {"request": request}
+                    with Trace("send_connection_init", logger, request, kwargs):
+                        self._send_connection_init(**kwargs)
+                except BaseException as exc:
+                    with ShieldCancellation():
+                        self.close()
+                    raise exc
+
+                self._sent_connection_init = True
+
+                # Initially start with just 1 until the remote server provides
+                # its max_concurrent_streams value
+                self._max_streams = 1
+
+                local_settings_max_streams = (
+                    self._h2_state.local_settings.max_concurrent_streams
+                )
+                self._max_streams_semaphore = Semaphore(local_settings_max_streams)
+
+                for _ in range(local_settings_max_streams - self._max_streams):
+                    self._max_streams_semaphore.acquire()
+
+        self._max_streams_semaphore.acquire()
+
+        try:
+            stream_id = self._h2_state.get_next_available_stream_id()
+            self._events[stream_id] = []
+        except h2.exceptions.NoAvailableStreamIDError:  # pragma: nocover
+            self._used_all_stream_ids = True
+            self._request_count -= 1
+            raise ConnectionNotAvailable()
+
+        try:
+            kwargs = {"request": request, "stream_id": stream_id}
+            with Trace("send_request_headers", logger, request, kwargs):
+                self._send_request_headers(request=request, stream_id=stream_id)
+            with Trace("send_request_body", logger, request, kwargs):
+                self._send_request_body(request=request, stream_id=stream_id)
+            with Trace(
+                "receive_response_headers", logger, request, kwargs
+            ) as trace:
+                status, headers = self._receive_response(
+                    request=request, stream_id=stream_id
+                )
+                trace.return_value = (status, headers)
+
+            return Response(
+                status=status,
+                headers=headers,
+                content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
+                extensions={
+                    "http_version": b"HTTP/2",
+                    "network_stream": self._network_stream,
+                    "stream_id": stream_id,
+                },
+            )
+        except BaseException as exc:  # noqa: PIE786
+            with ShieldCancellation():
+                kwargs = {"stream_id": stream_id}
+                with Trace("response_closed", logger, request, kwargs):
self._response_closed(stream_id=stream_id) + + if isinstance(exc, h2.exceptions.ProtocolError): + # One case where h2 can raise a protocol error is when a + # closed frame has been seen by the state machine. + # + # This happens when one stream is reading, and encounters + # a GOAWAY event. Other flows of control may then raise + # a protocol error at any point they interact with the 'h2_state'. + # + # In this case we'll have stored the event, and should raise + # it as a RemoteProtocolError. + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) + # If h2 raises a protocol error in some other state then we + # must somehow have made a protocol violation. + raise LocalProtocolError(exc) # pragma: nocover + + raise exc + + def _send_connection_init(self, request: Request) -> None: + """ + The HTTP/2 connection requires some initial setup before we can start + using individual request/response streams on it. + """ + # Need to set these manually here instead of manipulating via + # __setitem__() otherwise the H2Connection will emit SettingsUpdate + # frames in addition to sending the undesired defaults. + self._h2_state.local_settings = h2.settings.Settings( + client=True, + initial_values={ + # Disable PUSH_PROMISE frames from the server since we don't do anything + # with them for now. Maybe when we support caching? + h2.settings.SettingCodes.ENABLE_PUSH: 0, + # These two are taken from h2 for safe defaults + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100, + h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536, + }, + ) + + # Some websites (*cough* Yahoo *cough*) balk at this setting being + # present in the initial handshake since it's not defined in the original + # RFC despite the RFC mandating ignoring settings you don't know about. + del self._h2_state.local_settings[ + h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL + ] + + self._h2_state.initiate_connection() + self._h2_state.increment_flow_control_window(2**24) + self._write_outgoing_data(request) + + # Sending the request... + + def _send_request_headers(self, request: Request, stream_id: int) -> None: + """ + Send the request headers to a given stream ID. + """ + end_stream = not has_body_headers(request) + + # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'. + # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require + # HTTP/1.1 style headers, and map them appropriately if we end up on + # an HTTP/2 connection. + authority = [v for k, v in request.headers if k.lower() == b"host"][0] + + headers = [ + (b":method", request.method), + (b":authority", authority), + (b":scheme", request.url.scheme), + (b":path", request.url.target), + ] + [ + (k.lower(), v) + for k, v in request.headers + if k.lower() + not in ( + b"host", + b"transfer-encoding", + ) + ] + + self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) + self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id) + self._write_outgoing_data(request) + + def _send_request_body(self, request: Request, stream_id: int) -> None: + """ + Iterate over the request body sending it to a given stream ID. 
+ """ + if not has_body_headers(request): + return + + assert isinstance(request.stream, typing.Iterable) + for data in request.stream: + self._send_stream_data(request, stream_id, data) + self._send_end_stream(request, stream_id) + + def _send_stream_data( + self, request: Request, stream_id: int, data: bytes + ) -> None: + """ + Send a single chunk of data in one or more data frames. + """ + while data: + max_flow = self._wait_for_outgoing_flow(request, stream_id) + chunk_size = min(len(data), max_flow) + chunk, data = data[:chunk_size], data[chunk_size:] + self._h2_state.send_data(stream_id, chunk) + self._write_outgoing_data(request) + + def _send_end_stream(self, request: Request, stream_id: int) -> None: + """ + Send an empty data frame on on a given stream ID with the END_STREAM flag set. + """ + self._h2_state.end_stream(stream_id) + self._write_outgoing_data(request) + + # Receiving the response... + + def _receive_response( + self, request: Request, stream_id: int + ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]: + """ + Return the response status code and headers for a given stream ID. + """ + while True: + event = self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.ResponseReceived): + break + + status_code = 200 + headers = [] + for k, v in event.headers: + if k == b":status": + status_code = int(v.decode("ascii", errors="ignore")) + elif not k.startswith(b":"): + headers.append((k, v)) + + return (status_code, headers) + + def _receive_response_body( + self, request: Request, stream_id: int + ) -> typing.Iterator[bytes]: + """ + Iterator that returns the bytes of the response body for a given stream ID. + """ + while True: + event = self._receive_stream_event(request, stream_id) + if isinstance(event, h2.events.DataReceived): + amount = event.flow_controlled_length + self._h2_state.acknowledge_received_data(amount, stream_id) + self._write_outgoing_data(request) + yield event.data + elif isinstance(event, h2.events.StreamEnded): + break + + def _receive_stream_event( + self, request: Request, stream_id: int + ) -> typing.Union[ + h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded + ]: + """ + Return the next available event for a given stream ID. + + Will read more data from the network if required. + """ + while not self._events.get(stream_id): + self._receive_events(request, stream_id) + event = self._events[stream_id].pop(0) + if isinstance(event, h2.events.StreamReset): + raise RemoteProtocolError(event) + return event + + def _receive_events( + self, request: Request, stream_id: typing.Optional[int] = None + ) -> None: + """ + Read some data from the network until we see one or more events + for a given stream ID. + """ + with self._read_lock: + if self._connection_terminated is not None: + last_stream_id = self._connection_terminated.last_stream_id + if stream_id and last_stream_id and stream_id > last_stream_id: + self._request_count -= 1 + raise ConnectionNotAvailable() + raise RemoteProtocolError(self._connection_terminated) + + # This conditional is a bit icky. We don't want to block reading if we've + # actually got an event to return for a given stream. We need to do that + # check *within* the atomic read lock. Though it also need to be optional, + # because when we call it from `_wait_for_outgoing_flow` we *do* want to + # block until we've available flow control, event when we have events + # pending for the stream ID we're attempting to send on. 
+ if stream_id is None or not self._events.get(stream_id): + events = self._read_incoming_data(request) + for event in events: + if isinstance(event, h2.events.RemoteSettingsChanged): + with Trace( + "receive_remote_settings", logger, request + ) as trace: + self._receive_remote_settings_change(event) + trace.return_value = event + + elif isinstance( + event, + ( + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ), + ): + if event.stream_id in self._events: + self._events[event.stream_id].append(event) + + elif isinstance(event, h2.events.ConnectionTerminated): + self._connection_terminated = event + + self._write_outgoing_data(request) + + def _receive_remote_settings_change(self, event: h2.events.Event) -> None: + max_concurrent_streams = event.changed_settings.get( + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS + ) + if max_concurrent_streams: + new_max_streams = min( + max_concurrent_streams.new_value, + self._h2_state.local_settings.max_concurrent_streams, + ) + if new_max_streams and new_max_streams != self._max_streams: + while new_max_streams > self._max_streams: + self._max_streams_semaphore.release() + self._max_streams += 1 + while new_max_streams < self._max_streams: + self._max_streams_semaphore.acquire() + self._max_streams -= 1 + + def _response_closed(self, stream_id: int) -> None: + self._max_streams_semaphore.release() + del self._events[stream_id] + with self._state_lock: + if self._connection_terminated and not self._events: + self.close() + + elif self._state == HTTPConnectionState.ACTIVE and not self._events: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + if self._used_all_stream_ids: # pragma: nocover + self.close() + + def close(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._h2_state.close_connection() + self._state = HTTPConnectionState.CLOSED + self._network_stream.close() + + # Wrappers around network read/write operations... + + def _read_incoming_data( + self, request: Request + ) -> typing.List[h2.events.Event]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + if self._read_exception is not None: + raise self._read_exception # pragma: nocover + + try: + data = self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + except Exception as exc: + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future reads. + # (For example, this means that a single read timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. 
+ self._read_exception = exc + self._connection_error = True + raise exc + + events: typing.List[h2.events.Event] = self._h2_state.receive_data(data) + + return events + + def _write_outgoing_data(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with self._write_lock: + data_to_send = self._h2_state.data_to_send() + + if self._write_exception is not None: + raise self._write_exception # pragma: nocover + + try: + self._network_stream.write(data_to_send, timeout) + except Exception as exc: # pragma: nocover + # If we get a network error we should: + # + # 1. Save the exception and just raise it immediately on any future write. + # (For example, this means that a single write timeout or disconnect will + # immediately close all pending streams. Without requiring multiple + # sequential timeouts.) + # 2. Mark the connection as errored, so that we don't accept any other + # incoming requests. + self._write_exception = exc + self._connection_error = True + raise exc + + # Flow control... + + def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. + https://tools.ietf.org/html/rfc7540#section-6.9 + """ + local_flow: int = self._h2_state.local_flow_control_window(stream_id) + max_frame_size: int = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + while flow == 0: + self._receive_events(request) + local_flow = self._h2_state.local_flow_control_window(stream_id) + max_frame_size = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + return flow + + # Interface for connection pooling... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + return ( + self._state != HTTPConnectionState.CLOSED + and not self._connection_error + and not self._used_all_stream_ids + and not ( + self._h2_state.state_machine.state + == h2.connection.ConnectionState.CLOSED + ) + ) + + def has_expired(self) -> bool: + now = time.monotonic() + return self._expire_at is not None and now > self._expire_at + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/2, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
+ + def __enter__(self) -> "HTTP2Connection": + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[types.TracebackType] = None, + ) -> None: + self.close() + + +class HTTP2ConnectionByteStream: + def __init__( + self, connection: HTTP2Connection, request: Request, stream_id: int + ) -> None: + self._connection = connection + self._request = request + self._stream_id = stream_id + self._closed = False + + def __iter__(self) -> typing.Iterator[bytes]: + kwargs = {"request": self._request, "stream_id": self._stream_id} + try: + with Trace("receive_response_body", logger, self._request, kwargs): + for chunk in self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with ShieldCancellation(): + self.close() + raise exc + + def close(self) -> None: + if not self._closed: + self._closed = True + kwargs = {"stream_id": self._stream_id} + with Trace("response_closed", logger, self._request, kwargs): + self._connection._response_closed(stream_id=self._stream_id) diff --git a/contrib/python/httpcore/httpcore/_sync/http_proxy.py b/contrib/python/httpcore/httpcore/_sync/http_proxy.py new file mode 100644 index 0000000000..6acac9a7cd --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/http_proxy.py @@ -0,0 +1,368 @@ +import logging +import ssl +from base64 import b64encode +from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union + +from .._backends.base import SOCKET_OPTION, NetworkBackend +from .._exceptions import ProxyError +from .._models import ( + URL, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, +) +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from .connection import HTTPConnection +from .connection_pool import ConnectionPool +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface + +HeadersAsSequence = Sequence[Tuple[Union[bytes, str], Union[bytes, str]]] +HeadersAsMapping = Mapping[Union[bytes, str], Union[bytes, str]] + + +logger = logging.getLogger("httpcore.proxy") + + +def merge_headers( + default_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, + override_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, +) -> List[Tuple[bytes, bytes]]: + """ + Append default_headers and override_headers, de-duplicating if a key exists + in both cases. + """ + default_headers = [] if default_headers is None else list(default_headers) + override_headers = [] if override_headers is None else list(override_headers) + has_override = set(key.lower() for key, value in override_headers) + default_headers = [ + (key, value) + for key, value in default_headers + if key.lower() not in has_override + ] + return default_headers + override_headers + + +def build_auth_header(username: bytes, password: bytes) -> bytes: + userpass = username + b":" + password + return b"Basic " + b64encode(userpass) + + +class HTTPProxy(ConnectionPool): + """ + A connection pool that sends requests via an HTTP proxy. 
+ """ + + def __init__( + self, + proxy_url: Union[URL, bytes, str], + proxy_auth: Optional[Tuple[Union[bytes, str], Union[bytes, str]]] = None, + proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None, + ssl_context: Optional[ssl.SSLContext] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + max_connections: Optional[int] = 10, + max_keepalive_connections: Optional[int] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + retries: int = 0, + local_address: Optional[str] = None, + uds: Optional[str] = None, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + proxy_url: The URL to use when connecting to the proxy server. + For example `"http://127.0.0.1:8080/"`. + proxy_auth: Any proxy authentication as a two-tuple of + (username, password). May be either bytes or ascii-only str. + proxy_headers: Any HTTP headers to use for the proxy requests. + For example `{"Proxy-Authorization": "Basic <username>:<password>"}`. + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish + a connection. + local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. 
+ """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + local_address=local_address, + uds=uds, + socket_options=socket_options, + ) + + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if ( + self._proxy_url.scheme == b"http" and proxy_ssl_context is not None + ): # pragma: no cover + raise RuntimeError( + "The `proxy_ssl_context` argument is not allowed for the http scheme" + ) + + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + if proxy_auth is not None: + username = enforce_bytes(proxy_auth[0], name="proxy_auth") + password = enforce_bytes(proxy_auth[1], name="proxy_auth") + authorization = build_auth_header(username, password) + self._proxy_headers = [ + (b"Proxy-Authorization", authorization) + ] + self._proxy_headers + + def create_connection(self, origin: Origin) -> ConnectionInterface: + if origin.scheme == b"http": + return ForwardHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + proxy_ssl_context=self._proxy_ssl_context, + ) + return TunnelHTTPConnection( + proxy_origin=self._proxy_url.origin, + proxy_headers=self._proxy_headers, + remote_origin=origin, + ssl_context=self._ssl_context, + proxy_ssl_context=self._proxy_ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class ForwardHTTPConnection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_headers: Union[HeadersAsMapping, HeadersAsSequence, None] = None, + keepalive_expiry: Optional[float] = None, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + ) -> None: + self._connection = HTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._remote_origin = remote_origin + + def handle_request(self, request: Request) -> Response: + headers = merge_headers(self._proxy_headers, request.headers) + url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=bytes(request.url), + ) + proxy_request = Request( + method=request.method, + url=url, + headers=headers, + content=request.stream, + extensions=request.extensions, + ) + return self._connection.handle_request(proxy_request) + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + def close(self) -> None: + self._connection.close() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: 
+ return f"<{self.__class__.__name__} [{self.info()}]>" + + +class TunnelHTTPConnection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + ssl_context: Optional[ssl.SSLContext] = None, + proxy_ssl_context: Optional[ssl.SSLContext] = None, + proxy_headers: Optional[Sequence[Tuple[bytes, bytes]]] = None, + keepalive_expiry: Optional[float] = None, + http1: bool = True, + http2: bool = False, + network_backend: Optional[NetworkBackend] = None, + socket_options: Optional[Iterable[SOCKET_OPTION]] = None, + ) -> None: + self._connection: ConnectionInterface = HTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + socket_options=socket_options, + ssl_context=proxy_ssl_context, + ) + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._ssl_context = ssl_context + self._proxy_ssl_context = proxy_ssl_context + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + self._connect_lock = Lock() + self._connected = False + + def handle_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) + + with self._connect_lock: + if not self._connected: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, + ) + connect_headers = merge_headers( + [(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers + ) + connect_request = Request( + method=b"CONNECT", + url=connect_url, + headers=connect_headers, + extensions=request.extensions, + ) + connect_response = self._connection.handle_request( + connect_request + ) + + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + self._connection.close() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + return self._connection.handle_request(request) + + def can_handle_request(self, origin: Origin) -> bool: + 
return origin == self._remote_origin + + def close(self) -> None: + self._connection.close() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/contrib/python/httpcore/httpcore/_sync/interfaces.py b/contrib/python/httpcore/httpcore/_sync/interfaces.py new file mode 100644 index 0000000000..5e95be1ec7 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_sync/interfaces.py @@ -0,0 +1,135 @@ +from contextlib import contextmanager +from typing import Iterator, Optional, Union + +from .._models import ( + URL, + Extensions, + HeaderTypes, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, + include_request_headers, +) + + +class RequestInterface: + def request( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterator[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> Response: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = self.handle_request(request) + try: + response.read() + finally: + response.close() + return response + + @contextmanager + def stream( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: HeaderTypes = None, + content: Union[bytes, Iterator[bytes], None] = None, + extensions: Optional[Extensions] = None, + ) -> Iterator[Response]: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = self.handle_request(request) + try: + yield response + finally: + response.close() + + def handle_request(self, request: Request) -> Response: + raise NotImplementedError() # pragma: nocover + + +class ConnectionInterface(RequestInterface): + def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + def info(self) -> str: + raise NotImplementedError() # pragma: nocover + + def can_handle_request(self, origin: Origin) -> bool: + raise NotImplementedError() # pragma: nocover + + def is_available(self) -> bool: + """ + Return `True` if the connection is currently able to accept an + outgoing request. + + An HTTP/1.1 connection will only be available if it is currently idle. + + An HTTP/2 connection will be available so long as the stream ID space is + not yet exhausted, and the connection is not in an error state. 
+
+        While the connection is being established we may not yet know if it is going
+        to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
+        treated as being available, but might ultimately raise `NewConnectionRequired`
+        exceptions if multiple requests are attempted over a connection
+        that ends up being established as HTTP/1.1.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def has_expired(self) -> bool:
+        """
+        Return `True` if the connection is in a state where it should be closed.
+
+        This either means that the connection is idle and it has passed the
+        expiry time on its keep-alive, or that the server has sent an EOF.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_idle(self) -> bool:
+        """
+        Return `True` if the connection is currently idle.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_closed(self) -> bool:
+        """
+        Return `True` if the connection has been closed.
+
+        Used when a response is closed to determine if the connection may be
+        returned to the connection pool or not.
+        """
+        raise NotImplementedError()  # pragma: nocover
diff --git a/contrib/python/httpcore/httpcore/_sync/socks_proxy.py b/contrib/python/httpcore/httpcore/_sync/socks_proxy.py
new file mode 100644
index 0000000000..502e4d7fef
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/_sync/socks_proxy.py
@@ -0,0 +1,342 @@
+import logging
+import ssl
+import typing
+
+from socksio import socks5
+
+from .._backends.sync import SyncBackend
+from .._backends.base import NetworkBackend, NetworkStream
+from .._exceptions import ConnectionNotAvailable, ProxyError
+from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
+from .._ssl import default_ssl_context
+from .._synchronization import Lock
+from .._trace import Trace
+from .connection_pool import ConnectionPool
+from .http11 import HTTP11Connection
+from .interfaces import ConnectionInterface
+
+logger = logging.getLogger("httpcore.socks")
+
+
+AUTH_METHODS = {
+    b"\x00": "NO AUTHENTICATION REQUIRED",
+    b"\x01": "GSSAPI",
+    b"\x02": "USERNAME/PASSWORD",
+    b"\xff": "NO ACCEPTABLE METHODS",
+}
+
+REPLY_CODES = {
+    b"\x00": "Succeeded",
+    b"\x01": "General SOCKS server failure",
+    b"\x02": "Connection not allowed by ruleset",
+    b"\x03": "Network unreachable",
+    b"\x04": "Host unreachable",
+    b"\x05": "Connection refused",
+    b"\x06": "TTL expired",
+    b"\x07": "Command not supported",
+    b"\x08": "Address type not supported",
+}
+
+
+def _init_socks5_connection(
+    stream: NetworkStream,
+    *,
+    host: bytes,
+    port: int,
+    auth: typing.Optional[typing.Tuple[bytes, bytes]] = None,
+) -> None:
+    conn = socks5.SOCKS5Connection()
+
+    # Auth method request
+    auth_method = (
+        socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
+        if auth is None
+        else socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
+    )
+    conn.send(socks5.SOCKS5AuthMethodsRequest([auth_method]))
+    outgoing_bytes = conn.data_to_send()
+    stream.write(outgoing_bytes)
+
+    # Auth method response
+    incoming_bytes = stream.read(max_bytes=4096)
+    response = conn.receive_data(incoming_bytes)
+    assert isinstance(response, socks5.SOCKS5AuthReply)
+    if response.method != auth_method:
+        requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
+        responded = AUTH_METHODS.get(response.method, "UNKNOWN")
+        raise ProxyError(
+            f"Requested {requested} from proxy server, but got {responded}."
+        )
+
+    if response.method == socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
+        # Username/password request
+        assert auth is not None
+        username, password = auth
+        conn.send(socks5.SOCKS5UsernamePasswordRequest(username, password))
+        outgoing_bytes = conn.data_to_send()
+        stream.write(outgoing_bytes)
+
+        # Username/password response
+        incoming_bytes = stream.read(max_bytes=4096)
+        response = conn.receive_data(incoming_bytes)
+        assert isinstance(response, socks5.SOCKS5UsernamePasswordReply)
+        if not response.success:
+            raise ProxyError("Invalid username/password")
+
+    # Connect request
+    conn.send(
+        socks5.SOCKS5CommandRequest.from_address(
+            socks5.SOCKS5Command.CONNECT, (host, port)
+        )
+    )
+    outgoing_bytes = conn.data_to_send()
+    stream.write(outgoing_bytes)
+
+    # Connect response
+    incoming_bytes = stream.read(max_bytes=4096)
+    response = conn.receive_data(incoming_bytes)
+    assert isinstance(response, socks5.SOCKS5Reply)
+    if response.reply_code != socks5.SOCKS5ReplyCode.SUCCEEDED:
+        reply_code = REPLY_CODES.get(response.reply_code, "UNKNOWN")
+        raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
+
+
+class SOCKSProxy(ConnectionPool):
+    """
+    A connection pool that sends requests via a SOCKS proxy.
+    """
+
+    def __init__(
+        self,
+        proxy_url: typing.Union[URL, bytes, str],
+        proxy_auth: typing.Optional[
+            typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]]
+        ] = None,
+        ssl_context: typing.Optional[ssl.SSLContext] = None,
+        max_connections: typing.Optional[int] = 10,
+        max_keepalive_connections: typing.Optional[int] = None,
+        keepalive_expiry: typing.Optional[float] = None,
+        http1: bool = True,
+        http2: bool = False,
+        retries: int = 0,
+        network_backend: typing.Optional[NetworkBackend] = None,
+    ) -> None:
+        """
+        A connection pool for making HTTP requests.
+
+        Parameters:
+            proxy_url: The URL to use when connecting to the proxy server.
+                For example `"http://127.0.0.1:8080/"`.
+            proxy_auth: Any proxy authentication as a two-tuple of
+                (username, password). May be either bytes or ascii-only str.
+            ssl_context: An SSL context to use for verifying connections.
+                If not specified, the default `httpcore.default_ssl_context()`
+                will be used.
+            max_connections: The maximum number of concurrent HTTP connections that
+                the pool should allow. Any attempt to send a request on a pool that
+                would exceed this amount will block until a connection is available.
+            max_keepalive_connections: The maximum number of idle HTTP connections
+                that will be maintained in the pool.
+            keepalive_expiry: The duration in seconds that an idle HTTP connection
+                may be maintained for before being expired from the pool.
+            http1: A boolean indicating if HTTP/1.1 requests should be supported
+                by the connection pool. Defaults to True.
+            http2: A boolean indicating if HTTP/2 requests should be supported by
+                the connection pool. Defaults to False.
+            retries: The maximum number of retries when trying to establish
+                a connection.
+            network_backend: A backend instance to use for handling network I/O.
+ """ + super().__init__( + ssl_context=ssl_context, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, + keepalive_expiry=keepalive_expiry, + http1=http1, + http2=http2, + network_backend=network_backend, + retries=retries, + ) + self._ssl_context = ssl_context + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + if proxy_auth is not None: + username, password = proxy_auth + username_bytes = enforce_bytes(username, name="proxy_auth") + password_bytes = enforce_bytes(password, name="proxy_auth") + self._proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = ( + username_bytes, + password_bytes, + ) + else: + self._proxy_auth = None + + def create_connection(self, origin: Origin) -> ConnectionInterface: + return Socks5Connection( + proxy_origin=self._proxy_url.origin, + remote_origin=origin, + proxy_auth=self._proxy_auth, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + + +class Socks5Connection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + proxy_auth: typing.Optional[typing.Tuple[bytes, bytes]] = None, + ssl_context: typing.Optional[ssl.SSLContext] = None, + keepalive_expiry: typing.Optional[float] = None, + http1: bool = True, + http2: bool = False, + network_backend: typing.Optional[NetworkBackend] = None, + ) -> None: + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._proxy_auth = proxy_auth + self._ssl_context = ssl_context + self._keepalive_expiry = keepalive_expiry + self._http1 = http1 + self._http2 = http2 + + self._network_backend: NetworkBackend = ( + SyncBackend() if network_backend is None else network_backend + ) + self._connect_lock = Lock() + self._connection: typing.Optional[ConnectionInterface] = None + self._connect_failed = False + + def handle_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + sni_hostname = request.extensions.get("sni_hostname", None) + timeout = timeouts.get("connect", None) + + with self._connect_lock: + if self._connection is None: + try: + # Connect to the proxy + kwargs = { + "host": self._proxy_origin.host.decode("ascii"), + "port": self._proxy_origin.port, + "timeout": timeout, + } + with Trace("connect_tcp", logger, request, kwargs) as trace: + stream = self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream + + # Connect to the remote host using socks5 + kwargs = { + "stream": stream, + "host": self._remote_origin.host.decode("ascii"), + "port": self._remote_origin.port, + "auth": self._proxy_auth, + } + with Trace( + "setup_socks5_connection", logger, request, kwargs + ) as trace: + _init_socks5_connection(**kwargs) + trace.return_value = stream + + # Upgrade the stream to SSL + if self._remote_origin.scheme == b"https": + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context + ) + alpn_protocols = ( + ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ) + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": sni_hostname + or self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + 
http2_negotiated = (
+                        ssl_object is not None
+                        and ssl_object.selected_alpn_protocol() == "h2"
+                    )
+
+                    # Create the HTTP/1.1 or HTTP/2 connection
+                    if http2_negotiated or (
+                        self._http2 and not self._http1
+                    ):  # pragma: nocover
+                        from .http2 import HTTP2Connection
+
+                        self._connection = HTTP2Connection(
+                            origin=self._remote_origin,
+                            stream=stream,
+                            keepalive_expiry=self._keepalive_expiry,
+                        )
+                    else:
+                        self._connection = HTTP11Connection(
+                            origin=self._remote_origin,
+                            stream=stream,
+                            keepalive_expiry=self._keepalive_expiry,
+                        )
+                except Exception as exc:
+                    self._connect_failed = True
+                    raise exc
+            elif not self._connection.is_available():  # pragma: nocover
+                raise ConnectionNotAvailable()
+
+        return self._connection.handle_request(request)
+
+    def can_handle_request(self, origin: Origin) -> bool:
+        return origin == self._remote_origin
+
+    def close(self) -> None:
+        if self._connection is not None:
+            self._connection.close()
+
+    def is_available(self) -> bool:
+        if self._connection is None:  # pragma: nocover
+            # If HTTP/2 support is enabled, and the resulting connection could
+            # end up as HTTP/2 then we should indicate the connection as being
+            # available to service multiple requests.
+            return (
+                self._http2
+                and (self._remote_origin.scheme == b"https" or not self._http1)
+                and not self._connect_failed
+            )
+        return self._connection.is_available()
+
+    def has_expired(self) -> bool:
+        if self._connection is None:  # pragma: nocover
+            return self._connect_failed
+        return self._connection.has_expired()
+
+    def is_idle(self) -> bool:
+        if self._connection is None:  # pragma: nocover
+            return self._connect_failed
+        return self._connection.is_idle()
+
+    def is_closed(self) -> bool:
+        if self._connection is None:  # pragma: nocover
+            return self._connect_failed
+        return self._connection.is_closed()
+
+    def info(self) -> str:
+        if self._connection is None:  # pragma: nocover
+            return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
+        return self._connection.info()
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} [{self.info()}]>"
diff --git a/contrib/python/httpcore/httpcore/_synchronization.py b/contrib/python/httpcore/httpcore/_synchronization.py
new file mode 100644
index 0000000000..bae27c1b11
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/_synchronization.py
@@ -0,0 +1,279 @@
+import threading
+from types import TracebackType
+from typing import Optional, Type
+
+import sniffio
+
+from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions
+
+# Our async synchronization primitives use either 'anyio' or 'trio' depending
+# on whether they're running under asyncio or trio.
+
+try:
+    import trio
+except ImportError:  # pragma: nocover
+    trio = None  # type: ignore
+
+try:
+    import anyio
+except ImportError:  # pragma: nocover
+    anyio = None  # type: ignore
+
+
+class AsyncLock:
+    def __init__(self) -> None:
+        self._backend = ""
+
+    def setup(self) -> None:
+        """
+        Detect if we're running under 'asyncio' or 'trio' and create
+        a lock with the correct implementation.
+        """
+        self._backend = sniffio.current_async_library()
+        if self._backend == "trio":
+            if trio is None:  # pragma: nocover
+                raise RuntimeError(
+                    "Running under trio requires the 'trio' package to be installed."
+                )
+            self._trio_lock = trio.Lock()
+        else:
+            if anyio is None:  # pragma: nocover
+                raise RuntimeError(
+                    "Running under asyncio requires the 'anyio' package to be installed."
+                )
+            self._anyio_lock = anyio.Lock()
+
+    async def __aenter__(self) -> "AsyncLock":
+        if not self._backend:
+            self.setup()
+
+        if self._backend == "trio":
+            await self._trio_lock.acquire()
+        else:
+            await self._anyio_lock.acquire()
+
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]] = None,
+        exc_value: Optional[BaseException] = None,
+        traceback: Optional[TracebackType] = None,
+    ) -> None:
+        if self._backend == "trio":
+            self._trio_lock.release()
+        else:
+            self._anyio_lock.release()
+
+
+class AsyncEvent:
+    def __init__(self) -> None:
+        self._backend = ""
+
+    def setup(self) -> None:
+        """
+        Detect if we're running under 'asyncio' or 'trio' and create
+        an event with the correct implementation.
+        """
+        self._backend = sniffio.current_async_library()
+        if self._backend == "trio":
+            if trio is None:  # pragma: nocover
+                raise RuntimeError(
+                    "Running under trio requires the 'trio' package to be installed."
+                )
+            self._trio_event = trio.Event()
+        else:
+            if anyio is None:  # pragma: nocover
+                raise RuntimeError(
+                    "Running under asyncio requires the 'anyio' package to be installed."
+                )
+            self._anyio_event = anyio.Event()
+
+    def set(self) -> None:
+        if not self._backend:
+            self.setup()
+
+        if self._backend == "trio":
+            self._trio_event.set()
+        else:
+            self._anyio_event.set()
+
+    async def wait(self, timeout: Optional[float] = None) -> None:
+        if not self._backend:
+            self.setup()
+
+        if self._backend == "trio":
+            if trio is None:  # pragma: nocover
+                raise RuntimeError(
+                    "Running under trio requires the 'trio' package to be installed."
+                )
+
+            trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
+            timeout_or_inf = float("inf") if timeout is None else timeout
+            with map_exceptions(trio_exc_map):
+                with trio.fail_after(timeout_or_inf):
+                    await self._trio_event.wait()
+        else:
+            if anyio is None:  # pragma: nocover
+                raise RuntimeError(
+                    "Running under asyncio requires the 'anyio' package to be installed."
+                )
+
+            anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
+            with map_exceptions(anyio_exc_map):
+                with anyio.fail_after(timeout):
+                    await self._anyio_event.wait()
+
+
+class AsyncSemaphore:
+    def __init__(self, bound: int) -> None:
+        self._bound = bound
+        self._backend = ""
+
+    def setup(self) -> None:
+        """
+        Detect if we're running under 'asyncio' or 'trio' and create
+        a semaphore with the correct implementation.
+        """
+        self._backend = sniffio.current_async_library()
+        if self._backend == "trio":
+            if trio is None:  # pragma: nocover
+                raise RuntimeError(
+                    "Running under trio requires the 'trio' package to be installed."
+                )
+
+            self._trio_semaphore = trio.Semaphore(
+                initial_value=self._bound, max_value=self._bound
+            )
+        else:
+            if anyio is None:  # pragma: nocover
+                raise RuntimeError(
+                    "Running under asyncio requires the 'anyio' package to be installed."
+                )
+
+            self._anyio_semaphore = anyio.Semaphore(
+                initial_value=self._bound, max_value=self._bound
+            )
+
+    async def acquire(self) -> None:
+        if not self._backend:
+            self.setup()
+
+        if self._backend == "trio":
+            await self._trio_semaphore.acquire()
+        else:
+            await self._anyio_semaphore.acquire()
+
+    async def release(self) -> None:
+        if self._backend == "trio":
+            self._trio_semaphore.release()
+        else:
+            self._anyio_semaphore.release()
+
+
+class AsyncShieldCancellation:
+    # For certain portions of our codebase where we're dealing with
+    # closing connections during exception handling we want to shield
+    # the operation from being cancelled.
+ # + # with AsyncShieldCancellation(): + # ... # clean-up operations, shielded from cancellation. + + def __init__(self) -> None: + """ + Detect if we're running under 'asyncio' or 'trio' and create + a shielded scope with the correct implementation. + """ + self._backend = sniffio.current_async_library() + + if self._backend == "trio": + if trio is None: # pragma: nocover + raise RuntimeError( + "Running under trio requires the 'trio' package to be installed." + ) + + self._trio_shield = trio.CancelScope(shield=True) + else: + if anyio is None: # pragma: nocover + raise RuntimeError( + "Running under asyncio requires the 'anyio' package to be installed." + ) + + self._anyio_shield = anyio.CancelScope(shield=True) + + def __enter__(self) -> "AsyncShieldCancellation": + if self._backend == "trio": + self._trio_shield.__enter__() + else: + self._anyio_shield.__enter__() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + if self._backend == "trio": + self._trio_shield.__exit__(exc_type, exc_value, traceback) + else: + self._anyio_shield.__exit__(exc_type, exc_value, traceback) + + +# Our thread-based synchronization primitives... + + +class Lock: + def __init__(self) -> None: + self._lock = threading.Lock() + + def __enter__(self) -> "Lock": + self._lock.acquire() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + self._lock.release() + + +class Event: + def __init__(self) -> None: + self._event = threading.Event() + + def set(self) -> None: + self._event.set() + + def wait(self, timeout: Optional[float] = None) -> None: + if not self._event.wait(timeout=timeout): + raise PoolTimeout() # pragma: nocover + + +class Semaphore: + def __init__(self, bound: int) -> None: + self._semaphore = threading.Semaphore(value=bound) + + def acquire(self) -> None: + self._semaphore.acquire() + + def release(self) -> None: + self._semaphore.release() + + +class ShieldCancellation: + # Thread-synchronous codebases don't support cancellation semantics. + # We have this class because we need to mirror the async and sync + # cases within our package, but it's just a no-op. 
+ def __enter__(self) -> "ShieldCancellation": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + pass diff --git a/contrib/python/httpcore/httpcore/_trace.py b/contrib/python/httpcore/httpcore/_trace.py new file mode 100644 index 0000000000..b122a53e88 --- /dev/null +++ b/contrib/python/httpcore/httpcore/_trace.py @@ -0,0 +1,105 @@ +import inspect +import logging +from types import TracebackType +from typing import Any, Dict, Optional, Type + +from ._models import Request + + +class Trace: + def __init__( + self, + name: str, + logger: logging.Logger, + request: Optional[Request] = None, + kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self.name = name + self.logger = logger + self.trace_extension = ( + None if request is None else request.extensions.get("trace") + ) + self.debug = self.logger.isEnabledFor(logging.DEBUG) + self.kwargs = kwargs or {} + self.return_value: Any = None + self.should_trace = self.debug or self.trace_extension is not None + self.prefix = self.logger.name.split(".")[-1] + + def trace(self, name: str, info: Dict[str, Any]) -> None: + if self.trace_extension is not None: + prefix_and_name = f"{self.prefix}.{name}" + ret = self.trace_extension(prefix_and_name, info) + if inspect.iscoroutine(ret): # pragma: no cover + raise TypeError( + "If you are using a synchronous interface, " + "the callback of the `trace` extension should " + "be a normal function instead of an asynchronous function." + ) + + if self.debug: + if not info or "return_value" in info and info["return_value"] is None: + message = name + else: + args = " ".join([f"{key}={value!r}" for key, value in info.items()]) + message = f"{name} {args}" + self.logger.debug(message) + + def __enter__(self) -> "Trace": + if self.should_trace: + info = self.kwargs + self.trace(f"{self.name}.started", info) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + if self.should_trace: + if exc_value is None: + info = {"return_value": self.return_value} + self.trace(f"{self.name}.complete", info) + else: + info = {"exception": exc_value} + self.trace(f"{self.name}.failed", info) + + async def atrace(self, name: str, info: Dict[str, Any]) -> None: + if self.trace_extension is not None: + prefix_and_name = f"{self.prefix}.{name}" + coro = self.trace_extension(prefix_and_name, info) + if not inspect.iscoroutine(coro): # pragma: no cover + raise TypeError( + "If you're using an asynchronous interface, " + "the callback of the `trace` extension should " + "be an asynchronous function rather than a normal function." 
+            )
+        await coro
+
+        if self.debug:
+            if not info or "return_value" in info and info["return_value"] is None:
+                message = name
+            else:
+                args = " ".join([f"{key}={value!r}" for key, value in info.items()])
+                message = f"{name} {args}"
+            self.logger.debug(message)
+
+    async def __aenter__(self) -> "Trace":
+        if self.should_trace:
+            info = self.kwargs
+            await self.atrace(f"{self.name}.started", info)
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]] = None,
+        exc_value: Optional[BaseException] = None,
+        traceback: Optional[TracebackType] = None,
+    ) -> None:
+        if self.should_trace:
+            if exc_value is None:
+                info = {"return_value": self.return_value}
+                await self.atrace(f"{self.name}.complete", info)
+            else:
+                info = {"exception": exc_value}
+                await self.atrace(f"{self.name}.failed", info)
diff --git a/contrib/python/httpcore/httpcore/_utils.py b/contrib/python/httpcore/httpcore/_utils.py
new file mode 100644
index 0000000000..df5dea8fe4
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/_utils.py
@@ -0,0 +1,36 @@
+import select
+import socket
+import sys
+import typing
+
+
+def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool:
+    """
+    Return whether a socket, as identified by its file descriptor, is readable.
+    "A socket is readable" means that the read buffer isn't empty, i.e. that calling
+    .recv() on it would immediately return some data.
+    """
+    # NOTE: we want to check for readability without actually attempting to read,
+    # because we don't want to block forever if it's not readable.
+
+    # In the case that the socket no longer exists, or cannot return a file
+    # descriptor, we treat it as being readable, as if the next read operation
+    # on it is ready to return the terminating `b""`.
+    sock_fd = None if sock is None else sock.fileno()
+    if sock_fd is None or sock_fd < 0:  # pragma: nocover
+        return True
+
+    # The implementation below was stolen from:
+    # https://github.com/python-trio/trio/blob/20ee2b1b7376db637435d80e266212a35837ddcc/trio/_socket.py#L471-L478
+    # See also: https://github.com/encode/httpcore/pull/193#issuecomment-703129316
+
+    # Use select.select on Windows or when poll is unavailable, and select.poll
+    # everywhere else. (E.g. when eventlet is in use. See #327)
+    if (
+        sys.platform == "win32" or getattr(select, "poll", None) is None
+    ):  # pragma: nocover
+        rready, _, _ = select.select([sock_fd], [], [], 0)
+        return bool(rready)
+    p = select.poll()
+    p.register(sock_fd, select.POLLIN)
+    return bool(p.poll(0))
diff --git a/contrib/python/httpcore/httpcore/py.typed b/contrib/python/httpcore/httpcore/py.typed
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/contrib/python/httpcore/httpcore/py.typed
diff --git a/contrib/python/httpcore/ya.make b/contrib/python/httpcore/ya.make
new file mode 100644
index 0000000000..e8516afe10
--- /dev/null
+++ b/contrib/python/httpcore/ya.make
@@ -0,0 +1,68 @@
+# Generated by devtools/yamaker (pypi).
+ +PY3_LIBRARY() + +VERSION(0.18.0) + +LICENSE(BSD-3-Clause) + +PEERDIR( + contrib/python/anyio + contrib/python/certifi + contrib/python/h11 + contrib/python/sniffio +) + +NO_LINT() + +NO_CHECK_IMPORTS( + httpcore._async.http2 + httpcore._async.socks_proxy + httpcore._backends.trio + httpcore._sync.http2 + httpcore._sync.socks_proxy +) + +PY_SRCS( + TOP_LEVEL + httpcore/__init__.py + httpcore/_api.py + httpcore/_async/__init__.py + httpcore/_async/connection.py + httpcore/_async/connection_pool.py + httpcore/_async/http11.py + httpcore/_async/http2.py + httpcore/_async/http_proxy.py + httpcore/_async/interfaces.py + httpcore/_async/socks_proxy.py + httpcore/_backends/__init__.py + httpcore/_backends/anyio.py + httpcore/_backends/auto.py + httpcore/_backends/base.py + httpcore/_backends/mock.py + httpcore/_backends/sync.py + httpcore/_backends/trio.py + httpcore/_exceptions.py + httpcore/_models.py + httpcore/_ssl.py + httpcore/_sync/__init__.py + httpcore/_sync/connection.py + httpcore/_sync/connection_pool.py + httpcore/_sync/http11.py + httpcore/_sync/http2.py + httpcore/_sync/http_proxy.py + httpcore/_sync/interfaces.py + httpcore/_sync/socks_proxy.py + httpcore/_synchronization.py + httpcore/_trace.py + httpcore/_utils.py +) + +RESOURCE_FILES( + PREFIX contrib/python/httpcore/ + .dist-info/METADATA + .dist-info/top_level.txt + httpcore/py.typed +) + +END() diff --git a/contrib/python/httpx/.dist-info/METADATA b/contrib/python/httpx/.dist-info/METADATA new file mode 100644 index 0000000000..f3a6d509cd --- /dev/null +++ b/contrib/python/httpx/.dist-info/METADATA @@ -0,0 +1,216 @@ +Metadata-Version: 2.1 +Name: httpx +Version: 0.25.0 +Summary: The next generation HTTP client. +Project-URL: Changelog, https://github.com/encode/httpx/blob/master/CHANGELOG.md +Project-URL: Documentation, https://www.python-httpx.org +Project-URL: Homepage, https://github.com/encode/httpx +Project-URL: Source, https://github.com/encode/httpx +Author-email: Tom Christie <tom@tomchristie.com> +License-Expression: BSD-3-Clause +License-File: LICENSE.md +Classifier: Development Status :: 4 - Beta +Classifier: Environment :: Web Environment +Classifier: Framework :: AsyncIO +Classifier: Framework :: Trio +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Topic :: Internet :: WWW/HTTP +Requires-Python: >=3.8 +Requires-Dist: certifi +Requires-Dist: httpcore<0.19.0,>=0.18.0 +Requires-Dist: idna +Requires-Dist: sniffio +Provides-Extra: brotli +Requires-Dist: brotli; platform_python_implementation == 'CPython' and extra == 'brotli' +Requires-Dist: brotlicffi; platform_python_implementation != 'CPython' and extra == 'brotli' +Provides-Extra: cli +Requires-Dist: click==8.*; extra == 'cli' +Requires-Dist: pygments==2.*; extra == 'cli' +Requires-Dist: rich<14,>=10; extra == 'cli' +Provides-Extra: http2 +Requires-Dist: h2<5,>=3; extra == 'http2' +Provides-Extra: socks +Requires-Dist: socksio==1.*; extra == 'socks' +Description-Content-Type: text/markdown + +<p align="center"> + <a href="https://www.python-httpx.org/"><img width="350" height="208" 
src="https://raw.githubusercontent.com/encode/httpx/master/docs/img/butterfly.png" alt='HTTPX'></a> +</p> + +<p align="center"><strong>HTTPX</strong> <em>- A next-generation HTTP client for Python.</em></p> + +<p align="center"> +<a href="https://github.com/encode/httpx/actions"> + <img src="https://github.com/encode/httpx/workflows/Test%20Suite/badge.svg" alt="Test Suite"> +</a> +<a href="https://pypi.org/project/httpx/"> + <img src="https://badge.fury.io/py/httpx.svg" alt="Package version"> +</a> +</p> + +HTTPX is a fully featured HTTP client library for Python 3. It includes **an integrated +command line client**, has support for both **HTTP/1.1 and HTTP/2**, and provides both **sync +and async APIs**. + +--- + +Install HTTPX using pip: + +```shell +$ pip install httpx +``` + +Now, let's get started: + +```pycon +>>> import httpx +>>> r = httpx.get('https://www.example.org/') +>>> r +<Response [200 OK]> +>>> r.status_code +200 +>>> r.headers['content-type'] +'text/html; charset=UTF-8' +>>> r.text +'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>...' +``` + +Or, using the command-line client. + +```shell +$ pip install 'httpx[cli]' # The command line client is an optional dependency. +``` + +Which now allows us to use HTTPX directly from the command-line... + +<p align="center"> + <img width="700" src="https://raw.githubusercontent.com/encode/httpx/master/docs/img/httpx-help.png" alt='httpx --help'> +</p> + +Sending a request... + +<p align="center"> + <img width="700" src="https://raw.githubusercontent.com/encode/httpx/master/docs/img/httpx-request.png" alt='httpx http://httpbin.org/json'> +</p> + +## Features + +HTTPX builds on the well-established usability of `requests`, and gives you: + +* A broadly [requests-compatible API](https://www.python-httpx.org/compatibility/). +* An integrated command-line client. +* HTTP/1.1 [and HTTP/2 support](https://www.python-httpx.org/http2/). +* Standard synchronous interface, but with [async support if you need it](https://www.python-httpx.org/async/). +* Ability to make requests directly to [WSGI applications](https://www.python-httpx.org/advanced/#calling-into-python-web-apps) or [ASGI applications](https://www.python-httpx.org/async/#calling-into-python-web-apps). +* Strict timeouts everywhere. +* Fully type annotated. +* 100% test coverage. + +Plus all the standard features of `requests`... + +* International Domains and URLs +* Keep-Alive & Connection Pooling +* Sessions with Cookie Persistence +* Browser-style SSL Verification +* Basic/Digest Authentication +* Elegant Key/Value Cookies +* Automatic Decompression +* Automatic Content Decoding +* Unicode Response Bodies +* Multipart File Uploads +* HTTP(S) Proxy Support +* Connection Timeouts +* Streaming Downloads +* .netrc Support +* Chunked Requests + +## Installation + +Install with pip: + +```shell +$ pip install httpx +``` + +Or, to include the optional HTTP/2 support, use: + +```shell +$ pip install httpx[http2] +``` + +HTTPX requires Python 3.8+. + +## Documentation + +Project documentation is available at [https://www.python-httpx.org/](https://www.python-httpx.org/). + +For a run-through of all the basics, head over to the [QuickStart](https://www.python-httpx.org/quickstart/). + +For more advanced topics, see the [Advanced Usage](https://www.python-httpx.org/advanced/) section, the [async support](https://www.python-httpx.org/async/) section, or the [HTTP/2](https://www.python-httpx.org/http2/) section. 
+
+The [Developer Interface](https://www.python-httpx.org/api/) provides a comprehensive API reference.
+
+To find out about tools that integrate with HTTPX, see [Third Party Packages](https://www.python-httpx.org/third_party_packages/).
+
+## Contribute
+
+If you want to contribute to HTTPX, check out the [Contributing Guide](https://www.python-httpx.org/contributing/) to learn how to get started.
+
+## Dependencies
+
+The HTTPX project relies on these excellent libraries:
+
+* `httpcore` - The underlying transport implementation for `httpx`.
+  * `h11` - HTTP/1.1 support.
+* `certifi` - SSL certificates.
+* `idna` - Internationalized domain name support.
+* `sniffio` - Async library autodetection.
+
+As well as these optional installs:
+
+* `h2` - HTTP/2 support. *(Optional, with `httpx[http2]`)*
+* `socksio` - SOCKS proxy support. *(Optional, with `httpx[socks]`)*
+* `rich` - Rich terminal support. *(Optional, with `httpx[cli]`)*
+* `click` - Command line client support. *(Optional, with `httpx[cli]`)*
+* `brotli` or `brotlicffi` - Decoding for "brotli" compressed responses. *(Optional, with `httpx[brotli]`)*
+
+A huge amount of credit is due to `requests` for the API layout that
+much of this work follows, as well as to `urllib3` for plenty of design
+inspiration around the lower-level networking details.
+
+---
+
+<p align="center"><i>HTTPX is <a href="https://github.com/encode/httpx/blob/master/LICENSE.md">BSD licensed</a> code.<br/>Designed & crafted with care.</i><br/>&mdash; 🦋 &mdash;</p>
+
+## Release Information
+
+### Removed
+
+* Drop support for Python 3.7. (#2813)
+
+### Added
+
+* Support HTTPS proxies. (#2845)
+* Change the type of `Extensions` from `Mapping[str, Any]` to `MutableMapping[str, Any]`. (#2803)
+* Add `socket_options` argument to `httpx.HTTPTransport` and `httpx.AsyncHTTPTransport` classes. (#2716)
+* The `Response.raise_for_status()` method now returns the response instance. For example: `data = httpx.get('...').raise_for_status().json()`. (#2776)
+
+### Fixed
+
+* Return `500` error response instead of exceptions when `raise_app_exceptions=False` is set on `ASGITransport`. (#2669)
+* Ensure all `WSGITransport` environs have a `SERVER_PROTOCOL`. (#2708)
+* Always encode forward slashes as `%2F` in query parameters. (#2723)
+* Use Mozilla documentation instead of `httpstatuses.com` for HTTP error reference. (#2768)
+
+
+---
+
+[Full changelog](https://github.com/encode/httpx/blob/master/CHANGELOG.md)
diff --git a/contrib/python/httpx/.dist-info/entry_points.txt b/contrib/python/httpx/.dist-info/entry_points.txt
new file mode 100644
index 0000000000..8ae96007f7
--- /dev/null
+++ b/contrib/python/httpx/.dist-info/entry_points.txt
@@ -0,0 +1,2 @@
+[console_scripts]
+httpx = httpx:main
diff --git a/contrib/python/httpx/.dist-info/top_level.txt b/contrib/python/httpx/.dist-info/top_level.txt
new file mode 100644
index 0000000000..c180eb2f8e
--- /dev/null
+++ b/contrib/python/httpx/.dist-info/top_level.txt
@@ -0,0 +1,2 @@
+httpx
+httpx/_transports
diff --git a/contrib/python/httpx/LICENSE.md b/contrib/python/httpx/LICENSE.md
new file mode 100644
index 0000000000..ab79d16a3f
--- /dev/null
+++ b/contrib/python/httpx/LICENSE.md
@@ -0,0 +1,12 @@
+Copyright © 2019, [Encode OSS Ltd](https://www.encode.io/).
+All rights reserved.
+ +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/contrib/python/httpx/README.md b/contrib/python/httpx/README.md new file mode 100644 index 0000000000..62fb295d17 --- /dev/null +++ b/contrib/python/httpx/README.md @@ -0,0 +1,148 @@ +<p align="center"> + <a href="https://www.python-httpx.org/"><img width="350" height="208" src="https://raw.githubusercontent.com/encode/httpx/master/docs/img/butterfly.png" alt='HTTPX'></a> +</p> + +<p align="center"><strong>HTTPX</strong> <em>- A next-generation HTTP client for Python.</em></p> + +<p align="center"> +<a href="https://github.com/encode/httpx/actions"> + <img src="https://github.com/encode/httpx/workflows/Test%20Suite/badge.svg" alt="Test Suite"> +</a> +<a href="https://pypi.org/project/httpx/"> + <img src="https://badge.fury.io/py/httpx.svg" alt="Package version"> +</a> +</p> + +HTTPX is a fully featured HTTP client library for Python 3. It includes **an integrated +command line client**, has support for both **HTTP/1.1 and HTTP/2**, and provides both **sync +and async APIs**. + +--- + +Install HTTPX using pip: + +```shell +$ pip install httpx +``` + +Now, let's get started: + +```pycon +>>> import httpx +>>> r = httpx.get('https://www.example.org/') +>>> r +<Response [200 OK]> +>>> r.status_code +200 +>>> r.headers['content-type'] +'text/html; charset=UTF-8' +>>> r.text +'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>...' +``` + +Or, using the command-line client. + +```shell +$ pip install 'httpx[cli]' # The command line client is an optional dependency. +``` + +Which now allows us to use HTTPX directly from the command-line... + +<p align="center"> + <img width="700" src="docs/img/httpx-help.png" alt='httpx --help'> +</p> + +Sending a request... + +<p align="center"> + <img width="700" src="docs/img/httpx-request.png" alt='httpx http://httpbin.org/json'> +</p> + +## Features + +HTTPX builds on the well-established usability of `requests`, and gives you: + +* A broadly [requests-compatible API](https://www.python-httpx.org/compatibility/). +* An integrated command-line client. +* HTTP/1.1 [and HTTP/2 support](https://www.python-httpx.org/http2/). 
+* Standard synchronous interface, but with [async support if you need it](https://www.python-httpx.org/async/).
+* Ability to make requests directly to [WSGI applications](https://www.python-httpx.org/advanced/#calling-into-python-web-apps) or [ASGI applications](https://www.python-httpx.org/async/#calling-into-python-web-apps).
+* Strict timeouts everywhere.
+* Fully type annotated.
+* 100% test coverage.
+
+Plus all the standard features of `requests`...
+
+* International Domains and URLs
+* Keep-Alive & Connection Pooling
+* Sessions with Cookie Persistence
+* Browser-style SSL Verification
+* Basic/Digest Authentication
+* Elegant Key/Value Cookies
+* Automatic Decompression
+* Automatic Content Decoding
+* Unicode Response Bodies
+* Multipart File Uploads
+* HTTP(S) Proxy Support
+* Connection Timeouts
+* Streaming Downloads
+* .netrc Support
+* Chunked Requests
+
+## Installation
+
+Install with pip:
+
+```shell
+$ pip install httpx
+```
+
+Or, to include the optional HTTP/2 support, use:
+
+```shell
+$ pip install httpx[http2]
+```
+
+HTTPX requires Python 3.8+.
+
+## Documentation
+
+Project documentation is available at [https://www.python-httpx.org/](https://www.python-httpx.org/).
+
+For a run-through of all the basics, head over to the [QuickStart](https://www.python-httpx.org/quickstart/).
+
+For more advanced topics, see the [Advanced Usage](https://www.python-httpx.org/advanced/) section, the [async support](https://www.python-httpx.org/async/) section, or the [HTTP/2](https://www.python-httpx.org/http2/) section.
+
+The [Developer Interface](https://www.python-httpx.org/api/) provides a comprehensive API reference.
+
+To find out about tools that integrate with HTTPX, see [Third Party Packages](https://www.python-httpx.org/third_party_packages/).
+
+## Contribute
+
+If you want to contribute to HTTPX, check out the [Contributing Guide](https://www.python-httpx.org/contributing/) to learn how to get started.
+
+## Dependencies
+
+The HTTPX project relies on these excellent libraries:
+
+* `httpcore` - The underlying transport implementation for `httpx`.
+  * `h11` - HTTP/1.1 support.
+* `certifi` - SSL certificates.
+* `idna` - Internationalized domain name support.
+* `sniffio` - Async library autodetection.
+
+As well as these optional installs:
+
+* `h2` - HTTP/2 support. *(Optional, with `httpx[http2]`)*
+* `socksio` - SOCKS proxy support. *(Optional, with `httpx[socks]`)*
+* `rich` - Rich terminal support. *(Optional, with `httpx[cli]`)*
+* `click` - Command line client support. *(Optional, with `httpx[cli]`)*
+* `brotli` or `brotlicffi` - Decoding for "brotli" compressed responses. *(Optional, with `httpx[brotli]`)*
+
+A huge amount of credit is due to `requests` for the API layout that
+much of this work follows, as well as to `urllib3` for plenty of design
+inspiration around the lower-level networking details.
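+
+As a brief illustration of the optional HTTP/2 install listed above (a sketch; assumes `pip install httpx[http2]` has been run):
+
+```python
+import httpx
+
+# http2=True enables HTTP/2 negotiation where the server supports it;
+# otherwise the connection falls back to HTTP/1.1.
+with httpx.Client(http2=True) as client:
+    response = client.get('https://www.example.org/')
+    print(response.http_version)
+```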
+ +--- + +<p align="center"><i>HTTPX is <a href="https://github.com/encode/httpx/blob/master/LICENSE.md">BSD licensed</a> code.<br/>Designed & crafted with care.</i><br/>— 🦋 —</p> diff --git a/contrib/python/httpx/httpx/__init__.py b/contrib/python/httpx/httpx/__init__.py new file mode 100644 index 0000000000..f61112f8b2 --- /dev/null +++ b/contrib/python/httpx/httpx/__init__.py @@ -0,0 +1,138 @@ +from .__version__ import __description__, __title__, __version__ +from ._api import delete, get, head, options, patch, post, put, request, stream +from ._auth import Auth, BasicAuth, DigestAuth, NetRCAuth +from ._client import USE_CLIENT_DEFAULT, AsyncClient, Client +from ._config import Limits, Proxy, Timeout, create_ssl_context +from ._content import ByteStream +from ._exceptions import ( + CloseError, + ConnectError, + ConnectTimeout, + CookieConflict, + DecodingError, + HTTPError, + HTTPStatusError, + InvalidURL, + LocalProtocolError, + NetworkError, + PoolTimeout, + ProtocolError, + ProxyError, + ReadError, + ReadTimeout, + RemoteProtocolError, + RequestError, + RequestNotRead, + ResponseNotRead, + StreamClosed, + StreamConsumed, + StreamError, + TimeoutException, + TooManyRedirects, + TransportError, + UnsupportedProtocol, + WriteError, + WriteTimeout, +) +from ._models import Cookies, Headers, Request, Response +from ._status_codes import codes +from ._transports.asgi import ASGITransport +from ._transports.base import AsyncBaseTransport, BaseTransport +from ._transports.default import AsyncHTTPTransport, HTTPTransport +from ._transports.mock import MockTransport +from ._transports.wsgi import WSGITransport +from ._types import AsyncByteStream, SyncByteStream +from ._urls import URL, QueryParams + +try: + from ._main import main +except ImportError: # pragma: no cover + + def main() -> None: # type: ignore + import sys + + print( + "The httpx command line client could not run because the required " + "dependencies were not installed.\nMake sure you've installed " + "everything with: pip install 'httpx[cli]'" + ) + sys.exit(1) + + +__all__ = [ + "__description__", + "__title__", + "__version__", + "ASGITransport", + "AsyncBaseTransport", + "AsyncByteStream", + "AsyncClient", + "AsyncHTTPTransport", + "Auth", + "BaseTransport", + "BasicAuth", + "ByteStream", + "Client", + "CloseError", + "codes", + "ConnectError", + "ConnectTimeout", + "CookieConflict", + "Cookies", + "create_ssl_context", + "DecodingError", + "delete", + "DigestAuth", + "get", + "head", + "Headers", + "HTTPError", + "HTTPStatusError", + "HTTPTransport", + "InvalidURL", + "Limits", + "LocalProtocolError", + "main", + "MockTransport", + "NetRCAuth", + "NetworkError", + "options", + "patch", + "PoolTimeout", + "post", + "ProtocolError", + "Proxy", + "ProxyError", + "put", + "QueryParams", + "ReadError", + "ReadTimeout", + "RemoteProtocolError", + "request", + "Request", + "RequestError", + "RequestNotRead", + "Response", + "ResponseNotRead", + "stream", + "StreamClosed", + "StreamConsumed", + "StreamError", + "SyncByteStream", + "Timeout", + "TimeoutException", + "TooManyRedirects", + "TransportError", + "UnsupportedProtocol", + "URL", + "USE_CLIENT_DEFAULT", + "WriteError", + "WriteTimeout", + "WSGITransport", +] + + +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + setattr(__locals[__name], "__module__", "httpx") # noqa diff --git a/contrib/python/httpx/httpx/__version__.py b/contrib/python/httpx/httpx/__version__.py new file mode 100644 index 0000000000..bfa421ad60 --- /dev/null +++ 
b/contrib/python/httpx/httpx/__version__.py @@ -0,0 +1,3 @@ +__title__ = "httpx" +__description__ = "A next generation HTTP client, for Python 3." +__version__ = "0.25.0" diff --git a/contrib/python/httpx/httpx/_api.py b/contrib/python/httpx/httpx/_api.py new file mode 100644 index 0000000000..571289cf2b --- /dev/null +++ b/contrib/python/httpx/httpx/_api.py @@ -0,0 +1,445 @@ +import typing +from contextlib import contextmanager + +from ._client import Client +from ._config import DEFAULT_TIMEOUT_CONFIG +from ._models import Response +from ._types import ( + AuthTypes, + CertTypes, + CookieTypes, + HeaderTypes, + ProxiesTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + VerifyTypes, +) + + +def request( + method: str, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + trust_env: bool = True, +) -> Response: + """ + Sends an HTTP request. + + **Parameters:** + + * **method** - HTTP method for the new `Request` object: `GET`, `OPTIONS`, + `HEAD`, `POST`, `PUT`, `PATCH`, or `DELETE`. + * **url** - URL for the new `Request` object. + * **params** - *(optional)* Query parameters to include in the URL, as a + string, dictionary, or sequence of two-tuples. + * **content** - *(optional)* Binary content to include in the body of the + request, as bytes or a byte iterator. + * **data** - *(optional)* Form data to include in the body of the request, + as a dictionary. + * **files** - *(optional)* A dictionary of upload files to include in the + body of the request. + * **json** - *(optional)* A JSON serializable object to include in the body + of the request. + * **headers** - *(optional)* Dictionary of HTTP headers to include in the + request. + * **cookies** - *(optional)* Dictionary of Cookie items to include in the + request. + * **auth** - *(optional)* An authentication class to use when sending the + request. + * **proxies** - *(optional)* A dictionary mapping proxy keys to proxy URLs. + * **timeout** - *(optional)* The timeout configuration to use when sending + the request. + * **follow_redirects** - *(optional)* Enables or disables HTTP redirects. + * **verify** - *(optional)* SSL certificates (a.k.a CA bundle) used to + verify the identity of requested hosts. Either `True` (default CA bundle), + a path to an SSL certificate file, an `ssl.SSLContext`, or `False` + (which will disable verification). + * **cert** - *(optional)* An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). + * **trust_env** - *(optional)* Enables or disables usage of environment + variables for configuration. 
+ + **Returns:** `Response` + + Usage: + + ``` + >>> import httpx + >>> response = httpx.request('GET', 'https://httpbin.org/get') + >>> response + <Response [200 OK]> + ``` + """ + with Client( + cookies=cookies, + proxies=proxies, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + return client.request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + ) + + +@contextmanager +def stream( + method: str, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + trust_env: bool = True, +) -> typing.Iterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + with Client( + cookies=cookies, + proxies=proxies, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + with client.stream( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + ) as response: + yield response + + +def get( + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `GET` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `GET` requests should not include a request body. + """ + return request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def options( + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `OPTIONS` requests should not include a request body. 
+ """ + return request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def head( + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `HEAD` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `HEAD` requests should not include a request body. + """ + return request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def post( + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `POST` request. + + **Parameters**: See `httpx.request`. + """ + return request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def put( + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `PUT` request. + + **Parameters**: See `httpx.request`. 
+ """ + return request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def patch( + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `PATCH` request. + + **Parameters**: See `httpx.request`. + """ + return request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) + + +def delete( + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Optional[AuthTypes] = None, + proxies: typing.Optional[ProxiesTypes] = None, + follow_redirects: bool = False, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + trust_env: bool = True, +) -> Response: + """ + Sends a `DELETE` request. + + **Parameters**: See `httpx.request`. + + Note that the `data`, `files`, `json` and `content` parameters are not available + on this function, as `DELETE` requests should not include a request body. + """ + return request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + proxies=proxies, + follow_redirects=follow_redirects, + cert=cert, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) diff --git a/contrib/python/httpx/httpx/_auth.py b/contrib/python/httpx/httpx/_auth.py new file mode 100644 index 0000000000..1d7385d573 --- /dev/null +++ b/contrib/python/httpx/httpx/_auth.py @@ -0,0 +1,347 @@ +import hashlib +import netrc +import os +import re +import time +import typing +from base64 import b64encode +from urllib.request import parse_http_list + +from ._exceptions import ProtocolError +from ._models import Request, Response +from ._utils import to_bytes, to_str, unquote + +if typing.TYPE_CHECKING: # pragma: no cover + from hashlib import _Hash + + +class Auth: + """ + Base class for all authentication schemes. + + To implement a custom authentication scheme, subclass `Auth` and override + the `.auth_flow()` method. + + If the authentication scheme does I/O such as disk access or network calls, or uses + synchronization primitives such as locks, you should override `.sync_auth_flow()` + and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized + implementations that will be used by `Client` and `AsyncClient` respectively. 
+ """ + + requires_request_body = False + requires_response_body = False + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + """ + Execute the authentication flow. + + To dispatch a request, `yield` it: + + ``` + yield request + ``` + + The client will `.send()` the response back into the flow generator. You can + access it like so: + + ``` + response = yield request + ``` + + A `return` (or reaching the end of the generator) will result in the + client returning the last response obtained from the server. + + You can dispatch as many requests as is necessary. + """ + yield request + + def sync_auth_flow( + self, request: Request + ) -> typing.Generator[Request, Response, None]: + """ + Execute the authentication flow synchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. + """ + if self.requires_request_body: + request.read() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + response.read() + + try: + request = flow.send(response) + except StopIteration: + break + + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + """ + Execute the authentication flow asynchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. + """ + if self.requires_request_body: + await request.aread() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + await response.aread() + + try: + request = flow.send(response) + except StopIteration: + break + + +class FunctionAuth(Auth): + """ + Allows the 'auth' argument to be passed as a simple callable function, + that takes the request, and returns a new, modified request. + """ + + def __init__(self, func: typing.Callable[[Request], Request]) -> None: + self._func = func + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + yield self._func(request) + + +class BasicAuth(Auth): + """ + Allows the 'auth' argument to be passed as a (username, password) pair, + and uses HTTP Basic authentication. + """ + + def __init__( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ): + self._auth_header = self._build_auth_header(username, password) + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + request.headers["Authorization"] = self._auth_header + yield request + + def _build_auth_header( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ) -> str: + userpass = b":".join((to_bytes(username), to_bytes(password))) + token = b64encode(userpass).decode() + return f"Basic {token}" + + +class NetRCAuth(Auth): + """ + Use a 'netrc' file to lookup basic auth credentials based on the url host. + """ + + def __init__(self, file: typing.Optional[str] = None): + self._netrc_info = netrc.netrc(file) + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + auth_info = self._netrc_info.authenticators(request.url.host) + if auth_info is None or not auth_info[2]: + # The netrc file did not have authentication credentials for this host. + yield request + else: + # Build a basic auth header with credentials from the netrc file. 
+ request.headers["Authorization"] = self._build_auth_header( + username=auth_info[0], password=auth_info[2] + ) + yield request + + def _build_auth_header( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ) -> str: + userpass = b":".join((to_bytes(username), to_bytes(password))) + token = b64encode(userpass).decode() + return f"Basic {token}" + + +class DigestAuth(Auth): + _ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable[[bytes], "_Hash"]] = { + "MD5": hashlib.md5, + "MD5-SESS": hashlib.md5, + "SHA": hashlib.sha1, + "SHA-SESS": hashlib.sha1, + "SHA-256": hashlib.sha256, + "SHA-256-SESS": hashlib.sha256, + "SHA-512": hashlib.sha512, + "SHA-512-SESS": hashlib.sha512, + } + + def __init__( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ) -> None: + self._username = to_bytes(username) + self._password = to_bytes(password) + self._last_challenge: typing.Optional[_DigestAuthChallenge] = None + self._nonce_count = 1 + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + if self._last_challenge: + request.headers["Authorization"] = self._build_auth_header( + request, self._last_challenge + ) + + response = yield request + + if response.status_code != 401 or "www-authenticate" not in response.headers: + # If the response is not a 401 then we don't + # need to build an authenticated request. + return + + for auth_header in response.headers.get_list("www-authenticate"): + if auth_header.lower().startswith("digest "): + break + else: + # If the response does not include a 'WWW-Authenticate: Digest ...' + # header, then we don't need to build an authenticated request. + return + + self._last_challenge = self._parse_challenge(request, response, auth_header) + self._nonce_count = 1 + + request.headers["Authorization"] = self._build_auth_header( + request, self._last_challenge + ) + yield request + + def _parse_challenge( + self, request: Request, response: Response, auth_header: str + ) -> "_DigestAuthChallenge": + """ + Returns a challenge from a Digest WWW-Authenticate header. + These take the form of: + `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"` + """ + scheme, _, fields = auth_header.partition(" ") + + # This method should only ever have been called with a Digest auth header. 
+ assert scheme.lower() == "digest" + + header_dict: typing.Dict[str, str] = {} + for field in parse_http_list(fields): + key, value = field.strip().split("=", 1) + header_dict[key] = unquote(value) + + try: + realm = header_dict["realm"].encode() + nonce = header_dict["nonce"].encode() + algorithm = header_dict.get("algorithm", "MD5") + opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None + qop = header_dict["qop"].encode() if "qop" in header_dict else None + return _DigestAuthChallenge( + realm=realm, nonce=nonce, algorithm=algorithm, opaque=opaque, qop=qop + ) + except KeyError as exc: + message = "Malformed Digest WWW-Authenticate header" + raise ProtocolError(message, request=request) from exc + + def _build_auth_header( + self, request: Request, challenge: "_DigestAuthChallenge" + ) -> str: + hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()] + + def digest(data: bytes) -> bytes: + return hash_func(data).hexdigest().encode() + + A1 = b":".join((self._username, challenge.realm, self._password)) + + path = request.url.raw_path + A2 = b":".join((request.method.encode(), path)) + # TODO: implement auth-int + HA2 = digest(A2) + + nc_value = b"%08x" % self._nonce_count + cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce) + self._nonce_count += 1 + + HA1 = digest(A1) + if challenge.algorithm.lower().endswith("-sess"): + HA1 = digest(b":".join((HA1, challenge.nonce, cnonce))) + + qop = self._resolve_qop(challenge.qop, request=request) + if qop is None: + digest_data = [HA1, challenge.nonce, HA2] + else: + digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2] + key_digest = b":".join(digest_data) + + format_args = { + "username": self._username, + "realm": challenge.realm, + "nonce": challenge.nonce, + "uri": path, + "response": digest(b":".join((HA1, key_digest))), + "algorithm": challenge.algorithm.encode(), + } + if challenge.opaque: + format_args["opaque"] = challenge.opaque + if qop: + format_args["qop"] = b"auth" + format_args["nc"] = nc_value + format_args["cnonce"] = cnonce + + return "Digest " + self._get_header_value(format_args) + + def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes: + s = str(nonce_count).encode() + s += nonce + s += time.ctime().encode() + s += os.urandom(8) + + return hashlib.sha1(s).hexdigest()[:16].encode() + + def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str: + NON_QUOTED_FIELDS = ("algorithm", "qop", "nc") + QUOTED_TEMPLATE = '{}="{}"' + NON_QUOTED_TEMPLATE = "{}={}" + + header_value = "" + for i, (field, value) in enumerate(header_fields.items()): + if i > 0: + header_value += ", " + template = ( + QUOTED_TEMPLATE + if field not in NON_QUOTED_FIELDS + else NON_QUOTED_TEMPLATE + ) + header_value += template.format(field, to_str(value)) + + return header_value + + def _resolve_qop( + self, qop: typing.Optional[bytes], request: Request + ) -> typing.Optional[bytes]: + if qop is None: + return None + qops = re.split(b", ?", qop) + if b"auth" in qops: + return b"auth" + + if qops == [b"auth-int"]: + raise NotImplementedError("Digest auth-int support is not yet implemented") + + message = f'Unexpected qop value "{qop!r}" in digest auth' + raise ProtocolError(message, request=request) + + +class _DigestAuthChallenge(typing.NamedTuple): + realm: bytes + nonce: bytes + algorithm: str + opaque: typing.Optional[bytes] + qop: typing.Optional[bytes] diff --git a/contrib/python/httpx/httpx/_client.py b/contrib/python/httpx/httpx/_client.py new file mode 
100644 index 0000000000..cb475e0204 --- /dev/null +++ b/contrib/python/httpx/httpx/_client.py @@ -0,0 +1,2006 @@ +import datetime +import enum +import logging +import typing +import warnings +from contextlib import asynccontextmanager, contextmanager +from types import TracebackType + +from .__version__ import __version__ +from ._auth import Auth, BasicAuth, FunctionAuth +from ._config import ( + DEFAULT_LIMITS, + DEFAULT_MAX_REDIRECTS, + DEFAULT_TIMEOUT_CONFIG, + Limits, + Proxy, + Timeout, +) +from ._decoders import SUPPORTED_DECODERS +from ._exceptions import ( + InvalidURL, + RemoteProtocolError, + TooManyRedirects, + request_context, +) +from ._models import Cookies, Headers, Request, Response +from ._status_codes import codes +from ._transports.asgi import ASGITransport +from ._transports.base import AsyncBaseTransport, BaseTransport +from ._transports.default import AsyncHTTPTransport, HTTPTransport +from ._transports.wsgi import WSGITransport +from ._types import ( + AsyncByteStream, + AuthTypes, + CertTypes, + CookieTypes, + HeaderTypes, + ProxiesTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestExtensions, + RequestFiles, + SyncByteStream, + TimeoutTypes, + URLTypes, + VerifyTypes, +) +from ._urls import URL, QueryParams +from ._utils import ( + Timer, + URLPattern, + get_environment_proxies, + is_https_redirect, + same_origin, +) + +# The type annotation for @classmethod and context managers here follows PEP 484 +# https://www.python.org/dev/peps/pep-0484/#annotating-instance-and-class-methods +T = typing.TypeVar("T", bound="Client") +U = typing.TypeVar("U", bound="AsyncClient") + + +class UseClientDefault: + """ + For some parameters such as `auth=...` and `timeout=...` we need to be able + to indicate the default "unset" state, in a way that is distinctly different + to using `None`. + + The default "unset" state indicates that whatever default is set on the + client should be used. This is different to setting `None`, which + explicitly disables the parameter, possibly overriding a client default. + + For example we use `timeout=USE_CLIENT_DEFAULT` in the `request()` signature. + Omitting the `timeout` parameter will send a request using whatever default + timeout has been configured on the client. Including `timeout=None` will + ensure no timeout is used. + + Note that user code shouldn't need to use the `USE_CLIENT_DEFAULT` constant, + but it is used internally when a parameter is not included. + """ + + +USE_CLIENT_DEFAULT = UseClientDefault() + + +logger = logging.getLogger("httpx") + +USER_AGENT = f"python-httpx/{__version__}" +ACCEPT_ENCODING = ", ".join( + [key for key in SUPPORTED_DECODERS.keys() if key != "identity"] +) + + +class ClientState(enum.Enum): + # UNOPENED: + # The client has been instantiated, but has not been used to send a request, + # or been opened by entering the context of a `with` block. + UNOPENED = 1 + # OPENED: + # The client has either sent a request, or is within a `with` block. + OPENED = 2 + # CLOSED: + # The client has either exited the `with` block, or `close()` has + # been called explicitly. + CLOSED = 3 + + +class BoundSyncStream(SyncByteStream): + """ + A byte stream that is bound to a given response instance, and that + ensures the `response.elapsed` is set once the response is closed. 
+ """ + + def __init__( + self, stream: SyncByteStream, response: Response, timer: Timer + ) -> None: + self._stream = stream + self._response = response + self._timer = timer + + def __iter__(self) -> typing.Iterator[bytes]: + for chunk in self._stream: + yield chunk + + def close(self) -> None: + seconds = self._timer.sync_elapsed() + self._response.elapsed = datetime.timedelta(seconds=seconds) + self._stream.close() + + +class BoundAsyncStream(AsyncByteStream): + """ + An async byte stream that is bound to a given response instance, and that + ensures the `response.elapsed` is set once the response is closed. + """ + + def __init__( + self, stream: AsyncByteStream, response: Response, timer: Timer + ) -> None: + self._stream = stream + self._response = response + self._timer = timer + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async for chunk in self._stream: + yield chunk + + async def aclose(self) -> None: + seconds = await self._timer.async_elapsed() + self._response.elapsed = datetime.timedelta(seconds=seconds) + await self._stream.aclose() + + +EventHook = typing.Callable[..., typing.Any] + + +class BaseClient: + def __init__( + self, + *, + auth: typing.Optional[AuthTypes] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + event_hooks: typing.Optional[ + typing.Mapping[str, typing.List[EventHook]] + ] = None, + base_url: URLTypes = "", + trust_env: bool = True, + default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8", + ): + event_hooks = {} if event_hooks is None else event_hooks + + self._base_url = self._enforce_trailing_slash(URL(base_url)) + + self._auth = self._build_auth(auth) + self._params = QueryParams(params) + self.headers = Headers(headers) + self._cookies = Cookies(cookies) + self._timeout = Timeout(timeout) + self.follow_redirects = follow_redirects + self.max_redirects = max_redirects + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + self._trust_env = trust_env + self._default_encoding = default_encoding + self._state = ClientState.UNOPENED + + @property + def is_closed(self) -> bool: + """ + Check if the client being closed + """ + return self._state == ClientState.CLOSED + + @property + def trust_env(self) -> bool: + return self._trust_env + + def _enforce_trailing_slash(self, url: URL) -> URL: + if url.raw_path.endswith(b"/"): + return url + return url.copy_with(raw_path=url.raw_path + b"/") + + def _get_proxy_map( + self, proxies: typing.Optional[ProxiesTypes], allow_env_proxies: bool + ) -> typing.Dict[str, typing.Optional[Proxy]]: + if proxies is None: + if allow_env_proxies: + return { + key: None if url is None else Proxy(url=url) + for key, url in get_environment_proxies().items() + } + return {} + if isinstance(proxies, dict): + new_proxies = {} + for key, value in proxies.items(): + proxy = Proxy(url=value) if isinstance(value, (str, URL)) else value + new_proxies[str(key)] = proxy + return new_proxies + else: + proxy = Proxy(url=proxies) if isinstance(proxies, (str, URL)) else proxies + return {"all://": proxy} + + @property + def timeout(self) -> Timeout: + return self._timeout + + @timeout.setter + def timeout(self, timeout: TimeoutTypes) -> None: + self._timeout = Timeout(timeout) + + @property + 
def event_hooks(self) -> typing.Dict[str, typing.List[EventHook]]: + return self._event_hooks + + @event_hooks.setter + def event_hooks( + self, event_hooks: typing.Dict[str, typing.List[EventHook]] + ) -> None: + self._event_hooks = { + "request": list(event_hooks.get("request", [])), + "response": list(event_hooks.get("response", [])), + } + + @property + def auth(self) -> typing.Optional[Auth]: + """ + Authentication class used when none is passed at the request-level. + + See also [Authentication][0]. + + [0]: /quickstart/#authentication + """ + return self._auth + + @auth.setter + def auth(self, auth: AuthTypes) -> None: + self._auth = self._build_auth(auth) + + @property + def base_url(self) -> URL: + """ + Base URL to use when sending requests with relative URLs. + """ + return self._base_url + + @base_url.setter + def base_url(self, url: URLTypes) -> None: + self._base_url = self._enforce_trailing_slash(URL(url)) + + @property + def headers(self) -> Headers: + """ + HTTP headers to include when sending requests. + """ + return self._headers + + @headers.setter + def headers(self, headers: HeaderTypes) -> None: + client_headers = Headers( + { + b"Accept": b"*/*", + b"Accept-Encoding": ACCEPT_ENCODING.encode("ascii"), + b"Connection": b"keep-alive", + b"User-Agent": USER_AGENT.encode("ascii"), + } + ) + client_headers.update(headers) + self._headers = client_headers + + @property + def cookies(self) -> Cookies: + """ + Cookie values to include when sending requests. + """ + return self._cookies + + @cookies.setter + def cookies(self, cookies: CookieTypes) -> None: + self._cookies = Cookies(cookies) + + @property + def params(self) -> QueryParams: + """ + Query parameters to include in the URL when sending requests. + """ + return self._params + + @params.setter + def params(self, params: QueryParamTypes) -> None: + self._params = QueryParams(params) + + def build_request( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Request: + """ + Build and return a request instance. + + * The `params`, `headers` and `cookies` arguments + are merged with any values set on the client. + * The `url` argument is merged with any `base_url` set on the client. + + See also: [Request instances][0] + + [0]: /advanced/#request-instances + """ + url = self._merge_url(url) + headers = self._merge_headers(headers) + cookies = self._merge_cookies(cookies) + params = self._merge_queryparams(params) + extensions = {} if extensions is None else extensions + if "timeout" not in extensions: + timeout = ( + self.timeout + if isinstance(timeout, UseClientDefault) + else Timeout(timeout) + ) + extensions = dict(**extensions, timeout=timeout.as_dict()) + return Request( + method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + extensions=extensions, + ) + + def _merge_url(self, url: URLTypes) -> URL: + """ + Merge a URL argument together with any 'base_url' on the client, + to create the URL used for the outgoing request. 
+ """ + merge_url = URL(url) + if merge_url.is_relative_url: + # To merge URLs we always append to the base URL. To get this + # behaviour correct we always ensure the base URL ends in a '/' + # separator, and strip any leading '/' from the merge URL. + # + # So, eg... + # + # >>> client = Client(base_url="https://www.example.com/subpath") + # >>> client.base_url + # URL('https://www.example.com/subpath/') + # >>> client.build_request("GET", "/path").url + # URL('https://www.example.com/subpath/path') + merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/") + return self.base_url.copy_with(raw_path=merge_raw_path) + return merge_url + + def _merge_cookies( + self, cookies: typing.Optional[CookieTypes] = None + ) -> typing.Optional[CookieTypes]: + """ + Merge a cookies argument together with any cookies on the client, + to create the cookies used for the outgoing request. + """ + if cookies or self.cookies: + merged_cookies = Cookies(self.cookies) + merged_cookies.update(cookies) + return merged_cookies + return cookies + + def _merge_headers( + self, headers: typing.Optional[HeaderTypes] = None + ) -> typing.Optional[HeaderTypes]: + """ + Merge a headers argument together with any headers on the client, + to create the headers used for the outgoing request. + """ + merged_headers = Headers(self.headers) + merged_headers.update(headers) + return merged_headers + + def _merge_queryparams( + self, params: typing.Optional[QueryParamTypes] = None + ) -> typing.Optional[QueryParamTypes]: + """ + Merge a queryparams argument together with any queryparams on the client, + to create the queryparams used for the outgoing request. + """ + if params or self.params: + merged_queryparams = QueryParams(self.params) + return merged_queryparams.merge(params) + return params + + def _build_auth(self, auth: typing.Optional[AuthTypes]) -> typing.Optional[Auth]: + if auth is None: + return None + elif isinstance(auth, tuple): + return BasicAuth(username=auth[0], password=auth[1]) + elif isinstance(auth, Auth): + return auth + elif callable(auth): + return FunctionAuth(func=auth) + else: + raise TypeError(f'Invalid "auth" argument: {auth!r}') + + def _build_request_auth( + self, + request: Request, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + ) -> Auth: + auth = ( + self._auth if isinstance(auth, UseClientDefault) else self._build_auth(auth) + ) + + if auth is not None: + return auth + + username, password = request.url.username, request.url.password + if username or password: + return BasicAuth(username=username, password=password) + + return Auth() + + def _build_redirect_request(self, request: Request, response: Response) -> Request: + """ + Given a request and a redirect response, return a new request that + should be used to effect the redirect. + """ + method = self._redirect_method(request, response) + url = self._redirect_url(request, response) + headers = self._redirect_headers(request, url, method) + stream = self._redirect_stream(request, method) + cookies = Cookies(self.cookies) + return Request( + method=method, + url=url, + headers=headers, + cookies=cookies, + stream=stream, + extensions=request.extensions, + ) + + def _redirect_method(self, request: Request, response: Response) -> str: + """ + When being redirected we may want to change the method of the request + based on certain specs or browser behavior. 
+ """ + method = request.method + + # https://tools.ietf.org/html/rfc7231#section-6.4.4 + if response.status_code == codes.SEE_OTHER and method != "HEAD": + method = "GET" + + # Do what the browsers do, despite standards... + # Turn 302s into GETs. + if response.status_code == codes.FOUND and method != "HEAD": + method = "GET" + + # If a POST is responded to with a 301, turn it into a GET. + # This bizarre behaviour is explained in 'requests' issue 1704. + if response.status_code == codes.MOVED_PERMANENTLY and method == "POST": + method = "GET" + + return method + + def _redirect_url(self, request: Request, response: Response) -> URL: + """ + Return the URL for the redirect to follow. + """ + location = response.headers["Location"] + + try: + url = URL(location) + except InvalidURL as exc: + raise RemoteProtocolError( + f"Invalid URL in location header: {exc}.", request=request + ) from None + + # Handle malformed 'Location' headers that are "absolute" form, have no host. + # See: https://github.com/encode/httpx/issues/771 + if url.scheme and not url.host: + url = url.copy_with(host=request.url.host) + + # Facilitate relative 'Location' headers, as allowed by RFC 7231. + # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') + if url.is_relative_url: + url = request.url.join(url) + + # Attach previous fragment if needed (RFC 7231 7.1.2) + if request.url.fragment and not url.fragment: + url = url.copy_with(fragment=request.url.fragment) + + return url + + def _redirect_headers(self, request: Request, url: URL, method: str) -> Headers: + """ + Return the headers that should be used for the redirect request. + """ + headers = Headers(request.headers) + + if not same_origin(url, request.url): + if not is_https_redirect(request.url, url): + # Strip Authorization headers when responses are redirected + # away from the origin. (Except for direct HTTP to HTTPS redirects.) + headers.pop("Authorization", None) + + # Update the Host header. + headers["Host"] = url.netloc.decode("ascii") + + if method != request.method and method == "GET": + # If we've switch to a 'GET' request, then strip any headers which + # are only relevant to the request body. + headers.pop("Content-Length", None) + headers.pop("Transfer-Encoding", None) + + # We should use the client cookie store to determine any cookie header, + # rather than whatever was on the original outgoing request. + headers.pop("Cookie", None) + + return headers + + def _redirect_stream( + self, request: Request, method: str + ) -> typing.Optional[typing.Union[SyncByteStream, AsyncByteStream]]: + """ + Return the body that should be used for the redirect request. + """ + if method != request.method and method == "GET": + return None + + return request.stream + + +class Client(BaseClient): + """ + An HTTP client, with connection pooling, HTTP/2, redirects, cookie persistence, etc. + + It can be shared between threads. + + Usage: + + ```python + >>> client = httpx.Client() + >>> response = client.get('https://example.org') + ``` + + **Parameters:** + + * **auth** - *(optional)* An authentication class to use when sending + requests. + * **params** - *(optional)* Query parameters to include in request URLs, as + a string, dictionary, or sequence of two-tuples. + * **headers** - *(optional)* Dictionary of HTTP headers to include when + sending requests. + * **cookies** - *(optional)* Dictionary of Cookie items to include when + sending requests. 
+ * **verify** - *(optional)* SSL certificates (a.k.a CA bundle) used to + verify the identity of requested hosts. Either `True` (default CA bundle), + a path to an SSL certificate file, an `ssl.SSLContext`, or `False` + (which will disable verification). + * **cert** - *(optional)* An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). + * **proxies** - *(optional)* A dictionary mapping proxy keys to proxy + URLs. + * **timeout** - *(optional)* The timeout configuration to use when sending + requests. + * **limits** - *(optional)* The limits configuration to use. + * **max_redirects** - *(optional)* The maximum number of redirect responses + that should be followed. + * **base_url** - *(optional)* A URL to use as the base when building + request URLs. + * **transport** - *(optional)* A transport class to use for sending requests + over the network. + * **app** - *(optional)* An WSGI application to send requests to, + rather than sending actual network requests. + * **trust_env** - *(optional)* Enables or disables usage of environment + variables for configuration. + * **default_encoding** - *(optional)* The default encoding to use for decoding + response text, if no charset information is included in a response Content-Type + header. Set to a callable for automatic character set detection. Default: "utf-8". + """ + + def __init__( + self, + *, + auth: typing.Optional[AuthTypes] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + proxies: typing.Optional[ProxiesTypes] = None, + mounts: typing.Optional[typing.Mapping[str, BaseTransport]] = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + limits: Limits = DEFAULT_LIMITS, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + event_hooks: typing.Optional[ + typing.Mapping[str, typing.List[EventHook]] + ] = None, + base_url: URLTypes = "", + transport: typing.Optional[BaseTransport] = None, + app: typing.Optional[typing.Callable[..., typing.Any]] = None, + trust_env: bool = True, + default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8", + ): + super().__init__( + auth=auth, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + follow_redirects=follow_redirects, + max_redirects=max_redirects, + event_hooks=event_hooks, + base_url=base_url, + trust_env=trust_env, + default_encoding=default_encoding, + ) + + if http2: + try: + import h2 # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using http2=True, but the 'h2' package is not installed. " + "Make sure to install httpx using `pip install httpx[http2]`." 
+ ) from None + + allow_env_proxies = trust_env and app is None and transport is None + proxy_map = self._get_proxy_map(proxies, allow_env_proxies) + + self._transport = self._init_transport( + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + transport=transport, + app=app, + trust_env=trust_env, + ) + self._mounts: typing.Dict[URLPattern, typing.Optional[BaseTransport]] = { + URLPattern(key): None + if proxy is None + else self._init_proxy_transport( + proxy, + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + trust_env=trust_env, + ) + for key, proxy in proxy_map.items() + } + if mounts is not None: + self._mounts.update( + {URLPattern(key): transport for key, transport in mounts.items()} + ) + + self._mounts = dict(sorted(self._mounts.items())) + + def _init_transport( + self, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + transport: typing.Optional[BaseTransport] = None, + app: typing.Optional[typing.Callable[..., typing.Any]] = None, + trust_env: bool = True, + ) -> BaseTransport: + if transport is not None: + return transport + + if app is not None: + return WSGITransport(app=app) + + return HTTPTransport( + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + trust_env=trust_env, + ) + + def _init_proxy_transport( + self, + proxy: Proxy, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + trust_env: bool = True, + ) -> BaseTransport: + return HTTPTransport( + verify=verify, + cert=cert, + http1=http1, + http2=http2, + limits=limits, + trust_env=trust_env, + proxy=proxy, + ) + + def _transport_for_url(self, url: URL) -> BaseTransport: + """ + Returns the transport instance that should be used for a given URL. + This will either be the standard connection pool, or a proxy. + """ + for pattern, transport in self._mounts.items(): + if pattern.matches(url): + return self._transport if transport is None else transport + + return self._transport + + def request( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Build and send a request. + + Equivalent to: + + ```python + request = client.build_request(...) + response = client.send(request, ...) + ``` + + See `Client.build_request()`, `Client.send()` and + [Merging of configuration][0] for how the various parameters + are merged with client-level configuration. + + [0]: /advanced/#merging-of-configuration + """ + if cookies is not None: + message = ( + "Setting per-request cookies=<...> is being deprecated, because " + "the expected behaviour on cookie persistence is ambiguous. Set " + "cookies directly on the client instance instead." 
+ ) + warnings.warn(message, DeprecationWarning) + + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + return self.send(request, auth=auth, follow_redirects=follow_redirects) + + @contextmanager + def stream( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> typing.Iterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + response = self.send( + request=request, + auth=auth, + follow_redirects=follow_redirects, + stream=True, + ) + try: + yield response + finally: + response.close() + + def send( + self, + request: Request, + *, + stream: bool = False, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + ) -> Response: + """ + Send a request. + + The request is sent as-is, unmodified. + + Typically you'll want to build one with `Client.build_request()` + so that any client-level configuration is merged into the request, + but passing an explicit `httpx.Request()` is supported as well. 
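+
+ For example (a minimal sketch; the URL is illustrative):
+
+ ```python
+ request = client.build_request("GET", "https://example.org")
+ response = client.send(request)
+ ```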
+ + See also: [Request instances][0] + + [0]: /advanced/#request-instances + """ + if self._state == ClientState.CLOSED: + raise RuntimeError("Cannot send a request, as the client has been closed.") + + self._state = ClientState.OPENED + follow_redirects = ( + self.follow_redirects + if isinstance(follow_redirects, UseClientDefault) + else follow_redirects + ) + + auth = self._build_request_auth(request, auth) + + response = self._send_handling_auth( + request, + auth=auth, + follow_redirects=follow_redirects, + history=[], + ) + try: + if not stream: + response.read() + + return response + + except BaseException as exc: + response.close() + raise exc + + def _send_handling_auth( + self, + request: Request, + auth: Auth, + follow_redirects: bool, + history: typing.List[Response], + ) -> Response: + auth_flow = auth.sync_auth_flow(request) + try: + request = next(auth_flow) + + while True: + response = self._send_handling_redirects( + request, + follow_redirects=follow_redirects, + history=history, + ) + try: + try: + next_request = auth_flow.send(response) + except StopIteration: + return response + + response.history = list(history) + response.read() + request = next_request + history.append(response) + + except BaseException as exc: + response.close() + raise exc + finally: + auth_flow.close() + + def _send_handling_redirects( + self, + request: Request, + follow_redirects: bool, + history: typing.List[Response], + ) -> Response: + while True: + if len(history) > self.max_redirects: + raise TooManyRedirects( + "Exceeded maximum allowed redirects.", request=request + ) + + for hook in self._event_hooks["request"]: + hook(request) + + response = self._send_single_request(request) + try: + for hook in self._event_hooks["response"]: + hook(response) + response.history = list(history) + + if not response.has_redirect_location: + return response + + request = self._build_redirect_request(request, response) + history = history + [response] + + if follow_redirects: + response.read() + else: + response.next_request = request + return response + + except BaseException as exc: + response.close() + raise exc + + def _send_single_request(self, request: Request) -> Response: + """ + Sends a single request, without handling any redirections. + """ + transport = self._transport_for_url(request.url) + timer = Timer() + timer.sync_start() + + if not isinstance(request.stream, SyncByteStream): + raise RuntimeError( + "Attempted to send an async request with a sync Client instance." 
+ ) + + with request_context(request=request): + response = transport.handle_request(request) + + assert isinstance(response.stream, SyncByteStream) + + response.request = request + response.stream = BoundSyncStream( + response.stream, response=response, timer=timer + ) + self.cookies.extract_cookies(response) + response.default_encoding = self._default_encoding + + logger.info( + 'HTTP Request: %s %s "%s %d %s"', + request.method, + request.url, + response.http_version, + response.status_code, + response.reason_phrase, + ) + + return response + + def get( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `GET` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def options( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def head( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `HEAD` request. + + **Parameters**: See `httpx.request`. 
+ """ + return self.request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def post( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `POST` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def put( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `PUT` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def patch( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `PATCH` request. + + **Parameters**: See `httpx.request`. 
+ """ + return self.request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def delete( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `DELETE` request. + + **Parameters**: See `httpx.request`. + """ + return self.request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def close(self) -> None: + """ + Close transport and proxies. + """ + if self._state != ClientState.CLOSED: + self._state = ClientState.CLOSED + + self._transport.close() + for transport in self._mounts.values(): + if transport is not None: + transport.close() + + def __enter__(self: T) -> T: + if self._state != ClientState.UNOPENED: + msg = { + ClientState.OPENED: "Cannot open a client instance more than once.", + ClientState.CLOSED: "Cannot reopen a client instance, once it has been closed.", + }[self._state] + raise RuntimeError(msg) + + self._state = ClientState.OPENED + + self._transport.__enter__() + for transport in self._mounts.values(): + if transport is not None: + transport.__enter__() + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + self._state = ClientState.CLOSED + + self._transport.__exit__(exc_type, exc_value, traceback) + for transport in self._mounts.values(): + if transport is not None: + transport.__exit__(exc_type, exc_value, traceback) + + +class AsyncClient(BaseClient): + """ + An asynchronous HTTP client, with connection pooling, HTTP/2, redirects, + cookie persistence, etc. + + Usage: + + ```python + >>> async with httpx.AsyncClient() as client: + >>> response = await client.get('https://example.org') + ``` + + **Parameters:** + + * **auth** - *(optional)* An authentication class to use when sending + requests. + * **params** - *(optional)* Query parameters to include in request URLs, as + a string, dictionary, or sequence of two-tuples. + * **headers** - *(optional)* Dictionary of HTTP headers to include when + sending requests. + * **cookies** - *(optional)* Dictionary of Cookie items to include when + sending requests. + * **verify** - *(optional)* SSL certificates (a.k.a CA bundle) used to + verify the identity of requested hosts. Either `True` (default CA bundle), + a path to an SSL certificate file, an `ssl.SSLContext`, or `False` + (which will disable verification). + * **cert** - *(optional)* An SSL certificate used by the requested host + to authenticate the client. Either a path to an SSL certificate file, or + two-tuple of (certificate file, key file), or a three-tuple of (certificate + file, key file, password). + * **http2** - *(optional)* A boolean indicating if HTTP/2 support should be + enabled. Defaults to `False`. 
+ * **proxies** - *(optional)* A dictionary mapping proxy keys to proxy
+ URLs.
+ * **timeout** - *(optional)* The timeout configuration to use when sending
+ requests.
+ * **limits** - *(optional)* The limits configuration to use.
+ * **max_redirects** - *(optional)* The maximum number of redirect responses
+ that should be followed.
+ * **base_url** - *(optional)* A URL to use as the base when building
+ request URLs.
+ * **transport** - *(optional)* A transport class to use for sending requests
+ over the network.
+ * **app** - *(optional)* An ASGI application to send requests to,
+ rather than sending actual network requests.
+ * **trust_env** - *(optional)* Enables or disables usage of environment
+ variables for configuration.
+ * **default_encoding** - *(optional)* The default encoding to use for decoding
+ response text, if no charset information is included in a response Content-Type
+ header. Set to a callable for automatic character set detection. Default: "utf-8".
+ """
+
+ def __init__(
+ self,
+ *,
+ auth: typing.Optional[AuthTypes] = None,
+ params: typing.Optional[QueryParamTypes] = None,
+ headers: typing.Optional[HeaderTypes] = None,
+ cookies: typing.Optional[CookieTypes] = None,
+ verify: VerifyTypes = True,
+ cert: typing.Optional[CertTypes] = None,
+ http1: bool = True,
+ http2: bool = False,
+ proxies: typing.Optional[ProxiesTypes] = None,
+ mounts: typing.Optional[typing.Mapping[str, AsyncBaseTransport]] = None,
+ timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
+ follow_redirects: bool = False,
+ limits: Limits = DEFAULT_LIMITS,
+ max_redirects: int = DEFAULT_MAX_REDIRECTS,
+ event_hooks: typing.Optional[
+ typing.Mapping[str, typing.List[typing.Callable[..., typing.Any]]]
+ ] = None,
+ base_url: URLTypes = "",
+ transport: typing.Optional[AsyncBaseTransport] = None,
+ app: typing.Optional[typing.Callable[..., typing.Any]] = None,
+ trust_env: bool = True,
+ default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8",
+ ):
+ super().__init__(
+ auth=auth,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ timeout=timeout,
+ follow_redirects=follow_redirects,
+ max_redirects=max_redirects,
+ event_hooks=event_hooks,
+ base_url=base_url,
+ trust_env=trust_env,
+ default_encoding=default_encoding,
+ )
+
+ if http2:
+ try:
+ import h2 # noqa
+ except ImportError: # pragma: no cover
+ raise ImportError(
+ "Using http2=True, but the 'h2' package is not installed. "
+ "Make sure to install httpx using `pip install httpx[http2]`."
+ ) from None
+
+ allow_env_proxies = trust_env and app is None and transport is None
+ proxy_map = self._get_proxy_map(proxies, allow_env_proxies)
+
+ self._transport = self._init_transport(
+ verify=verify,
+ cert=cert,
+ http1=http1,
+ http2=http2,
+ limits=limits,
+ transport=transport,
+ app=app,
+ trust_env=trust_env,
+ )
+
+ self._mounts: typing.Dict[URLPattern, typing.Optional[AsyncBaseTransport]] = {
+ URLPattern(key): None
+ if proxy is None
+ else self._init_proxy_transport(
+ proxy,
+ verify=verify,
+ cert=cert,
+ http1=http1,
+ http2=http2,
+ limits=limits,
+ trust_env=trust_env,
+ )
+ for key, proxy in proxy_map.items()
+ }
+ if mounts is not None:
+ self._mounts.update(
+ {URLPattern(key): transport for key, transport in mounts.items()}
+ )
+ self._mounts = dict(sorted(self._mounts.items()))
+
+ def _init_transport(
+ self,
+ verify: VerifyTypes = True,
+ cert: typing.Optional[CertTypes] = None,
+ http1: bool = True,
+ http2: bool = False,
+ limits: Limits = DEFAULT_LIMITS,
+ transport: typing.Optional[AsyncBaseTransport] = None,
+ app: typing.Optional[typing.Callable[..., typing.Any]] = None,
+ trust_env: bool = True,
+ ) -> AsyncBaseTransport:
+ if transport is not None:
+ return transport
+
+ if app is not None:
+ return ASGITransport(app=app)
+
+ return AsyncHTTPTransport(
+ verify=verify,
+ cert=cert,
+ http1=http1,
+ http2=http2,
+ limits=limits,
+ trust_env=trust_env,
+ )
+
+ def _init_proxy_transport(
+ self,
+ proxy: Proxy,
+ verify: VerifyTypes = True,
+ cert: typing.Optional[CertTypes] = None,
+ http1: bool = True,
+ http2: bool = False,
+ limits: Limits = DEFAULT_LIMITS,
+ trust_env: bool = True,
+ ) -> AsyncBaseTransport:
+ return AsyncHTTPTransport(
+ verify=verify,
+ cert=cert,
+ http1=http1,
+ http2=http2,
+ limits=limits,
+ trust_env=trust_env,
+ proxy=proxy,
+ )
+
+ def _transport_for_url(self, url: URL) -> AsyncBaseTransport:
+ """
+ Returns the transport instance that should be used for a given URL.
+ This will either be the standard connection pool, or a proxy.
+ """
+ for pattern, transport in self._mounts.items():
+ if pattern.matches(url):
+ return self._transport if transport is None else transport
+
+ return self._transport
+
+ async def request(
+ self,
+ method: str,
+ url: URLTypes,
+ *,
+ content: typing.Optional[RequestContent] = None,
+ data: typing.Optional[RequestData] = None,
+ files: typing.Optional[RequestFiles] = None,
+ json: typing.Optional[typing.Any] = None,
+ params: typing.Optional[QueryParamTypes] = None,
+ headers: typing.Optional[HeaderTypes] = None,
+ cookies: typing.Optional[CookieTypes] = None,
+ auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT,
+ follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT,
+ timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT,
+ extensions: typing.Optional[RequestExtensions] = None,
+ ) -> Response:
+ """
+ Build and send a request.
+
+ Equivalent to:
+
+ ```python
+ request = client.build_request(...)
+ response = await client.send(request, ...)
+ ```
+
+ See `AsyncClient.build_request()`, `AsyncClient.send()`
+ and [Merging of configuration][0] for how the various parameters
+ are merged with client-level configuration.
+ + [0]: /advanced/#merging-of-configuration + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + return await self.send(request, auth=auth, follow_redirects=follow_redirects) + + @asynccontextmanager + async def stream( + self, + method: str, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> typing.AsyncIterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + extensions=extensions, + ) + response = await self.send( + request=request, + auth=auth, + follow_redirects=follow_redirects, + stream=True, + ) + try: + yield response + finally: + await response.aclose() + + async def send( + self, + request: Request, + *, + stream: bool = False, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + ) -> Response: + """ + Send a request. + + The request is sent as-is, unmodified. + + Typically you'll want to build one with `AsyncClient.build_request()` + so that any client-level configuration is merged into the request, + but passing an explicit `httpx.Request()` is supported as well. 
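+
+ For example (a minimal sketch; the URL is illustrative):
+
+ ```python
+ request = client.build_request("GET", "https://example.org")
+ response = await client.send(request)
+ ```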
+
+ See also: [Request instances][0]
+
+ [0]: /advanced/#request-instances
+ """
+ if self._state == ClientState.CLOSED:
+ raise RuntimeError("Cannot send a request, as the client has been closed.")
+
+ self._state = ClientState.OPENED
+ follow_redirects = (
+ self.follow_redirects
+ if isinstance(follow_redirects, UseClientDefault)
+ else follow_redirects
+ )
+
+ auth = self._build_request_auth(request, auth)
+
+ response = await self._send_handling_auth(
+ request,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ history=[],
+ )
+ try:
+ if not stream:
+ await response.aread()
+
+ return response
+
+ except BaseException as exc: # pragma: no cover
+ await response.aclose()
+ raise exc
+
+ async def _send_handling_auth(
+ self,
+ request: Request,
+ auth: Auth,
+ follow_redirects: bool,
+ history: typing.List[Response],
+ ) -> Response:
+ auth_flow = auth.async_auth_flow(request)
+ try:
+ request = await auth_flow.__anext__()
+
+ while True:
+ response = await self._send_handling_redirects(
+ request,
+ follow_redirects=follow_redirects,
+ history=history,
+ )
+ try:
+ try:
+ next_request = await auth_flow.asend(response)
+ except StopAsyncIteration:
+ return response
+
+ response.history = list(history)
+ await response.aread()
+ request = next_request
+ history.append(response)
+
+ except BaseException as exc:
+ await response.aclose()
+ raise exc
+ finally:
+ await auth_flow.aclose()
+
+ async def _send_handling_redirects(
+ self,
+ request: Request,
+ follow_redirects: bool,
+ history: typing.List[Response],
+ ) -> Response:
+ while True:
+ if len(history) > self.max_redirects:
+ raise TooManyRedirects(
+ "Exceeded maximum allowed redirects.", request=request
+ )
+
+ for hook in self._event_hooks["request"]:
+ await hook(request)
+
+ response = await self._send_single_request(request)
+ try:
+ for hook in self._event_hooks["response"]:
+ await hook(response)
+
+ response.history = list(history)
+
+ if not response.has_redirect_location:
+ return response
+
+ request = self._build_redirect_request(request, response)
+ history = history + [response]
+
+ if follow_redirects:
+ await response.aread()
+ else:
+ response.next_request = request
+ return response
+
+ except BaseException as exc:
+ await response.aclose()
+ raise exc
+
+ async def _send_single_request(self, request: Request) -> Response:
+ """
+ Sends a single request, without handling any redirections.
+ """
+ transport = self._transport_for_url(request.url)
+ timer = Timer()
+ await timer.async_start()
+
+ if not isinstance(request.stream, AsyncByteStream):
+ raise RuntimeError(
+ "Attempted to send a sync request with an AsyncClient instance."
+ ) + + with request_context(request=request): + response = await transport.handle_async_request(request) + + assert isinstance(response.stream, AsyncByteStream) + response.request = request + response.stream = BoundAsyncStream( + response.stream, response=response, timer=timer + ) + self.cookies.extract_cookies(response) + response.default_encoding = self._default_encoding + + logger.info( + 'HTTP Request: %s %s "%s %d %s"', + request.method, + request.url, + response.http_version, + response.status_code, + response.reason_phrase, + ) + + return response + + async def get( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault, None] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `GET` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def options( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send an `OPTIONS` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def head( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `HEAD` request. + + **Parameters**: See `httpx.request`. 
+ """ + return await self.request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def post( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `POST` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "POST", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def put( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `PUT` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "PUT", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def patch( + self, + url: URLTypes, + *, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `PATCH` request. + + **Parameters**: See `httpx.request`. 
+ """ + return await self.request( + "PATCH", + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def delete( + self, + url: URLTypes, + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + auth: typing.Union[AuthTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + follow_redirects: typing.Union[bool, UseClientDefault] = USE_CLIENT_DEFAULT, + timeout: typing.Union[TimeoutTypes, UseClientDefault] = USE_CLIENT_DEFAULT, + extensions: typing.Optional[RequestExtensions] = None, + ) -> Response: + """ + Send a `DELETE` request. + + **Parameters**: See `httpx.request`. + """ + return await self.request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def aclose(self) -> None: + """ + Close transport and proxies. + """ + if self._state != ClientState.CLOSED: + self._state = ClientState.CLOSED + + await self._transport.aclose() + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.aclose() + + async def __aenter__(self: U) -> U: + if self._state != ClientState.UNOPENED: + msg = { + ClientState.OPENED: "Cannot open a client instance more than once.", + ClientState.CLOSED: "Cannot reopen a client instance, once it has been closed.", + }[self._state] + raise RuntimeError(msg) + + self._state = ClientState.OPENED + + await self._transport.__aenter__() + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.__aenter__() + return self + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + self._state = ClientState.CLOSED + + await self._transport.__aexit__(exc_type, exc_value, traceback) + for proxy in self._mounts.values(): + if proxy is not None: + await proxy.__aexit__(exc_type, exc_value, traceback) diff --git a/contrib/python/httpx/httpx/_compat.py b/contrib/python/httpx/httpx/_compat.py new file mode 100644 index 0000000000..a271c6b800 --- /dev/null +++ b/contrib/python/httpx/httpx/_compat.py @@ -0,0 +1,43 @@ +""" +The _compat module is used for code which requires branching between different +Python environments. It is excluded from the code coverage checks. +""" +import ssl +import sys + +# Brotli support is optional +# The C bindings in `brotli` are recommended for CPython. +# The CFFI bindings in `brotlicffi` are recommended for PyPy and everything else. +try: + import brotlicffi as brotli +except ImportError: # pragma: no cover + try: + import brotli + except ImportError: + brotli = None + +if sys.version_info >= (3, 10) or ( + sys.version_info >= (3, 8) and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0, 7) +): + + def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None: + # The OP_NO_SSL* and OP_NO_TLS* become deprecated in favor of + # 'SSLContext.minimum_version' from Python 3.7 onwards, however + # this attribute is not available unless the ssl module is compiled + # with OpenSSL 1.1.0g or newer. 
+ # https://docs.python.org/3.10/library/ssl.html#ssl.SSLContext.minimum_version + # https://docs.python.org/3.7/library/ssl.html#ssl.SSLContext.minimum_version + context.minimum_version = ssl.TLSVersion.TLSv1_2 + +else: + + def set_minimum_tls_version_1_2(context: ssl.SSLContext) -> None: + # If 'minimum_version' isn't available, we configure these options with + # the older deprecated variants. + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_TLSv1 + context.options |= ssl.OP_NO_TLSv1_1 + + +__all__ = ["brotli", "set_minimum_tls_version_1_2"] diff --git a/contrib/python/httpx/httpx/_config.py b/contrib/python/httpx/httpx/_config.py new file mode 100644 index 0000000000..8d4e03add5 --- /dev/null +++ b/contrib/python/httpx/httpx/_config.py @@ -0,0 +1,378 @@ +import logging +import os +import ssl +import sys +import typing +from pathlib import Path + +import certifi + +from ._compat import set_minimum_tls_version_1_2 +from ._models import Headers +from ._types import CertTypes, HeaderTypes, TimeoutTypes, URLTypes, VerifyTypes +from ._urls import URL +from ._utils import get_ca_bundle_from_env + +DEFAULT_CIPHERS = ":".join( + [ + "ECDHE+AESGCM", + "ECDHE+CHACHA20", + "DHE+AESGCM", + "DHE+CHACHA20", + "ECDH+AESGCM", + "DH+AESGCM", + "ECDH+AES", + "DH+AES", + "RSA+AESGCM", + "RSA+AES", + "!aNULL", + "!eNULL", + "!MD5", + "!DSS", + ] +) + + +logger = logging.getLogger("httpx") + + +class UnsetType: + pass # pragma: no cover + + +UNSET = UnsetType() + + +def create_ssl_context( + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + trust_env: bool = True, + http2: bool = False, +) -> ssl.SSLContext: + return SSLConfig( + cert=cert, verify=verify, trust_env=trust_env, http2=http2 + ).ssl_context + + +class SSLConfig: + """ + SSL Configuration. + """ + + DEFAULT_CA_BUNDLE_PATH = certifi.where() + if callable(DEFAULT_CA_BUNDLE_PATH): + DEFAULT_CA_BUNDLE_PATH = staticmethod(DEFAULT_CA_BUNDLE_PATH) + else: + DEFAULT_CA_BUNDLE_PATH = Path(DEFAULT_CA_BUNDLE_PATH) + + def __init__( + self, + *, + cert: typing.Optional[CertTypes] = None, + verify: VerifyTypes = True, + trust_env: bool = True, + http2: bool = False, + ): + self.cert = cert + self.verify = verify + self.trust_env = trust_env + self.http2 = http2 + self.ssl_context = self.load_ssl_context() + + def load_ssl_context(self) -> ssl.SSLContext: + logger.debug( + "load_ssl_context verify=%r cert=%r trust_env=%r http2=%r", + self.verify, + self.cert, + self.trust_env, + self.http2, + ) + + if self.verify: + return self.load_ssl_context_verify() + return self.load_ssl_context_no_verify() + + def load_ssl_context_no_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for unverified connections. + """ + context = self._create_default_ssl_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + self._load_client_certs(context) + return context + + def load_ssl_context_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for verified connections. + """ + if self.trust_env and self.verify is True: + ca_bundle = get_ca_bundle_from_env() + if ca_bundle is not None: + self.verify = ca_bundle + + if isinstance(self.verify, ssl.SSLContext): + # Allow passing in our own SSLContext object that's pre-configured. 
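+ # Note that a user-supplied context is returned as-is, apart from
+ # loading any client certificates into it below; other TLS settings
+ # on it remain the caller's responsibility in this branch.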
+ context = self.verify + self._load_client_certs(context) + return context + elif isinstance(self.verify, bool): + ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH + elif Path(self.verify).exists(): + ca_bundle_path = Path(self.verify) + else: + raise IOError( + "Could not find a suitable TLS CA certificate bundle, " + "invalid path: {}".format(self.verify) + ) + + context = self._create_default_ssl_context() + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + + # Signal to server support for PHA in TLS 1.3. Raises an + # AttributeError if only read-only access is implemented. + if sys.version_info >= (3, 8): # pragma: no cover + try: + context.post_handshake_auth = True + except AttributeError: # pragma: no cover + pass + + # Disable using 'commonName' for SSLContext.check_hostname + # when the 'subjectAltName' extension isn't available. + try: + context.hostname_checks_common_name = False + except AttributeError: # pragma: no cover + pass + + if callable(ca_bundle_path): + logger.debug("load_verify_locations cafile=%r", ca_bundle_path) + context.load_verify_locations(cafile=ca_bundle_path) + elif ca_bundle_path.is_file(): + cafile = str(ca_bundle_path) + logger.debug("load_verify_locations cafile=%r", cafile) + context.load_verify_locations(cafile=cafile) + elif ca_bundle_path.is_dir(): + capath = str(ca_bundle_path) + logger.debug("load_verify_locations capath=%r", capath) + context.load_verify_locations(capath=capath) + + self._load_client_certs(context) + + return context + + def _create_default_ssl_context(self) -> ssl.SSLContext: + """ + Creates the default SSLContext object that's used for both verified + and unverified connections. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + set_minimum_tls_version_1_2(context) + context.options |= ssl.OP_NO_COMPRESSION + context.set_ciphers(DEFAULT_CIPHERS) + + if ssl.HAS_ALPN: + alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"] + context.set_alpn_protocols(alpn_idents) + + if sys.version_info >= (3, 8): # pragma: no cover + keylogfile = os.environ.get("SSLKEYLOGFILE") + if keylogfile and self.trust_env: + context.keylog_filename = keylogfile + + return context + + def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None: + """ + Loads client certificates into our SSLContext object + """ + if self.cert is not None: + if isinstance(self.cert, str): + ssl_context.load_cert_chain(certfile=self.cert) + elif isinstance(self.cert, tuple) and len(self.cert) == 2: + ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1]) + elif isinstance(self.cert, tuple) and len(self.cert) == 3: + ssl_context.load_cert_chain( + certfile=self.cert[0], + keyfile=self.cert[1], + password=self.cert[2], # type: ignore + ) + + +class Timeout: + """ + Timeout configuration. + + **Usage**: + + Timeout(None) # No timeouts. + Timeout(5.0) # 5s timeout on all operations. + Timeout(None, connect=5.0) # 5s timeout on connect, no other timeouts. + Timeout(5.0, connect=10.0) # 10s timeout on connect. 5s timeout elsewhere. + Timeout(5.0, pool=None) # No timeout on acquiring connection from pool. + # 5s timeout elsewhere. 
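+
+ # Illustrative: all four aspects may also be given explicitly,
+ # in which case no overall default value is required.
+ Timeout(connect=5.0, read=10.0, write=10.0, pool=5.0)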
+ """ + + def __init__( + self, + timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, + *, + connect: typing.Union[None, float, UnsetType] = UNSET, + read: typing.Union[None, float, UnsetType] = UNSET, + write: typing.Union[None, float, UnsetType] = UNSET, + pool: typing.Union[None, float, UnsetType] = UNSET, + ): + if isinstance(timeout, Timeout): + # Passed as a single explicit Timeout. + assert connect is UNSET + assert read is UNSET + assert write is UNSET + assert pool is UNSET + self.connect = timeout.connect # type: typing.Optional[float] + self.read = timeout.read # type: typing.Optional[float] + self.write = timeout.write # type: typing.Optional[float] + self.pool = timeout.pool # type: typing.Optional[float] + elif isinstance(timeout, tuple): + # Passed as a tuple. + self.connect = timeout[0] + self.read = timeout[1] + self.write = None if len(timeout) < 3 else timeout[2] + self.pool = None if len(timeout) < 4 else timeout[3] + elif not ( + isinstance(connect, UnsetType) + or isinstance(read, UnsetType) + or isinstance(write, UnsetType) + or isinstance(pool, UnsetType) + ): + self.connect = connect + self.read = read + self.write = write + self.pool = pool + else: + if isinstance(timeout, UnsetType): + raise ValueError( + "httpx.Timeout must either include a default, or set all " + "four parameters explicitly." + ) + self.connect = timeout if isinstance(connect, UnsetType) else connect + self.read = timeout if isinstance(read, UnsetType) else read + self.write = timeout if isinstance(write, UnsetType) else write + self.pool = timeout if isinstance(pool, UnsetType) else pool + + def as_dict(self) -> typing.Dict[str, typing.Optional[float]]: + return { + "connect": self.connect, + "read": self.read, + "write": self.write, + "pool": self.pool, + } + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.connect == other.connect + and self.read == other.read + and self.write == other.write + and self.pool == other.pool + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + if len({self.connect, self.read, self.write, self.pool}) == 1: + return f"{class_name}(timeout={self.connect})" + return ( + f"{class_name}(connect={self.connect}, " + f"read={self.read}, write={self.write}, pool={self.pool})" + ) + + +class Limits: + """ + Configuration for limits to various client behaviors. + + **Parameters:** + + * **max_connections** - The maximum number of concurrent connections that may be + established. + * **max_keepalive_connections** - Allow the connection pool to maintain + keep-alive connections below this point. Should be less than or equal + to `max_connections`. + * **keepalive_expiry** - Time limit on idle keep-alive connections in seconds. 
+ """ + + def __init__( + self, + *, + max_connections: typing.Optional[int] = None, + max_keepalive_connections: typing.Optional[int] = None, + keepalive_expiry: typing.Optional[float] = 5.0, + ): + self.max_connections = max_connections + self.max_keepalive_connections = max_keepalive_connections + self.keepalive_expiry = keepalive_expiry + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.max_connections == other.max_connections + and self.max_keepalive_connections == other.max_keepalive_connections + and self.keepalive_expiry == other.keepalive_expiry + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return ( + f"{class_name}(max_connections={self.max_connections}, " + f"max_keepalive_connections={self.max_keepalive_connections}, " + f"keepalive_expiry={self.keepalive_expiry})" + ) + + +class Proxy: + def __init__( + self, + url: URLTypes, + *, + ssl_context: typing.Optional[ssl.SSLContext] = None, + auth: typing.Optional[typing.Tuple[str, str]] = None, + headers: typing.Optional[HeaderTypes] = None, + ): + url = URL(url) + headers = Headers(headers) + + if url.scheme not in ("http", "https", "socks5"): + raise ValueError(f"Unknown scheme for proxy URL {url!r}") + + if url.username or url.password: + # Remove any auth credentials from the URL. + auth = (url.username, url.password) + url = url.copy_with(username=None, password=None) + + self.url = url + self.auth = auth + self.headers = headers + self.ssl_context = ssl_context + + @property + def raw_auth(self) -> typing.Optional[typing.Tuple[bytes, bytes]]: + # The proxy authentication as raw bytes. + return ( + None + if self.auth is None + else (self.auth[0].encode("utf-8"), self.auth[1].encode("utf-8")) + ) + + def __repr__(self) -> str: + # The authentication is represented with the password component masked. + auth = (self.auth[0], "********") if self.auth else None + + # Build a nice concise representation. 
+ url_str = f"{str(self.url)!r}" + auth_str = f", auth={auth!r}" if auth else "" + headers_str = f", headers={dict(self.headers)!r}" if self.headers else "" + return f"Proxy({url_str}{auth_str}{headers_str})" + + +DEFAULT_TIMEOUT_CONFIG = Timeout(timeout=5.0) +DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20) +DEFAULT_MAX_REDIRECTS = 20 diff --git a/contrib/python/httpx/httpx/_content.py b/contrib/python/httpx/httpx/_content.py new file mode 100644 index 0000000000..b16e12d954 --- /dev/null +++ b/contrib/python/httpx/httpx/_content.py @@ -0,0 +1,238 @@ +import inspect +import warnings +from json import dumps as json_dumps +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + Tuple, + Union, +) +from urllib.parse import urlencode + +from ._exceptions import StreamClosed, StreamConsumed +from ._multipart import MultipartStream +from ._types import ( + AsyncByteStream, + RequestContent, + RequestData, + RequestFiles, + ResponseContent, + SyncByteStream, +) +from ._utils import peek_filelike_length, primitive_value_to_str + + +class ByteStream(AsyncByteStream, SyncByteStream): + def __init__(self, stream: bytes) -> None: + self._stream = stream + + def __iter__(self) -> Iterator[bytes]: + yield self._stream + + async def __aiter__(self) -> AsyncIterator[bytes]: + yield self._stream + + +class IteratorByteStream(SyncByteStream): + CHUNK_SIZE = 65_536 + + def __init__(self, stream: Iterable[bytes]): + self._stream = stream + self._is_stream_consumed = False + self._is_generator = inspect.isgenerator(stream) + + def __iter__(self) -> Iterator[bytes]: + if self._is_stream_consumed and self._is_generator: + raise StreamConsumed() + + self._is_stream_consumed = True + if hasattr(self._stream, "read"): + # File-like interfaces should use 'read' directly. + chunk = self._stream.read(self.CHUNK_SIZE) + while chunk: + yield chunk + chunk = self._stream.read(self.CHUNK_SIZE) + else: + # Otherwise iterate. + for part in self._stream: + yield part + + +class AsyncIteratorByteStream(AsyncByteStream): + CHUNK_SIZE = 65_536 + + def __init__(self, stream: AsyncIterable[bytes]): + self._stream = stream + self._is_stream_consumed = False + self._is_generator = inspect.isasyncgen(stream) + + async def __aiter__(self) -> AsyncIterator[bytes]: + if self._is_stream_consumed and self._is_generator: + raise StreamConsumed() + + self._is_stream_consumed = True + if hasattr(self._stream, "aread"): + # File-like interfaces should use 'aread' directly. + chunk = await self._stream.aread(self.CHUNK_SIZE) + while chunk: + yield chunk + chunk = await self._stream.aread(self.CHUNK_SIZE) + else: + # Otherwise iterate. + async for part in self._stream: + yield part + + +class UnattachedStream(AsyncByteStream, SyncByteStream): + """ + If a request or response is serialized using pickle, then it is no longer + attached to a stream for I/O purposes. Any stream operations should result + in `httpx.StreamClosed`. 
+ """ + + def __iter__(self) -> Iterator[bytes]: + raise StreamClosed() + + async def __aiter__(self) -> AsyncIterator[bytes]: + raise StreamClosed() + yield b"" # pragma: no cover + + +def encode_content( + content: Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]: + if isinstance(content, (bytes, str)): + body = content.encode("utf-8") if isinstance(content, str) else content + content_length = len(body) + headers = {"Content-Length": str(content_length)} if body else {} + return headers, ByteStream(body) + + elif isinstance(content, Iterable) and not isinstance(content, dict): + # `not isinstance(content, dict)` is a bit oddly specific, but it + # catches a case that's easy for users to make in error, and would + # otherwise pass through here, like any other bytes-iterable, + # because `dict` happens to be iterable. See issue #2491. + content_length_or_none = peek_filelike_length(content) + + if content_length_or_none is None: + headers = {"Transfer-Encoding": "chunked"} + else: + headers = {"Content-Length": str(content_length_or_none)} + return headers, IteratorByteStream(content) # type: ignore + + elif isinstance(content, AsyncIterable): + headers = {"Transfer-Encoding": "chunked"} + return headers, AsyncIteratorByteStream(content) + + raise TypeError(f"Unexpected type for 'content', {type(content)!r}") + + +def encode_urlencoded_data( + data: RequestData, +) -> Tuple[Dict[str, str], ByteStream]: + plain_data = [] + for key, value in data.items(): + if isinstance(value, (list, tuple)): + plain_data.extend([(key, primitive_value_to_str(item)) for item in value]) + else: + plain_data.append((key, primitive_value_to_str(value))) + body = urlencode(plain_data, doseq=True).encode("utf-8") + content_length = str(len(body)) + content_type = "application/x-www-form-urlencoded" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_multipart_data( + data: RequestData, files: RequestFiles, boundary: Optional[bytes] +) -> Tuple[Dict[str, str], MultipartStream]: + multipart = MultipartStream(data=data, files=files, boundary=boundary) + headers = multipart.get_headers() + return headers, multipart + + +def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]: + body = text.encode("utf-8") + content_length = str(len(body)) + content_type = "text/plain; charset=utf-8" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]: + body = html.encode("utf-8") + content_length = str(len(body)) + content_type = "text/html; charset=utf-8" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]: + body = json_dumps(json).encode("utf-8") + content_length = str(len(body)) + content_type = "application/json" + headers = {"Content-Length": content_length, "Content-Type": content_type} + return headers, ByteStream(body) + + +def encode_request( + content: Optional[RequestContent] = None, + data: Optional[RequestData] = None, + files: Optional[RequestFiles] = None, + json: Optional[Any] = None, + boundary: Optional[bytes] = None, +) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]: + """ + Handles encoding the given `content`, `data`, `files`, and `json`, + returning a two-tuple of (<headers>, 
<stream>). + """ + if data is not None and not isinstance(data, Mapping): + # We prefer to separate `content=<bytes|str|byte iterator|bytes aiterator>` + # for raw request content, and `data=<form data>` for url encoded or + # multipart form content. + # + # However for compat with requests, we *do* still support + # `data=<bytes...>` usages. We deal with that case here, treating it + # as if `content=<...>` had been supplied instead. + message = "Use 'content=<...>' to upload raw bytes/text content." + warnings.warn(message, DeprecationWarning) + return encode_content(data) + + if content is not None: + return encode_content(content) + elif files: + return encode_multipart_data(data or {}, files, boundary) + elif data: + return encode_urlencoded_data(data) + elif json is not None: + return encode_json(json) + + return {}, ByteStream(b"") + + +def encode_response( + content: Optional[ResponseContent] = None, + text: Optional[str] = None, + html: Optional[str] = None, + json: Optional[Any] = None, +) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]: + """ + Handles encoding the given `content`, returning a two-tuple of + (<headers>, <stream>). + """ + if content is not None: + return encode_content(content) + elif text is not None: + return encode_text(text) + elif html is not None: + return encode_html(html) + elif json is not None: + return encode_json(json) + + return {}, ByteStream(b"") diff --git a/contrib/python/httpx/httpx/_decoders.py b/contrib/python/httpx/httpx/_decoders.py new file mode 100644 index 0000000000..500ce7ffc3 --- /dev/null +++ b/contrib/python/httpx/httpx/_decoders.py @@ -0,0 +1,324 @@ +""" +Handlers for Content-Encoding. + +See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding +""" +import codecs +import io +import typing +import zlib + +from ._compat import brotli +from ._exceptions import DecodingError + + +class ContentDecoder: + def decode(self, data: bytes) -> bytes: + raise NotImplementedError() # pragma: no cover + + def flush(self) -> bytes: + raise NotImplementedError() # pragma: no cover + + +class IdentityDecoder(ContentDecoder): + """ + Handle unencoded data. + """ + + def decode(self, data: bytes) -> bytes: + return data + + def flush(self) -> bytes: + return b"" + + +class DeflateDecoder(ContentDecoder): + """ + Handle 'deflate' decoding. + + See: https://stackoverflow.com/questions/1838699 + """ + + def __init__(self) -> None: + self.first_attempt = True + self.decompressor = zlib.decompressobj() + + def decode(self, data: bytes) -> bytes: + was_first_attempt = self.first_attempt + self.first_attempt = False + try: + return self.decompressor.decompress(data) + except zlib.error as exc: + if was_first_attempt: + self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS) + return self.decode(data) + raise DecodingError(str(exc)) from exc + + def flush(self) -> bytes: + try: + return self.decompressor.flush() + except zlib.error as exc: # pragma: no cover + raise DecodingError(str(exc)) from exc + + +class GZipDecoder(ContentDecoder): + """ + Handle 'gzip' decoding. 
+
+ See: https://stackoverflow.com/questions/1838699
+ """
+
+ def __init__(self) -> None:
+ self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16)
+
+ def decode(self, data: bytes) -> bytes:
+ try:
+ return self.decompressor.decompress(data)
+ except zlib.error as exc:
+ raise DecodingError(str(exc)) from exc
+
+ def flush(self) -> bytes:
+ try:
+ return self.decompressor.flush()
+ except zlib.error as exc: # pragma: no cover
+ raise DecodingError(str(exc)) from exc
+
+
+ class BrotliDecoder(ContentDecoder):
+ """
+ Handle 'brotli' decoding.
+
+ Requires either `pip install brotlicffi` (CFFI bindings, recommended for
+ PyPy) or `pip install brotli` (C bindings, recommended for CPython).
+ See https://github.com/google/brotli
+ Supports both the 'brotlicffi' and 'brotli' packages, since they share an
+ import name: the 'brotlicffi' branches below use `decompress()`, and the
+ 'brotli' branches use `process()`.
+ """
+
+ def __init__(self) -> None:
+ if brotli is None: # pragma: no cover
+ raise ImportError(
+ "Using 'BrotliDecoder', but neither of the 'brotlicffi' or 'brotli' "
+ "packages have been installed. "
+ "Make sure to install httpx using `pip install httpx[brotli]`."
+ ) from None
+
+ self.decompressor = brotli.Decompressor()
+ self.seen_data = False
+ self._decompress: typing.Callable[[bytes], bytes]
+ if hasattr(self.decompressor, "decompress"):
+ # The 'brotlicffi' package.
+ self._decompress = self.decompressor.decompress # pragma: no cover
+ else:
+ # The 'brotli' package.
+ self._decompress = self.decompressor.process # pragma: no cover
+
+ def decode(self, data: bytes) -> bytes:
+ if not data:
+ return b""
+ self.seen_data = True
+ try:
+ return self._decompress(data)
+ except brotli.error as exc:
+ raise DecodingError(str(exc)) from exc
+
+ def flush(self) -> bytes:
+ if not self.seen_data:
+ return b""
+ try:
+ if hasattr(self.decompressor, "finish"):
+ # Only available in the 'brotlicffi' package.
+
+ # As the decompressor decompresses eagerly, this
+ # will never actually emit any data. However, it will potentially throw
+ # errors if a truncated or damaged data stream has been used.
+ self.decompressor.finish() # pragma: no cover
+ return b""
+ except brotli.error as exc: # pragma: no cover
+ raise DecodingError(str(exc)) from exc
+
+
+ class MultiDecoder(ContentDecoder):
+ """
+ Handle the case where multiple encodings have been applied.
+ """
+
+ def __init__(self, children: typing.Sequence[ContentDecoder]) -> None:
+ """
+ 'children' should be a sequence of decoders in the order in which
+ each was applied.
+ """
+ # Note that we reverse the order for decoding.
+ self.children = list(reversed(children))
+
+ def decode(self, data: bytes) -> bytes:
+ for child in self.children:
+ data = child.decode(data)
+ return data
+
+ def flush(self) -> bytes:
+ data = b""
+ for child in self.children:
+ data = child.decode(data) + child.flush()
+ return data
+
+
+ class ByteChunker:
+ """
+ Handles returning byte content in fixed-size chunks.
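+
+ For example (illustrative values):
+
+ chunker = ByteChunker(chunk_size=3)
+ chunker.decode(b"abcdefg") # -> [b"abc", b"def"]
+ chunker.flush() # -> [b"g"]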
+    """
+
+    def __init__(self, chunk_size: typing.Optional[int] = None) -> None:
+        self._buffer = io.BytesIO()
+        self._chunk_size = chunk_size
+
+    def decode(self, content: bytes) -> typing.List[bytes]:
+        if self._chunk_size is None:
+            return [content] if content else []
+
+        self._buffer.write(content)
+        if self._buffer.tell() >= self._chunk_size:
+            value = self._buffer.getvalue()
+            chunks = [
+                value[i : i + self._chunk_size]
+                for i in range(0, len(value), self._chunk_size)
+            ]
+            if len(chunks[-1]) == self._chunk_size:
+                self._buffer.seek(0)
+                self._buffer.truncate()
+                return chunks
+            else:
+                self._buffer.seek(0)
+                self._buffer.write(chunks[-1])
+                self._buffer.truncate()
+                return chunks[:-1]
+        else:
+            return []
+
+    def flush(self) -> typing.List[bytes]:
+        value = self._buffer.getvalue()
+        self._buffer.seek(0)
+        self._buffer.truncate()
+        return [value] if value else []
+
+
+class TextChunker:
+    """
+    Handles returning text content in fixed-size chunks.
+    """
+
+    def __init__(self, chunk_size: typing.Optional[int] = None) -> None:
+        self._buffer = io.StringIO()
+        self._chunk_size = chunk_size
+
+    def decode(self, content: str) -> typing.List[str]:
+        if self._chunk_size is None:
+            return [content]
+
+        self._buffer.write(content)
+        if self._buffer.tell() >= self._chunk_size:
+            value = self._buffer.getvalue()
+            chunks = [
+                value[i : i + self._chunk_size]
+                for i in range(0, len(value), self._chunk_size)
+            ]
+            if len(chunks[-1]) == self._chunk_size:
+                self._buffer.seek(0)
+                self._buffer.truncate()
+                return chunks
+            else:
+                self._buffer.seek(0)
+                self._buffer.write(chunks[-1])
+                self._buffer.truncate()
+                return chunks[:-1]
+        else:
+            return []
+
+    def flush(self) -> typing.List[str]:
+        value = self._buffer.getvalue()
+        self._buffer.seek(0)
+        self._buffer.truncate()
+        return [value] if value else []
+
+
+class TextDecoder:
+    """
+    Handles incrementally decoding bytes into text.
+    """
+
+    def __init__(self, encoding: str = "utf-8"):
+        self.decoder = codecs.getincrementaldecoder(encoding)(errors="replace")
+
+    def decode(self, data: bytes) -> str:
+        return self.decoder.decode(data)
+
+    def flush(self) -> str:
+        return self.decoder.decode(b"", True)
+
+
+class LineDecoder:
+    """
+    Handles incrementally reading lines from text.
+
+    Has the same behaviour as the stdlib splitlines, but handles the input iteratively.
+    """
+
+    def __init__(self) -> None:
+        self.buffer: typing.List[str] = []
+        self.trailing_cr: bool = False
+
+    def decode(self, text: str) -> typing.List[str]:
+        # See https://docs.python.org/3/library/stdtypes.html#str.splitlines
+        NEWLINE_CHARS = "\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029"
+
+        # We always push a trailing `\r` into the next decode iteration.
+        if self.trailing_cr:
+            text = "\r" + text
+            self.trailing_cr = False
+        if text.endswith("\r"):
+            self.trailing_cr = True
+            text = text[:-1]
+
+        if not text:
+            return []
+
+        trailing_newline = text[-1] in NEWLINE_CHARS
+        lines = text.splitlines()
+
+        if len(lines) == 1 and not trailing_newline:
+            # No new lines, buffer the input and continue.
+            self.buffer.append(lines[0])
+            return []
+
+        if self.buffer:
+            # Include any existing buffer in the first portion of the
+            # splitlines result.
+            lines = ["".join(self.buffer) + lines[0]] + lines[1:]
+            self.buffer = []
+
+        if not trailing_newline:
+            # If the last segment of splitlines is not newline terminated,
+            # then drop it from our output and start a new buffer.
+ self.buffer = [lines.pop()] + + return lines + + def flush(self) -> typing.List[str]: + if not self.buffer and not self.trailing_cr: + return [] + + lines = ["".join(self.buffer)] + self.buffer = [] + self.trailing_cr = False + return lines + + +SUPPORTED_DECODERS = { + "identity": IdentityDecoder, + "gzip": GZipDecoder, + "deflate": DeflateDecoder, + "br": BrotliDecoder, +} + + +if brotli is None: + SUPPORTED_DECODERS.pop("br") # pragma: no cover diff --git a/contrib/python/httpx/httpx/_exceptions.py b/contrib/python/httpx/httpx/_exceptions.py new file mode 100644 index 0000000000..24a4f8aba3 --- /dev/null +++ b/contrib/python/httpx/httpx/_exceptions.py @@ -0,0 +1,343 @@ +""" +Our exception hierarchy: + +* HTTPError + x RequestError + + TransportError + - TimeoutException + · ConnectTimeout + · ReadTimeout + · WriteTimeout + · PoolTimeout + - NetworkError + · ConnectError + · ReadError + · WriteError + · CloseError + - ProtocolError + · LocalProtocolError + · RemoteProtocolError + - ProxyError + - UnsupportedProtocol + + DecodingError + + TooManyRedirects + x HTTPStatusError +* InvalidURL +* CookieConflict +* StreamError + x StreamConsumed + x StreamClosed + x ResponseNotRead + x RequestNotRead +""" +import contextlib +import typing + +if typing.TYPE_CHECKING: + from ._models import Request, Response # pragma: no cover + + +class HTTPError(Exception): + """ + Base class for `RequestError` and `HTTPStatusError`. + + Useful for `try...except` blocks when issuing a request, + and then calling `.raise_for_status()`. + + For example: + + ``` + try: + response = httpx.get("https://www.example.com") + response.raise_for_status() + except httpx.HTTPError as exc: + print(f"HTTP Exception for {exc.request.url} - {exc}") + ``` + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + self._request: typing.Optional["Request"] = None + + @property + def request(self) -> "Request": + if self._request is None: + raise RuntimeError("The .request property has not been set.") + return self._request + + @request.setter + def request(self, request: "Request") -> None: + self._request = request + + +class RequestError(HTTPError): + """ + Base class for all exceptions that may occur when issuing a `.request()`. + """ + + def __init__( + self, message: str, *, request: typing.Optional["Request"] = None + ) -> None: + super().__init__(message) + # At the point an exception is raised we won't typically have a request + # instance to associate it with. + # + # The 'request_context' context manager is used within the Client and + # Response methods in order to ensure that any raised exceptions + # have a `.request` property set on them. + self._request = request + + +class TransportError(RequestError): + """ + Base class for all exceptions that occur at the level of the Transport API. + """ + + +# Timeout exceptions... + + +class TimeoutException(TransportError): + """ + The base class for timeout errors. + + An operation has timed out. + """ + + +class ConnectTimeout(TimeoutException): + """ + Timed out while connecting to the host. + """ + + +class ReadTimeout(TimeoutException): + """ + Timed out while receiving data from the host. + """ + + +class WriteTimeout(TimeoutException): + """ + Timed out while sending data to the host. + """ + + +class PoolTimeout(TimeoutException): + """ + Timed out waiting to acquire a connection from the pool. + """ + + +# Core networking exceptions... + + +class NetworkError(TransportError): + """ + The base class for network-related errors. 
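Because every timeout class above derives from `TimeoutException`, callers can catch either the broad category or a specific phase. A sketch against the public httpx API (the URL and the aggressive timeout value are placeholders):

```python
import httpx

try:
    httpx.get("https://example.com", timeout=0.001)
except httpx.ConnectTimeout:
    print("timed out while establishing the connection")
except httpx.TimeoutException:
    print("timed out in some other phase (read/write/pool)")
```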
+ + An error occurred while interacting with the network. + """ + + +class ReadError(NetworkError): + """ + Failed to receive data from the network. + """ + + +class WriteError(NetworkError): + """ + Failed to send data through the network. + """ + + +class ConnectError(NetworkError): + """ + Failed to establish a connection. + """ + + +class CloseError(NetworkError): + """ + Failed to close a connection. + """ + + +# Other transport exceptions... + + +class ProxyError(TransportError): + """ + An error occurred while establishing a proxy connection. + """ + + +class UnsupportedProtocol(TransportError): + """ + Attempted to make a request to an unsupported protocol. + + For example issuing a request to `ftp://www.example.com`. + """ + + +class ProtocolError(TransportError): + """ + The protocol was violated. + """ + + +class LocalProtocolError(ProtocolError): + """ + A protocol was violated by the client. + + For example if the user instantiated a `Request` instance explicitly, + failed to include the mandatory `Host:` header, and then issued it directly + using `client.send()`. + """ + + +class RemoteProtocolError(ProtocolError): + """ + The protocol was violated by the server. + + For example, returning malformed HTTP. + """ + + +# Other request exceptions... + + +class DecodingError(RequestError): + """ + Decoding of the response failed, due to a malformed encoding. + """ + + +class TooManyRedirects(RequestError): + """ + Too many redirects. + """ + + +# Client errors + + +class HTTPStatusError(HTTPError): + """ + The response had an error HTTP status of 4xx or 5xx. + + May be raised when calling `response.raise_for_status()` + """ + + def __init__( + self, message: str, *, request: "Request", response: "Response" + ) -> None: + super().__init__(message) + self.request = request + self.response = response + + +class InvalidURL(Exception): + """ + URL is improperly formed or cannot be parsed. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class CookieConflict(Exception): + """ + Attempted to lookup a cookie by name, but multiple cookies existed. + + Can occur when calling `response.cookies.get(...)`. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +# Stream exceptions... + +# These may occur as the result of a programming error, by accessing +# the request/response stream in an invalid manner. + + +class StreamError(RuntimeError): + """ + The base class for stream exceptions. + + The developer made an error in accessing the request stream in + an invalid way. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class StreamConsumed(StreamError): + """ + Attempted to read or stream content, but the content has already + been streamed. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream some content, but the content has " + "already been streamed. For requests, this could be due to passing " + "a generator as request content, and then receiving a redirect " + "response or a secondary request as part of an authentication flow." + "For responses, this could be due to attempting to stream the response " + "content more than once." + ) + super().__init__(message) + + +class StreamClosed(StreamError): + """ + Attempted to read or stream response content, but the request has been + closed. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream content, but the stream has " "been closed." 
+ ) + super().__init__(message) + + +class ResponseNotRead(StreamError): + """ + Attempted to access streaming response content, without having called `read()`. + """ + + def __init__(self) -> None: + message = "Attempted to access streaming response content, without having called `read()`." + super().__init__(message) + + +class RequestNotRead(StreamError): + """ + Attempted to access streaming request content, without having called `read()`. + """ + + def __init__(self) -> None: + message = "Attempted to access streaming request content, without having called `read()`." + super().__init__(message) + + +@contextlib.contextmanager +def request_context( + request: typing.Optional["Request"] = None, +) -> typing.Iterator[None]: + """ + A context manager that can be used to attach the given request context + to any `RequestError` exceptions that are raised within the block. + """ + try: + yield + except RequestError as exc: + if request is not None: + exc.request = request + raise exc diff --git a/contrib/python/httpx/httpx/_main.py b/contrib/python/httpx/httpx/_main.py new file mode 100644 index 0000000000..7c12ce841d --- /dev/null +++ b/contrib/python/httpx/httpx/_main.py @@ -0,0 +1,506 @@ +import functools +import json +import sys +import typing + +import click +import httpcore +import pygments.lexers +import pygments.util +import rich.console +import rich.markup +import rich.progress +import rich.syntax +import rich.table + +from ._client import Client +from ._exceptions import RequestError +from ._models import Response +from ._status_codes import codes + + +def print_help() -> None: + console = rich.console.Console() + + console.print("[bold]HTTPX :butterfly:", justify="center") + console.print() + console.print("A next generation HTTP client.", justify="center") + console.print() + console.print( + "Usage: [bold]httpx[/bold] [cyan]<URL> [OPTIONS][/cyan] ", justify="left" + ) + console.print() + + table = rich.table.Table.grid(padding=1, pad_edge=True) + table.add_column("Parameter", no_wrap=True, justify="left", style="bold") + table.add_column("Description") + table.add_row( + "-m, --method [cyan]METHOD", + "Request method, such as GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD.\n" + "[Default: GET, or POST if a request body is included]", + ) + table.add_row( + "-p, --params [cyan]<NAME VALUE> ...", + "Query parameters to include in the request URL.", + ) + table.add_row( + "-c, --content [cyan]TEXT", "Byte content to include in the request body." + ) + table.add_row( + "-d, --data [cyan]<NAME VALUE> ...", "Form data to include in the request body." + ) + table.add_row( + "-f, --files [cyan]<NAME FILENAME> ...", + "Form files to include in the request body.", + ) + table.add_row("-j, --json [cyan]TEXT", "JSON data to include in the request body.") + table.add_row( + "-h, --headers [cyan]<NAME VALUE> ...", + "Include additional HTTP headers in the request.", + ) + table.add_row( + "--cookies [cyan]<NAME VALUE> ...", "Cookies to include in the request." + ) + table.add_row( + "--auth [cyan]<USER PASS>", + "Username and password to include in the request. Specify '-' for the password to use " + "a password prompt. Note that using --verbose/-v will expose the Authorization " + "header, including the password encoding in a trivially reversible format.", + ) + + table.add_row( + "--proxies [cyan]URL", + "Send the request via a proxy. 
Should be the URL giving the proxy address.", + ) + + table.add_row( + "--timeout [cyan]FLOAT", + "Timeout value to use for network operations, such as establishing the connection, " + "reading some data, etc... [Default: 5.0]", + ) + + table.add_row("--follow-redirects", "Automatically follow redirects.") + table.add_row("--no-verify", "Disable SSL verification.") + table.add_row( + "--http2", "Send the request using HTTP/2, if the remote server supports it." + ) + + table.add_row( + "--download [cyan]FILE", + "Save the response content as a file, rather than displaying it.", + ) + + table.add_row("-v, --verbose", "Verbose output. Show request as well as response.") + table.add_row("--help", "Show this message and exit.") + console.print(table) + + +def get_lexer_for_response(response: Response) -> str: + content_type = response.headers.get("Content-Type") + if content_type is not None: + mime_type, _, _ = content_type.partition(";") + try: + return typing.cast( + str, pygments.lexers.get_lexer_for_mimetype(mime_type.strip()).name + ) + except pygments.util.ClassNotFound: # pragma: no cover + pass + return "" # pragma: no cover + + +def format_request_headers(request: httpcore.Request, http2: bool = False) -> str: + version = "HTTP/2" if http2 else "HTTP/1.1" + headers = [ + (name.lower() if http2 else name, value) for name, value in request.headers + ] + method = request.method.decode("ascii") + target = request.url.target.decode("ascii") + lines = [f"{method} {target} {version}"] + [ + f"{name.decode('ascii')}: {value.decode('ascii')}" for name, value in headers + ] + return "\n".join(lines) + + +def format_response_headers( + http_version: bytes, + status: int, + reason_phrase: typing.Optional[bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], +) -> str: + version = http_version.decode("ascii") + reason = ( + codes.get_reason_phrase(status) + if reason_phrase is None + else reason_phrase.decode("ascii") + ) + lines = [f"{version} {status} {reason}"] + [ + f"{name.decode('ascii')}: {value.decode('ascii')}" for name, value in headers + ] + return "\n".join(lines) + + +def print_request_headers(request: httpcore.Request, http2: bool = False) -> None: + console = rich.console.Console() + http_text = format_request_headers(request, http2=http2) + syntax = rich.syntax.Syntax(http_text, "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + syntax = rich.syntax.Syntax("", "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + + +def print_response_headers( + http_version: bytes, + status: int, + reason_phrase: typing.Optional[bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], +) -> None: + console = rich.console.Console() + http_text = format_response_headers(http_version, status, reason_phrase, headers) + syntax = rich.syntax.Syntax(http_text, "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + syntax = rich.syntax.Syntax("", "http", theme="ansi_dark", word_wrap=True) + console.print(syntax) + + +def print_response(response: Response) -> None: + console = rich.console.Console() + lexer_name = get_lexer_for_response(response) + if lexer_name: + if lexer_name.lower() == "json": + try: + data = response.json() + text = json.dumps(data, indent=4) + except ValueError: # pragma: no cover + text = response.text + else: + text = response.text + + syntax = rich.syntax.Syntax(text, lexer_name, theme="ansi_dark", word_wrap=True) + console.print(syntax) + else: + console.print(f"<{len(response.content)} bytes of binary data>") + + +_PCTRTT = 
typing.Tuple[typing.Tuple[str, str], ...] +_PCTRTTT = typing.Tuple[_PCTRTT, ...] +_PeerCertRetDictType = typing.Dict[str, typing.Union[str, _PCTRTTT, _PCTRTT]] + + +def format_certificate(cert: _PeerCertRetDictType) -> str: # pragma: no cover + lines = [] + for key, value in cert.items(): + if isinstance(value, (list, tuple)): + lines.append(f"* {key}:") + for item in value: + if key in ("subject", "issuer"): + for sub_item in item: + lines.append(f"* {sub_item[0]}: {sub_item[1]!r}") + elif isinstance(item, tuple) and len(item) == 2: + lines.append(f"* {item[0]}: {item[1]!r}") + else: + lines.append(f"* {item!r}") + else: + lines.append(f"* {key}: {value!r}") + return "\n".join(lines) + + +def trace( + name: str, info: typing.Mapping[str, typing.Any], verbose: bool = False +) -> None: + console = rich.console.Console() + if name == "connection.connect_tcp.started" and verbose: + host = info["host"] + console.print(f"* Connecting to {host!r}") + elif name == "connection.connect_tcp.complete" and verbose: + stream = info["return_value"] + server_addr = stream.get_extra_info("server_addr") + console.print(f"* Connected to {server_addr[0]!r} on port {server_addr[1]}") + elif name == "connection.start_tls.complete" and verbose: # pragma: no cover + stream = info["return_value"] + ssl_object = stream.get_extra_info("ssl_object") + version = ssl_object.version() + cipher = ssl_object.cipher() + server_cert = ssl_object.getpeercert() + alpn = ssl_object.selected_alpn_protocol() + console.print(f"* SSL established using {version!r} / {cipher[0]!r}") + console.print(f"* Selected ALPN protocol: {alpn!r}") + if server_cert: + console.print("* Server certificate:") + console.print(format_certificate(server_cert)) + elif name == "http11.send_request_headers.started" and verbose: + request = info["request"] + print_request_headers(request, http2=False) + elif name == "http2.send_request_headers.started" and verbose: # pragma: no cover + request = info["request"] + print_request_headers(request, http2=True) + elif name == "http11.receive_response_headers.complete": + http_version, status, reason_phrase, headers = info["return_value"] + print_response_headers(http_version, status, reason_phrase, headers) + elif name == "http2.receive_response_headers.complete": # pragma: no cover + status, headers = info["return_value"] + http_version = b"HTTP/2" + reason_phrase = None + print_response_headers(http_version, status, reason_phrase, headers) + + +def download_response(response: Response, download: typing.BinaryIO) -> None: + console = rich.console.Console() + console.print() + content_length = response.headers.get("Content-Length") + with rich.progress.Progress( + "[progress.description]{task.description}", + "[progress.percentage]{task.percentage:>3.0f}%", + rich.progress.BarColumn(bar_width=None), + rich.progress.DownloadColumn(), + rich.progress.TransferSpeedColumn(), + ) as progress: + description = f"Downloading [bold]{rich.markup.escape(download.name)}" + download_task = progress.add_task( + description, + total=int(content_length or 0), + start=content_length is not None, + ) + for chunk in response.iter_bytes(): + download.write(chunk) + progress.update(download_task, completed=response.num_bytes_downloaded) + + +def validate_json( + ctx: click.Context, + param: typing.Union[click.Option, click.Parameter], + value: typing.Any, +) -> typing.Any: + if value is None: + return None + + try: + return json.loads(value) + except json.JSONDecodeError: # pragma: no cover + raise click.BadParameter("Not valid 
JSON") + + +def validate_auth( + ctx: click.Context, + param: typing.Union[click.Option, click.Parameter], + value: typing.Any, +) -> typing.Any: + if value == (None, None): + return None + + username, password = value + if password == "-": # pragma: no cover + password = click.prompt("Password", hide_input=True) + return (username, password) + + +def handle_help( + ctx: click.Context, + param: typing.Union[click.Option, click.Parameter], + value: typing.Any, +) -> None: + if not value or ctx.resilient_parsing: + return + + print_help() + ctx.exit() + + +@click.command(add_help_option=False) +@click.argument("url", type=str) +@click.option( + "--method", + "-m", + "method", + type=str, + help=( + "Request method, such as GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD. " + "[Default: GET, or POST if a request body is included]" + ), +) +@click.option( + "--params", + "-p", + "params", + type=(str, str), + multiple=True, + help="Query parameters to include in the request URL.", +) +@click.option( + "--content", + "-c", + "content", + type=str, + help="Byte content to include in the request body.", +) +@click.option( + "--data", + "-d", + "data", + type=(str, str), + multiple=True, + help="Form data to include in the request body.", +) +@click.option( + "--files", + "-f", + "files", + type=(str, click.File(mode="rb")), + multiple=True, + help="Form files to include in the request body.", +) +@click.option( + "--json", + "-j", + "json", + type=str, + callback=validate_json, + help="JSON data to include in the request body.", +) +@click.option( + "--headers", + "-h", + "headers", + type=(str, str), + multiple=True, + help="Include additional HTTP headers in the request.", +) +@click.option( + "--cookies", + "cookies", + type=(str, str), + multiple=True, + help="Cookies to include in the request.", +) +@click.option( + "--auth", + "auth", + type=(str, str), + default=(None, None), + callback=validate_auth, + help=( + "Username and password to include in the request. " + "Specify '-' for the password to use a password prompt. " + "Note that using --verbose/-v will expose the Authorization header, " + "including the password encoding in a trivially reversible format." + ), +) +@click.option( + "--proxies", + "proxies", + type=str, + default=None, + help="Send the request via a proxy. Should be the URL giving the proxy address.", +) +@click.option( + "--timeout", + "timeout", + type=float, + default=5.0, + help=( + "Timeout value to use for network operations, such as establishing the " + "connection, reading some data, etc... [Default: 5.0]" + ), +) +@click.option( + "--follow-redirects", + "follow_redirects", + is_flag=True, + default=False, + help="Automatically follow redirects.", +) +@click.option( + "--no-verify", + "verify", + is_flag=True, + default=True, + help="Disable SSL verification.", +) +@click.option( + "--http2", + "http2", + type=bool, + is_flag=True, + default=False, + help="Send the request using HTTP/2, if the remote server supports it.", +) +@click.option( + "--download", + type=click.File("wb"), + help="Save the response content as a file, rather than displaying it.", +) +@click.option( + "--verbose", + "-v", + type=bool, + is_flag=True, + default=False, + help="Verbose. 
Show request as well as response.", +) +@click.option( + "--help", + is_flag=True, + is_eager=True, + expose_value=False, + callback=handle_help, + help="Show this message and exit.", +) +def main( + url: str, + method: str, + params: typing.List[typing.Tuple[str, str]], + content: str, + data: typing.List[typing.Tuple[str, str]], + files: typing.List[typing.Tuple[str, click.File]], + json: str, + headers: typing.List[typing.Tuple[str, str]], + cookies: typing.List[typing.Tuple[str, str]], + auth: typing.Optional[typing.Tuple[str, str]], + proxies: str, + timeout: float, + follow_redirects: bool, + verify: bool, + http2: bool, + download: typing.Optional[typing.BinaryIO], + verbose: bool, +) -> None: + """ + An HTTP command line client. + Sends a request and displays the response. + """ + if not method: + method = "POST" if content or data or files or json else "GET" + + try: + with Client( + proxies=proxies, + timeout=timeout, + verify=verify, + http2=http2, + ) as client: + with client.stream( + method, + url, + params=list(params), + content=content, + data=dict(data), + files=files, # type: ignore + json=json, + headers=headers, + cookies=dict(cookies), + auth=auth, + follow_redirects=follow_redirects, + extensions={"trace": functools.partial(trace, verbose=verbose)}, + ) as response: + if download is not None: + download_response(response, download) + else: + response.read() + if response.content: + print_response(response) + + except RequestError as exc: + console = rich.console.Console() + console.print(f"[red]{type(exc).__name__}[/red]: {exc}") + sys.exit(1) + + sys.exit(0 if response.is_success else 1) diff --git a/contrib/python/httpx/httpx/_models.py b/contrib/python/httpx/httpx/_models.py new file mode 100644 index 0000000000..e1e45cf06b --- /dev/null +++ b/contrib/python/httpx/httpx/_models.py @@ -0,0 +1,1209 @@ +import datetime +import email.message +import json as jsonlib +import typing +import urllib.request +from collections.abc import Mapping +from http.cookiejar import Cookie, CookieJar + +from ._content import ByteStream, UnattachedStream, encode_request, encode_response +from ._decoders import ( + SUPPORTED_DECODERS, + ByteChunker, + ContentDecoder, + IdentityDecoder, + LineDecoder, + MultiDecoder, + TextChunker, + TextDecoder, +) +from ._exceptions import ( + CookieConflict, + HTTPStatusError, + RequestNotRead, + ResponseNotRead, + StreamClosed, + StreamConsumed, + request_context, +) +from ._multipart import get_multipart_boundary_from_content_type +from ._status_codes import codes +from ._types import ( + AsyncByteStream, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestExtensions, + RequestFiles, + ResponseContent, + ResponseExtensions, + SyncByteStream, +) +from ._urls import URL +from ._utils import ( + guess_json_utf, + is_known_encoding, + normalize_header_key, + normalize_header_value, + obfuscate_sensitive_headers, + parse_content_type_charset, + parse_header_links, +) + + +class Headers(typing.MutableMapping[str, str]): + """ + HTTP headers, as a case-insensitive multi-dict. 
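Stripped of the rich-formatted output and tracing, the core of `main()` above is an ordinary streamed client request. A reduced sketch of the same flow (URL is a placeholder):

```python
import httpx

with httpx.Client(timeout=5.0, verify=True, http2=False) as client:
    # Streaming keeps the flow identical whether the body is
    # displayed or written to a --download file chunk by chunk.
    with client.stream("GET", "https://example.com") as response:
        response.read()
        print(response.status_code, len(response.content))
```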
+ """ + + def __init__( + self, + headers: typing.Optional[HeaderTypes] = None, + encoding: typing.Optional[str] = None, + ) -> None: + if headers is None: + self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]] + elif isinstance(headers, Headers): + self._list = list(headers._list) + elif isinstance(headers, Mapping): + self._list = [ + ( + normalize_header_key(k, lower=False, encoding=encoding), + normalize_header_key(k, lower=True, encoding=encoding), + normalize_header_value(v, encoding), + ) + for k, v in headers.items() + ] + else: + self._list = [ + ( + normalize_header_key(k, lower=False, encoding=encoding), + normalize_header_key(k, lower=True, encoding=encoding), + normalize_header_value(v, encoding), + ) + for k, v in headers + ] + + self._encoding = encoding + + @property + def encoding(self) -> str: + """ + Header encoding is mandated as ascii, but we allow fallbacks to utf-8 + or iso-8859-1. + """ + if self._encoding is None: + for encoding in ["ascii", "utf-8"]: + for key, value in self.raw: + try: + key.decode(encoding) + value.decode(encoding) + except UnicodeDecodeError: + break + else: + # The else block runs if 'break' did not occur, meaning + # all values fitted the encoding. + self._encoding = encoding + break + else: + # The ISO-8859-1 encoding covers all 256 code points in a byte, + # so will never raise decode errors. + self._encoding = "iso-8859-1" + return self._encoding + + @encoding.setter + def encoding(self, value: str) -> None: + self._encoding = value + + @property + def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: + """ + Returns a list of the raw header items, as byte pairs. + """ + return [(raw_key, value) for raw_key, _, value in self._list] + + def keys(self) -> typing.KeysView[str]: + return {key.decode(self.encoding): None for _, key, value in self._list}.keys() + + def values(self) -> typing.ValuesView[str]: + values_dict: typing.Dict[str, str] = {} + for _, key, value in self._list: + str_key = key.decode(self.encoding) + str_value = value.decode(self.encoding) + if str_key in values_dict: + values_dict[str_key] += f", {str_value}" + else: + values_dict[str_key] = str_value + return values_dict.values() + + def items(self) -> typing.ItemsView[str, str]: + """ + Return `(key, value)` items of headers. Concatenate headers + into a single comma separated value when a key occurs multiple times. + """ + values_dict: typing.Dict[str, str] = {} + for _, key, value in self._list: + str_key = key.decode(self.encoding) + str_value = value.decode(self.encoding) + if str_key in values_dict: + values_dict[str_key] += f", {str_value}" + else: + values_dict[str_key] = str_value + return values_dict.items() + + def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + """ + Return a list of `(key, value)` pairs of headers. Allow multiple + occurrences of the same key without concatenating into a single + comma separated value. + """ + return [ + (key.decode(self.encoding), value.decode(self.encoding)) + for _, key, value in self._list + ] + + def get(self, key: str, default: typing.Any = None) -> typing.Any: + """ + Return a header value. If multiple occurrences of the header occur + then concatenate them together with commas. + """ + try: + return self[key] + except KeyError: + return default + + def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]: + """ + Return a list of all header values for a given key. 
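The multi-dict behaviour means repeated keys are preserved internally and only joined with commas at access time (per RFC 7230, section 3.2.2). A short sketch using the public `Headers` class:

```python
import httpx

h = httpx.Headers([("set-cookie", "a=1"), ("set-cookie", "b=2")])
assert h["set-cookie"] == "a=1, b=2"               # joined on access
assert h.get_list("set-cookie") == ["a=1", "b=2"]  # preserved individually
```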
+ If `split_commas=True` is passed, then any comma separated header + values are split into multiple return strings. + """ + get_header_key = key.lower().encode(self.encoding) + + values = [ + item_value.decode(self.encoding) + for _, item_key, item_value in self._list + if item_key.lower() == get_header_key + ] + + if not split_commas: + return values + + split_values = [] + for value in values: + split_values.extend([item.strip() for item in value.split(",")]) + return split_values + + def update(self, headers: typing.Optional[HeaderTypes] = None) -> None: # type: ignore + headers = Headers(headers) + for key in headers.keys(): + if key in self: + self.pop(key) + self._list.extend(headers._list) + + def copy(self) -> "Headers": + return Headers(self, encoding=self.encoding) + + def __getitem__(self, key: str) -> str: + """ + Return a single header value. + + If there are multiple headers with the same key, then we concatenate + them with commas. See: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + normalized_key = key.lower().encode(self.encoding) + + items = [ + header_value.decode(self.encoding) + for _, header_key, header_value in self._list + if header_key == normalized_key + ] + + if items: + return ", ".join(items) + + raise KeyError(key) + + def __setitem__(self, key: str, value: str) -> None: + """ + Set the header `key` to `value`, removing any duplicate entries. + Retains insertion order. + """ + set_key = key.encode(self._encoding or "utf-8") + set_value = value.encode(self._encoding or "utf-8") + lookup_key = set_key.lower() + + found_indexes = [ + idx + for idx, (_, item_key, _) in enumerate(self._list) + if item_key == lookup_key + ] + + for idx in reversed(found_indexes[1:]): + del self._list[idx] + + if found_indexes: + idx = found_indexes[0] + self._list[idx] = (set_key, lookup_key, set_value) + else: + self._list.append((set_key, lookup_key, set_value)) + + def __delitem__(self, key: str) -> None: + """ + Remove the header `key`. 
+ """ + del_key = key.lower().encode(self.encoding) + + pop_indexes = [ + idx + for idx, (_, item_key, _) in enumerate(self._list) + if item_key.lower() == del_key + ] + + if not pop_indexes: + raise KeyError(key) + + for idx in reversed(pop_indexes): + del self._list[idx] + + def __contains__(self, key: typing.Any) -> bool: + header_key = key.lower().encode(self.encoding) + return header_key in [key for _, key, _ in self._list] + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._list) + + def __eq__(self, other: typing.Any) -> bool: + try: + other_headers = Headers(other) + except ValueError: + return False + + self_list = [(key, value) for _, key, value in self._list] + other_list = [(key, value) for _, key, value in other_headers._list] + return sorted(self_list) == sorted(other_list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + + encoding_str = "" + if self.encoding != "ascii": + encoding_str = f", encoding={self.encoding!r}" + + as_list = list(obfuscate_sensitive_headers(self.multi_items())) + as_dict = dict(as_list) + + no_duplicate_keys = len(as_dict) == len(as_list) + if no_duplicate_keys: + return f"{class_name}({as_dict!r}{encoding_str})" + return f"{class_name}({as_list!r}{encoding_str})" + + +class Request: + def __init__( + self, + method: typing.Union[str, bytes], + url: typing.Union["URL", str], + *, + params: typing.Optional[QueryParamTypes] = None, + headers: typing.Optional[HeaderTypes] = None, + cookies: typing.Optional[CookieTypes] = None, + content: typing.Optional[RequestContent] = None, + data: typing.Optional[RequestData] = None, + files: typing.Optional[RequestFiles] = None, + json: typing.Optional[typing.Any] = None, + stream: typing.Union[SyncByteStream, AsyncByteStream, None] = None, + extensions: typing.Optional[RequestExtensions] = None, + ): + self.method = ( + method.decode("ascii").upper() + if isinstance(method, bytes) + else method.upper() + ) + self.url = URL(url) + if params is not None: + self.url = self.url.copy_merge_params(params=params) + self.headers = Headers(headers) + self.extensions = {} if extensions is None else extensions + + if cookies: + Cookies(cookies).set_cookie_header(self) + + if stream is None: + content_type: typing.Optional[str] = self.headers.get("content-type") + headers, stream = encode_request( + content=content, + data=data, + files=files, + json=json, + boundary=get_multipart_boundary_from_content_type( + content_type=content_type.encode(self.headers.encoding) + if content_type + else None + ), + ) + self._prepare(headers) + self.stream = stream + # Load the request body, except for streaming content. + if isinstance(stream, ByteStream): + self.read() + else: + # There's an important distinction between `Request(content=...)`, + # and `Request(stream=...)`. + # + # Using `content=...` implies automatically populated `Host` and content + # headers, of either `Content-Length: ...` or `Transfer-Encoding: chunked`. + # + # Using `stream=...` will not automatically include *any* auto-populated headers. + # + # As an end-user you don't really need `stream=...`. It's only + # useful when: + # + # * Preserving the request stream when copying requests, eg for redirects. + # * Creating request instances on the *server-side* of the transport API. 
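The `content=` versus `stream=` distinction described in the comment above is observable through the auto-populated headers. A sketch (host and path are placeholders):

```python
import httpx

# content=... populates Host and Content-Length automatically;
# stream=... would leave headers entirely to the caller.
req = httpx.Request("POST", "http://example.com/upload", content=b"abc")
assert req.headers["Host"] == "example.com"
assert req.headers["Content-Length"] == "3"
```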
+ self.stream = stream + + def _prepare(self, default_headers: typing.Dict[str, str]) -> None: + for key, value in default_headers.items(): + # Ignore Transfer-Encoding if the Content-Length has been set explicitly. + if key.lower() == "transfer-encoding" and "Content-Length" in self.headers: + continue + self.headers.setdefault(key, value) + + auto_headers: typing.List[typing.Tuple[bytes, bytes]] = [] + + has_host = "Host" in self.headers + has_content_length = ( + "Content-Length" in self.headers or "Transfer-Encoding" in self.headers + ) + + if not has_host and self.url.host: + auto_headers.append((b"Host", self.url.netloc)) + if not has_content_length and self.method in ("POST", "PUT", "PATCH"): + auto_headers.append((b"Content-Length", b"0")) + + self.headers = Headers(auto_headers + self.headers.raw) + + @property + def content(self) -> bytes: + if not hasattr(self, "_content"): + raise RequestNotRead() + return self._content + + def read(self) -> bytes: + """ + Read and return the request content. + """ + if not hasattr(self, "_content"): + assert isinstance(self.stream, typing.Iterable) + self._content = b"".join(self.stream) + if not isinstance(self.stream, ByteStream): + # If a streaming request has been read entirely into memory, then + # we can replace the stream with a raw bytes implementation, + # to ensure that any non-replayable streams can still be used. + self.stream = ByteStream(self._content) + return self._content + + async def aread(self) -> bytes: + """ + Read and return the request content. + """ + if not hasattr(self, "_content"): + assert isinstance(self.stream, typing.AsyncIterable) + self._content = b"".join([part async for part in self.stream]) + if not isinstance(self.stream, ByteStream): + # If a streaming request has been read entirely into memory, then + # we can replace the stream with a raw bytes implementation, + # to ensure that any non-replayable streams can still be used. + self.stream = ByteStream(self._content) + return self._content + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + url = str(self.url) + return f"<{class_name}({self.method!r}, {url!r})>" + + def __getstate__(self) -> typing.Dict[str, typing.Any]: + return { + name: value + for name, value in self.__dict__.items() + if name not in ["extensions", "stream"] + } + + def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None: + for name, value in state.items(): + setattr(self, name, value) + self.extensions = {} + self.stream = UnattachedStream() + + +class Response: + def __init__( + self, + status_code: int, + *, + headers: typing.Optional[HeaderTypes] = None, + content: typing.Optional[ResponseContent] = None, + text: typing.Optional[str] = None, + html: typing.Optional[str] = None, + json: typing.Any = None, + stream: typing.Union[SyncByteStream, AsyncByteStream, None] = None, + request: typing.Optional[Request] = None, + extensions: typing.Optional[ResponseExtensions] = None, + history: typing.Optional[typing.List["Response"]] = None, + default_encoding: typing.Union[str, typing.Callable[[bytes], str]] = "utf-8", + ): + self.status_code = status_code + self.headers = Headers(headers) + + self._request: typing.Optional[Request] = request + + # When follow_redirects=False and a redirect is received, + # the client will set `response.next_request`. 
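The `__getstate__`/`__setstate__` pair above makes requests picklable by dropping the live stream and extensions; the restored instance gets an `UnattachedStream`. A sketch:

```python
import pickle

import httpx

req = httpx.Request("GET", "https://example.com")
restored = pickle.loads(pickle.dumps(req))
assert restored.method == "GET"
assert str(restored.url) == "https://example.com"
```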
+ self.next_request: typing.Optional[Request] = None + + self.extensions: ResponseExtensions = {} if extensions is None else extensions + self.history = [] if history is None else list(history) + + self.is_closed = False + self.is_stream_consumed = False + + self.default_encoding = default_encoding + + if stream is None: + headers, stream = encode_response(content, text, html, json) + self._prepare(headers) + self.stream = stream + if isinstance(stream, ByteStream): + # Load the response body, except for streaming content. + self.read() + else: + # There's an important distinction between `Response(content=...)`, + # and `Response(stream=...)`. + # + # Using `content=...` implies automatically populated content headers, + # of either `Content-Length: ...` or `Transfer-Encoding: chunked`. + # + # Using `stream=...` will not automatically include any content headers. + # + # As an end-user you don't really need `stream=...`. It's only + # useful when creating response instances having received a stream + # from the transport API. + self.stream = stream + + self._num_bytes_downloaded = 0 + + def _prepare(self, default_headers: typing.Dict[str, str]) -> None: + for key, value in default_headers.items(): + # Ignore Transfer-Encoding if the Content-Length has been set explicitly. + if key.lower() == "transfer-encoding" and "content-length" in self.headers: + continue + self.headers.setdefault(key, value) + + @property + def elapsed(self) -> datetime.timedelta: + """ + Returns the time taken for the complete request/response + cycle to complete. + """ + if not hasattr(self, "_elapsed"): + raise RuntimeError( + "'.elapsed' may only be accessed after the response " + "has been read or closed." + ) + return self._elapsed + + @elapsed.setter + def elapsed(self, elapsed: datetime.timedelta) -> None: + self._elapsed = elapsed + + @property + def request(self) -> Request: + """ + Returns the request instance associated to the current response. + """ + if self._request is None: + raise RuntimeError( + "The request instance has not been set on this response." + ) + return self._request + + @request.setter + def request(self, value: Request) -> None: + self._request = value + + @property + def http_version(self) -> str: + try: + http_version: bytes = self.extensions["http_version"] + except KeyError: + return "HTTP/1.1" + else: + return http_version.decode("ascii", errors="ignore") + + @property + def reason_phrase(self) -> str: + try: + reason_phrase: bytes = self.extensions["reason_phrase"] + except KeyError: + return codes.get_reason_phrase(self.status_code) + else: + return reason_phrase.decode("ascii", errors="ignore") + + @property + def url(self) -> URL: + """ + Returns the URL for which the request was made. + """ + return self.request.url + + @property + def content(self) -> bytes: + if not hasattr(self, "_content"): + raise ResponseNotRead() + return self._content + + @property + def text(self) -> str: + if not hasattr(self, "_text"): + content = self.content + if not content: + self._text = "" + else: + decoder = TextDecoder(encoding=self.encoding or "utf-8") + self._text = "".join([decoder.decode(self.content), decoder.flush()]) + return self._text + + @property + def encoding(self) -> typing.Optional[str]: + """ + Return an encoding to use for decoding the byte content into text. + The priority for determining this is given by... + + * `.encoding = <>` has been set explicitly. + * The encoding as specified by the charset parameter in the Content-Type header. 
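That precedence can be seen by constructing a response directly; a charset parameter in `Content-Type` wins over `default_encoding`:

```python
import httpx

r = httpx.Response(
    200,
    headers={"Content-Type": "text/plain; charset=iso-8859-1"},
    content=b"caf\xe9",
)
assert r.encoding == "iso-8859-1"
assert r.text == "caf\u00e9"  # decoded as latin-1, not utf-8
```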
+ * The encoding as determined by `default_encoding`, which may either be + a string like "utf-8" indicating the encoding to use, or may be a callable + which enables charset autodetection. + """ + if not hasattr(self, "_encoding"): + encoding = self.charset_encoding + if encoding is None or not is_known_encoding(encoding): + if isinstance(self.default_encoding, str): + encoding = self.default_encoding + elif hasattr(self, "_content"): + encoding = self.default_encoding(self._content) + self._encoding = encoding or "utf-8" + return self._encoding + + @encoding.setter + def encoding(self, value: str) -> None: + self._encoding = value + + @property + def charset_encoding(self) -> typing.Optional[str]: + """ + Return the encoding, as specified by the Content-Type header. + """ + content_type = self.headers.get("Content-Type") + if content_type is None: + return None + + return parse_content_type_charset(content_type) + + def _get_content_decoder(self) -> ContentDecoder: + """ + Returns a decoder instance which can be used to decode the raw byte + content, depending on the Content-Encoding used in the response. + """ + if not hasattr(self, "_decoder"): + decoders: typing.List[ContentDecoder] = [] + values = self.headers.get_list("content-encoding", split_commas=True) + for value in values: + value = value.strip().lower() + try: + decoder_cls = SUPPORTED_DECODERS[value] + decoders.append(decoder_cls()) + except KeyError: + continue + + if len(decoders) == 1: + self._decoder = decoders[0] + elif len(decoders) > 1: + self._decoder = MultiDecoder(children=decoders) + else: + self._decoder = IdentityDecoder() + + return self._decoder + + @property + def is_informational(self) -> bool: + """ + A property which is `True` for 1xx status codes, `False` otherwise. + """ + return codes.is_informational(self.status_code) + + @property + def is_success(self) -> bool: + """ + A property which is `True` for 2xx status codes, `False` otherwise. + """ + return codes.is_success(self.status_code) + + @property + def is_redirect(self) -> bool: + """ + A property which is `True` for 3xx status codes, `False` otherwise. + + Note that not all responses with a 3xx status code indicate a URL redirect. + + Use `response.has_redirect_location` to determine responses with a properly + formed URL redirection. + """ + return codes.is_redirect(self.status_code) + + @property + def is_client_error(self) -> bool: + """ + A property which is `True` for 4xx status codes, `False` otherwise. + """ + return codes.is_client_error(self.status_code) + + @property + def is_server_error(self) -> bool: + """ + A property which is `True` for 5xx status codes, `False` otherwise. + """ + return codes.is_server_error(self.status_code) + + @property + def is_error(self) -> bool: + """ + A property which is `True` for 4xx and 5xx status codes, `False` otherwise. + """ + return codes.is_error(self.status_code) + + @property + def has_redirect_location(self) -> bool: + """ + Returns True for 3xx responses with a properly formed URL redirection, + `False` otherwise. + """ + return ( + self.status_code + in ( + # 301 (Cacheable redirect. Method may change to GET.) + codes.MOVED_PERMANENTLY, + # 302 (Uncacheable redirect. Method may change to GET.) + codes.FOUND, + # 303 (Client should make a GET or HEAD request.) + codes.SEE_OTHER, + # 307 (Equiv. 302, but retain method) + codes.TEMPORARY_REDIRECT, + # 308 (Equiv. 
301, but retain method) + codes.PERMANENT_REDIRECT, + ) + and "Location" in self.headers + ) + + def raise_for_status(self) -> "Response": + """ + Raise the `HTTPStatusError` if one occurred. + """ + request = self._request + if request is None: + raise RuntimeError( + "Cannot call `raise_for_status` as the request " + "instance has not been set on this response." + ) + + if self.is_success: + return self + + if self.has_redirect_location: + message = ( + "{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n" + "Redirect location: '{0.headers[location]}'\n" + "For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}" + ) + else: + message = ( + "{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n" + "For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}" + ) + + status_class = self.status_code // 100 + error_types = { + 1: "Informational response", + 3: "Redirect response", + 4: "Client error", + 5: "Server error", + } + error_type = error_types.get(status_class, "Invalid status code") + message = message.format(self, error_type=error_type) + raise HTTPStatusError(message, request=request, response=self) + + def json(self, **kwargs: typing.Any) -> typing.Any: + if self.charset_encoding is None and self.content and len(self.content) > 3: + encoding = guess_json_utf(self.content) + if encoding is not None: + return jsonlib.loads(self.content.decode(encoding), **kwargs) + return jsonlib.loads(self.text, **kwargs) + + @property + def cookies(self) -> "Cookies": + if not hasattr(self, "_cookies"): + self._cookies = Cookies() + self._cookies.extract_cookies(self) + return self._cookies + + @property + def links(self) -> typing.Dict[typing.Optional[str], typing.Dict[str, str]]: + """ + Returns the parsed header links of the response, if any + """ + header = self.headers.get("link") + ldict = {} + if header: + links = parse_header_links(header) + for link in links: + key = link.get("rel") or link.get("url") + ldict[key] = link + return ldict + + @property + def num_bytes_downloaded(self) -> int: + return self._num_bytes_downloaded + + def __repr__(self) -> str: + return f"<Response [{self.status_code} {self.reason_phrase}]>" + + def __getstate__(self) -> typing.Dict[str, typing.Any]: + return { + name: value + for name, value in self.__dict__.items() + if name not in ["extensions", "stream", "is_closed", "_decoder"] + } + + def __setstate__(self, state: typing.Dict[str, typing.Any]) -> None: + for name, value in state.items(): + setattr(self, name, value) + self.is_closed = True + self.extensions = {} + self.stream = UnattachedStream() + + def read(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join(self.iter_bytes()) + return self._content + + def iter_bytes( + self, chunk_size: typing.Optional[int] = None + ) -> typing.Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. 
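Because `iter_bytes()` routes through the content decoders, iterating a compressed response yields the decompressed payload, and `.content` holds the decoded bytes. A sketch:

```python
import gzip

import httpx

compressed = gzip.compress(b"x" * 100)
r = httpx.Response(200, headers={"Content-Encoding": "gzip"}, content=compressed)
assert r.content == b"x" * 100  # decoded when the body was read
assert b"".join(r.iter_bytes(chunk_size=32)) == b"x" * 100
```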
+ """ + if hasattr(self, "_content"): + chunk_size = len(self._content) if chunk_size is None else chunk_size + for i in range(0, len(self._content), max(chunk_size, 1)): + yield self._content[i : i + chunk_size] + else: + decoder = self._get_content_decoder() + chunker = ByteChunker(chunk_size=chunk_size) + with request_context(request=self._request): + for raw_bytes in self.iter_raw(): + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() + for chunk in chunker.decode(decoded): + yield chunk # pragma: no cover + for chunk in chunker.flush(): + yield chunk + + def iter_text( + self, chunk_size: typing.Optional[int] = None + ) -> typing.Iterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + decoder = TextDecoder(encoding=self.encoding or "utf-8") + chunker = TextChunker(chunk_size=chunk_size) + with request_context(request=self._request): + for byte_content in self.iter_bytes(): + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk + text_content = decoder.flush() + for chunk in chunker.decode(text_content): + yield chunk + for chunk in chunker.flush(): + yield chunk + + def iter_lines(self) -> typing.Iterator[str]: + decoder = LineDecoder() + with request_context(request=self._request): + for text in self.iter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): + yield line + + def iter_raw( + self, chunk_size: typing.Optional[int] = None + ) -> typing.Iterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise StreamClosed() + if not isinstance(self.stream, SyncByteStream): + raise RuntimeError("Attempted to call a sync iterator on an async stream.") + + self.is_stream_consumed = True + self._num_bytes_downloaded = 0 + chunker = ByteChunker(chunk_size=chunk_size) + + with request_context(request=self._request): + for raw_stream_bytes in self.stream: + self._num_bytes_downloaded += len(raw_stream_bytes) + for chunk in chunker.decode(raw_stream_bytes): + yield chunk + + for chunk in chunker.flush(): + yield chunk + + self.close() + + def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not isinstance(self.stream, SyncByteStream): + raise RuntimeError("Attempted to call an sync close on an async stream.") + + if not self.is_closed: + self.is_closed = True + with request_context(request=self._request): + self.stream.close() + + async def aread(self) -> bytes: + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join([part async for part in self.aiter_bytes()]) + return self._content + + async def aiter_bytes( + self, chunk_size: typing.Optional[int] = None + ) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. 
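`iter_lines()` composes `iter_text()` with the `LineDecoder` from `_decoders.py`, so mixed line endings are normalised and a trailing unterminated segment is still emitted at the end. A sketch:

```python
import httpx

r = httpx.Response(200, content=b"first\r\nsecond\nthird")
assert list(r.iter_lines()) == ["first", "second", "third"]
```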
+ """ + if hasattr(self, "_content"): + chunk_size = len(self._content) if chunk_size is None else chunk_size + for i in range(0, len(self._content), max(chunk_size, 1)): + yield self._content[i : i + chunk_size] + else: + decoder = self._get_content_decoder() + chunker = ByteChunker(chunk_size=chunk_size) + with request_context(request=self._request): + async for raw_bytes in self.aiter_raw(): + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() + for chunk in chunker.decode(decoded): + yield chunk # pragma: no cover + for chunk in chunker.flush(): + yield chunk + + async def aiter_text( + self, chunk_size: typing.Optional[int] = None + ) -> typing.AsyncIterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + decoder = TextDecoder(encoding=self.encoding or "utf-8") + chunker = TextChunker(chunk_size=chunk_size) + with request_context(request=self._request): + async for byte_content in self.aiter_bytes(): + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk + text_content = decoder.flush() + for chunk in chunker.decode(text_content): + yield chunk + for chunk in chunker.flush(): + yield chunk + + async def aiter_lines(self) -> typing.AsyncIterator[str]: + decoder = LineDecoder() + with request_context(request=self._request): + async for text in self.aiter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): + yield line + + async def aiter_raw( + self, chunk_size: typing.Optional[int] = None + ) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise StreamClosed() + if not isinstance(self.stream, AsyncByteStream): + raise RuntimeError("Attempted to call an async iterator on an sync stream.") + + self.is_stream_consumed = True + self._num_bytes_downloaded = 0 + chunker = ByteChunker(chunk_size=chunk_size) + + with request_context(request=self._request): + async for raw_stream_bytes in self.stream: + self._num_bytes_downloaded += len(raw_stream_bytes) + for chunk in chunker.decode(raw_stream_bytes): + yield chunk + + for chunk in chunker.flush(): + yield chunk + + await self.aclose() + + async def aclose(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not isinstance(self.stream, AsyncByteStream): + raise RuntimeError("Attempted to call an async close on an sync stream.") + + if not self.is_closed: + self.is_closed = True + with request_context(request=self._request): + await self.stream.aclose() + + +class Cookies(typing.MutableMapping[str, str]): + """ + HTTP Cookies, as a mutable mapping. 
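The async variants mirror the sync ones exactly; for example `aiter_lines()` stacks the same `TextDecoder` and `LineDecoder` over an async byte stream. A sketch (URL is a placeholder):

```python
import asyncio

import httpx

async def main() -> None:
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", "https://example.com") as response:
            async for line in response.aiter_lines():
                print(line)

asyncio.run(main())
```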
+ """ + + def __init__(self, cookies: typing.Optional[CookieTypes] = None) -> None: + if cookies is None or isinstance(cookies, dict): + self.jar = CookieJar() + if isinstance(cookies, dict): + for key, value in cookies.items(): + self.set(key, value) + elif isinstance(cookies, list): + self.jar = CookieJar() + for key, value in cookies: + self.set(key, value) + elif isinstance(cookies, Cookies): + self.jar = CookieJar() + for cookie in cookies.jar: + self.jar.set_cookie(cookie) + else: + self.jar = cookies + + def extract_cookies(self, response: Response) -> None: + """ + Loads any cookies based on the response `Set-Cookie` headers. + """ + urllib_response = self._CookieCompatResponse(response) + urllib_request = self._CookieCompatRequest(response.request) + + self.jar.extract_cookies(urllib_response, urllib_request) # type: ignore + + def set_cookie_header(self, request: Request) -> None: + """ + Sets an appropriate 'Cookie:' HTTP header on the `Request`. + """ + urllib_request = self._CookieCompatRequest(request) + self.jar.add_cookie_header(urllib_request) + + def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None: + """ + Set a cookie value by name. May optionally include domain and path. + """ + kwargs = { + "version": 0, + "name": name, + "value": value, + "port": None, + "port_specified": False, + "domain": domain, + "domain_specified": bool(domain), + "domain_initial_dot": domain.startswith("."), + "path": path, + "path_specified": bool(path), + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": {"HttpOnly": None}, + "rfc2109": False, + } + cookie = Cookie(**kwargs) # type: ignore + self.jar.set_cookie(cookie) + + def get( # type: ignore + self, + name: str, + default: typing.Optional[str] = None, + domain: typing.Optional[str] = None, + path: typing.Optional[str] = None, + ) -> typing.Optional[str]: + """ + Get a cookie by name. May optionally include domain and path + in order to specify exactly which cookie to retrieve. + """ + value = None + for cookie in self.jar: + if cookie.name == name: + if domain is None or cookie.domain == domain: + if path is None or cookie.path == path: + if value is not None: + message = f"Multiple cookies exist with name={name}" + raise CookieConflict(message) + value = cookie.value + + if value is None: + return default + return value + + def delete( + self, + name: str, + domain: typing.Optional[str] = None, + path: typing.Optional[str] = None, + ) -> None: + """ + Delete a cookie by name. May optionally include domain and path + in order to specify exactly which cookie to delete. + """ + if domain is not None and path is not None: + return self.jar.clear(domain, path, name) + + remove = [ + cookie + for cookie in self.jar + if cookie.name == name + and (domain is None or cookie.domain == domain) + and (path is None or cookie.path == path) + ] + + for cookie in remove: + self.jar.clear(cookie.domain, cookie.path, cookie.name) + + def clear( + self, domain: typing.Optional[str] = None, path: typing.Optional[str] = None + ) -> None: + """ + Delete all cookies. Optionally include a domain and path in + order to only delete a subset of all the cookies. 
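The `domain`/`path` arguments to `get()` exist because a bare name can be ambiguous across the jar, in which case `CookieConflict` is raised, as the nested checks above implement. A sketch:

```python
import httpx

cookies = httpx.Cookies()
cookies.set("session", "a", domain="example.com")
cookies.set("session", "b", domain="example.org")

try:
    cookies.get("session")  # ambiguous: two cookies match the name
except httpx.CookieConflict:
    pass
assert cookies.get("session", domain="example.org") == "b"
```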
+ """ + args = [] + if domain is not None: + args.append(domain) + if path is not None: + assert domain is not None + args.append(path) + self.jar.clear(*args) + + def update(self, cookies: typing.Optional[CookieTypes] = None) -> None: # type: ignore + cookies = Cookies(cookies) + for cookie in cookies.jar: + self.jar.set_cookie(cookie) + + def __setitem__(self, name: str, value: str) -> None: + return self.set(name, value) + + def __getitem__(self, name: str) -> str: + value = self.get(name) + if value is None: + raise KeyError(name) + return value + + def __delitem__(self, name: str) -> None: + return self.delete(name) + + def __len__(self) -> int: + return len(self.jar) + + def __iter__(self) -> typing.Iterator[str]: + return (cookie.name for cookie in self.jar) + + def __bool__(self) -> bool: + for _ in self.jar: + return True + return False + + def __repr__(self) -> str: + cookies_repr = ", ".join( + [ + f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />" + for cookie in self.jar + ] + ) + + return f"<Cookies[{cookies_repr}]>" + + class _CookieCompatRequest(urllib.request.Request): + """ + Wraps a `Request` instance up in a compatibility interface suitable + for use with `CookieJar` operations. + """ + + def __init__(self, request: Request) -> None: + super().__init__( + url=str(request.url), + headers=dict(request.headers), + method=request.method, + ) + self.request = request + + def add_unredirected_header(self, key: str, value: str) -> None: + super().add_unredirected_header(key, value) + self.request.headers[key] = value + + class _CookieCompatResponse: + """ + Wraps a `Request` instance up in a compatibility interface suitable + for use with `CookieJar` operations. + """ + + def __init__(self, response: Response): + self.response = response + + def info(self) -> email.message.Message: + info = email.message.Message() + for key, value in self.response.headers.multi_items(): + # Note that setting `info[key]` here is an "append" operation, + # not a "replace" operation. + # https://docs.python.org/3/library/email.compat32-message.html#email.message.Message.__setitem__ + info[key] = value + return info diff --git a/contrib/python/httpx/httpx/_multipart.py b/contrib/python/httpx/httpx/_multipart.py new file mode 100644 index 0000000000..446f4ad2df --- /dev/null +++ b/contrib/python/httpx/httpx/_multipart.py @@ -0,0 +1,267 @@ +import binascii +import io +import os +import typing +from pathlib import Path + +from ._types import ( + AsyncByteStream, + FileContent, + FileTypes, + RequestData, + RequestFiles, + SyncByteStream, +) +from ._utils import ( + format_form_param, + guess_content_type, + peek_filelike_length, + primitive_value_to_str, + to_bytes, +) + + +def get_multipart_boundary_from_content_type( + content_type: typing.Optional[bytes], +) -> typing.Optional[bytes]: + if not content_type or not content_type.startswith(b"multipart/form-data"): + return None + # parse boundary according to + # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1 + if b";" in content_type: + for section in content_type.split(b";"): + if section.strip().lower().startswith(b"boundary="): + return section.strip()[len(b"boundary=") :].strip(b'"') + return None + + +class DataField: + """ + A single form field item, within a multipart form field. + """ + + def __init__( + self, name: str, value: typing.Union[str, bytes, int, float, None] + ) -> None: + if not isinstance(name, str): + raise TypeError( + f"Invalid type for name. 
Expected str, got {type(name)}: {name!r}" + ) + if value is not None and not isinstance(value, (str, bytes, int, float)): + raise TypeError( + f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}" + ) + self.name = name + self.value: typing.Union[str, bytes] = ( + value if isinstance(value, bytes) else primitive_value_to_str(value) + ) + + def render_headers(self) -> bytes: + if not hasattr(self, "_headers"): + name = format_form_param("name", self.name) + self._headers = b"".join( + [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"] + ) + + return self._headers + + def render_data(self) -> bytes: + if not hasattr(self, "_data"): + self._data = to_bytes(self.value) + + return self._data + + def get_length(self) -> int: + headers = self.render_headers() + data = self.render_data() + return len(headers) + len(data) + + def render(self) -> typing.Iterator[bytes]: + yield self.render_headers() + yield self.render_data() + + +class FileField: + """ + A single file field item, within a multipart form field. + """ + + CHUNK_SIZE = 64 * 1024 + + def __init__(self, name: str, value: FileTypes) -> None: + self.name = name + + fileobj: FileContent + + headers: typing.Dict[str, str] = {} + content_type: typing.Optional[str] = None + + # This large tuple based API largely mirror's requests' API + # It would be good to think of better APIs for this that we could include in httpx 2.0 + # since variable length tuples (especially of 4 elements) are quite unwieldly + if isinstance(value, tuple): + if len(value) == 2: + # neither the 3rd parameter (content_type) nor the 4th (headers) was included + filename, fileobj = value # type: ignore + elif len(value) == 3: + filename, fileobj, content_type = value # type: ignore + else: + # all 4 parameters included + filename, fileobj, content_type, headers = value # type: ignore + else: + filename = Path(str(getattr(value, "name", "upload"))).name + fileobj = value + + if content_type is None: + content_type = guess_content_type(filename) + + has_content_type_header = any("content-type" in key.lower() for key in headers) + if content_type is not None and not has_content_type_header: + # note that unlike requests, we ignore the content_type + # provided in the 3rd tuple element if it is also included in the headers + # requests does the opposite (it overwrites the header with the 3rd tuple element) + headers["Content-Type"] = content_type + + if isinstance(fileobj, io.StringIO): + raise TypeError( + "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'." + ) + if isinstance(fileobj, io.TextIOBase): + raise TypeError( + "Multipart file uploads must be opened in binary mode, not text mode." + ) + + self.filename = filename + self.file = fileobj + self.headers = headers + + def get_length(self) -> typing.Optional[int]: + headers = self.render_headers() + + if isinstance(self.file, (str, bytes)): + return len(headers) + len(to_bytes(self.file)) + + file_length = peek_filelike_length(self.file) + + # If we can't determine the filesize without reading it into memory, + # then return `None` here, to indicate an unknown file length. 
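As a usage sketch, the 2-, 3- and 4-tuple forms described above correspond to the `files=...` argument of the public API (the file names and payloads here are illustrative):

```
import io
import httpx

files = {
    "plain": ("data.txt", b"hello"),                               # (filename, content)
    "typed": ("data.csv", io.BytesIO(b"a,b\n1,2\n"), "text/csv"),  # + content_type
    "extra": ("raw.bin", b"\x00\x01", "application/octet-stream",
              {"X-Custom": "1"}),                                  # + headers
}
request = httpx.Request("POST", "https://example.org/upload", files=files)
assert request.headers["Content-Type"].startswith("multipart/form-data; boundary=")
```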
+ if file_length is None: + return None + + return len(headers) + file_length + + def render_headers(self) -> bytes: + if not hasattr(self, "_headers"): + parts = [ + b"Content-Disposition: form-data; ", + format_form_param("name", self.name), + ] + if self.filename: + filename = format_form_param("filename", self.filename) + parts.extend([b"; ", filename]) + for header_name, header_value in self.headers.items(): + key, val = f"\r\n{header_name}: ".encode(), header_value.encode() + parts.extend([key, val]) + parts.append(b"\r\n\r\n") + self._headers = b"".join(parts) + + return self._headers + + def render_data(self) -> typing.Iterator[bytes]: + if isinstance(self.file, (str, bytes)): + yield to_bytes(self.file) + return + + if hasattr(self.file, "seek"): + try: + self.file.seek(0) + except io.UnsupportedOperation: + pass + + chunk = self.file.read(self.CHUNK_SIZE) + while chunk: + yield to_bytes(chunk) + chunk = self.file.read(self.CHUNK_SIZE) + + def render(self) -> typing.Iterator[bytes]: + yield self.render_headers() + yield from self.render_data() + + +class MultipartStream(SyncByteStream, AsyncByteStream): + """ + Request content as streaming multipart encoded form data. + """ + + def __init__( + self, + data: RequestData, + files: RequestFiles, + boundary: typing.Optional[bytes] = None, + ) -> None: + if boundary is None: + boundary = binascii.hexlify(os.urandom(16)) + + self.boundary = boundary + self.content_type = "multipart/form-data; boundary=%s" % boundary.decode( + "ascii" + ) + self.fields = list(self._iter_fields(data, files)) + + def _iter_fields( + self, data: RequestData, files: RequestFiles + ) -> typing.Iterator[typing.Union[FileField, DataField]]: + for name, value in data.items(): + if isinstance(value, (tuple, list)): + for item in value: + yield DataField(name=name, value=item) + else: + yield DataField(name=name, value=value) + + file_items = files.items() if isinstance(files, typing.Mapping) else files + for name, value in file_items: + yield FileField(name=name, value=value) + + def iter_chunks(self) -> typing.Iterator[bytes]: + for field in self.fields: + yield b"--%s\r\n" % self.boundary + yield from field.render() + yield b"\r\n" + yield b"--%s--\r\n" % self.boundary + + def get_content_length(self) -> typing.Optional[int]: + """ + Return the length of the multipart encoded content, or `None` if + any of the files have a length that cannot be determined upfront. + """ + boundary_length = len(self.boundary) + length = 0 + + for field in self.fields: + field_length = field.get_length() + if field_length is None: + return None + + length += 2 + boundary_length + 2 # b"--{boundary}\r\n" + length += field_length + length += 2 # b"\r\n" + + length += 2 + boundary_length + 4 # b"--{boundary}--\r\n" + return length + + # Content stream interface. 
+ + def get_headers(self) -> typing.Dict[str, str]: + content_length = self.get_content_length() + content_type = self.content_type + if content_length is None: + return {"Transfer-Encoding": "chunked", "Content-Type": content_type} + return {"Content-Length": str(content_length), "Content-Type": content_type} + + def __iter__(self) -> typing.Iterator[bytes]: + for chunk in self.iter_chunks(): + yield chunk + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + for chunk in self.iter_chunks(): + yield chunk diff --git a/contrib/python/httpx/httpx/_status_codes.py b/contrib/python/httpx/httpx/_status_codes.py new file mode 100644 index 0000000000..671c30e1b8 --- /dev/null +++ b/contrib/python/httpx/httpx/_status_codes.py @@ -0,0 +1,158 @@ +from enum import IntEnum + + +class codes(IntEnum): + """HTTP status codes and reason phrases + + Status codes from the following RFCs are all observed: + + * RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616 + * RFC 6585: Additional HTTP Status Codes + * RFC 3229: Delta encoding in HTTP + * RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518 + * RFC 5842: Binding Extensions to WebDAV + * RFC 7238: Permanent Redirect + * RFC 2295: Transparent Content Negotiation in HTTP + * RFC 2774: An HTTP Extension Framework + * RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2) + * RFC 2324: Hyper Text Coffee Pot Control Protocol (HTCPCP/1.0) + * RFC 7725: An HTTP Status Code to Report Legal Obstacles + * RFC 8297: An HTTP Status Code for Indicating Hints + * RFC 8470: Using Early Data in HTTP + """ + + def __new__(cls, value: int, phrase: str = "") -> "codes": + obj = int.__new__(cls, value) + obj._value_ = value + + obj.phrase = phrase # type: ignore[attr-defined] + return obj + + def __str__(self) -> str: + return str(self.value) + + @classmethod + def get_reason_phrase(cls, value: int) -> str: + try: + return codes(value).phrase # type: ignore + except ValueError: + return "" + + @classmethod + def is_informational(cls, value: int) -> bool: + """ + Returns `True` for 1xx status codes, `False` otherwise. + """ + return 100 <= value <= 199 + + @classmethod + def is_success(cls, value: int) -> bool: + """ + Returns `True` for 2xx status codes, `False` otherwise. + """ + return 200 <= value <= 299 + + @classmethod + def is_redirect(cls, value: int) -> bool: + """ + Returns `True` for 3xx status codes, `False` otherwise. + """ + return 300 <= value <= 399 + + @classmethod + def is_client_error(cls, value: int) -> bool: + """ + Returns `True` for 4xx status codes, `False` otherwise. + """ + return 400 <= value <= 499 + + @classmethod + def is_server_error(cls, value: int) -> bool: + """ + Returns `True` for 5xx status codes, `False` otherwise. + """ + return 500 <= value <= 599 + + @classmethod + def is_error(cls, value: int) -> bool: + """ + Returns `True` for 4xx or 5xx status codes, `False` otherwise. 
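A brief sketch of how the range helpers above behave, using enum members defined just below:

```
import httpx

assert httpx.codes.NOT_FOUND == 404
assert httpx.codes.get_reason_phrase(404) == "Not Found"
assert httpx.codes.is_client_error(404)
assert httpx.codes.is_error(503) and not httpx.codes.is_error(301)
assert str(httpx.codes.OK) == "200"
```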
+ """ + return 400 <= value <= 599 + + # informational + CONTINUE = 100, "Continue" + SWITCHING_PROTOCOLS = 101, "Switching Protocols" + PROCESSING = 102, "Processing" + EARLY_HINTS = 103, "Early Hints" + + # success + OK = 200, "OK" + CREATED = 201, "Created" + ACCEPTED = 202, "Accepted" + NON_AUTHORITATIVE_INFORMATION = 203, "Non-Authoritative Information" + NO_CONTENT = 204, "No Content" + RESET_CONTENT = 205, "Reset Content" + PARTIAL_CONTENT = 206, "Partial Content" + MULTI_STATUS = 207, "Multi-Status" + ALREADY_REPORTED = 208, "Already Reported" + IM_USED = 226, "IM Used" + + # redirection + MULTIPLE_CHOICES = 300, "Multiple Choices" + MOVED_PERMANENTLY = 301, "Moved Permanently" + FOUND = 302, "Found" + SEE_OTHER = 303, "See Other" + NOT_MODIFIED = 304, "Not Modified" + USE_PROXY = 305, "Use Proxy" + TEMPORARY_REDIRECT = 307, "Temporary Redirect" + PERMANENT_REDIRECT = 308, "Permanent Redirect" + + # client error + BAD_REQUEST = 400, "Bad Request" + UNAUTHORIZED = 401, "Unauthorized" + PAYMENT_REQUIRED = 402, "Payment Required" + FORBIDDEN = 403, "Forbidden" + NOT_FOUND = 404, "Not Found" + METHOD_NOT_ALLOWED = 405, "Method Not Allowed" + NOT_ACCEPTABLE = 406, "Not Acceptable" + PROXY_AUTHENTICATION_REQUIRED = 407, "Proxy Authentication Required" + REQUEST_TIMEOUT = 408, "Request Timeout" + CONFLICT = 409, "Conflict" + GONE = 410, "Gone" + LENGTH_REQUIRED = 411, "Length Required" + PRECONDITION_FAILED = 412, "Precondition Failed" + REQUEST_ENTITY_TOO_LARGE = 413, "Request Entity Too Large" + REQUEST_URI_TOO_LONG = 414, "Request-URI Too Long" + UNSUPPORTED_MEDIA_TYPE = 415, "Unsupported Media Type" + REQUESTED_RANGE_NOT_SATISFIABLE = 416, "Requested Range Not Satisfiable" + EXPECTATION_FAILED = 417, "Expectation Failed" + IM_A_TEAPOT = 418, "I'm a teapot" + MISDIRECTED_REQUEST = 421, "Misdirected Request" + UNPROCESSABLE_ENTITY = 422, "Unprocessable Entity" + LOCKED = 423, "Locked" + FAILED_DEPENDENCY = 424, "Failed Dependency" + TOO_EARLY = 425, "Too Early" + UPGRADE_REQUIRED = 426, "Upgrade Required" + PRECONDITION_REQUIRED = 428, "Precondition Required" + TOO_MANY_REQUESTS = 429, "Too Many Requests" + REQUEST_HEADER_FIELDS_TOO_LARGE = 431, "Request Header Fields Too Large" + UNAVAILABLE_FOR_LEGAL_REASONS = 451, "Unavailable For Legal Reasons" + + # server errors + INTERNAL_SERVER_ERROR = 500, "Internal Server Error" + NOT_IMPLEMENTED = 501, "Not Implemented" + BAD_GATEWAY = 502, "Bad Gateway" + SERVICE_UNAVAILABLE = 503, "Service Unavailable" + GATEWAY_TIMEOUT = 504, "Gateway Timeout" + HTTP_VERSION_NOT_SUPPORTED = 505, "HTTP Version Not Supported" + VARIANT_ALSO_NEGOTIATES = 506, "Variant Also Negotiates" + INSUFFICIENT_STORAGE = 507, "Insufficient Storage" + LOOP_DETECTED = 508, "Loop Detected" + NOT_EXTENDED = 510, "Not Extended" + NETWORK_AUTHENTICATION_REQUIRED = 511, "Network Authentication Required" + + +# Include lower-case styles for `requests` compatibility. 
+for code in codes: + setattr(codes, code._name_.lower(), int(code)) diff --git a/contrib/python/httpx/httpx/_transports/__init__.py b/contrib/python/httpx/httpx/_transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/__init__.py diff --git a/contrib/python/httpx/httpx/_transports/asgi.py b/contrib/python/httpx/httpx/_transports/asgi.py new file mode 100644 index 0000000000..f67f0fbd5b --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/asgi.py @@ -0,0 +1,179 @@ +import typing + +import sniffio + +from .._models import Request, Response +from .._types import AsyncByteStream +from .base import AsyncBaseTransport + +if typing.TYPE_CHECKING: # pragma: no cover + import asyncio + + import trio + + Event = typing.Union[asyncio.Event, trio.Event] + + +_Message = typing.Dict[str, typing.Any] +_Receive = typing.Callable[[], typing.Awaitable[_Message]] +_Send = typing.Callable[ + [typing.Dict[str, typing.Any]], typing.Coroutine[None, None, None] +] +_ASGIApp = typing.Callable[ + [typing.Dict[str, typing.Any], _Receive, _Send], typing.Coroutine[None, None, None] +] + + +def create_event() -> "Event": + if sniffio.current_async_library() == "trio": + import trio + + return trio.Event() + else: + import asyncio + + return asyncio.Event() + + +class ASGIResponseStream(AsyncByteStream): + def __init__(self, body: typing.List[bytes]) -> None: + self._body = body + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + yield b"".join(self._body) + + +class ASGITransport(AsyncBaseTransport): + """ + A custom AsyncTransport that handles sending requests directly to an ASGI app. + The simplest way to use this functionality is to use the `app` argument. + + ``` + client = httpx.AsyncClient(app=app) + ``` + + Alternatively, you can set up the transport instance explicitly. + This allows you to include any additional configuration arguments specific + to the ASGITransport class: + + ``` + transport = httpx.ASGITransport( + app=app, + root_path="/submount", + client=("1.2.3.4", 123) + ) + client = httpx.AsyncClient(transport=transport) + ``` + + Arguments: + + * `app` - The ASGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Defaults to `True`. Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `root_path` - The root path on which the ASGI application should be mounted. + * `client` - A two-tuple indicating the client IP and port of incoming requests. + """ + + def __init__( + self, + app: _ASGIApp, + raise_app_exceptions: bool = True, + root_path: str = "", + client: typing.Tuple[str, int] = ("127.0.0.1", 123), + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.root_path = root_path + self.client = client + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + # ASGI scope. + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path, + "query_string": request.url.query, + "server": (request.url.host, request.url.port), + "client": self.client, + "root_path": self.root_path, + } + + # Request.
+ request_body_chunks = request.stream.__aiter__() + request_complete = False + + # Response. + status_code = None + response_headers = None + body_parts = [] + response_started = False + response_complete = create_event() + + # ASGI callables. + + async def receive() -> typing.Dict[str, typing.Any]: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} + + try: + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: typing.Dict[str, typing.Any]) -> None: + nonlocal status_code, response_headers, response_started + + if message["type"] == "http.response.start": + assert not response_started + + status_code = message["status"] + response_headers = message.get("headers", []) + response_started = True + + elif message["type"] == "http.response.body": + assert not response_complete.is_set() + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and request.method != "HEAD": + body_parts.append(body) + + if not more_body: + response_complete.set() + + try: + await self.app(scope, receive, send) + except Exception: # noqa: PIE-786 + if self.raise_app_exceptions: + raise + + response_complete.set() + if status_code is None: + status_code = 500 + if response_headers is None: + response_headers = {} + + assert response_complete.is_set() + assert status_code is not None + assert response_headers is not None + + stream = ASGIResponseStream(body_parts) + + return Response(status_code, headers=response_headers, stream=stream) diff --git a/contrib/python/httpx/httpx/_transports/base.py b/contrib/python/httpx/httpx/_transports/base.py new file mode 100644 index 0000000000..f6fdfe6943 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/base.py @@ -0,0 +1,82 @@ +import typing +from types import TracebackType + +from .._models import Request, Response + +T = typing.TypeVar("T", bound="BaseTransport") +A = typing.TypeVar("A", bound="AsyncBaseTransport") + + +class BaseTransport: + def __enter__(self: T) -> T: + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + self.close() + + def handle_request(self, request: Request) -> Response: + """ + Send a single HTTP request and return a response. + + Developers shouldn't typically need to call into this API directly, + since the Client class provides all the higher level user-facing API + niceties. + + In order to properly release any network resources, the response + stream should *either* be consumed immediately, with a call to + `response.stream.read()`, or else the `handle_request` call should + be followed with a try/finally block to ensure the stream is + always closed. + + Example usage: + + with httpx.HTTPTransport() as transport: + req = httpx.Request( + method=b"GET", + url=(b"https", b"www.example.com", 443, b"/"), + headers=[(b"Host", b"www.example.com")], + ) + resp = transport.handle_request(req) + body = resp.stream.read() + print(resp.status_code, resp.headers, body) + + + Takes a `Request` instance as the only argument. + + Returns a `Response` instance. + """ + raise NotImplementedError( + "The 'handle_request' method must be implemented."
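Tying the `ASGITransport` above together, a minimal end-to-end sketch (the app, base URL, and response body are placeholders for illustration):

```
import asyncio
import httpx

async def app(scope, receive, send):
    # A deliberately tiny ASGI app, for illustration only.
    assert scope["type"] == "http"
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-type", b"text/plain")]})
    await send({"type": "http.response.body", "body": b"Hello, ASGI!"})

async def main() -> None:
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        response = await client.get("/")
        assert response.status_code == 200
        assert response.text == "Hello, ASGI!"

asyncio.run(main())
```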
+ ) # pragma: no cover + + def close(self) -> None: + pass + + +class AsyncBaseTransport: + async def __aenter__(self: A) -> A: + return self + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + await self.aclose() + + async def handle_async_request( + self, + request: Request, + ) -> Response: + raise NotImplementedError( + "The 'handle_async_request' method must be implemented." + ) # pragma: no cover + + async def aclose(self) -> None: + pass diff --git a/contrib/python/httpx/httpx/_transports/default.py b/contrib/python/httpx/httpx/_transports/default.py new file mode 100644 index 0000000000..7dba5b8208 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/default.py @@ -0,0 +1,378 @@ +""" +Custom transports, with nicely configured defaults. + +The following additional keyword arguments are currently supported by httpcore... + +* uds: str +* local_address: str +* retries: int + +Example usages... + +# Disable HTTP/2 on a single specific domain. +mounts = { + "all://": httpx.HTTPTransport(http2=True), + "all://*example.org": httpx.HTTPTransport() +} + +# Using advanced httpcore configuration, with connection retries. +transport = httpx.HTTPTransport(retries=1) +client = httpx.Client(transport=transport) + +# Using advanced httpcore configuration, with unix domain sockets. +transport = httpx.HTTPTransport(uds="socket.uds") +client = httpx.Client(transport=transport) +""" +import contextlib +import typing +from types import TracebackType + +import httpcore + +from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context +from .._exceptions import ( + ConnectError, + ConnectTimeout, + LocalProtocolError, + NetworkError, + PoolTimeout, + ProtocolError, + ProxyError, + ReadError, + ReadTimeout, + RemoteProtocolError, + TimeoutException, + UnsupportedProtocol, + WriteError, + WriteTimeout, +) +from .._models import Request, Response +from .._types import AsyncByteStream, CertTypes, SyncByteStream, VerifyTypes +from .base import AsyncBaseTransport, BaseTransport + +T = typing.TypeVar("T", bound="HTTPTransport") +A = typing.TypeVar("A", bound="AsyncHTTPTransport") + +SOCKET_OPTION = typing.Union[ + typing.Tuple[int, int, int], + typing.Tuple[int, int, typing.Union[bytes, bytearray]], + typing.Tuple[int, int, None, int], +] + + +@contextlib.contextmanager +def map_httpcore_exceptions() -> typing.Iterator[None]: + try: + yield + except Exception as exc: # noqa: PIE-786 + mapped_exc = None + + for from_exc, to_exc in HTTPCORE_EXC_MAP.items(): + if not isinstance(exc, from_exc): + continue + # We want to map to the most specific exception we can find. + # Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to + # `httpx.ReadTimeout`, not just `httpx.TimeoutException`. 
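To make the "most specific match" rule concrete, here is a standalone sketch of the same selection logic the loop below applies (it assumes `httpcore` is importable, as it is for this transport module):

```
import httpcore
import httpx

exc = httpcore.ReadTimeout("read timed out")

mapped = None
for from_exc, to_exc in {
    httpcore.TimeoutException: httpx.TimeoutException,
    httpcore.ReadTimeout: httpx.ReadTimeout,
}.items():
    # Keep the most specific candidate: ReadTimeout subclasses TimeoutException.
    if isinstance(exc, from_exc) and (mapped is None or issubclass(to_exc, mapped)):
        mapped = to_exc

assert mapped is httpx.ReadTimeout  # not merely httpx.TimeoutException
```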
+ if mapped_exc is None or issubclass(to_exc, mapped_exc): + mapped_exc = to_exc + + if mapped_exc is None: # pragma: no cover + raise + + message = str(exc) + raise mapped_exc(message) from exc + + +HTTPCORE_EXC_MAP = { + httpcore.TimeoutException: TimeoutException, + httpcore.ConnectTimeout: ConnectTimeout, + httpcore.ReadTimeout: ReadTimeout, + httpcore.WriteTimeout: WriteTimeout, + httpcore.PoolTimeout: PoolTimeout, + httpcore.NetworkError: NetworkError, + httpcore.ConnectError: ConnectError, + httpcore.ReadError: ReadError, + httpcore.WriteError: WriteError, + httpcore.ProxyError: ProxyError, + httpcore.UnsupportedProtocol: UnsupportedProtocol, + httpcore.ProtocolError: ProtocolError, + httpcore.LocalProtocolError: LocalProtocolError, + httpcore.RemoteProtocolError: RemoteProtocolError, +} + + +class ResponseStream(SyncByteStream): + def __init__(self, httpcore_stream: typing.Iterable[bytes]): + self._httpcore_stream = httpcore_stream + + def __iter__(self) -> typing.Iterator[bytes]: + with map_httpcore_exceptions(): + for part in self._httpcore_stream: + yield part + + def close(self) -> None: + if hasattr(self._httpcore_stream, "close"): + self._httpcore_stream.close() + + +class HTTPTransport(BaseTransport): + def __init__( + self, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + trust_env: bool = True, + proxy: typing.Optional[Proxy] = None, + uds: typing.Optional[str] = None, + local_address: typing.Optional[str] = None, + retries: int = 0, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> None: + ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + + if proxy is None: + self._pool = httpcore.ConnectionPool( + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + uds=uds, + local_address=local_address, + retries=retries, + socket_options=socket_options, + ) + elif proxy.url.scheme in ("http", "https"): + self._pool = httpcore.HTTPProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + proxy_headers=proxy.headers.raw, + ssl_context=ssl_context, + proxy_ssl_context=proxy.ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + socket_options=socket_options, + ) + elif proxy.url.scheme == "socks5": + try: + import socksio # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using SOCKS proxy, but the 'socksio' package is not installed. " + "Make sure to install httpx using `pip install httpx[socks]`." + ) from None + + self._pool = httpcore.SOCKSProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + ) + else: # pragma: no cover + raise ValueError( + f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}." 
+ ) + + def __enter__(self: T) -> T: # Use generics for subclass support. + self._pool.__enter__() + return self + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + with map_httpcore_exceptions(): + self._pool.__exit__(exc_type, exc_value, traceback) + + def handle_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, SyncByteStream) + + req = httpcore.Request( + method=request.method, + url=httpcore.URL( + scheme=request.url.raw_scheme, + host=request.url.raw_host, + port=request.url.port, + target=request.url.raw_path, + ), + headers=request.headers.raw, + content=request.stream, + extensions=request.extensions, + ) + with map_httpcore_exceptions(): + resp = self._pool.handle_request(req) + + assert isinstance(resp.stream, typing.Iterable) + + return Response( + status_code=resp.status, + headers=resp.headers, + stream=ResponseStream(resp.stream), + extensions=resp.extensions, + ) + + def close(self) -> None: + self._pool.close() + + +class AsyncResponseStream(AsyncByteStream): + def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]): + self._httpcore_stream = httpcore_stream + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + with map_httpcore_exceptions(): + async for part in self._httpcore_stream: + yield part + + async def aclose(self) -> None: + if hasattr(self._httpcore_stream, "aclose"): + await self._httpcore_stream.aclose() + + +class AsyncHTTPTransport(AsyncBaseTransport): + def __init__( + self, + verify: VerifyTypes = True, + cert: typing.Optional[CertTypes] = None, + http1: bool = True, + http2: bool = False, + limits: Limits = DEFAULT_LIMITS, + trust_env: bool = True, + proxy: typing.Optional[Proxy] = None, + uds: typing.Optional[str] = None, + local_address: typing.Optional[str] = None, + retries: int = 0, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> None: + ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) + + if proxy is None: + self._pool = httpcore.AsyncConnectionPool( + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + uds=uds, + local_address=local_address, + retries=retries, + socket_options=socket_options, + ) + elif proxy.url.scheme in ("http", "https"): + self._pool = httpcore.AsyncHTTPProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + proxy_headers=proxy.headers.raw, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + socket_options=socket_options, + ) + elif proxy.url.scheme == "socks5": + try: + import socksio # noqa + except ImportError: # pragma: no cover + raise ImportError( + "Using SOCKS proxy, but the 'socksio' package is not installed. " + "Make sure to install httpx using `pip install httpx[socks]`." 
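As a configuration sketch using the constructor parameters shown above (the proxy URL is a placeholder):

```
import httpx

transport = httpx.HTTPTransport(
    proxy=httpx.Proxy("http://localhost:8030"),  # placeholder proxy URL
    retries=3,
)
client = httpx.Client(transport=transport)
```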
+ ) from None + + self._pool = httpcore.AsyncSOCKSProxy( + proxy_url=httpcore.URL( + scheme=proxy.url.raw_scheme, + host=proxy.url.raw_host, + port=proxy.url.port, + target=proxy.url.raw_path, + ), + proxy_auth=proxy.raw_auth, + ssl_context=ssl_context, + max_connections=limits.max_connections, + max_keepalive_connections=limits.max_keepalive_connections, + keepalive_expiry=limits.keepalive_expiry, + http1=http1, + http2=http2, + ) + else: # pragma: no cover + raise ValueError( + f"Proxy protocol must be either 'http', 'https', or 'socks5', but got {proxy.url.scheme!r}." + ) + + async def __aenter__(self: A) -> A: # Use generics for subclass support. + await self._pool.__aenter__() + return self + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + with map_httpcore_exceptions(): + await self._pool.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + req = httpcore.Request( + method=request.method, + url=httpcore.URL( + scheme=request.url.raw_scheme, + host=request.url.raw_host, + port=request.url.port, + target=request.url.raw_path, + ), + headers=request.headers.raw, + content=request.stream, + extensions=request.extensions, + ) + with map_httpcore_exceptions(): + resp = await self._pool.handle_async_request(req) + + assert isinstance(resp.stream, typing.AsyncIterable) + + return Response( + status_code=resp.status, + headers=resp.headers, + stream=AsyncResponseStream(resp.stream), + extensions=resp.extensions, + ) + + async def aclose(self) -> None: + await self._pool.aclose() diff --git a/contrib/python/httpx/httpx/_transports/mock.py b/contrib/python/httpx/httpx/_transports/mock.py new file mode 100644 index 0000000000..82043da2d9 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/mock.py @@ -0,0 +1,38 @@ +import typing + +from .._models import Request, Response +from .base import AsyncBaseTransport, BaseTransport + +SyncHandler = typing.Callable[[Request], Response] +AsyncHandler = typing.Callable[[Request], typing.Coroutine[None, None, Response]] + + +class MockTransport(AsyncBaseTransport, BaseTransport): + def __init__(self, handler: typing.Union[SyncHandler, AsyncHandler]) -> None: + self.handler = handler + + def handle_request( + self, + request: Request, + ) -> Response: + request.read() + response = self.handler(request) + if not isinstance(response, Response): # pragma: no cover + raise TypeError("Cannot use an async handler in a sync Client") + return response + + async def handle_async_request( + self, + request: Request, + ) -> Response: + await request.aread() + response = self.handler(request) + + # Allow handler to *optionally* be an `async` function. + # If it is, then the `response` variable needs to be awaited to actually + # return the result.
+ + if not isinstance(response, Response): + response = await response + + return response diff --git a/contrib/python/httpx/httpx/_transports/wsgi.py b/contrib/python/httpx/httpx/_transports/wsgi.py new file mode 100644 index 0000000000..a23d42c414 --- /dev/null +++ b/contrib/python/httpx/httpx/_transports/wsgi.py @@ -0,0 +1,144 @@ +import io +import itertools +import sys +import typing + +from .._models import Request, Response +from .._types import SyncByteStream +from .base import BaseTransport + +if typing.TYPE_CHECKING: + from _typeshed import OptExcInfo # pragma: no cover + from _typeshed.wsgi import WSGIApplication # pragma: no cover + +_T = typing.TypeVar("_T") + + +def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]: + body = iter(body) + for chunk in body: + if chunk: + return itertools.chain([chunk], body) + return [] + + +class WSGIByteStream(SyncByteStream): + def __init__(self, result: typing.Iterable[bytes]) -> None: + self._close = getattr(result, "close", None) + self._result = _skip_leading_empty_chunks(result) + + def __iter__(self) -> typing.Iterator[bytes]: + for part in self._result: + yield part + + def close(self) -> None: + if self._close is not None: + self._close() + + +class WSGITransport(BaseTransport): + """ + A custom transport that handles sending requests directly to a WSGI app. + The simplest way to use this functionality is to use the `app` argument. + + ``` + client = httpx.Client(app=app) + ``` + + Alternatively, you can set up the transport instance explicitly. + This allows you to include any additional configuration arguments specific + to the WSGITransport class: + + ``` + transport = httpx.WSGITransport( + app=app, + script_name="/submount", + remote_addr="1.2.3.4" + ) + client = httpx.Client(transport=transport) + ``` + + Arguments: + + * `app` - The WSGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Defaults to `True`. Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `script_name` - The root path on which the WSGI application should be mounted. + * `remote_addr` - A string indicating the client IP of incoming requests.
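A minimal end-to-end sketch of the transport described above (the WSGI app and base URL are placeholders for illustration):

```
import httpx

def app(environ, start_response):
    # A deliberately tiny WSGI app, for illustration only.
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"Hello, WSGI!"]

transport = httpx.WSGITransport(app=app)
with httpx.Client(transport=transport, base_url="http://testserver") as client:
    response = client.get("/")
    assert response.status_code == 200
    assert response.text == "Hello, WSGI!"
```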
+ """ + + def __init__( + self, + app: "WSGIApplication", + raise_app_exceptions: bool = True, + script_name: str = "", + remote_addr: str = "127.0.0.1", + wsgi_errors: typing.Optional[typing.TextIO] = None, + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.script_name = script_name + self.remote_addr = remote_addr + self.wsgi_errors = wsgi_errors + + def handle_request(self, request: Request) -> Response: + request.read() + wsgi_input = io.BytesIO(request.content) + + port = request.url.port or {"http": 80, "https": 443}[request.url.scheme] + environ = { + "wsgi.version": (1, 0), + "wsgi.url_scheme": request.url.scheme, + "wsgi.input": wsgi_input, + "wsgi.errors": self.wsgi_errors or sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "REQUEST_METHOD": request.method, + "SCRIPT_NAME": self.script_name, + "PATH_INFO": request.url.path, + "QUERY_STRING": request.url.query.decode("ascii"), + "SERVER_NAME": request.url.host, + "SERVER_PORT": str(port), + "SERVER_PROTOCOL": "HTTP/1.1", + "REMOTE_ADDR": self.remote_addr, + } + for header_key, header_value in request.headers.raw: + key = header_key.decode("ascii").upper().replace("-", "_") + if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"): + key = "HTTP_" + key + environ[key] = header_value.decode("ascii") + + seen_status = None + seen_response_headers = None + seen_exc_info = None + + def start_response( + status: str, + response_headers: typing.List[typing.Tuple[str, str]], + exc_info: typing.Optional["OptExcInfo"] = None, + ) -> typing.Callable[[bytes], typing.Any]: + nonlocal seen_status, seen_response_headers, seen_exc_info + seen_status = status + seen_response_headers = response_headers + seen_exc_info = exc_info + return lambda _: None + + result = self.app(environ, start_response) + + stream = WSGIByteStream(result) + + assert seen_status is not None + assert seen_response_headers is not None + if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions: + raise seen_exc_info[1] + + status_code = int(seen_status.split()[0]) + headers = [ + (key.encode("ascii"), value.encode("ascii")) + for key, value in seen_response_headers + ] + + return Response(status_code, headers=headers, stream=stream) diff --git a/contrib/python/httpx/httpx/_types.py b/contrib/python/httpx/httpx/_types.py new file mode 100644 index 0000000000..83cf35a32a --- /dev/null +++ b/contrib/python/httpx/httpx/_types.py @@ -0,0 +1,133 @@ +""" +Type definitions for type checking purposes.
+""" + +import ssl +from http.cookiejar import CookieJar +from typing import ( + IO, + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) + +if TYPE_CHECKING: # pragma: no cover + from ._auth import Auth # noqa: F401 + from ._config import Proxy, Timeout # noqa: F401 + from ._models import Cookies, Headers, Request # noqa: F401 + from ._urls import URL, QueryParams # noqa: F401 + + +PrimitiveData = Optional[Union[str, int, float, bool]] + +RawURL = NamedTuple( + "RawURL", + [ + ("raw_scheme", bytes), + ("raw_host", bytes), + ("port", Optional[int]), + ("raw_path", bytes), + ], +) + +URLTypes = Union["URL", str] + +QueryParamTypes = Union[ + "QueryParams", + Mapping[str, Union[PrimitiveData, Sequence[PrimitiveData]]], + List[Tuple[str, PrimitiveData]], + Tuple[Tuple[str, PrimitiveData], ...], + str, + bytes, +] + +HeaderTypes = Union[ + "Headers", + Mapping[str, str], + Mapping[bytes, bytes], + Sequence[Tuple[str, str]], + Sequence[Tuple[bytes, bytes]], +] + +CookieTypes = Union["Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]] + +CertTypes = Union[ + # certfile + str, + # (certfile, keyfile) + Tuple[str, Optional[str]], + # (certfile, keyfile, password) + Tuple[str, Optional[str], Optional[str]], +] +VerifyTypes = Union[str, bool, ssl.SSLContext] +TimeoutTypes = Union[ + Optional[float], + Tuple[Optional[float], Optional[float], Optional[float], Optional[float]], + "Timeout", +] +ProxiesTypes = Union[URLTypes, "Proxy", Dict[URLTypes, Union[None, URLTypes, "Proxy"]]] + +AuthTypes = Union[ + Tuple[Union[str, bytes], Union[str, bytes]], + Callable[["Request"], "Request"], + "Auth", +] + +RequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +ResponseContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] +ResponseExtensions = MutableMapping[str, Any] + +RequestData = Mapping[str, Any] + +FileContent = Union[IO[bytes], bytes, str] +FileTypes = Union[ + # file (or bytes) + FileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], FileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], +] +RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] + +RequestExtensions = MutableMapping[str, Any] + + +class SyncByteStream: + def __iter__(self) -> Iterator[bytes]: + raise NotImplementedError( + "The '__iter__' method must be implemented." + ) # pragma: no cover + yield b"" # pragma: no cover + + def close(self) -> None: + """ + Subclasses can override this method to release any network resources + after a request/response cycle is complete. + """ + + +class AsyncByteStream: + async def __aiter__(self) -> AsyncIterator[bytes]: + raise NotImplementedError( + "The '__aiter__' method must be implemented." + ) # pragma: no cover + yield b"" # pragma: no cover + + async def aclose(self) -> None: + pass diff --git a/contrib/python/httpx/httpx/_urlparse.py b/contrib/python/httpx/httpx/_urlparse.py new file mode 100644 index 0000000000..e1ba8dcdb7 --- /dev/null +++ b/contrib/python/httpx/httpx/_urlparse.py @@ -0,0 +1,464 @@ +""" +An implementation of `urlparse` that provides URL validation and normalization +as described by RFC3986. 
+ +We rely on this implementation rather than the one in Python's stdlib, because: + +* It provides more complete URL validation. +* It properly differentiates between an empty querystring and an absent querystring, + to distinguish URLs with a trailing '?'. +* It handles scheme, hostname, port, and path normalization. +* It supports IDNA hostnames, normalizing them to their encoded form. +* The API supports passing individual components, as well as the complete URL string. + +Previously we relied on the excellent `rfc3986` package to handle URL parsing and +validation, but this module provides a simpler alternative, with less indirection +required. +""" +import ipaddress +import re +import typing + +import idna + +from ._exceptions import InvalidURL + +MAX_URL_LENGTH = 65536 + +# https://datatracker.ietf.org/doc/html/rfc3986.html#section-2.3 +UNRESERVED_CHARACTERS = ( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" +) +SUB_DELIMS = "!$&'()*+,;=" + +PERCENT_ENCODED_REGEX = re.compile("%[A-Fa-f0-9]{2}") + + +# {scheme}: (optional) +# //{authority} (optional) +# {path} +# ?{query} (optional) +# #{fragment} (optional) +URL_REGEX = re.compile( + ( + r"(?:(?P<scheme>{scheme}):)?" + r"(?://(?P<authority>{authority}))?" + r"(?P<path>{path})" + r"(?:\?(?P<query>{query}))?" + r"(?:#(?P<fragment>{fragment}))?" + ).format( + scheme="([a-zA-Z][a-zA-Z0-9+.-]*)?", + authority="[^/?#]*", + path="[^?#]*", + query="[^#]*", + fragment=".*", + ) +) + +# {userinfo}@ (optional) +# {host} +# :{port} (optional) +AUTHORITY_REGEX = re.compile( + ( + r"(?:(?P<userinfo>{userinfo})@)?" r"(?P<host>{host})" r":?(?P<port>{port})?" + ).format( + userinfo="[^@]*", # Any character sequence not including '@'. + host="(\\[.*\\]|[^:]*)", # Either any character sequence not including ':', + # or an IPv6 address enclosed within square brackets. + port=".*", # Any character sequence. + ) +) + + +# If we call urlparse with an individual component, then we need to regex +# validate that component individually. +# Note that we're duplicating the same strings as above. Shock! Horror!! +COMPONENT_REGEX = { + "scheme": re.compile("([a-zA-Z][a-zA-Z0-9+.-]*)?"), + "authority": re.compile("[^/?#]*"), + "path": re.compile("[^?#]*"), + "query": re.compile("[^#]*"), + "fragment": re.compile(".*"), + "userinfo": re.compile("[^@]*"), + "host": re.compile("(\\[.*\\]|[^:]*)"), + "port": re.compile(".*"), +} + + +# We use these simple regexes as a first pass before handing off to +# the stdlib 'ipaddress' module for IP address validation.
+IPv4_STYLE_HOSTNAME = re.compile(r"^[0-9]+.[0-9]+.[0-9]+.[0-9]+$") +IPv6_STYLE_HOSTNAME = re.compile(r"^\[.*\]$") + + +class ParseResult(typing.NamedTuple): + scheme: str + userinfo: str + host: str + port: typing.Optional[int] + path: str + query: typing.Optional[str] + fragment: typing.Optional[str] + + @property + def authority(self) -> str: + return "".join( + [ + f"{self.userinfo}@" if self.userinfo else "", + f"[{self.host}]" if ":" in self.host else self.host, + f":{self.port}" if self.port is not None else "", + ] + ) + + @property + def netloc(self) -> str: + return "".join( + [ + f"[{self.host}]" if ":" in self.host else self.host, + f":{self.port}" if self.port is not None else "", + ] + ) + + def copy_with(self, **kwargs: typing.Optional[str]) -> "ParseResult": + if not kwargs: + return self + + defaults = { + "scheme": self.scheme, + "authority": self.authority, + "path": self.path, + "query": self.query, + "fragment": self.fragment, + } + defaults.update(kwargs) + return urlparse("", **defaults) + + def __str__(self) -> str: + authority = self.authority + return "".join( + [ + f"{self.scheme}:" if self.scheme else "", + f"//{authority}" if authority else "", + self.path, + f"?{self.query}" if self.query is not None else "", + f"#{self.fragment}" if self.fragment is not None else "", + ] + ) + + +def urlparse(url: str = "", **kwargs: typing.Optional[str]) -> ParseResult: + # Initial basic checks on allowable URLs. + # --------------------------------------- + + # Hard limit the maximum allowable URL length. + if len(url) > MAX_URL_LENGTH: + raise InvalidURL("URL too long") + + # If a URL includes any ASCII control characters including \t, \r, \n, + # then treat it as invalid. + if any(char.isascii() and not char.isprintable() for char in url): + raise InvalidURL("Invalid non-printable ASCII character in URL") + + # Some keyword arguments require special handling. + # ------------------------------------------------ + + # Coerce "port" to a string, if it is provided as an integer. + if "port" in kwargs: + port = kwargs["port"] + kwargs["port"] = str(port) if isinstance(port, int) else port + + # Replace "netloc" with "host" and "port". + if "netloc" in kwargs: + netloc = kwargs.pop("netloc") or "" + kwargs["host"], _, kwargs["port"] = netloc.partition(":") + + # Replace "username" and/or "password" with "userinfo". + if "username" in kwargs or "password" in kwargs: + username = quote(kwargs.pop("username", "") or "") + password = quote(kwargs.pop("password", "") or "") + kwargs["userinfo"] = f"{username}:{password}" if password else username + + # Replace "raw_path" with "path" and "query". + if "raw_path" in kwargs: + raw_path = kwargs.pop("raw_path") or "" + kwargs["path"], separator, kwargs["query"] = raw_path.partition("?") + if not separator: + kwargs["query"] = None + + # Ensure that IPv6 "host" addresses are always escaped with "[...]". + if "host" in kwargs: + host = kwargs.get("host") or "" + if ":" in host and not (host.startswith("[") and host.endswith("]")): + kwargs["host"] = f"[{host}]" + + # If any keyword arguments are provided, ensure they are valid. + # ------------------------------------------------------------- + + for key, value in kwargs.items(): + if value is not None: + if len(value) > MAX_URL_LENGTH: + raise InvalidURL(f"URL component '{key}' too long") + + # If a component includes any ASCII control characters including \t, \r, \n, + # then treat it as invalid.
+ if any(char.isascii() and not char.isprintable() for char in value): + raise InvalidURL( + f"Invalid non-printable ASCII character in URL component '{key}'" + ) + + # Ensure that keyword arguments match as a valid regex. + if not COMPONENT_REGEX[key].fullmatch(value): + raise InvalidURL(f"Invalid URL component '{key}'") + + # The URL_REGEX will always match, but may have empty components. + url_match = URL_REGEX.match(url) + assert url_match is not None + url_dict = url_match.groupdict() + + # * 'scheme', 'authority', and 'path' may be empty strings. + # * 'query' may be 'None', indicating no trailing "?" portion. + # Any string including the empty string, indicates a trailing "?". + # * 'fragment' may be 'None', indicating no trailing "#" portion. + # Any string including the empty string, indicates a trailing "#". + scheme = kwargs.get("scheme", url_dict["scheme"]) or "" + authority = kwargs.get("authority", url_dict["authority"]) or "" + path = kwargs.get("path", url_dict["path"]) or "" + query = kwargs.get("query", url_dict["query"]) + fragment = kwargs.get("fragment", url_dict["fragment"]) + + # The AUTHORITY_REGEX will always match, but may have empty components. + authority_match = AUTHORITY_REGEX.match(authority) + assert authority_match is not None + authority_dict = authority_match.groupdict() + + # * 'userinfo' and 'host' may be empty strings. + # * 'port' may be 'None'. + userinfo = kwargs.get("userinfo", authority_dict["userinfo"]) or "" + host = kwargs.get("host", authority_dict["host"]) or "" + port = kwargs.get("port", authority_dict["port"]) + + # Normalize and validate each component. + # We end up with a parsed representation of the URL, + # with components that are plain ASCII bytestrings. + parsed_scheme: str = scheme.lower() + parsed_userinfo: str = quote(userinfo, safe=SUB_DELIMS + ":") + parsed_host: str = encode_host(host) + parsed_port: typing.Optional[int] = normalize_port(port, scheme) + + has_scheme = parsed_scheme != "" + has_authority = ( + parsed_userinfo != "" or parsed_host != "" or parsed_port is not None + ) + validate_path(path, has_scheme=has_scheme, has_authority=has_authority) + if has_authority: + path = normalize_path(path) + + # The GEN_DELIMS set is... : / ? # [ ] @ + # These do not need to be percent-quoted unless they serve as delimiters for the + # specific component. + + # For 'path' we need to drop ? and # from the GEN_DELIMS set. + parsed_path: str = quote(path, safe=SUB_DELIMS + ":/[]@") + # For 'query' we need to drop '#' from the GEN_DELIMS set. + # We also exclude '/' because it is more robust to replace it with a percent + # encoding despite it not being a requirement of the spec. + parsed_query: typing.Optional[str] = ( + None if query is None else quote(query, safe=SUB_DELIMS + ":?[]@") + ) + # For 'fragment' we can include all of the GEN_DELIMS set. + parsed_fragment: typing.Optional[str] = ( + None if fragment is None else quote(fragment, safe=SUB_DELIMS + ":/?#[]@") + ) + + # The parsed ASCII bytestrings are our canonical form. + # All properties of the URL are derived from these. + return ParseResult( + parsed_scheme, + parsed_userinfo, + parsed_host, + parsed_port, + parsed_path, + parsed_query, + parsed_fragment, + ) + + +def encode_host(host: str) -> str: + if not host: + return "" + + elif IPv4_STYLE_HOSTNAME.match(host): + # Validate IPv4 hostnames like #.#.#.# + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # IPv4address = dec-octet "." dec-octet "." dec-octet "." dec-octet + try: + ipaddress.IPv4Address(host) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv4 address: {host!r}") + return host + + elif IPv6_STYLE_HOSTNAME.match(host): + # Validate IPv6 hostnames like [...] + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # "A host identified by an Internet Protocol literal address, version 6 + # [RFC3513] or later, is distinguished by enclosing the IP literal + # within square brackets ("[" and "]"). This is the only place where + # square bracket characters are allowed in the URI syntax." + try: + ipaddress.IPv6Address(host[1:-1]) + except ipaddress.AddressValueError: + raise InvalidURL(f"Invalid IPv6 address: {host!r}") + return host[1:-1] + + elif host.isascii(): + # Regular ASCII hostnames + # + # From https://datatracker.ietf.org/doc/html/rfc3986/#section-3.2.2 + # + # reg-name = *( unreserved / pct-encoded / sub-delims ) + return quote(host.lower(), safe=SUB_DELIMS) + + # IDNA hostnames + try: + return idna.encode(host.lower()).decode("ascii") + except idna.IDNAError: + raise InvalidURL(f"Invalid IDNA hostname: {host!r}") + + +def normalize_port( + port: typing.Optional[typing.Union[str, int]], scheme: str +) -> typing.Optional[int]: + # From https://tools.ietf.org/html/rfc3986#section-3.2.3 + # + # "A scheme may define a default port. For example, the "http" scheme + # defines a default port of "80", corresponding to its reserved TCP + # port number. The type of port designated by the port number (e.g., + # TCP, UDP, SCTP) is defined by the URI scheme. URI producers and + # normalizers should omit the port component and its ":" delimiter if + # port is empty or if its value would be the same as that of the + # scheme's default." + if port is None or port == "": + return None + + try: + port_as_int = int(port) + except ValueError: + raise InvalidURL(f"Invalid port: {port!r}") + + # See https://url.spec.whatwg.org/#url-miscellaneous + default_port = {"ftp": 21, "http": 80, "https": 443, "ws": 80, "wss": 443}.get( + scheme + ) + if port_as_int == default_port: + return None + return port_as_int + + +def validate_path(path: str, has_scheme: bool, has_authority: bool) -> None: + """ + Path validation rules that depend on whether the URL contains a scheme or authority component. + + See https://datatracker.ietf.org/doc/html/rfc3986.html#section-3.3 + """ + if has_authority: + # > If a URI contains an authority component, then the path component + # > must either be empty or begin with a slash ("/") character." + if path and not path.startswith("/"): + raise InvalidURL("For absolute URLs, path must be empty or begin with '/'") + else: + # > If a URI does not contain an authority component, then the path cannot begin + # > with two slash characters ("//"). + if path.startswith("//"): + raise InvalidURL( + "URLs with no authority component cannot have a path starting with '//'" + ) + # > In addition, a URI reference (Section 4.1) may be a relative-path reference, in which + # > case the first path segment cannot contain a colon (":") character. + if path.startswith(":") and not has_scheme: + raise InvalidURL( + "URLs with no scheme component cannot have a path starting with ':'" + ) + + +def normalize_path(path: str) -> str: + """ + Drop "." and ".." segments from a URL path.
+ + For example: + + normalize_path("/path/./to/somewhere/..") == "/path/to" + """ + # https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4 + components = path.split("/") + output: typing.List[str] = [] + for component in components: + if component == ".": + pass + elif component == "..": + if output and output != [""]: + output.pop() + else: + output.append(component) + return "/".join(output) + + +def percent_encode(char: str) -> str: + """ + Replace a single character with the percent-encoded representation. + + Characters outside the ASCII range are represented with the percent-encoded + representation of their UTF-8 byte sequence. + + For example: + + percent_encode(" ") == "%20" + """ + return "".join([f"%{byte:02x}" for byte in char.encode("utf-8")]).upper() + + +def is_safe(string: str, safe: str = "/") -> bool: + """ + Determine if a given string is already quote-safe. + """ + NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + "%" + + # All characters must already be non-escaping or '%' + for char in string: + if char not in NON_ESCAPED_CHARS: + return False + + # Any '%' characters must be valid '%xx' escape sequences. + return string.count("%") == len(PERCENT_ENCODED_REGEX.findall(string)) + + +def quote(string: str, safe: str = "/") -> str: + """ + Use percent-encoding to quote a string if required. + """ + if is_safe(string, safe=safe): + return string + + NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe + return "".join( + [char if char in NON_ESCAPED_CHARS else percent_encode(char) for char in string] + ) + + +def urlencode(items: typing.List[typing.Tuple[str, str]]) -> str: + # We can use a much simpler version of the stdlib urlencode here because + # we don't need to handle a bunch of different typing cases, such as bytes vs str. + # + # https://github.com/python/cpython/blob/b2f7b2ef0b5421e01efb8c7bee2ef95d3bab77eb/Lib/urllib/parse.py#L926 + # + # Note that we use '%20' encoding for spaces, and '%2F' for '/'. + # This is slightly different than `requests`, but is the behaviour that browsers use.
+ # + # See + # - https://github.com/encode/httpx/issues/2536 + # - https://github.com/encode/httpx/issues/2721 + # - https://docs.python.org/3/library/urllib.parse.html#urllib.parse.urlencode + return "&".join([quote(k, safe="") + "=" + quote(v, safe="") for k, v in items]) diff --git a/contrib/python/httpx/httpx/_urls.py b/contrib/python/httpx/httpx/_urls.py new file mode 100644 index 0000000000..b023941b62 --- /dev/null +++ b/contrib/python/httpx/httpx/_urls.py @@ -0,0 +1,642 @@ +import typing +from urllib.parse import parse_qs, unquote + +import idna + +from ._types import QueryParamTypes, RawURL, URLTypes +from ._urlparse import urlencode, urlparse +from ._utils import primitive_value_to_str + + +class URL: + """ + url = httpx.URL("HTTPS://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink") + + assert url.scheme == "https" + assert url.username == "jo@email.com" + assert url.password == "a secret" + assert url.userinfo == b"jo%40email.com:a%20secret" + assert url.host == "müller.de" + assert url.raw_host == b"xn--mller-kva.de" + assert url.port == 1234 + assert url.netloc == b"xn--mller-kva.de:1234" + assert url.path == "/pa th" + assert url.query == b"?search=ab" + assert url.raw_path == b"/pa%20th?search=ab" + assert url.fragment == "anchorlink" + + The components of a URL are broken down like this: + + https://jo%40email.com:a%20secret@müller.de:1234/pa%20th?search=ab#anchorlink + [scheme] [ username ] [password] [ host ][port][ path ] [ query ] [fragment] + [ userinfo ] [ netloc ][ raw_path ] + + Note that: + + * `url.scheme` is normalized to always be lowercased. + + * `url.host` is normalized to always be lowercased. Internationalized domain + names are represented in unicode, without IDNA encoding applied. For instance: + + url = httpx.URL("http://中国.icom.museum") + assert url.host == "中国.icom.museum" + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.host == "中国.icom.museum" + + * `url.raw_host` is normalized to always be lowercased, and is IDNA encoded. + + url = httpx.URL("http://中国.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + * `url.port` is either None or an integer. URLs that include the default port for + "http", "https", "ws", "wss", and "ftp" schemes have their port normalized to `None`. + + assert httpx.URL("http://example.com") == httpx.URL("http://example.com:80") + assert httpx.URL("http://example.com").port is None + assert httpx.URL("http://example.com:80").port is None + + * `url.userinfo` is raw bytes, without URL escaping. Usually you'll want to work with + `url.username` and `url.password` instead, which handle the URL escaping. + + * `url.raw_path` is raw bytes of both the path and query, without URL escaping. + This portion is used as the target when constructing HTTP requests. Usually you'll + want to work with `url.path` instead. + + * `url.query` is raw bytes, without URL escaping. A URL query string portion can only + be properly URL escaped when decoding the parameter names and values themselves. 
+ """ + + def __init__( + self, url: typing.Union["URL", str] = "", **kwargs: typing.Any + ) -> None: + if kwargs: + allowed = { + "scheme": str, + "username": str, + "password": str, + "userinfo": bytes, + "host": str, + "port": int, + "netloc": bytes, + "path": str, + "query": bytes, + "raw_path": bytes, + "fragment": str, + "params": object, + } + + # Perform type checking for all supported keyword arguments. + for key, value in kwargs.items(): + if key not in allowed: + message = f"{key!r} is an invalid keyword argument for URL()" + raise TypeError(message) + if value is not None and not isinstance(value, allowed[key]): + expected = allowed[key].__name__ + seen = type(value).__name__ + message = f"Argument {key!r} must be {expected} but got {seen}" + raise TypeError(message) + if isinstance(value, bytes): + kwargs[key] = value.decode("ascii") + + if "params" in kwargs: + # Replace any "params" keyword with the raw "query" instead. + # + # Ensure that empty params use `kwargs["query"] = None` rather + # than `kwargs["query"] = ""`, so that generated URLs do not + # include an empty trailing "?". + params = kwargs.pop("params") + kwargs["query"] = None if not params else str(QueryParams(params)) + + if isinstance(url, str): + self._uri_reference = urlparse(url, **kwargs) + elif isinstance(url, URL): + self._uri_reference = url._uri_reference.copy_with(**kwargs) + else: + raise TypeError( + f"Invalid type for url. Expected str or httpx.URL, got {type(url)}: {url!r}" + ) + + @property + def scheme(self) -> str: + """ + The URL scheme, such as "http", "https". + Always normalised to lowercase. + """ + return self._uri_reference.scheme + + @property + def raw_scheme(self) -> bytes: + """ + The raw bytes representation of the URL scheme, such as b"http", b"https". + Always normalised to lowercase. + """ + return self._uri_reference.scheme.encode("ascii") + + @property + def userinfo(self) -> bytes: + """ + The URL userinfo as a raw bytestring. + For example: b"jo%40email.com:a%20secret". + """ + return self._uri_reference.userinfo.encode("ascii") + + @property + def username(self) -> str: + """ + The URL username as a string, with URL decoding applied. + For example: "jo@email.com" + """ + userinfo = self._uri_reference.userinfo + return unquote(userinfo.partition(":")[0]) + + @property + def password(self) -> str: + """ + The URL password as a string, with URL decoding applied. + For example: "a secret" + """ + userinfo = self._uri_reference.userinfo + return unquote(userinfo.partition(":")[2]) + + @property + def host(self) -> str: + """ + The URL host as a string. + Always normalized to lowercase, with IDNA hosts decoded into unicode. + + Examples: + + url = httpx.URL("http://www.EXAMPLE.org") + assert url.host == "www.example.org" + + url = httpx.URL("http://中国.icom.museum") + assert url.host == "中国.icom.museum" + + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.host == "中国.icom.museum" + + url = httpx.URL("https://[::ffff:192.168.0.1]") + assert url.host == "::ffff:192.168.0.1" + """ + host: str = self._uri_reference.host + + if host.startswith("xn--"): + host = idna.decode(host) + + return host + + @property + def raw_host(self) -> bytes: + """ + The raw bytes representation of the URL host. + Always normalized to lowercase, and IDNA encoded. 
+ + Examples: + + url = httpx.URL("http://www.EXAMPLE.org") + assert url.raw_host == b"www.example.org" + + url = httpx.URL("http://中国.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + url = httpx.URL("http://xn--fiqs8s.icom.museum") + assert url.raw_host == b"xn--fiqs8s.icom.museum" + + url = httpx.URL("https://[::ffff:192.168.0.1]") + assert url.raw_host == b"::ffff:192.168.0.1" + """ + return self._uri_reference.host.encode("ascii") + + @property + def port(self) -> typing.Optional[int]: + """ + The URL port as an integer. + + Note that the URL class performs port normalization as per the WHATWG spec. + Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always + treated as `None`. + + For example: + + assert httpx.URL("http://www.example.com") == httpx.URL("http://www.example.com:80") + assert httpx.URL("http://www.example.com:80").port is None + """ + return self._uri_reference.port + + @property + def netloc(self) -> bytes: + """ + Either `<host>` or `<host>:<port>` as bytes. + Always normalized to lowercase, and IDNA encoded. + + This property may be used for generating the value of a request + "Host" header. + """ + return self._uri_reference.netloc.encode("ascii") + + @property + def path(self) -> str: + """ + The URL path as a string, excluding the query string, and URL decoded. + + For example: + + url = httpx.URL("https://example.com/pa%20th") + assert url.path == "/pa th" + """ + path = self._uri_reference.path or "/" + return unquote(path) + + @property + def query(self) -> bytes: + """ + The URL query string, as raw bytes, excluding the leading b"?". + + This is necessarily a bytewise interface, because we cannot + perform URL decoding of this representation until we've parsed + the keys and values into a QueryParams instance. + + For example: + + url = httpx.URL("https://example.com/?filter=some%20search%20terms") + assert url.query == b"filter=some%20search%20terms" + """ + query = self._uri_reference.query or "" + return query.encode("ascii") + + @property + def params(self) -> "QueryParams": + """ + The URL query parameters, neatly parsed and packaged into an immutable + multidict representation. + """ + return QueryParams(self._uri_reference.query) + + @property + def raw_path(self) -> bytes: + """ + The complete URL path and query string as raw bytes. + Used as the target when constructing HTTP requests. + + For example: + + GET /users?search=some%20text HTTP/1.1 + Host: www.example.org + Connection: close + """ + path = self._uri_reference.path or "/" + if self._uri_reference.query is not None: + path += "?" + self._uri_reference.query + return path.encode("ascii") + + @property + def fragment(self) -> str: + """ + The URL fragment, as used in HTML anchors. + As a string, without the leading '#'. + """ + return unquote(self._uri_reference.fragment or "") + + @property + def raw(self) -> RawURL: + """ + Provides the (scheme, host, port, target) for the outgoing request. + + In older versions of `httpx` this was used in the low-level transport API. + We no longer use `RawURL`, and this property will be deprecated in a future release. + """ + return RawURL( + self.raw_scheme, + self.raw_host, + self.port, + self.raw_path, + ) + + @property + def is_absolute_url(self) -> bool: + """ + Return `True` for absolute URLs such as 'http://example.com/path', + and `False` for relative URLs such as '/path'. + """ + # We don't use `.is_absolute` from `rfc3986` because it treats + # URLs with a fragment portion as not absolute.
+ # What we actually care about is whether the URL provides + # a scheme and hostname to which connections should be made. + return bool(self._uri_reference.scheme and self._uri_reference.host) + + @property + def is_relative_url(self) -> bool: + """ + Return `False` for absolute URLs such as 'http://example.com/path', + and `True` for relative URLs such as '/path'. + """ + return not self.is_absolute_url + + def copy_with(self, **kwargs: typing.Any) -> "URL": + """ + Copy this URL, returning a new URL with some components altered. + Accepts the same set of parameters as the components that are made + available via properties on the `URL` class. + + For example: + + url = httpx.URL("https://www.example.com").copy_with(username="jo@email.com", password="a secret") + assert url == "https://jo%40email.com:a%20secret@www.example.com" + """ + return URL(self, **kwargs) + + def copy_set_param(self, key: str, value: typing.Any = None) -> "URL": + return self.copy_with(params=self.params.set(key, value)) + + def copy_add_param(self, key: str, value: typing.Any = None) -> "URL": + return self.copy_with(params=self.params.add(key, value)) + + def copy_remove_param(self, key: str) -> "URL": + return self.copy_with(params=self.params.remove(key)) + + def copy_merge_params(self, params: QueryParamTypes) -> "URL": + return self.copy_with(params=self.params.merge(params)) + + def join(self, url: URLTypes) -> "URL": + """ + Return an absolute URL, using this URL as the base. + + E.g. + + url = httpx.URL("https://www.example.com/test") + url = url.join("/new/path") + assert url == "https://www.example.com/new/path" + """ + from urllib.parse import urljoin + + return URL(urljoin(str(self), str(URL(url)))) + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, (URL, str)) and str(self) == str(URL(other)) + + def __str__(self) -> str: + return str(self._uri_reference) + + def __repr__(self) -> str: + scheme, userinfo, host, port, path, query, fragment = self._uri_reference + + if ":" in userinfo: + # Mask any password component. + userinfo = f'{userinfo.split(":")[0]}:[secure]' + + authority = "".join( + [ + f"{userinfo}@" if userinfo else "", + f"[{host}]" if ":" in host else host, + f":{port}" if port is not None else "", + ] + ) + url = "".join( + [ + f"{self.scheme}:" if scheme else "", + f"//{authority}" if authority else "", + path, + f"?{query}" if query is not None else "", + f"#{fragment}" if fragment is not None else "", + ] + ) + + return f"{self.__class__.__name__}({url!r})" + + +class QueryParams(typing.Mapping[str, str]): + """ + URL query parameters, as a multi-dict. + """ + + def __init__( + self, *args: typing.Optional[QueryParamTypes], **kwargs: typing.Any + ) -> None: + assert len(args) < 2, "Too many arguments." + assert not (args and kwargs), "Cannot mix named and unnamed arguments."
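+ # A summary of the accepted inputs, handled by the branches below:
+ # a query string ("a=1&b=2") or bytes, an existing QueryParams instance,
+ # a list of two-tuples, a plain dict, or keyword arguments.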
+ + value = args[0] if args else kwargs + + if value is None or isinstance(value, (str, bytes)): + value = value.decode("ascii") if isinstance(value, bytes) else value + self._dict = parse_qs(value, keep_blank_values=True) + elif isinstance(value, QueryParams): + self._dict = {k: list(v) for k, v in value._dict.items()} + else: + dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {} + if isinstance(value, (list, tuple)): + # Convert list inputs like: + # [("a", "123"), ("a", "456"), ("b", "789")] + # To a dict representation, like: + # {"a": ["123", "456"], "b": ["789"]} + for item in value: + dict_value.setdefault(item[0], []).append(item[1]) + else: + # Convert dict inputs like: + # {"a": "123", "b": ["456", "789"]} + # To dict inputs where values are always lists, like: + # {"a": ["123"], "b": ["456", "789"]} + dict_value = { + k: list(v) if isinstance(v, (list, tuple)) else [v] + for k, v in value.items() + } + + # Ensure that keys and values are neatly coerced to strings. + # We coerce values `True` and `False` to JSON-like "true" and "false" + # representations, and coerce `None` values to the empty string. + self._dict = { + str(k): [primitive_value_to_str(item) for item in v] + for k, v in dict_value.items() + } + + def keys(self) -> typing.KeysView[str]: + """ + Return all the keys in the query params. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.keys()) == ["a", "b"] + """ + return self._dict.keys() + + def values(self) -> typing.ValuesView[str]: + """ + Return all the values in the query params. If a key occurs more than once + only the first item for that key is returned. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.values()) == ["123", "789"] + """ + return {k: v[0] for k, v in self._dict.items()}.values() + + def items(self) -> typing.ItemsView[str, str]: + """ + Return all items in the query params. If a key occurs more than once + only the first item for that key is returned. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.items()) == [("a", "123"), ("b", "789")] + """ + return {k: v[0] for k, v in self._dict.items()}.items() + + def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + """ + Return all items in the query params. Allow duplicate keys to occur. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")] + """ + multi_items: typing.List[typing.Tuple[str, str]] = [] + for k, v in self._dict.items(): + multi_items.extend([(k, i) for i in v]) + return multi_items + + def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + """ + Get a value from the query param for a given key. If the key occurs + more than once, then only the first value is returned. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert q.get("a") == "123" + """ + if key in self._dict: + return self._dict[str(key)][0] + return default + + def get_list(self, key: str) -> typing.List[str]: + """ + Get all values from the query param for a given key. + + Usage: + + q = httpx.QueryParams("a=123&a=456&b=789") + assert q.get_list("a") == ["123", "456"] + """ + return list(self._dict.get(str(key), [])) + + def set(self, key: str, value: typing.Any = None) -> "QueryParams": + """ + Return a new QueryParams instance, setting the value of a key. 
+ + Usage: + + q = httpx.QueryParams("a=123") + q = q.set("a", "456") + assert q == httpx.QueryParams("a=456") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict[str(key)] = [primitive_value_to_str(value)] + return q + + def add(self, key: str, value: typing.Any = None) -> "QueryParams": + """ + Return a new QueryParams instance, setting or appending the value of a key. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.add("a", "456") + assert q == httpx.QueryParams("a=123&a=456") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)] + return q + + def remove(self, key: str) -> "QueryParams": + """ + Return a new QueryParams instance, removing the value of a key. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.remove("a") + assert q == httpx.QueryParams("") + """ + q = QueryParams() + q._dict = dict(self._dict) + q._dict.pop(str(key), None) + return q + + def merge(self, params: typing.Optional[QueryParamTypes] = None) -> "QueryParams": + """ + Return a new QueryParams instance, updated with the given params. + + Usage: + + q = httpx.QueryParams("a=123") + q = q.merge({"b": "456"}) + assert q == httpx.QueryParams("a=123&b=456") + + q = httpx.QueryParams("a=123") + q = q.merge({"a": "456", "b": "789"}) + assert q == httpx.QueryParams("a=456&b=789") + """ + q = QueryParams(params) + q._dict = {**self._dict, **q._dict} + return q + + def __getitem__(self, key: typing.Any) -> str: + return self._dict[key][0] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __bool__(self) -> bool: + return bool(self._dict) + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, self.__class__): + return False + return sorted(self.multi_items()) == sorted(other.multi_items()) + + def __str__(self) -> str: + """ + Note that we use '%20' encoding for spaces, and '%2F' encoding for '/'. + + See https://github.com/encode/httpx/issues/2536 and + https://docs.python.org/3/library/urllib.parse.html#urllib.parse.urlencode + """ + return urlencode(self.multi_items()) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + query_string = str(self) + return f"{class_name}({query_string!r})" + + def update(self, params: typing.Optional[QueryParamTypes] = None) -> None: + raise RuntimeError( + "QueryParams are immutable since 0.18.0. " + "Use `q = q.merge(...)` to create an updated copy." + ) + + def __setitem__(self, key: str, value: str) -> None: + raise RuntimeError( + "QueryParams are immutable since 0.18.0. " + "Use `q = q.set(key, value)` to create an updated copy."
+ ) diff --git a/contrib/python/httpx/httpx/_utils.py b/contrib/python/httpx/httpx/_utils.py new file mode 100644 index 0000000000..1775b1a1ef --- /dev/null +++ b/contrib/python/httpx/httpx/_utils.py @@ -0,0 +1,477 @@ +import codecs +import email.message +import ipaddress +import mimetypes +import os +import re +import time +import typing +from pathlib import Path +from urllib.request import getproxies + +import sniffio + +from ._types import PrimitiveData + +if typing.TYPE_CHECKING: # pragma: no cover + from ._urls import URL + + +_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"} +_HTML5_FORM_ENCODING_REPLACEMENTS.update( + {chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B} +) +_HTML5_FORM_ENCODING_RE = re.compile( + r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()]) +) + + +def normalize_header_key( + value: typing.Union[str, bytes], + lower: bool, + encoding: typing.Optional[str] = None, +) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header key. + """ + if isinstance(value, bytes): + bytes_value = value + else: + bytes_value = value.encode(encoding or "ascii") + + return bytes_value.lower() if lower else bytes_value + + +def normalize_header_value( + value: typing.Union[str, bytes], encoding: typing.Optional[str] = None +) -> bytes: + """ + Coerce str/bytes into a strictly byte-wise HTTP header value. + """ + if isinstance(value, bytes): + return value + return value.encode(encoding or "ascii") + + +def primitive_value_to_str(value: "PrimitiveData") -> str: + """ + Coerce a primitive data type into a string value. + + Note that we prefer JSON-style 'true'/'false' for boolean values here. + """ + if value is True: + return "true" + elif value is False: + return "false" + elif value is None: + return "" + return str(value) + + +def is_known_encoding(encoding: str) -> bool: + """ + Return `True` if `encoding` is a known codec. + """ + try: + codecs.lookup(encoding) + except LookupError: + return False + return True + + +def format_form_param(name: str, value: str) -> bytes: + """ + Encode a name/value pair within a multipart form. + """ + + def replacer(match: typing.Match[str]) -> str: + return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)] + + value = _HTML5_FORM_ENCODING_RE.sub(replacer, value) + return f'{name}="{value}"'.encode() + + +# Null bytes; no need to recreate these on each call to guess_json_utf +_null = b"\x00" +_null2 = _null * 2 +_null3 = _null * 3 + + +def guess_json_utf(data: bytes) -> typing.Optional[str]: + # JSON always starts with two ASCII characters, so detection is as + # easy as counting the nulls and from their location and count + # determine the encoding. Also detect a BOM, if present. 
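+ # Illustrative examples (not part of the original source):
+ #   b'{"a"'          -> no BOM, no nulls in first 4 bytes  -> "utf-8"
+ #   b'{\x00"\x00'    -> nulls in 2nd and 4th positions     -> "utf-16-le"
+ #   b'\x00{\x00"'    -> nulls in 1st and 3rd positions     -> "utf-16-be"
+ #   b'{\x00\x00\x00' -> three nulls after the first byte   -> "utf-32-le"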
+ sample = data[:4] + if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE): + return "utf-32" # BOM included + if sample[:3] == codecs.BOM_UTF8: + return "utf-8-sig" # BOM included, MS style (discouraged) + if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE): + return "utf-16" # BOM included + nullcount = sample.count(_null) + if nullcount == 0: + return "utf-8" + if nullcount == 2: + if sample[::2] == _null2: # 1st and 3rd are null + return "utf-16-be" + if sample[1::2] == _null2: # 2nd and 4th are null + return "utf-16-le" + # Did not detect 2 valid UTF-16 ascii-range characters + if nullcount == 3: + if sample[:3] == _null3: + return "utf-32-be" + if sample[1:] == _null3: + return "utf-32-le" + # Did not detect a valid UTF-32 ascii-range character + return None + + +def get_ca_bundle_from_env() -> typing.Optional[str]: + if "SSL_CERT_FILE" in os.environ: + ssl_file = Path(os.environ["SSL_CERT_FILE"]) + if ssl_file.is_file(): + return str(ssl_file) + if "SSL_CERT_DIR" in os.environ: + ssl_path = Path(os.environ["SSL_CERT_DIR"]) + if ssl_path.is_dir(): + return str(ssl_path) + return None + + +def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]: + """ + Returns a list of parsed link headers. For more info see: + https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link + The generic syntax of those is: + Link: < uri-reference >; param1=value1; param2="value2" + So for instance: + Link: '<http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;' + would return + [ + {"url": "http:/.../front.jpeg", "type": "image/jpeg"}, + {"url": "http://.../back.jpeg"}, + ] + :param value: HTTP Link entity-header field + :return: list of parsed link headers + """ + links: typing.List[typing.Dict[str, str]] = [] + replace_chars = " '\"" + value = value.strip(replace_chars) + if not value: + return links + for val in re.split(", *<", value): + try: + url, params = val.split(";", 1) + except ValueError: + url, params = val, "" + link = {"url": url.strip("<> '\"")} + for param in params.split(";"): + try: + key, value = param.split("=") + except ValueError: + break + link[key.strip(replace_chars)] = value.strip(replace_chars) + links.append(link) + return links + + +def parse_content_type_charset(content_type: str) -> typing.Optional[str]: + # We used to use `cgi.parse_header()` here, but `cgi` became a dead battery. + # See: https://peps.python.org/pep-0594/#cgi + msg = email.message.Message() + msg["content-type"] = content_type + return msg.get_content_charset(failobj=None) + + +SENSITIVE_HEADERS = {"authorization", "proxy-authorization"} + + +def obfuscate_sensitive_headers( + items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]] +) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]: + for k, v in items: + if to_str(k.lower()) in SENSITIVE_HEADERS: + v = to_bytes_or_str("[secure]", match_type_of=v) + yield k, v + + +def port_or_default(url: "URL") -> typing.Optional[int]: + if url.port is not None: + return url.port + return {"http": 80, "https": 443}.get(url.scheme) + + +def same_origin(url: "URL", other: "URL") -> bool: + """ + Return 'True' if the given URLs share the same origin.
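+
+ For example (illustrative), 'http://example.com/a' and 'http://example.com:80/b'
+ share the same origin, since the default port for "http" is 80.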
+ """ + return ( + url.scheme == other.scheme + and url.host == other.host + and port_or_default(url) == port_or_default(other) + ) + + +def is_https_redirect(url: "URL", location: "URL") -> bool: + """ + Return 'True' if 'location' is a HTTPS upgrade of 'url' + """ + if url.host != location.host: + return False + + return ( + url.scheme == "http" + and port_or_default(url) == 80 + and location.scheme == "https" + and port_or_default(location) == 443 + ) + + +def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]: + """Gets proxy information from the environment""" + + # urllib.request.getproxies() falls back on System + # Registry and Config for proxies on Windows and macOS. + # We don't want to propagate non-HTTP proxies into + # our configuration such as 'TRAVIS_APT_PROXY'. + proxy_info = getproxies() + mounts: typing.Dict[str, typing.Optional[str]] = {} + + for scheme in ("http", "https", "all"): + if proxy_info.get(scheme): + hostname = proxy_info[scheme] + mounts[f"{scheme}://"] = ( + hostname if "://" in hostname else f"http://{hostname}" + ) + + no_proxy_hosts = [host.strip() for host in proxy_info.get("no", "").split(",")] + for hostname in no_proxy_hosts: + # See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details + # on how names in `NO_PROXY` are handled. + if hostname == "*": + # If NO_PROXY=* is used or if "*" occurs as any one of the comma + # separated hostnames, then we should just bypass any information + # from HTTP_PROXY, HTTPS_PROXY, ALL_PROXY, and always ignore + # proxies. + return {} + elif hostname: + # NO_PROXY=.google.com is marked as "all://*.google.com, + # which disables "www.google.com" but not "google.com" + # NO_PROXY=google.com is marked as "all://*google.com, + # which disables "www.google.com" and "google.com". + # (But not "wwwgoogle.com") + # NO_PROXY can include domains, IPv6, IPv4 addresses and "localhost" + # NO_PROXY=example.com,::1,localhost,192.168.0.0/16 + if is_ipv4_hostname(hostname): + mounts[f"all://{hostname}"] = None + elif is_ipv6_hostname(hostname): + mounts[f"all://[{hostname}]"] = None + elif hostname.lower() == "localhost": + mounts[f"all://{hostname}"] = None + else: + mounts[f"all://*{hostname}"] = None + + return mounts + + +def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes: + return value.encode(encoding) if isinstance(value, str) else value + + +def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str: + return value if isinstance(value, str) else value.decode(encoding) + + +def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr: + return value if isinstance(match_type_of, str) else value.encode() + + +def unquote(value: str) -> str: + return value[1:-1] if value[0] == value[-1] == '"' else value + + +def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]: + if filename: + return mimetypes.guess_type(filename)[0] or "application/octet-stream" + return None + + +def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]: + """ + Given a file-like stream object, return its length in number of bytes + without reading it into memory. + """ + try: + # Is it an actual file? + fd = stream.fileno() + # Yup, seems to be an actual file. + length = os.fstat(fd).st_size + except (AttributeError, OSError): + # No... Maybe it's something that supports random access, like `io.BytesIO`? + try: + # Assuming so, go to end of stream to figure out its length, + # then put it back in place. 
+ offset = stream.tell() + length = stream.seek(0, os.SEEK_END) + stream.seek(offset) + except (AttributeError, OSError): + # Not even that? Sorry, we're doomed... + return None + + return length + + +class Timer: + async def _get_time(self) -> float: + library = sniffio.current_async_library() + if library == "trio": + import trio + + return trio.current_time() + elif library == "curio": # pragma: no cover + import curio + + return typing.cast(float, await curio.clock()) + + import asyncio + + return asyncio.get_event_loop().time() + + def sync_start(self) -> None: + self.started = time.perf_counter() + + async def async_start(self) -> None: + self.started = await self._get_time() + + def sync_elapsed(self) -> float: + now = time.perf_counter() + return now - self.started + + async def async_elapsed(self) -> float: + now = await self._get_time() + return now - self.started + + +class URLPattern: + """ + A utility class currently used for making lookups against proxy keys... + + # Wildcard matching... + >>> pattern = URLPattern("all://") + >>> pattern.matches(httpx.URL("http://example.com")) + True + + # With scheme matching... + >>> pattern = URLPattern("https://") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + False + + # With domain matching... + >>> pattern = URLPattern("https://example.com") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + False + >>> pattern.matches(httpx.URL("https://other.com")) + False + + # Wildcard scheme, with domain matching... + >>> pattern = URLPattern("all://example.com") + >>> pattern.matches(httpx.URL("https://example.com")) + True + >>> pattern.matches(httpx.URL("http://example.com")) + True + >>> pattern.matches(httpx.URL("https://other.com")) + False + + # With port matching... + >>> pattern = URLPattern("https://example.com:1234") + >>> pattern.matches(httpx.URL("https://example.com:1234")) + True + >>> pattern.matches(httpx.URL("https://example.com")) + False + """ + + def __init__(self, pattern: str) -> None: + from ._urls import URL + + if pattern and ":" not in pattern: + raise ValueError( + f"Proxy keys should use proper URL forms rather " + f"than plain scheme strings. 
" + f'Instead of "{pattern}", use "{pattern}://"' + ) + + url = URL(pattern) + self.pattern = pattern + self.scheme = "" if url.scheme == "all" else url.scheme + self.host = "" if url.host == "*" else url.host + self.port = url.port + if not url.host or url.host == "*": + self.host_regex: typing.Optional[typing.Pattern[str]] = None + elif url.host.startswith("*."): + # *.example.com should match "www.example.com", but not "example.com" + domain = re.escape(url.host[2:]) + self.host_regex = re.compile(f"^.+\\.{domain}$") + elif url.host.startswith("*"): + # *example.com should match "www.example.com" and "example.com" + domain = re.escape(url.host[1:]) + self.host_regex = re.compile(f"^(.+\\.)?{domain}$") + else: + # example.com should match "example.com" but not "www.example.com" + domain = re.escape(url.host) + self.host_regex = re.compile(f"^{domain}$") + + def matches(self, other: "URL") -> bool: + if self.scheme and self.scheme != other.scheme: + return False + if ( + self.host + and self.host_regex is not None + and not self.host_regex.match(other.host) + ): + return False + if self.port is not None and self.port != other.port: + return False + return True + + @property + def priority(self) -> typing.Tuple[int, int, int]: + """ + The priority allows URLPattern instances to be sortable, so that + we can match from most specific to least specific. + """ + # URLs with a port should take priority over URLs without a port. + port_priority = 0 if self.port is not None else 1 + # Longer hostnames should match first. + host_priority = -len(self.host) + # Longer schemes should match first. + scheme_priority = -len(self.scheme) + return (port_priority, host_priority, scheme_priority) + + def __hash__(self) -> int: + return hash(self.pattern) + + def __lt__(self, other: "URLPattern") -> bool: + return self.priority < other.priority + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, URLPattern) and self.pattern == other.pattern + + +def is_ipv4_hostname(hostname: str) -> bool: + try: + ipaddress.IPv4Address(hostname.split("/")[0]) + except Exception: + return False + return True + + +def is_ipv6_hostname(hostname: str) -> bool: + try: + ipaddress.IPv6Address(hostname.split("/")[0]) + except Exception: + return False + return True diff --git a/contrib/python/httpx/httpx/py.typed b/contrib/python/httpx/httpx/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/httpx/httpx/py.typed diff --git a/contrib/python/httpx/ya.make b/contrib/python/httpx/ya.make new file mode 100644 index 0000000000..850e354ef0 --- /dev/null +++ b/contrib/python/httpx/ya.make @@ -0,0 +1,58 @@ +# Generated by devtools/yamaker (pypi). 
+ +PY3_LIBRARY() + +VERSION(0.25.0) + +LICENSE(BSD-3-Clause) + +PEERDIR( + contrib/python/certifi + contrib/python/httpcore + contrib/python/idna + contrib/python/sniffio +) + +NO_LINT() + +NO_CHECK_IMPORTS( + httpx._main +) + +PY_SRCS( + TOP_LEVEL + httpx/__init__.py + httpx/__version__.py + httpx/_api.py + httpx/_auth.py + httpx/_client.py + httpx/_compat.py + httpx/_config.py + httpx/_content.py + httpx/_decoders.py + httpx/_exceptions.py + httpx/_main.py + httpx/_models.py + httpx/_multipart.py + httpx/_status_codes.py + httpx/_transports/__init__.py + httpx/_transports/asgi.py + httpx/_transports/base.py + httpx/_transports/default.py + httpx/_transports/mock.py + httpx/_transports/wsgi.py + httpx/_types.py + httpx/_urlparse.py + httpx/_urls.py + httpx/_utils.py +) + +RESOURCE_FILES( + PREFIX contrib/python/httpx/ + .dist-info/METADATA + .dist-info/entry_points.txt + .dist-info/top_level.txt + httpx/py.typed +) + +END() diff --git a/contrib/python/sniffio/.dist-info/METADATA b/contrib/python/sniffio/.dist-info/METADATA new file mode 100644 index 0000000000..22520c72af --- /dev/null +++ b/contrib/python/sniffio/.dist-info/METADATA @@ -0,0 +1,102 @@ +Metadata-Version: 2.1 +Name: sniffio +Version: 1.3.0 +Summary: Sniff out which async library your code is running under +Home-page: https://github.com/python-trio/sniffio +Author: Nathaniel J. Smith +Author-email: njs@pobox.com +License: MIT OR Apache-2.0 +Keywords: async,trio,asyncio +Classifier: License :: OSI Approved :: MIT License +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Framework :: Trio +Classifier: Framework :: AsyncIO +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Intended Audience :: Developers +Classifier: Development Status :: 5 - Production/Stable +Requires-Python: >=3.7 +License-File: LICENSE +License-File: LICENSE.APACHE2 +License-File: LICENSE.MIT + +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://sniffio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://img.shields.io/pypi/v/sniffio.svg + :target: https://pypi.org/project/sniffio + :alt: Latest PyPi version + +.. image:: https://img.shields.io/conda/vn/conda-forge/sniffio.svg + :target: https://anaconda.org/conda-forge/sniffio + :alt: Latest conda-forge version + +.. image:: https://travis-ci.org/python-trio/sniffio.svg?branch=master + :target: https://travis-ci.org/python-trio/sniffio + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-trio/sniffio/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-trio/sniffio + :alt: Test coverage + +================================================================= +sniffio: Sniff out which async library your code is running under +================================================================= + +You're writing a library. You've decided to be ambitious, and support +multiple async I/O packages, like `Trio +<https://trio.readthedocs.io>`__, and `asyncio +<https://docs.python.org/3/library/asyncio.html>`__, and ... 
You've +written a bunch of clever code to handle all the differences. But... +how do you know *which* piece of clever code to run? + +This is a tiny package whose only purpose is to let you detect which +async library your code is running under. + +* Documentation: https://sniffio.readthedocs.io + +* Bug tracker and source code: https://github.com/python-trio/sniffio + +* License: MIT or Apache License 2.0, your choice + +* Contributor guide: https://trio.readthedocs.io/en/latest/contributing.html + +* Code of conduct: Contributors are requested to follow our `code of + conduct + <https://trio.readthedocs.io/en/latest/code-of-conduct.html>`_ + in all project spaces. + +This library is maintained by the Trio project, as a service to the +async Python community as a whole. + + +Quickstart +---------- + +.. code-block:: python3 + + from sniffio import current_async_library + import trio + import asyncio + + async def print_library(): + library = current_async_library() + print("This is:", library) + + # Prints "This is trio" + trio.run(print_library) + + # Prints "This is asyncio" + asyncio.run(print_library()) + +For more details, including how to add support to new async libraries, +`please peruse our fine manual <https://sniffio.readthedocs.io>`__. diff --git a/contrib/python/sniffio/.dist-info/top_level.txt b/contrib/python/sniffio/.dist-info/top_level.txt new file mode 100644 index 0000000000..01c650244d --- /dev/null +++ b/contrib/python/sniffio/.dist-info/top_level.txt @@ -0,0 +1 @@ +sniffio diff --git a/contrib/python/sniffio/LICENSE b/contrib/python/sniffio/LICENSE new file mode 100644 index 0000000000..51f3442917 --- /dev/null +++ b/contrib/python/sniffio/LICENSE @@ -0,0 +1,3 @@ +This software is made available under the terms of *either* of the +licenses found in LICENSE.APACHE2 or LICENSE.MIT. Contributions to are +made under the terms of *both* these licenses. diff --git a/contrib/python/sniffio/LICENSE.APACHE2 b/contrib/python/sniffio/LICENSE.APACHE2 new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/contrib/python/sniffio/LICENSE.APACHE2 @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrib/python/sniffio/LICENSE.MIT b/contrib/python/sniffio/LICENSE.MIT new file mode 100644 index 0000000000..b8bb971859 --- /dev/null +++ b/contrib/python/sniffio/LICENSE.MIT @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/contrib/python/sniffio/README.rst b/contrib/python/sniffio/README.rst new file mode 100644 index 0000000000..2a62cea5a0 --- /dev/null +++ b/contrib/python/sniffio/README.rst @@ -0,0 +1,76 @@ +.. image:: https://img.shields.io/badge/chat-join%20now-blue.svg + :target: https://gitter.im/python-trio/general + :alt: Join chatroom + +.. image:: https://img.shields.io/badge/docs-read%20now-blue.svg + :target: https://sniffio.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://img.shields.io/pypi/v/sniffio.svg + :target: https://pypi.org/project/sniffio + :alt: Latest PyPi version + +.. image:: https://img.shields.io/conda/vn/conda-forge/sniffio.svg + :target: https://anaconda.org/conda-forge/sniffio + :alt: Latest conda-forge version + +.. image:: https://travis-ci.org/python-trio/sniffio.svg?branch=master + :target: https://travis-ci.org/python-trio/sniffio + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-trio/sniffio/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-trio/sniffio + :alt: Test coverage + +================================================================= +sniffio: Sniff out which async library your code is running under +================================================================= + +You're writing a library. You've decided to be ambitious, and support +multiple async I/O packages, like `Trio +<https://trio.readthedocs.io>`__, and `asyncio +<https://docs.python.org/3/library/asyncio.html>`__, and ... You've +written a bunch of clever code to handle all the differences. But... +how do you know *which* piece of clever code to run? + +This is a tiny package whose only purpose is to let you detect which +async library your code is running under. + +* Documentation: https://sniffio.readthedocs.io + +* Bug tracker and source code: https://github.com/python-trio/sniffio + +* License: MIT or Apache License 2.0, your choice + +* Contributor guide: https://trio.readthedocs.io/en/latest/contributing.html + +* Code of conduct: Contributors are requested to follow our `code of + conduct + <https://trio.readthedocs.io/en/latest/code-of-conduct.html>`_ + in all project spaces. + +This library is maintained by the Trio project, as a service to the +async Python community as a whole. + + +Quickstart +---------- + +.. code-block:: python3 + + from sniffio import current_async_library + import trio + import asyncio + + async def print_library(): + library = current_async_library() + print("This is:", library) + + # Prints "This is trio" + trio.run(print_library) + + # Prints "This is asyncio" + asyncio.run(print_library()) + +For more details, including how to add support to new async libraries, +`please peruse our fine manual <https://sniffio.readthedocs.io>`__. 
diff --git a/contrib/python/sniffio/sniffio/__init__.py b/contrib/python/sniffio/sniffio/__init__.py new file mode 100644 index 0000000000..fb3364d7f1 --- /dev/null +++ b/contrib/python/sniffio/sniffio/__init__.py @@ -0,0 +1,15 @@ +"""Top-level package for sniffio.""" + +__all__ = [ + "current_async_library", "AsyncLibraryNotFoundError", + "current_async_library_cvar" +] + +from ._version import __version__ + +from ._impl import ( + current_async_library, + AsyncLibraryNotFoundError, + current_async_library_cvar, + thread_local, +) diff --git a/contrib/python/sniffio/sniffio/_impl.py b/contrib/python/sniffio/sniffio/_impl.py new file mode 100644 index 0000000000..c1a7bbf218 --- /dev/null +++ b/contrib/python/sniffio/sniffio/_impl.py @@ -0,0 +1,95 @@ +from contextvars import ContextVar +from typing import Optional +import sys +import threading + +current_async_library_cvar = ContextVar( + "current_async_library_cvar", default=None +) # type: ContextVar[Optional[str]] + + +class _ThreadLocal(threading.local): + # Since threading.local provides no explicit mechanism for setting + # a default for a value, a custom class with a class attribute is used + # instead. + name = None # type: Optional[str] + + +thread_local = _ThreadLocal() + + +class AsyncLibraryNotFoundError(RuntimeError): + pass + + +def current_async_library() -> str: + """Detect which async library is currently running. + + The following libraries are currently supported: + + ================ =========== ============================ + Library Requires Magic string + ================ =========== ============================ + **Trio** Trio v0.6+ ``"trio"`` + **Curio** - ``"curio"`` + **asyncio** ``"asyncio"`` + **Trio-asyncio** v0.8.2+ ``"trio"`` or ``"asyncio"``, + depending on current mode + ================ =========== ============================ + + Returns: + A string like ``"trio"``. + + Raises: + AsyncLibraryNotFoundError: if called from synchronous context, + or if the current async library was not recognized. + + Examples: + + .. code-block:: python3 + + from sniffio import current_async_library + + async def generic_sleep(seconds): + library = current_async_library() + if library == "trio": + import trio + await trio.sleep(seconds) + elif library == "asyncio": + import asyncio + await asyncio.sleep(seconds) + # ... and so on ... 
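+                # (a "curio" branch would look the same, using curio.sleep)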
+ else: + raise RuntimeError(f"Unsupported library {library!r}") + + """ + value = thread_local.name + if value is not None: + return value + + value = current_async_library_cvar.get() + if value is not None: + return value + + # Need to sniff for asyncio + if "asyncio" in sys.modules: + import asyncio + try: + current_task = asyncio.current_task # type: ignore[attr-defined] + except AttributeError: + current_task = asyncio.Task.current_task # type: ignore[attr-defined] + try: + if current_task() is not None: + return "asyncio" + except RuntimeError: + pass + + # Sniff for curio (for now) + if 'curio' in sys.modules: + from curio.meta import curio_running + if curio_running(): + return 'curio' + + raise AsyncLibraryNotFoundError( + "unknown async library, or not in async context" + ) diff --git a/contrib/python/sniffio/sniffio/_version.py b/contrib/python/sniffio/sniffio/_version.py new file mode 100644 index 0000000000..5a5f906bbf --- /dev/null +++ b/contrib/python/sniffio/sniffio/_version.py @@ -0,0 +1,3 @@ +# This file is imported from __init__.py and exec'd from setup.py + +__version__ = "1.3.0" diff --git a/contrib/python/sniffio/sniffio/py.typed b/contrib/python/sniffio/sniffio/py.typed new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/contrib/python/sniffio/sniffio/py.typed diff --git a/contrib/python/sniffio/ya.make b/contrib/python/sniffio/ya.make new file mode 100644 index 0000000000..d0e376d4ca --- /dev/null +++ b/contrib/python/sniffio/ya.make @@ -0,0 +1,25 @@ +# Generated by devtools/yamaker (pypi). + +PY3_LIBRARY() + +VERSION(1.3.0) + +LICENSE(Apache-2.0 AND MIT) + +NO_LINT() + +PY_SRCS( + TOP_LEVEL + sniffio/__init__.py + sniffio/_impl.py + sniffio/_version.py +) + +RESOURCE_FILES( + PREFIX contrib/python/sniffio/ + .dist-info/METADATA + .dist-info/top_level.txt + sniffio/py.typed +) + +END() diff --git a/library/cpp/CMakeLists.darwin-x86_64.txt b/library/cpp/CMakeLists.darwin-x86_64.txt index 0f393b2039..22b108ee05 100644 --- a/library/cpp/CMakeLists.darwin-x86_64.txt +++ b/library/cpp/CMakeLists.darwin-x86_64.txt @@ -63,6 +63,7 @@ add_subdirectory(openssl) add_subdirectory(packedtypes) add_subdirectory(packers) add_subdirectory(pop_count) +add_subdirectory(porto) add_subdirectory(presort) add_subdirectory(protobuf) add_subdirectory(random_provider) diff --git a/library/cpp/CMakeLists.linux-aarch64.txt b/library/cpp/CMakeLists.linux-aarch64.txt index cf47314f07..b42033fd82 100644 --- a/library/cpp/CMakeLists.linux-aarch64.txt +++ b/library/cpp/CMakeLists.linux-aarch64.txt @@ -62,6 +62,7 @@ add_subdirectory(openssl) add_subdirectory(packedtypes) add_subdirectory(packers) add_subdirectory(pop_count) +add_subdirectory(porto) add_subdirectory(presort) add_subdirectory(protobuf) add_subdirectory(random_provider) diff --git a/library/cpp/CMakeLists.linux-x86_64.txt b/library/cpp/CMakeLists.linux-x86_64.txt index 0f393b2039..22b108ee05 100644 --- a/library/cpp/CMakeLists.linux-x86_64.txt +++ b/library/cpp/CMakeLists.linux-x86_64.txt @@ -63,6 +63,7 @@ add_subdirectory(openssl) add_subdirectory(packedtypes) add_subdirectory(packers) add_subdirectory(pop_count) +add_subdirectory(porto) add_subdirectory(presort) add_subdirectory(protobuf) add_subdirectory(random_provider) diff --git a/library/cpp/CMakeLists.windows-x86_64.txt b/library/cpp/CMakeLists.windows-x86_64.txt index 772027a342..8925d1f2bf 100644 --- a/library/cpp/CMakeLists.windows-x86_64.txt +++ b/library/cpp/CMakeLists.windows-x86_64.txt @@ -62,6 +62,7 @@ add_subdirectory(openssl) 
add_subdirectory(packedtypes) add_subdirectory(packers) add_subdirectory(pop_count) +add_subdirectory(porto) add_subdirectory(presort) add_subdirectory(protobuf) add_subdirectory(random_provider) diff --git a/library/cpp/porto/CMakeLists.darwin-x86_64.txt b/library/cpp/porto/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..499930c4b0 --- /dev/null +++ b/library/cpp/porto/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,9 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(proto) diff --git a/library/cpp/porto/CMakeLists.linux-aarch64.txt b/library/cpp/porto/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..f61df9eb93 --- /dev/null +++ b/library/cpp/porto/CMakeLists.linux-aarch64.txt @@ -0,0 +1,22 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(proto) + +add_library(library-cpp-porto) +target_link_libraries(library-cpp-porto PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + cpp-porto-proto + contrib-libs-protobuf +) +target_sources(library-cpp-porto PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/porto/libporto.cpp + ${CMAKE_SOURCE_DIR}/library/cpp/porto/metrics.cpp +) diff --git a/library/cpp/porto/CMakeLists.linux-x86_64.txt b/library/cpp/porto/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..f61df9eb93 --- /dev/null +++ b/library/cpp/porto/CMakeLists.linux-x86_64.txt @@ -0,0 +1,22 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(proto) + +add_library(library-cpp-porto) +target_link_libraries(library-cpp-porto PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + cpp-porto-proto + contrib-libs-protobuf +) +target_sources(library-cpp-porto PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/porto/libporto.cpp + ${CMAKE_SOURCE_DIR}/library/cpp/porto/metrics.cpp +) diff --git a/library/cpp/porto/CMakeLists.txt b/library/cpp/porto/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/library/cpp/porto/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
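+
+# The dispatch below selects exactly one platform-specific include, keyed on
+# CMAKE_SYSTEM_NAME and CMAKE_SYSTEM_PROCESSOR, with NOT HAVE_CUDA guarding
+# the Linux and Windows branches; the same pattern is used by every
+# generated CMakeLists.txt in this change.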
+ + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/library/cpp/porto/CMakeLists.windows-x86_64.txt b/library/cpp/porto/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..499930c4b0 --- /dev/null +++ b/library/cpp/porto/CMakeLists.windows-x86_64.txt @@ -0,0 +1,9 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(proto) diff --git a/library/cpp/porto/libporto.cpp b/library/cpp/porto/libporto.cpp new file mode 100644 index 0000000000..8fd8924300 --- /dev/null +++ b/library/cpp/porto/libporto.cpp @@ -0,0 +1,1547 @@ +#include "libporto.hpp" +#include "metrics.hpp" + +#include <google/protobuf/text_format.h> +#include <google/protobuf/io/zero_copy_stream_impl.h> +#include <google/protobuf/io/coded_stream.h> + +extern "C" { +#include <errno.h> +#include <time.h> +#include <unistd.h> +#include <sys/socket.h> +#include <sys/un.h> + +#ifndef __linux__ +#include <fcntl.h> +#else +#include <sys/epoll.h> +#endif +} + +namespace Porto { + +TPortoApi::~TPortoApi() { + Disconnect(); +} + +EError TPortoApi::SetError(const TString &prefix, int _errno) { + LastErrorMsg = prefix + ": " + strerror(_errno); + + switch (_errno) { + case ENOENT: + LastError = EError::SocketUnavailable; + break; + case EAGAIN: + LastErrorMsg = prefix + ": Timeout exceeded. 
Timeout value: " + std::to_string(Timeout); + LastError = EError::SocketTimeout; + break; + case EIO: + case EPIPE: + LastError = EError::SocketError; + break; + default: + LastError = EError::Unknown; + break; + } + + Disconnect(); + return LastError; +} + +TString TPortoApi::GetLastError() const { + return EError_Name(LastError) + ":(" + LastErrorMsg + ")"; +} + +EError TPortoApi::Connect(const char *socket_path) { + struct sockaddr_un peer_addr; + socklen_t peer_addr_size; + + Disconnect(); + +#ifdef __linux__ + Fd = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); + if (Fd < 0) + return SetError("socket", errno); +#else + Fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (Fd < 0) + return SetError("socket", errno); + if (fcntl(Fd, F_SETFD, FD_CLOEXEC) < 0) + return SetError("fcntl FD_CLOEXEC", errno); +#endif + + if (Timeout > 0 && SetSocketTimeout(3, Timeout)) + return LastError; + + memset(&peer_addr, 0, sizeof(struct sockaddr_un)); + peer_addr.sun_family = AF_UNIX; + strncpy(peer_addr.sun_path, socket_path, strlen(socket_path)); + + peer_addr_size = sizeof(struct sockaddr_un); + if (connect(Fd, (struct sockaddr *) &peer_addr, peer_addr_size) < 0) + return SetError("connect", errno); + + /* Restore async wait state */ + if (!AsyncWaitNames.empty()) { + for (auto &name: AsyncWaitNames) + Req.mutable_asyncwait()->add_name(name); + for (auto &label: AsyncWaitLabels) + Req.mutable_asyncwait()->add_label(label); + if (AsyncWaitTimeout >= 0) + Req.mutable_asyncwait()->set_timeout_ms(AsyncWaitTimeout * 1000); + return Call(); + } + + return EError::Success; +} + +void TPortoApi::Disconnect() { + if (Fd >= 0) + close(Fd); + Fd = -1; +} + +EError TPortoApi::SetSocketTimeout(int direction, int timeout) { + struct timeval tv; + + if (Fd < 0) + return EError::Success; + + tv.tv_sec = timeout > 0 ? timeout : 0; + tv.tv_usec = 0; + + if ((direction & 1) && setsockopt(Fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof tv)) + return SetError("setsockopt SO_SNDTIMEO", errno); + + if ((direction & 2) && setsockopt(Fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof tv)) + return SetError("setsockopt SO_RCVTIMEO", errno); + + return EError::Success; +} + +EError TPortoApi::SetTimeout(int timeout) { + Timeout = timeout ? timeout : DEFAULT_TIMEOUT; + return SetSocketTimeout(3, Timeout); +} + +EError TPortoApi::SetDiskTimeout(int timeout) { + DiskTimeout = timeout ? 
timeout : DEFAULT_DISK_TIMEOUT; + return EError::Success; +} + +EError TPortoApi::Send(const TPortoRequest &req) { + google::protobuf::io::FileOutputStream raw(Fd); + + if (!req.IsInitialized()) { + LastError = EError::InvalidMethod; + LastErrorMsg = "Request is not initialized"; + return EError::InvalidMethod; + } + + { + google::protobuf::io::CodedOutputStream output(&raw); + + output.WriteVarint32(req.ByteSize()); + req.SerializeWithCachedSizes(&output); + } + + raw.Flush(); + + int err = raw.GetErrno(); + if (err) + return SetError("send", err); + + return EError::Success; +} + +EError TPortoApi::Recv(TPortoResponse &rsp) { + google::protobuf::io::FileInputStream raw(Fd); + google::protobuf::io::CodedInputStream input(&raw); + + while (true) { + uint32_t size; + + if (!input.ReadVarint32(&size)) + return SetError("recv", raw.GetErrno() ?: EIO); + + auto prev_limit = input.PushLimit(size); + + rsp.Clear(); + + if (!rsp.ParseFromCodedStream(&input)) + return SetError("recv", raw.GetErrno() ?: EIO); + + input.PopLimit(prev_limit); + + if (rsp.has_asyncwait()) { + if (AsyncWaitCallback) + AsyncWaitCallback(rsp.asyncwait()); + + if (AsyncWaitOneShot) + return EError::Success; + + continue; + } + + return EError::Success; + } +} + +EError TPortoApi::Call(const TPortoRequest &req, + TPortoResponse &rsp, + int extra_timeout) { + bool reconnect = AutoReconnect; + EError err = EError::Success; + + if (Fd < 0) { + if (!reconnect) + return SetError("Not connected", EIO); + err = Connect(); + reconnect = false; + } + + if (!err) { + err = Send(req); + if (err == EError::SocketError && reconnect) { + err = Connect(); + if (!err) + err = Send(req); + } + } + + if (!err && extra_timeout && Timeout > 0) + err = SetSocketTimeout(2, extra_timeout > 0 ? (extra_timeout + Timeout) : -1); + + if (!err) + err = Recv(rsp); + + if (extra_timeout && Timeout > 0) { + EError err = SetSocketTimeout(2, Timeout); + (void)err; + } + + if (!err) { + err = LastError = rsp.error(); + LastErrorMsg = rsp.errormsg(); + } + + return err; +} + +EError TPortoApi::Call(int extra_timeout) { + return Call(Req, Rsp, extra_timeout); +} + +EError TPortoApi::Call(const TString &req, + TString &rsp, + int extra_timeout) { + Req.Clear(); + if (!google::protobuf::TextFormat::ParseFromString(req, &Req)) { + LastError = EError::InvalidMethod; + LastErrorMsg = "Cannot parse request"; + rsp = ""; + return EError::InvalidMethod; + } + + EError err = Call(Req, Rsp, extra_timeout); + + rsp = Rsp.DebugString(); + + return err; +} + +EError TPortoApi::GetVersion(TString &tag, TString &revision) { + Req.Clear(); + Req.mutable_version(); + + if (!Call()) { + tag = Rsp.version().tag(); + revision = Rsp.version().revision(); + } + + return LastError; +} + +const TGetSystemResponse *TPortoApi::GetSystem() { + Req.Clear(); + Req.mutable_getsystem(); + if (!Call()) + return &Rsp.getsystem(); + return nullptr; +} + +EError TPortoApi::SetSystem(const TString &key, const TString &val) { + TString rsp; + return Call("SetSystem {" + key + ":" + val + "}", rsp); +} + +/* Container */ + +EError TPortoApi::Create(const TString &name) { + Req.Clear(); + auto req = Req.mutable_create(); + req->set_name(name); + return Call(); +} + +EError TPortoApi::CreateWeakContainer(const TString &name) { + Req.Clear(); + auto req = Req.mutable_createweak(); + req->set_name(name); + return Call(); +} + +EError TPortoApi::Destroy(const TString &name) { + Req.Clear(); + auto req = Req.mutable_destroy(); + req->set_name(name); + return Call(); +} + +const TListResponse 
*TPortoApi::List(const TString &mask) { + Req.Clear(); + auto req = Req.mutable_list(); + + if(!mask.empty()) + req->set_mask(mask); + + if (!Call()) + return &Rsp.list(); + + return nullptr; +} + +EError TPortoApi::List(TVector<TString> &list, const TString &mask) { + Req.Clear(); + auto req = Req.mutable_list(); + if(!mask.empty()) + req->set_mask(mask); + if (!Call()) + list = TVector<TString>(std::begin(Rsp.list().name()), + std::end(Rsp.list().name())); + return LastError; +} + +const TListPropertiesResponse *TPortoApi::ListProperties() { + Req.Clear(); + Req.mutable_listproperties(); + + if (Call()) + return nullptr; + + bool has_data = false; + for (const auto &prop: Rsp.listproperties().list()) { + if (prop.read_only()) { + has_data = true; + break; + } + } + + if (!has_data) { + TPortoRequest req; + TPortoResponse rsp; + + req.mutable_listdataproperties(); + if (!Call(req, rsp)) { + for (const auto &data: rsp.listdataproperties().list()) { + auto d = Rsp.mutable_listproperties()->add_list(); + d->set_name(data.name()); + d->set_desc(data.desc()); + d->set_read_only(true); + } + } + } + + return &Rsp.listproperties(); +} + +EError TPortoApi::ListProperties(TVector<TString> &properties) { + properties.clear(); + auto rsp = ListProperties(); + if (rsp) { + for (auto &prop: rsp->list()) + properties.push_back(prop.name()); + } + return LastError; +} + +const TGetResponse *TPortoApi::Get(const TVector<TString> &names, + const TVector<TString> &vars, + int flags) { + Req.Clear(); + auto get = Req.mutable_get(); + + for (const auto &n : names) + get->add_name(n); + + for (const auto &v : vars) + get->add_variable(v); + + if (flags & GET_NONBLOCK) + get->set_nonblock(true); + if (flags & GET_SYNC) + get->set_sync(true); + if (flags & GET_REAL) + get->set_real(true); + + if (!Call()) + return &Rsp.get(); + + return nullptr; +} + +EError TPortoApi::GetContainerSpec(const TString &name, TContainer &container) { + Req.Clear(); + TListContainersRequest req; + auto filter = req.add_filters(); + filter->set_name(name); + + TVector<TContainer> containers; + + auto ret = ListContainersBy(req, containers); + if (containers.empty()) + return EError::ContainerDoesNotExist; + + if (!ret) + container = containers[0]; + + return ret; +} + +EError TPortoApi::ListContainersBy(const TListContainersRequest &listContainersRequest, TVector<TContainer> &containers) { + Req.Clear(); + auto req = Req.mutable_listcontainersby(); + *req = listContainersRequest; + + auto ret = Call(); + if (ret) + return ret; + + for (auto &ct : Rsp.listcontainersby().containers()) + containers.push_back(ct); + + return EError::Success; +} + +EError TPortoApi::CreateFromSpec(const TContainerSpec &container, TVector<TVolumeSpec> volumes, bool start) { + Req.Clear(); + auto req = Req.mutable_createfromspec(); + + auto ct = req->mutable_container(); + *ct = container; + + for (auto &volume : volumes) { + auto v = req->add_volumes(); + *v = volume; + } + + req->set_start(start); + + return Call(); +} + +EError TPortoApi::UpdateFromSpec(const TContainerSpec &container) { + Req.Clear(); + auto req = Req.mutable_updatefromspec(); + + auto ct = req->mutable_container(); + *ct = container; + + return Call(); +} + +EError TPortoApi::GetProperty(const TString &name, + const TString &property, + TString &value, + int flags) { + Req.Clear(); + auto req = Req.mutable_getproperty(); + + req->set_name(name); + req->set_property(property); + if (flags & GET_SYNC) + req->set_sync(true); + if (flags & GET_REAL) + req->set_real(true); + + if (!Call()) 
+ value = Rsp.getproperty().value(); + + return LastError; +} + +EError TPortoApi::SetProperty(const TString &name, + const TString &property, + const TString &value) { + Req.Clear(); + auto req = Req.mutable_setproperty(); + + req->set_name(name); + req->set_property(property); + req->set_value(value); + + return Call(); +} + +EError TPortoApi::GetInt(const TString &name, + const TString &property, + const TString &index, + uint64_t &value) { + TString key = property, str; + if (index.size()) + key = property + "[" + index + "]"; + if (!GetProperty(name, key, str)) { + const char *ptr = str.c_str(); + char *end; + errno = 0; + value = strtoull(ptr, &end, 10); + if (errno || end == ptr || *end) { + LastError = EError::InvalidValue; + LastErrorMsg = " value: " + str; + } + } + return LastError; +} + +EError TPortoApi::SetInt(const TString &name, + const TString &property, + const TString &index, + uint64_t value) { + TString key = property; + if (index.size()) + key = property + "[" + index + "]"; + return SetProperty(name, key, ToString(value)); +} + +EError TPortoApi::GetProcMetric(const TVector<TString> &names, + const TString &metric, + TMap<TString, uint64_t> &values) { + auto it = ProcMetrics.find(metric); + + if (it == ProcMetrics.end()) { + LastError = EError::InvalidValue; + LastErrorMsg = " Unknown metric: " + metric; + return LastError; + } + + LastError = it->second->GetValues(names, values, *this); + + if (LastError) + LastErrorMsg = "Unknown error on Get() method"; + + return LastError; +} + +EError TPortoApi::SetLabel(const TString &name, + const TString &label, + const TString &value, + const TString &prev_value) { + Req.Clear(); + auto req = Req.mutable_setlabel(); + + req->set_name(name); + req->set_label(label); + req->set_value(value); + if (prev_value != " ") + req->set_prev_value(prev_value); + + return Call(); +} + +EError TPortoApi::IncLabel(const TString &name, + const TString &label, + int64_t add, + int64_t &result) { + Req.Clear(); + auto req = Req.mutable_inclabel(); + + req->set_name(name); + req->set_label(label); + req->set_add(add); + + EError err = Call(); + + if (Rsp.has_inclabel()) + result = Rsp.inclabel().result(); + + return err; +} + +EError TPortoApi::Start(const TString &name) { + Req.Clear(); + auto req = Req.mutable_start(); + + req->set_name(name); + + return Call(); +} + +EError TPortoApi::Stop(const TString &name, int stop_timeout) { + Req.Clear(); + auto req = Req.mutable_stop(); + + req->set_name(name); + if (stop_timeout >= 0) + req->set_timeout_ms(stop_timeout * 1000); + + return Call(stop_timeout > 0 ? 
stop_timeout : 0); +} + +EError TPortoApi::Kill(const TString &name, int sig) { + Req.Clear(); + auto req = Req.mutable_kill(); + + req->set_name(name); + req->set_sig(sig); + + return Call(); +} + +EError TPortoApi::Pause(const TString &name) { + Req.Clear(); + auto req = Req.mutable_pause(); + + req->set_name(name); + + return Call(); +} + +EError TPortoApi::Resume(const TString &name) { + Req.Clear(); + auto req = Req.mutable_resume(); + + req->set_name(name); + + return Call(); +} + +EError TPortoApi::Respawn(const TString &name) { + Req.Clear(); + auto req = Req.mutable_respawn(); + + req->set_name(name); + + return Call(); +} + +EError TPortoApi::CallWait(TString &result_state, int wait_timeout) { + time_t deadline = 0; + time_t last_retry = 0; + + if (wait_timeout >= 0) { + deadline = time(nullptr) + wait_timeout; + Req.mutable_wait()->set_timeout_ms(wait_timeout * 1000); + } + +retry: + if (!Call(wait_timeout)) { + if (Rsp.wait().has_state()) + result_state = Rsp.wait().state(); + else if (Rsp.wait().name() == "") + result_state = "timeout"; + else + result_state = "dead"; + } else if (LastError == EError::SocketError && AutoReconnect) { + time_t now = time(nullptr); + + if (wait_timeout < 0 || now < deadline) { + if (wait_timeout >= 0) { + wait_timeout = deadline - now; + Req.mutable_wait()->set_timeout_ms(wait_timeout * 1000); + } + if (last_retry == now) + sleep(1); + last_retry = now; + goto retry; + } + + result_state = "timeout"; + } else + result_state = "unknown"; + + return LastError; +} + +EError TPortoApi::WaitContainer(const TString &name, + TString &result_state, + int wait_timeout) { + Req.Clear(); + auto req = Req.mutable_wait(); + + req->add_name(name); + + return CallWait(result_state, wait_timeout); +} + +EError TPortoApi::WaitContainers(const TVector<TString> &names, + TString &result_name, + TString &result_state, + int wait_timeout) { + Req.Clear(); + auto req = Req.mutable_wait(); + + for (auto &c : names) + req->add_name(c); + + EError err = CallWait(result_state, wait_timeout); + + result_name = Rsp.wait().name(); + + return err; +} + +const TWaitResponse *TPortoApi::Wait(const TVector<TString> &names, + const TVector<TString> &labels, + int wait_timeout) { + Req.Clear(); + auto req = Req.mutable_wait(); + TString result_state; + + for (auto &c : names) + req->add_name(c); + for (auto &label: labels) + req->add_label(label); + + EError err = CallWait(result_state, wait_timeout); + (void)err; + + if (Rsp.has_wait()) + return &Rsp.wait(); + + return nullptr; +} + +EError TPortoApi::AsyncWait(const TVector<TString> &names, + const TVector<TString> &labels, + TWaitCallback callback, + int wait_timeout, + const TString &targetState) { + Req.Clear(); + auto req = Req.mutable_asyncwait(); + + AsyncWaitNames.clear(); + AsyncWaitLabels.clear(); + AsyncWaitTimeout = wait_timeout; + AsyncWaitCallback = callback; + + for (auto &name: names) + req->add_name(name); + for (auto &label: labels) + req->add_label(label); + if (wait_timeout >= 0) + req->set_timeout_ms(wait_timeout * 1000); + if (!targetState.empty()) { + req->set_target_state(targetState); + AsyncWaitOneShot = true; + } else + AsyncWaitOneShot = false; + + if (Call()) { + AsyncWaitCallback = nullptr; + } else { + AsyncWaitNames = names; + AsyncWaitLabels = labels; + } + + return LastError; +} + +EError TPortoApi::StopAsyncWait(const TVector<TString> &names, + const TVector<TString> &labels, + const TString &targetState) { + Req.Clear(); + auto req = Req.mutable_stopasyncwait(); + + AsyncWaitNames.clear(); + 
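+    // AsyncWaitNames and AsyncWaitLabels double as replay state: Connect()
+    // re-issues the asyncwait request from them after a reconnect, so
+    // clearing them here also keeps a stopped subscription from being
+    // silently restored.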
AsyncWaitLabels.clear(); + + for (auto &name: names) + req->add_name(name); + for (auto &label: labels) + req->add_label(label); + if (!targetState.empty()) { + req->set_target_state(targetState); + } + + return Call(); +} + +EError TPortoApi::ConvertPath(const TString &path, + const TString &src, + const TString &dest, + TString &res) { + Req.Clear(); + auto req = Req.mutable_convertpath(); + + req->set_path(path); + req->set_source(src); + req->set_destination(dest); + + if (!Call()) + res = Rsp.convertpath().path(); + + return LastError; +} + +EError TPortoApi::AttachProcess(const TString &name, int pid, + const TString &comm) { + Req.Clear(); + auto req = Req.mutable_attachprocess(); + + req->set_name(name); + req->set_pid(pid); + req->set_comm(comm); + + return Call(); +} + +EError TPortoApi::AttachThread(const TString &name, int pid, + const TString &comm) { + Req.Clear(); + auto req = Req.mutable_attachthread(); + + req->set_name(name); + req->set_pid(pid); + req->set_comm(comm); + + return Call(); +} + +EError TPortoApi::LocateProcess(int pid, const TString &comm, + TString &name) { + Req.Clear(); + auto req = Req.mutable_locateprocess(); + + req->set_pid(pid); + req->set_comm(comm); + + if (!Call()) + name = Rsp.locateprocess().name(); + + return LastError; +} + +/* Volume */ + +const TListVolumePropertiesResponse *TPortoApi::ListVolumeProperties() { + Req.Clear(); + Req.mutable_listvolumeproperties(); + + if (!Call()) + return &Rsp.listvolumeproperties(); + + return nullptr; +} + +EError TPortoApi::ListVolumeProperties(TVector<TString> &properties) { + properties.clear(); + auto rsp = ListVolumeProperties(); + if (rsp) { + for (auto &prop: rsp->list()) + properties.push_back(prop.name()); + } + return LastError; +} + +EError TPortoApi::CreateVolume(TString &path, + const TMap<TString, TString> &config) { + Req.Clear(); + auto req = Req.mutable_createvolume(); + + req->set_path(path); + + *(req->mutable_properties()) = + google::protobuf::Map<TString, TString>(config.begin(), config.end()); + + if (!Call(DiskTimeout) && path.empty()) + path = Rsp.createvolume().path(); + + return LastError; +} + +EError TPortoApi::TuneVolume(const TString &path, + const TMap<TString, TString> &config) { + Req.Clear(); + auto req = Req.mutable_tunevolume(); + + req->set_path(path); + + *(req->mutable_properties()) = + google::protobuf::Map<TString, TString>(config.begin(), config.end()); + + return Call(DiskTimeout); +} + +EError TPortoApi::LinkVolume(const TString &path, + const TString &container, + const TString &target, + bool read_only, + bool required) { + Req.Clear(); + auto req = (target.empty() && !required) ? Req.mutable_linkvolume() : + Req.mutable_linkvolumetarget(); + + req->set_path(path); + if (!container.empty()) + req->set_container(container); + if (target != "") + req->set_target(target); + if (read_only) + req->set_read_only(read_only); + if (required) + req->set_required(required); + + return Call(); +} + +EError TPortoApi::UnlinkVolume(const TString &path, + const TString &container, + const TString &target, + bool strict) { + Req.Clear(); + auto req = (target == "***") ? 
Req.mutable_unlinkvolume() : + Req.mutable_unlinkvolumetarget(); + + req->set_path(path); + if (!container.empty()) + req->set_container(container); + if (target != "***") + req->set_target(target); + if (strict) + req->set_strict(strict); + + return Call(DiskTimeout); +} + +const TListVolumesResponse * +TPortoApi::ListVolumes(const TString &path, + const TString &container) { + Req.Clear(); + auto req = Req.mutable_listvolumes(); + + if (!path.empty()) + req->set_path(path); + + if (!container.empty()) + req->set_container(container); + + if (Call()) + return nullptr; + + auto list = Rsp.mutable_listvolumes(); + + /* compat: iterate by reference so the synthesized links stick */ + for (auto &v: *list->mutable_volumes()) { + if (v.links().size()) + break; + for (auto &ct: v.containers()) + v.add_links()->set_container(ct); + } + + return list; +} + +EError TPortoApi::ListVolumes(TVector<TString> &paths) { + Req.Clear(); + auto rsp = ListVolumes(); + paths.clear(); + if (rsp) { + for (auto &v : rsp->volumes()) + paths.push_back(v.path()); + } + return LastError; +} + +const TVolumeDescription *TPortoApi::GetVolumeDesc(const TString &path) { + Req.Clear(); + auto rsp = ListVolumes(path); + + if (rsp && rsp->volumes().size()) + return &rsp->volumes(0); + + return nullptr; +} + +const TVolumeSpec *TPortoApi::GetVolume(const TString &path) { + Req.Clear(); + auto req = Req.mutable_getvolume(); + + req->add_path(path); + + if (!Call() && Rsp.getvolume().volume().size()) + return &Rsp.getvolume().volume(0); + + return nullptr; +} + +const TGetVolumeResponse *TPortoApi::GetVolumes(uint64_t changed_since) { + Req.Clear(); + auto req = Req.mutable_getvolume(); + + if (changed_since) + req->set_changed_since(changed_since); + + if (!Call() && Rsp.has_getvolume()) + return &Rsp.getvolume(); + + return nullptr; +} + + +EError TPortoApi::ListVolumesBy(const TGetVolumeRequest &getVolumeRequest, TVector<TVolumeSpec> &volumes) { + Req.Clear(); + auto req = Req.mutable_getvolume(); + *req = getVolumeRequest; + + auto ret = Call(); + if (ret) + return ret; + + for (auto volume : Rsp.getvolume().volume()) + volumes.push_back(volume); + return EError::Success; +} + +EError TPortoApi::CreateVolumeFromSpec(const TVolumeSpec &volume, TVolumeSpec &resultSpec) { + Req.Clear(); + auto req = Req.mutable_newvolume(); + auto vol = req->mutable_volume(); + *vol = volume; + + auto ret = Call(); + if (ret) + return ret; + + resultSpec = Rsp.newvolume().volume(); + + return ret; +} + +/* Layer */ + +EError TPortoApi::ImportLayer(const TString &layer, + const TString &tarball, + bool merge, + const TString &place, + const TString &private_value, + bool verboseError) { + Req.Clear(); + auto req = Req.mutable_importlayer(); + + req->set_layer(layer); + req->set_tarball(tarball); + req->set_merge(merge); + req->set_verbose_error(verboseError); + if (place.size()) + req->set_place(place); + if (private_value.size()) + req->set_private_value(private_value); + + return Call(DiskTimeout); +} + +EError TPortoApi::ExportLayer(const TString &volume, + const TString &tarball, + const TString &compress) { + Req.Clear(); + auto req = Req.mutable_exportlayer(); + + req->set_volume(volume); + req->set_tarball(tarball); + if (compress.size()) + req->set_compress(compress); + + return Call(DiskTimeout); +} + +EError TPortoApi::ReExportLayer(const TString &layer, + const TString &tarball, + const TString &compress) { + Req.Clear(); + auto req = Req.mutable_exportlayer(); + + req->set_volume(""); + req->set_layer(layer); + req->set_tarball(tarball); + if (compress.size()) +
req->set_compress(compress); + + return Call(DiskTimeout); +} + +EError TPortoApi::RemoveLayer(const TString &layer, + const TString &place, + bool async) { + Req.Clear(); + auto req = Req.mutable_removelayer(); + + req->set_layer(layer); + req->set_async(async); + if (place.size()) + req->set_place(place); + + return Call(DiskTimeout); +} + +const TListLayersResponse *TPortoApi::ListLayers(const TString &place, + const TString &mask) { + Req.Clear(); + auto req = Req.mutable_listlayers(); + + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + + if (Call()) + return nullptr; + + auto list = Rsp.mutable_listlayers(); + + /* compat conversion */ + if (!list->layers().size() && list->layer().size()) { + for (auto &name: list->layer()) { + auto l = list->add_layers(); + l->set_name(name); + l->set_owner_user(""); + l->set_owner_group(""); + l->set_last_usage(0); + l->set_private_value(""); + } + } + + return list; +} + +EError TPortoApi::ListLayers(TVector<TString> &layers, + const TString &place, + const TString &mask) { + Req.Clear(); + auto req = Req.mutable_listlayers(); + + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + + if (!Call()) + layers = TVector<TString>(std::begin(Rsp.listlayers().layer()), + std::end(Rsp.listlayers().layer())); + + return LastError; +} + +EError TPortoApi::GetLayerPrivate(TString &private_value, + const TString &layer, + const TString &place) { + Req.Clear(); + auto req = Req.mutable_getlayerprivate(); + + req->set_layer(layer); + if (place.size()) + req->set_place(place); + + if (!Call()) + private_value = Rsp.getlayerprivate().private_value(); + + return LastError; +} + +EError TPortoApi::SetLayerPrivate(const TString &private_value, + const TString &layer, + const TString &place) { + Req.Clear(); + auto req = Req.mutable_setlayerprivate(); + + req->set_layer(layer); + req->set_private_value(private_value); + if (place.size()) + req->set_place(place); + + return Call(); +} + +/* Docker images */ + +DockerImage::DockerImage(const TDockerImage &i) { + Id = i.id(); + for (const auto &tag: i.tags()) + Tags.emplace_back(tag); + for (const auto &digest: i.digests()) + Digests.emplace_back(digest); + for (const auto &layer: i.layers()) + Layers.emplace_back(layer); + if (i.has_size()) + Size = i.size(); + if (i.has_config()) { + auto &cfg = i.config(); + for (const auto &cmd: cfg.cmd()) + Config.Cmd.emplace_back(cmd); + for (const auto &env: cfg.env()) + Config.Env.emplace_back(env); + } +} + +EError TPortoApi::DockerImageStatus(DockerImage &image, + const TString &name, + const TString &place) { + auto req = Req.mutable_dockerimagestatus(); + req->set_name(name); + if (!place.empty()) + req->set_place(place); + EError ret = Call(); + if (!ret && Rsp.dockerimagestatus().has_image()) + image = DockerImage(Rsp.dockerimagestatus().image()); + return ret; +} + +EError TPortoApi::ListDockerImages(std::vector<DockerImage> &images, + const TString &place, + const TString &mask) { + auto req = Req.mutable_listdockerimages(); + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + EError ret = Call(); + if (!ret) { + for (const auto &i: Rsp.listdockerimages().images()) + images.emplace_back(i); + } + return ret; +} + +EError TPortoApi::PullDockerImage(DockerImage &image, + const TString &name, + const TString &place, + const TString &auth_token, + const TString &auth_path, + const TString &auth_service) { + auto req = Req.mutable_pulldockerimage(); + 
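+    // Every field below except name is optional and is sent only when
+    // non-empty; judging by the field names, place selects the image storage
+    // location and auth_token/auth_path/auth_service carry registry
+    // credentials.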
req->set_name(name); + if (place.size()) + req->set_place(place); + if (auth_token.size()) + req->set_auth_token(auth_token); + if (auth_path.size()) + req->set_auth_path(auth_path); + if (auth_service.size()) + req->set_auth_service(auth_service); + EError ret = Call(); + if (!ret && Rsp.pulldockerimage().has_image()) + image = DockerImage(Rsp.pulldockerimage().image()); + return ret; +} + +EError TPortoApi::RemoveDockerImage(const TString &name, + const TString &place) { + auto req = Req.mutable_removedockerimage(); + req->set_name(name); + if (place.size()) + req->set_place(place); + return Call(); +} + +/* Storage */ + +const TListStoragesResponse *TPortoApi::ListStorages(const TString &place, + const TString &mask) { + Req.Clear(); + auto req = Req.mutable_liststorages(); + + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + + if (Call()) + return nullptr; + + return &Rsp.liststorages(); +} + +EError TPortoApi::ListStorages(TVector<TString> &storages, + const TString &place, + const TString &mask) { + Req.Clear(); + auto req = Req.mutable_liststorages(); + + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + + if (!Call()) { + storages.clear(); + for (auto &storage: Rsp.liststorages().storages()) + storages.push_back(storage.name()); + } + + return LastError; +} + +EError TPortoApi::RemoveStorage(const TString &storage, + const TString &place) { + Req.Clear(); + auto req = Req.mutable_removestorage(); + + req->set_name(storage); + if (place.size()) + req->set_place(place); + + return Call(DiskTimeout); +} + +EError TPortoApi::ImportStorage(const TString &storage, + const TString &archive, + const TString &place, + const TString &compression, + const TString &private_value) { + Req.Clear(); + auto req = Req.mutable_importstorage(); + + req->set_name(storage); + req->set_tarball(archive); + if (place.size()) + req->set_place(place); + if (compression.size()) + req->set_compress(compression); + if (private_value.size()) + req->set_private_value(private_value); + + return Call(DiskTimeout); +} + +EError TPortoApi::ExportStorage(const TString &storage, + const TString &archive, + const TString &place, + const TString &compression) { + Req.Clear(); + auto req = Req.mutable_exportstorage(); + + req->set_name(storage); + req->set_tarball(archive); + if (place.size()) + req->set_place(place); + if (compression.size()) + req->set_compress(compression); + + return Call(DiskTimeout); +} + +#ifdef __linux__ +void TAsyncWaiter::MainCallback(const TWaitResponse &event) { + CallbacksCount++; + + auto it = AsyncCallbacks.find(event.name()); + if (it != AsyncCallbacks.end() && it->second.State == event.state()) { + it->second.Callback(event); + AsyncCallbacks.erase(it); + } +} + +int TAsyncWaiter::Repair() { + for (const auto &it : AsyncCallbacks) { + int ret = Api.AsyncWait({it.first}, {}, GetMainCallback(), -1, it.second.State); + if (ret) + return ret; + } + return 0; +} + +void TAsyncWaiter::WatchDog() { + int ret; + auto apiFd = Api.Fd; + + while (true) { + struct epoll_event events[2]; + int nfds = epoll_wait(EpollFd, events, 2, -1); + + if (nfds < 0) { + if (errno == EINTR) + continue; + + Fatal("Can not make epoll_wait", errno); + return; + } + + for (int n = 0; n < nfds; ++n) { + if (events[n].data.fd == apiFd) { + TPortoResponse rsp; + ret = Api.Recv(rsp); + // portod reloaded - async_wait must be repaired + if (ret == EError::SocketError) { + ret = Api.Connect(); + if (ret) { + Fatal("Can not connect to porto api", 
ret); + return; + } + + ret = Repair(); + if (ret) { + Fatal("Can not repair", ret); + return; + } + + apiFd = Api.Fd; + + struct epoll_event portoEv; + portoEv.events = EPOLLIN; + portoEv.data.fd = apiFd; + if (epoll_ctl(EpollFd, EPOLL_CTL_ADD, apiFd, &portoEv)) { + Fatal("Can not epoll_ctl", errno); + return; + } + } + } else if (events[n].data.fd == Sock) { + ERequestType requestType = static_cast<ERequestType>(RecvInt(Sock)); + + switch (requestType) { + case ERequestType::Add: + HandleAddRequest(); + break; + case ERequestType::Del: + HandleDelRequest(); + break; + case ERequestType::Stop: + return; + case ERequestType::None: + default: + Fatal("Unknown request", static_cast<int>(requestType)); + } + } + } + } +} + +void TAsyncWaiter::SendInt(int fd, int value) { + int ret = write(fd, &value, sizeof(value)); + if (ret != sizeof(value)) + Fatal("Can not send int", errno); +} + +int TAsyncWaiter::RecvInt(int fd) { + int value; + int ret = read(fd, &value, sizeof(value)); + if (ret != sizeof(value)) + Fatal("Can not recv int", errno); + + return value; +} + +void TAsyncWaiter::HandleAddRequest() { + int ret = 0; + + auto it = AsyncCallbacks.find(ReqCt); + if (it != AsyncCallbacks.end()) { + ret = Api.StopAsyncWait({ReqCt}, {}, it->second.State); + AsyncCallbacks.erase(it); + } + + AsyncCallbacks.insert(std::make_pair(ReqCt, TCallbackData({ReqCallback, ReqState}))); + + ret = Api.AsyncWait({ReqCt}, {}, GetMainCallback(), -1, ReqState); + SendInt(Sock, ret); +} + +void TAsyncWaiter::HandleDelRequest() { + int ret = 0; + + auto it = AsyncCallbacks.find(ReqCt); + if (it != AsyncCallbacks.end()) { + ret = Api.StopAsyncWait({ReqCt}, {}, it->second.State); + AsyncCallbacks.erase(it); + } + + SendInt(Sock, ret); +} + +TAsyncWaiter::TAsyncWaiter(std::function<void(const TString &error, int ret)> fatalCallback) + : CallbacksCount(0ul) + , FatalCallback(fatalCallback) +{ + int socketPair[2]; + int ret = socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, socketPair); + if (ret) { + Fatal("Can not make socketpair", ret); + return; + } + + MasterSock = socketPair[0]; + Sock = socketPair[1]; + + ret = Api.Connect(); + if (ret) { + Fatal("Can not connect to porto api", ret); + return; + } + + auto apiFd = Api.Fd; + + EpollFd = epoll_create1(EPOLL_CLOEXEC); + + if (EpollFd == -1) { + Fatal("Can not epoll_create", errno); + return; + } + + struct epoll_event pairEv; + pairEv.events = EPOLLIN; + pairEv.data.fd = Sock; + + struct epoll_event portoEv; + portoEv.events = EPOLLIN; + portoEv.data.fd = apiFd; + + if (epoll_ctl(EpollFd, EPOLL_CTL_ADD, Sock, &pairEv)) { + Fatal("Can not epoll_ctl", errno); + return; + } + + if (epoll_ctl(EpollFd, EPOLL_CTL_ADD, apiFd, &portoEv)) { + Fatal("Can not epoll_ctl", errno); + return; + } + + WatchDogThread = std::unique_ptr<std::thread>(new std::thread(&TAsyncWaiter::WatchDog, this)); +} + +TAsyncWaiter::~TAsyncWaiter() { + SendInt(MasterSock, static_cast<int>(ERequestType::Stop)); + WatchDogThread->join(); + + // pedantic check, that porto api is watching by epoll + if (epoll_ctl(EpollFd, EPOLL_CTL_DEL, Api.Fd, nullptr) || epoll_ctl(EpollFd, EPOLL_CTL_DEL, Sock, nullptr)) { + Fatal("Can not epoll_ctl_del", errno); + } + + close(EpollFd); + close(Sock); + close(MasterSock); +} + +int TAsyncWaiter::Add(const TString &ct, const TString &state, TWaitCallback callback) { + if (FatalError) + return -1; + + ReqCt = ct; + ReqState = state; + ReqCallback = callback; + + SendInt(MasterSock, static_cast<int>(ERequestType::Add)); + return RecvInt(MasterSock); +} + +int 
TAsyncWaiter::Remove(const TString &ct) { + if (FatalError) + return -1; + + ReqCt = ct; + + SendInt(MasterSock, static_cast<int>(ERequestType::Del)); + return RecvInt(MasterSock); +} +#endif + +} /* namespace Porto */ diff --git a/library/cpp/porto/libporto.hpp b/library/cpp/porto/libporto.hpp new file mode 100644 index 0000000000..e30f22a41e --- /dev/null +++ b/library/cpp/porto/libporto.hpp @@ -0,0 +1,492 @@ +#pragma once + +#include <atomic> +#include <thread> + +#include <util/string/cast.h> +#include <util/generic/hash.h> +#include <util/generic/map.h> +#include <util/generic/vector.h> + +#include <library/cpp/porto/proto/rpc.pb.h> + +namespace Porto { + +constexpr int INFINITE_TIMEOUT = -1; +constexpr int DEFAULT_TIMEOUT = 300; // 5min +constexpr int DEFAULT_DISK_TIMEOUT = 900; // 15min + +constexpr char SOCKET_PATH[] = "/run/portod.socket"; + +typedef std::function<void(const TWaitResponse &event)> TWaitCallback; + +enum { + GET_NONBLOCK = 1, // try lock container state + GET_SYNC = 2, // refresh cached values, cache ttl 5s + GET_REAL = 4, // no faked or inherited values +}; + +struct DockerImage { + std::string Id; + std::vector<std::string> Tags; + std::vector<std::string> Digests; + std::vector<std::string> Layers; + uint64_t Size; + struct Config { + std::vector<std::string> Cmd; + std::vector<std::string> Env; + } Config; + + DockerImage() = default; + DockerImage(const TDockerImage &i); + + DockerImage(const DockerImage &i) = default; + DockerImage(DockerImage &&i) = default; + + DockerImage& operator=(const DockerImage &i) = default; + DockerImage& operator=(DockerImage &&i) = default; +}; + +class TPortoApi { +#ifdef __linux__ + friend class TAsyncWaiter; +#endif +private: + int Fd = -1; + int Timeout = DEFAULT_TIMEOUT; + int DiskTimeout = DEFAULT_DISK_TIMEOUT; + bool AutoReconnect = true; + + EError LastError = EError::Success; + TString LastErrorMsg; + + /* + * These keep last request and response. Method might return + * pointers to Rsp innards -> pointers valid until next call. 
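+     *
+     * The wire format itself is a varint-framed protobuf stream: Send()
+     * writes req.ByteSize() as a varint32 followed by the serialized
+     * TPortoRequest, and Recv() reads frames the same way, dispatching any
+     * interleaved asyncwait events to AsyncWaitCallback before returning
+     * the regular response.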
+ */ + TPortoRequest Req; + TPortoResponse Rsp; + + std::vector<TString> AsyncWaitNames; + std::vector<TString> AsyncWaitLabels; + int AsyncWaitTimeout = INFINITE_TIMEOUT; + TWaitCallback AsyncWaitCallback; + bool AsyncWaitOneShot = false; + + EError SetError(const TString &prefix, int _errno) Y_WARN_UNUSED_RESULT; + + EError SetSocketTimeout(int direction, int timeout) Y_WARN_UNUSED_RESULT; + + EError Send(const TPortoRequest &req) Y_WARN_UNUSED_RESULT; + + EError Recv(TPortoResponse &rsp) Y_WARN_UNUSED_RESULT; + + EError Call(int extra_timeout = 0) Y_WARN_UNUSED_RESULT; + + EError CallWait(TString &result_state, int wait_timeout) Y_WARN_UNUSED_RESULT; + +public: + TPortoApi() { } + ~TPortoApi(); + + int GetFd() const { + return Fd; + } + + bool Connected() const { + return Fd >= 0; + } + + EError Connect(const char *socket_path = SOCKET_PATH) Y_WARN_UNUSED_RESULT; + void Disconnect(); + + /* Requires signal(SIGPIPE, SIG_IGN) */ + void SetAutoReconnect(bool auto_reconnect) { + AutoReconnect = auto_reconnect; + } + + /* Request and response timeout in seconds */ + int GetTimeout() const { + return Timeout; + } + EError SetTimeout(int timeout); + + /* Extra timeout for disk operations in seconds */ + int GetDiskTimeout() const { + return DiskTimeout; + } + EError SetDiskTimeout(int timeout); + + EError Error() const Y_WARN_UNUSED_RESULT { + return LastError; + } + + EError GetLastError(TString &msg) const Y_WARN_UNUSED_RESULT { + msg = LastErrorMsg; + return LastError; + } + + /* Returns "LastError:(LastErrorMsg)" */ + TString GetLastError() const Y_WARN_UNUSED_RESULT; + + /* Returns text protobuf */ + TString GetLastRequest() const { + return Req.DebugString(); + } + TString GetLastResponse() const { + return Rsp.DebugString(); + } + + /* To be used for next changed_since */ + uint64_t ResponseTimestamp() const Y_WARN_UNUSED_RESULT { + return Rsp.timestamp(); + } + + // extra_timeout: 0 - none, -1 - infinite + EError Call(const TPortoRequest &req, + TPortoResponse &rsp, + int extra_timeout = 0) Y_WARN_UNUSED_RESULT; + + EError Call(const TString &req, + TString &rsp, + int extra_timeout = 0) Y_WARN_UNUSED_RESULT; + + /* System */ + + EError GetVersion(TString &tag, TString &revision) Y_WARN_UNUSED_RESULT; + + const TGetSystemResponse *GetSystem(); + + EError SetSystem(const TString &key, const TString &val) Y_WARN_UNUSED_RESULT; + + /* Container */ + + const TListPropertiesResponse *ListProperties(); + + EError ListProperties(TVector<TString> &properties) Y_WARN_UNUSED_RESULT; + + const TListResponse *List(const TString &mask = ""); + + EError List(TVector<TString> &names, const TString &mask = "") Y_WARN_UNUSED_RESULT; + + EError Create(const TString &name) Y_WARN_UNUSED_RESULT; + + EError CreateWeakContainer(const TString &name) Y_WARN_UNUSED_RESULT; + + EError Destroy(const TString &name) Y_WARN_UNUSED_RESULT; + + EError Start(const TString &name)Y_WARN_UNUSED_RESULT; + + // stop_timeout: time between SIGTERM and SIGKILL, -1 - default + EError Stop(const TString &name, int stop_timeout = -1) Y_WARN_UNUSED_RESULT; + + EError Kill(const TString &name, int sig = 9) Y_WARN_UNUSED_RESULT; + + EError Pause(const TString &name) Y_WARN_UNUSED_RESULT; + + EError Resume(const TString &name) Y_WARN_UNUSED_RESULT; + + EError Respawn(const TString &name) Y_WARN_UNUSED_RESULT; + + // wait_timeout: 0 - nonblock, -1 - infinite + EError WaitContainer(const TString &name, + TString &result_state, + int wait_timeout = INFINITE_TIMEOUT) Y_WARN_UNUSED_RESULT; + + EError WaitContainers(const TVector<TString> 
&names, + TString &result_name, + TString &result_state, + int wait_timeout = INFINITE_TIMEOUT) Y_WARN_UNUSED_RESULT; + + const TWaitResponse *Wait(const TVector<TString> &names, + const TVector<TString> &labels, + int wait_timeout = INFINITE_TIMEOUT) Y_WARN_UNUSED_RESULT; + + EError AsyncWait(const TVector<TString> &names, + const TVector<TString> &labels, + TWaitCallback callbacks, + int wait_timeout = INFINITE_TIMEOUT, + const TString &targetState = "") Y_WARN_UNUSED_RESULT; + + EError StopAsyncWait(const TVector<TString> &names, + const TVector<TString> &labels, + const TString &targetState = "") Y_WARN_UNUSED_RESULT; + + const TGetResponse *Get(const TVector<TString> &names, + const TVector<TString> &properties, + int flags = 0) Y_WARN_UNUSED_RESULT; + + /* Porto v5 api */ + EError GetContainerSpec(const TString &name, TContainer &container) Y_WARN_UNUSED_RESULT ; + EError ListContainersBy(const TListContainersRequest &listContainersRequest, TVector<TContainer> &containers) Y_WARN_UNUSED_RESULT; + EError CreateFromSpec(const TContainerSpec &container, TVector<TVolumeSpec> volumes, bool start = false) Y_WARN_UNUSED_RESULT; + EError UpdateFromSpec(const TContainerSpec &container) Y_WARN_UNUSED_RESULT; + + EError GetProperty(const TString &name, + const TString &property, + TString &value, + int flags = 0) Y_WARN_UNUSED_RESULT; + + EError GetProperty(const TString &name, + const TString &property, + const TString &index, + TString &value, + int flags = 0) Y_WARN_UNUSED_RESULT { + return GetProperty(name, property + "[" + index + "]", value, flags); + } + + EError SetProperty(const TString &name, + const TString &property, + const TString &value) Y_WARN_UNUSED_RESULT; + + EError SetProperty(const TString &name, + const TString &property, + const TString &index, + const TString &value) Y_WARN_UNUSED_RESULT { + return SetProperty(name, property + "[" + index + "]", value); + } + + EError GetInt(const TString &name, + const TString &property, + const TString &index, + uint64_t &value) Y_WARN_UNUSED_RESULT; + + EError GetInt(const TString &name, + const TString &property, + uint64_t &value) Y_WARN_UNUSED_RESULT { + return GetInt(name, property, "", value); + } + + EError SetInt(const TString &name, + const TString &property, + const TString &index, + uint64_t value) Y_WARN_UNUSED_RESULT; + + EError SetInt(const TString &name, + const TString &property, + uint64_t value) Y_WARN_UNUSED_RESULT { + return SetInt(name, property, "", value); + } + + EError GetProcMetric(const TVector<TString> &names, + const TString &metric, + TMap<TString, uint64_t> &values); + + EError GetLabel(const TString &name, + const TString &label, + TString &value) Y_WARN_UNUSED_RESULT { + return GetProperty(name, "labels", label, value); + } + + EError SetLabel(const TString &name, + const TString &label, + const TString &value, + const TString &prev_value = " ") Y_WARN_UNUSED_RESULT; + + EError IncLabel(const TString &name, + const TString &label, + int64_t add, + int64_t &result) Y_WARN_UNUSED_RESULT; + + EError IncLabel(const TString &name, + const TString &label, + int64_t add = 1) Y_WARN_UNUSED_RESULT { + int64_t result; + return IncLabel(name, label, add, result); + } + + EError ConvertPath(const TString &path, + const TString &src_name, + const TString &dst_name, + TString &result_path) Y_WARN_UNUSED_RESULT; + + EError AttachProcess(const TString &name, int pid, + const TString &comm = "") Y_WARN_UNUSED_RESULT; + + EError AttachThread(const TString &name, int pid, + const TString &comm = "") Y_WARN_UNUSED_RESULT; 
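+
+    /*
+     * Minimal lifecycle sketch (illustrative only: "demo-ct" and the
+     * property values are placeholders, and real callers should branch on
+     * each EError instead of chaining):
+     *
+     *   Porto::TPortoApi api;
+     *   TString state;
+     *   if (!api.Create("demo-ct") &&
+     *       !api.SetProperty("demo-ct", "command", "sleep 5") &&
+     *       !api.Start("demo-ct") &&
+     *       !api.WaitContainer("demo-ct", state))  // blocks until dead
+     *       Cout << "final state: " << state << Endl;
+     *   EError err = api.Destroy("demo-ct");
+     *   Y_UNUSED(err);
+     */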
+ + EError LocateProcess(int pid, + const TString &comm /* = "" */, + TString &name) Y_WARN_UNUSED_RESULT; + + /* Volume */ + + const TListVolumePropertiesResponse *ListVolumeProperties(); + + EError ListVolumeProperties(TVector<TString> &properties) Y_WARN_UNUSED_RESULT; + + const TListVolumesResponse *ListVolumes(const TString &path = "", + const TString &container = ""); + + EError ListVolumes(TVector<TString> &paths) Y_WARN_UNUSED_RESULT; + + const TVolumeDescription *GetVolumeDesc(const TString &path); + + /* Porto v5 api */ + EError ListVolumesBy(const TGetVolumeRequest &getVolumeRequest, TVector<TVolumeSpec> &volumes) Y_WARN_UNUSED_RESULT; + EError CreateVolumeFromSpec(const TVolumeSpec &volume, TVolumeSpec &resultSpec) Y_WARN_UNUSED_RESULT; + + const TVolumeSpec *GetVolume(const TString &path); + + const TGetVolumeResponse *GetVolumes(uint64_t changed_since = 0); + + EError CreateVolume(TString &path, + const TMap<TString, TString> &config) Y_WARN_UNUSED_RESULT; + + EError LinkVolume(const TString &path, + const TString &container = "", + const TString &target = "", + bool read_only = false, + bool required = false) Y_WARN_UNUSED_RESULT; + + EError UnlinkVolume(const TString &path, + const TString &container = "", + const TString &target = "***", + bool strict = false) Y_WARN_UNUSED_RESULT; + + EError TuneVolume(const TString &path, + const TMap<TString, TString> &config) Y_WARN_UNUSED_RESULT; + + /* Layer */ + + const TListLayersResponse *ListLayers(const TString &place = "", + const TString &mask = ""); + + EError ListLayers(TVector<TString> &layers, + const TString &place = "", + const TString &mask = "") Y_WARN_UNUSED_RESULT; + + EError ImportLayer(const TString &layer, + const TString &tarball, + bool merge = false, + const TString &place = "", + const TString &private_value = "", + bool verboseError = false) Y_WARN_UNUSED_RESULT; + + EError ExportLayer(const TString &volume, + const TString &tarball, + const TString &compress = "") Y_WARN_UNUSED_RESULT; + + EError ReExportLayer(const TString &layer, + const TString &tarball, + const TString &compress = "") Y_WARN_UNUSED_RESULT; + + EError RemoveLayer(const TString &layer, + const TString &place = "", + bool async = false) Y_WARN_UNUSED_RESULT; + + EError GetLayerPrivate(TString &private_value, + const TString &layer, + const TString &place = "") Y_WARN_UNUSED_RESULT; + + EError SetLayerPrivate(const TString &private_value, + const TString &layer, + const TString &place = "") Y_WARN_UNUSED_RESULT; + + /* Docker images */ + + EError DockerImageStatus(DockerImage &image, + const TString &name, + const TString &place = "") Y_WARN_UNUSED_RESULT; + + EError ListDockerImages(std::vector<DockerImage> &images, + const TString &place = "", + const TString &mask = "") Y_WARN_UNUSED_RESULT; + + EError PullDockerImage(DockerImage &image, + const TString &name, + const TString &place = "", + const TString &auth_token = "", + const TString &auth_host = "", + const TString &auth_service = "") Y_WARN_UNUSED_RESULT; + + EError RemoveDockerImage(const TString &name, + const TString &place = ""); + + /* Storage */ + + const TListStoragesResponse *ListStorages(const TString &place = "", + const TString &mask = ""); + + EError ListStorages(TVector<TString> &storages, + const TString &place = "", + const TString &mask = "") Y_WARN_UNUSED_RESULT; + + EError RemoveStorage(const TString &storage, + const TString &place = "") Y_WARN_UNUSED_RESULT; + + EError ImportStorage(const TString &storage, + const TString &archive, + const TString &place = "", + 
const TString &compression = "", + const TString &private_value = "") Y_WARN_UNUSED_RESULT; + + EError ExportStorage(const TString &storage, + const TString &archive, + const TString &place = "", + const TString &compression = "") Y_WARN_UNUSED_RESULT; +}; + +#ifdef __linux__ +class TAsyncWaiter { + struct TCallbackData { + const TWaitCallback Callback; + const TString State; + }; + + enum class ERequestType { + None, + Add, + Del, + Stop, + }; + + THashMap<TString, TCallbackData> AsyncCallbacks; + std::unique_ptr<std::thread> WatchDogThread; + std::atomic<uint64_t> CallbacksCount; + int EpollFd = -1; + TPortoApi Api; + + int Sock, MasterSock; + TString ReqCt; + TString ReqState; + TWaitCallback ReqCallback; + + std::function<void(const TString &error, int ret)> FatalCallback; + bool FatalError = false; + + void MainCallback(const TWaitResponse &event); + inline TWaitCallback GetMainCallback() { + return [this](const TWaitResponse &event) { + MainCallback(event); + }; + } + + int Repair(); + void WatchDog(); + + void SendInt(int fd, int value); + int RecvInt(int fd); + + void HandleAddRequest(); + void HandleDelRequest(); + + void Fatal(const TString &error, int ret) { + FatalError = true; + FatalCallback(error, ret); + } + + public: + TAsyncWaiter(std::function<void(const TString &error, int ret)> fatalCallback); + ~TAsyncWaiter(); + + int Add(const TString &ct, const TString &state, TWaitCallback callback); + int Remove(const TString &ct); + uint64_t InvocationCount() const { + return CallbacksCount; + } +}; +#endif + +} /* namespace Porto */ diff --git a/library/cpp/porto/libporto_ut.cpp b/library/cpp/porto/libporto_ut.cpp new file mode 100644 index 0000000000..9d78397fb8 --- /dev/null +++ b/library/cpp/porto/libporto_ut.cpp @@ -0,0 +1,226 @@ +#include <library/cpp/testing/unittest/registar.h> +#include <libporto.hpp> + +#include <signal.h> +#include <cassert> + +#define Expect(a) assert(a) +#define ExpectEq(a, b) assert((a) == (b)) +#define ExpectNeq(a, b) assert((a) != (b)) +#define ExpectSuccess(ret) assert((ret) == Porto::EError::Success) + +const TString CT_NAME = "test-a"; + +void test_porto() { + TVector<TString> list; + TString str, path; + + signal(SIGPIPE, SIG_IGN); + + Porto::TPortoApi api; + + Expect(!api.Connected()); + Expect(api.GetFd() < 0); + + // Connect + ExpectSuccess(api.Connect()); + + Expect(api.Connected()); + Expect(api.GetFd() >= 0); + + // Disconnect + api.Disconnect(); + + Expect(!api.Connected()); + Expect(api.GetFd() < 0); + + // Auto connect + ExpectSuccess(api.GetVersion(str, str)); + Expect(api.Connected()); + + // Auto reconnect + api.Disconnect(); + ExpectSuccess(api.GetVersion(str, str)); + Expect(api.Connected()); + + // No auto reconnect + api.Disconnect(); + api.SetAutoReconnect(false); + ExpectEq(api.GetVersion(str, str), Porto::EError::SocketError); + api.SetAutoReconnect(true); + + uint64_t val = api.GetTimeout(); + ExpectNeq(val, 0); + ExpectSuccess(api.SetTimeout(5)); + + ExpectSuccess(api.List(list)); + + ExpectSuccess(api.ListProperties(list)); + + ExpectSuccess(api.ListVolumes(list)); + + ExpectSuccess(api.ListVolumeProperties(list)); + + ExpectSuccess(api.ListLayers(list)); + + ExpectSuccess(api.ListStorages(list)); + + ExpectSuccess(api.Call("Version {}", str)); + + ExpectSuccess(api.GetProperty("/", "state", str)); + ExpectEq(str, "meta"); + + ExpectSuccess(api.GetProperty("/", "controllers", "memory", str)); + ExpectEq(str, "true"); + + ExpectSuccess(api.GetProperty("/", "memory_usage", str)); + ExpectNeq(str, "0"); + + val = 0; + 
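+    // GetInt wraps the string property API: it fetches the value as text and
+    // parses it with strtoull, so a non-numeric value surfaces as
+    // EError::InvalidValue, and a non-empty index argument turns the key
+    // into "property[index]".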
ExpectSuccess(api.GetInt("/", "memory_usage", val)); + ExpectNeq(val, 0); + + Porto::TContainer ct; + ExpectSuccess(api.GetContainerSpec("/", ct)); + ExpectEq(ct.spec().name(), "/"); + + ExpectEq(api.GetInt("/", "__wrong__", val), Porto::EError::InvalidProperty); + ExpectEq(api.Error(), Porto::EError::InvalidProperty); + ExpectEq(api.GetLastError(str), Porto::EError::InvalidProperty); + + ExpectSuccess(api.GetContainerSpec(CT_NAME, ct)); + ExpectEq(ct.error().error(), Porto::EError::ContainerDoesNotExist); + + ExpectSuccess(api.CreateWeakContainer(CT_NAME)); + + ExpectSuccess(api.SetProperty(CT_NAME, "memory_limit", "20M")); + ExpectSuccess(api.GetProperty(CT_NAME, "memory_limit", str)); + ExpectEq(str, "20971520"); + + ExpectSuccess(api.SetInt(CT_NAME, "memory_limit", 10<<20)); + ExpectSuccess(api.GetInt(CT_NAME, "memory_limit", val)); + ExpectEq(val, 10485760); + + ExpectSuccess(api.SetLabel(CT_NAME, "TEST.a", ".")); + + ExpectSuccess(api.GetContainerSpec(CT_NAME, ct)); + ExpectEq(ct.status().state(), "stopped"); + ExpectEq(ct.spec().memory_limit(), 10 << 20); + + ExpectSuccess(api.WaitContainer(CT_NAME, str)); + ExpectEq(str, "stopped"); + + ExpectSuccess(api.CreateVolume(path, { + {"containers", CT_NAME}, + {"backend", "native"}, + {"space_limit", "1G"}})); + ExpectNeq(path, ""); + + [[maybe_unused]] auto vd = api.GetVolumeDesc(path); + Expect(vd != nullptr); + ExpectEq(vd->path(), path); + + [[maybe_unused]] auto vs = api.GetVolume(path); + Expect(vs != nullptr); + ExpectEq(vs->path(), path); + + ExpectSuccess(api.SetProperty(CT_NAME, "command", "sleep 1000")); + ExpectSuccess(api.Start(CT_NAME)); + + ExpectSuccess(api.GetProperty(CT_NAME, "state", str)); + ExpectEq(str, "running"); + + ExpectSuccess(api.Destroy(CT_NAME)); + + TMap<TString, uint64_t> values; + auto CT_NAME_0 = CT_NAME + "abcd"; + auto CT_NAME_CHILD = CT_NAME + "/b"; + + ExpectSuccess(api.Create(CT_NAME_0)); + ExpectSuccess(api.SetProperty(CT_NAME_0, "command", "sleep 15")); + ExpectSuccess(api.Start(CT_NAME_0)); + + ExpectSuccess(api.Create(CT_NAME)); + ExpectSuccess(api.SetProperty(CT_NAME, "command", "sleep 10")); + ExpectSuccess(api.GetProcMetric(TVector<TString>{CT_NAME, CT_NAME_0}, "ctxsw", values)); + ExpectEq(values[CT_NAME], 0); + ExpectNeq(values[CT_NAME_0], 0); + + ExpectSuccess(api.Start(CT_NAME)); + ExpectSuccess(api.GetProcMetric(TVector<TString>{CT_NAME}, "ctxsw", values)); + ExpectNeq(values[CT_NAME], 0); + + ExpectSuccess(api.Create(CT_NAME_CHILD)); + ExpectSuccess(api.SetProperty(CT_NAME_CHILD, "command", "sleep 10")); + ExpectSuccess(api.GetProcMetric(TVector<TString>{CT_NAME_CHILD}, "ctxsw", values)); + ExpectEq(values[CT_NAME_CHILD], 0); + + ExpectSuccess(api.Start(CT_NAME_CHILD)); + ExpectSuccess(api.GetProcMetric(TVector<TString>{CT_NAME, CT_NAME_CHILD}, "ctxsw", values)); + ExpectNeq(values[CT_NAME_CHILD], 0); + Expect(values[CT_NAME] > values[CT_NAME_CHILD]); + + ExpectSuccess(api.Stop(CT_NAME_CHILD)); + ExpectSuccess(api.GetProcMetric(TVector<TString>{CT_NAME_CHILD}, "ctxsw", values)); + ExpectEq(values[CT_NAME_CHILD], 0); + + ExpectSuccess(api.Stop(CT_NAME)); + ExpectSuccess(api.GetProcMetric(TVector<TString>{CT_NAME}, "ctxsw", values)); + ExpectEq(values[CT_NAME], 0); + + ExpectSuccess(api.Destroy(CT_NAME_CHILD)); + ExpectSuccess(api.Destroy(CT_NAME)); + ExpectSuccess(api.Destroy(CT_NAME_0)); + +#ifdef __linux__ + // test TAsyncWaiter + Porto::TAsyncWaiter waiter([](const TString &error, int ret) { + Y_UNUSED(error); + Y_UNUSED(ret); + + Expect(false); + }); + + TString result; + 
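+    // TAsyncWaiter::Add subscribes through AsyncWait with a target state,
+    // which makes the subscription one-shot: MainCallback erases the entry
+    // after the first matching event, so the "starting" callback below fires
+    // at most once however many times the container is restarted.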
waiter.Add("abc", "starting", [&result](const Porto::TWaitResponse &event) {
+        result += event.name() + "-" + event.state();
+    });
+
+    TString name = "abcdef";
+    ExpectSuccess(api.Create(name));
+    ExpectSuccess(api.SetProperty(name, "command", "sleep 1"));
+    ExpectSuccess(api.Start(name));
+    ExpectSuccess(api.Destroy(name));
+    ExpectEq(result, "");
+
+    // callback works only once
+    for (int i = 0; i < 3; i++) {
+        name = "abc";
+        ExpectSuccess(api.Create(name));
+        ExpectSuccess(api.SetProperty(name, "command", "sleep 1"));
+        ExpectSuccess(api.Start(name));
+        ExpectSuccess(api.Destroy(name));
+        ExpectEq(result, "abc-starting");
+    }
+
+    waiter.Add("abc", "starting", [&result](const Porto::TWaitResponse &event) {
+        result += event.name() + "-" + event.state();
+    });
+    waiter.Remove("abc");
+
+    name = "abc";
+    ExpectSuccess(api.Create(name));
+    ExpectSuccess(api.SetProperty(name, "command", "sleep 1"));
+    ExpectSuccess(api.Start(name));
+    ExpectSuccess(api.Destroy(name));
+    ExpectEq(result, "abc-starting");
+#endif
+
+    api.Disconnect();
+}
+
+Y_UNIT_TEST_SUITE(Porto) {
+    Y_UNIT_TEST(All) {
+        test_porto();
+    }
+}
diff --git a/library/cpp/porto/metrics.cpp b/library/cpp/porto/metrics.cpp
new file mode 100644
index 0000000000..7d17d0aee4
--- /dev/null
+++ b/library/cpp/porto/metrics.cpp
@@ -0,0 +1,183 @@
+#include "metrics.hpp"
+
+#include <util/folder/path.h>
+#include <util/generic/maybe.h>
+#include <util/stream/file.h>
+
+namespace Porto {
+
+TMap<TString, TMetric*> ProcMetrics;
+
+TMetric::TMetric(const TString& name, EMetric metric) {
+    Name = name;
+    Metric = metric;
+    ProcMetrics[name] = this;
+}
+
+void TMetric::ClearValues(const TVector<TString>& names, TMap<TString, uint64_t>& values) const {
+    values.clear();
+
+    for (const auto& name : names)
+        values[name] = 0;
+}
+
+EError TMetric::GetValues(const TVector<TString>& names, TMap<TString, uint64_t>& values, TPortoApi& api) const {
+    ClearValues(names, values);
+
+    int procFd = open("/proc", O_RDONLY | O_CLOEXEC | O_DIRECTORY | O_NOCTTY);
+    TFileHandle procFdHandle(procFd);
+    if (procFd == -1)
+        return EError::Unknown;
+
+    TVector<TString> tids;
+    TidSnapshot(tids);
+
+    auto getResponse = api.Get(names, TVector<TString>{"cgroups[freezer]"});
+
+    if (getResponse == nullptr)
+        return EError::Unknown;
+
+    const auto containersCgroups = GetCtFreezerCgroups(getResponse);
+
+    for (const auto& tid : tids) {
+        const TString tidCgroup = GetFreezerCgroup(procFd, tid);
+        if (tidCgroup == "")
+            continue;
+
+        TMaybe<uint64_t> metricValue;
+
+        for (const auto& keyval : containersCgroups) {
+            const TString& containerCgroup = keyval.second;
+            if (MatchCgroups(tidCgroup, containerCgroup)) {
+                if (!metricValue)
+                    metricValue = GetMetric(procFd, tid);
+                values[keyval.first] += *metricValue;
+            }
+        }
+    }
+
+    return EError::Success;
+}
+
+uint64_t TMetric::GetTidSchedMetricValue(int procFd, const TString& tid, const TString& metricName) const {
+    const TString schedPath = tid + "/sched";
+    try {
+        int fd = openat(procFd, schedPath.c_str(), O_RDONLY | O_CLOEXEC | O_NOCTTY, 0);
+        TFile file(fd);
+        if (!file.IsOpen())
+            return 0ul;
+
+        TIFStream iStream(file);
+        TString line;
+        while (iStream.ReadLine(line)) {
+            auto metricPos = line.find(metricName);
+
+            if (metricPos != TString::npos) {
+                auto valuePos = metricPos;
+
+                while (valuePos < line.size() && !::isdigit(line[valuePos]))
+                    ++valuePos;
+
+                TString value = line.substr(valuePos);
+                if (!value.empty() && IsNumber(value))
+                    return IntFromString<uint64_t, 10>(value);
+            }
+        }
+    }
+    catch(...) {}
+
+    return 0ul;
+}
+
+void TMetric::GetPidTasks(const TString& pid, TVector<TString>& tids) const {
+    TFsPath task("/proc/" + pid + "/task");
+    TVector<TString> rawTids;
+
+    try {
+        task.ListNames(rawTids);
+    }
+    catch(...) {}
+
+    for (const auto& tid : rawTids) {
+        tids.push_back(tid);
+    }
+}
+
+void TMetric::TidSnapshot(TVector<TString>& tids) const {
+    TFsPath proc("/proc");
+    TVector<TString> rawPids;
+
+    try {
+        proc.ListNames(rawPids);
+    }
+    catch(...) {}
+
+    for (const auto& pid : rawPids) {
+        if (IsNumber(pid))
+            GetPidTasks(pid, tids);
+    }
+}
+
+TString TMetric::GetFreezerCgroup(int procFd, const TString& tid) const {
+    const TString cgroupPath = tid + "/cgroup";
+    try {
+        int fd = openat(procFd, cgroupPath.c_str(), O_RDONLY | O_CLOEXEC | O_NOCTTY, 0);
+        TFile file(fd);
+        if (!file.IsOpen())
+            return TString();
+
+        TIFStream iStream(file);
+        TString line;
+
+        while (iStream.ReadLine(line)) {
+            static const TString freezer = ":freezer:";
+            auto freezerPos = line.find(freezer);
+
+            if (freezerPos != TString::npos) {
+                line = line.substr(freezerPos + freezer.size());
+                return line;
+            }
+        }
+    }
+    catch(...) {}
+
+    return TString();
+}
+
+TMap<TString, TString> TMetric::GetCtFreezerCgroups(const TGetResponse* response) const {
+    TMap<TString, TString> containersProps;
+
+    for (const auto& ctGetListResponse : response->list()) {
+        for (const auto& keyval : ctGetListResponse.keyval()) {
+            if (!keyval.error()) {
+                TString value = keyval.value();
+                static const TString freezerPath = "/sys/fs/cgroup/freezer";
+
+                if (value.find(freezerPath) != TString::npos)
+                    value = value.substr(freezerPath.size());
+
+                containersProps[ctGetListResponse.name()] = value;
+            }
+        }
+    }
+
+    return containersProps;
+}
+
+bool TMetric::MatchCgroups(const TString& tidCgroup, const TString& ctCgroup) const {
+    if (tidCgroup.size() <= ctCgroup.size())
+        return tidCgroup == ctCgroup;
+    return ctCgroup == tidCgroup.substr(0, ctCgroup.size()) && tidCgroup[ctCgroup.size()] == '/';
+}
+
+class TCtxsw : public TMetric {
+public:
+    TCtxsw() : TMetric(M_CTXSW, EMetric::CTXSW)
+    {}
+
+    uint64_t GetMetric(int procFd, const TString& tid) const override {
+        return GetTidSchedMetricValue(procFd, tid, "nr_switches");
+    }
+} static Ctxsw;
+
+} /* namespace Porto */
diff --git a/library/cpp/porto/metrics.hpp b/library/cpp/porto/metrics.hpp
new file mode 100644
index 0000000000..5b2ffde8d9
--- /dev/null
+++ b/library/cpp/porto/metrics.hpp
@@ -0,0 +1,50 @@
+#pragma once
+
+#include "libporto.hpp"
+
+#include <util/generic/map.h>
+#include <util/generic/vector.h>
+#include <util/string/cast.h>
+#include <util/string/type.h>
+
+#include <library/cpp/porto/proto/rpc.pb.h>
+namespace Porto {
+
+constexpr const char *M_CTXSW = "ctxsw";
+
+enum class EMetric {
+    NONE,
+    CTXSW,
+};
+
+class TMetric {
+public:
+    TString Name;
+    EMetric Metric;
+
+    TMetric(const TString& name, EMetric metric);
+
+    void ClearValues(const TVector<TString>& names, TMap<TString, uint64_t>& values) const;
+    EError GetValues(const TVector<TString>& names, TMap<TString, uint64_t>& values, TPortoApi& api) const;
+
+    // Returns value of metric from /proc/tid/sched for some tid
+    uint64_t GetTidSchedMetricValue(int procFd, const TString& tid, const TString& metricName) const;
+
+    void TidSnapshot(TVector<TString>& tids) const;
+    void GetPidTasks(const TString& pid, TVector<TString>& tids) const;
+
+    // Returns freezer cgroup from /proc/tid/cgroup
+    TString GetFreezerCgroup(int procFd, const TString& tid) const;
+
+    // Returns clean cgroup[freezer] for container names
+    TMap<TString, TString> GetCtFreezerCgroups(const TGetResponse* response) const;
+
+    // Verify inclusion of container cgroup in process cgroup
+    bool MatchCgroups(const TString& tidCgroup, const TString& ctCgroup) const;
+
+private:
+    virtual uint64_t GetMetric(int procFd, const TString& tid) const = 0;
+};
+
+extern TMap<TString, TMetric*> ProcMetrics;
+} /* namespace Porto */
diff --git a/library/cpp/porto/proto/CMakeLists.darwin-x86_64.txt b/library/cpp/porto/proto/CMakeLists.darwin-x86_64.txt
new file mode 100644
index 0000000000..9b8be22fe6
--- /dev/null
+++ b/library/cpp/porto/proto/CMakeLists.darwin-x86_64.txt
@@ -0,0 +1,43 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties
+# like target_include_directories). These modifications will be ported to original
+# ya.make files by maintainers. Any complex modifications which can't be ported back to the
+# original buildsystem will not be accepted.
+
+
+get_built_tool_path(
+  TOOL_protoc_bin
+  TOOL_protoc_dependency
+  contrib/tools/protoc/bin
+  protoc
+)
+get_built_tool_path(
+  TOOL_cpp_styleguide_bin
+  TOOL_cpp_styleguide_dependency
+  contrib/tools/protoc/plugins/cpp_styleguide
+  cpp_styleguide
+)
+
+add_library(cpp-porto-proto)
+target_link_libraries(cpp-porto-proto PUBLIC
+  contrib-libs-cxxsupp
+  yutil
+  contrib-libs-protobuf
+)
+target_proto_messages(cpp-porto-proto PRIVATE
+  ${CMAKE_SOURCE_DIR}/library/cpp/porto/proto/rpc.proto
+)
+target_proto_addincls(cpp-porto-proto
+  ./
+  ${CMAKE_SOURCE_DIR}/
+  ${CMAKE_BINARY_DIR}
+  ${CMAKE_SOURCE_DIR}
+  ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src
+  ${CMAKE_BINARY_DIR}
+  ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src
+)
+target_proto_outs(cpp-porto-proto
+  --cpp_out=${CMAKE_BINARY_DIR}/
+  --cpp_styleguide_out=${CMAKE_BINARY_DIR}/
+)
diff --git a/library/cpp/porto/proto/CMakeLists.linux-aarch64.txt b/library/cpp/porto/proto/CMakeLists.linux-aarch64.txt
new file mode 100644
index 0000000000..ba0aa7060d
--- /dev/null
+++ b/library/cpp/porto/proto/CMakeLists.linux-aarch64.txt
@@ -0,0 +1,44 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties
+# like target_include_directories). These modifications will be ported to original
+# ya.make files by maintainers. Any complex modifications which can't be ported back to the
+# original buildsystem will not be accepted.
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(cpp-porto-proto) +target_link_libraries(cpp-porto-proto PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + contrib-libs-protobuf +) +target_proto_messages(cpp-porto-proto PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/porto/proto/rpc.proto +) +target_proto_addincls(cpp-porto-proto + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(cpp-porto-proto + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) diff --git a/library/cpp/porto/proto/CMakeLists.linux-x86_64.txt b/library/cpp/porto/proto/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..ba0aa7060d --- /dev/null +++ b/library/cpp/porto/proto/CMakeLists.linux-x86_64.txt @@ -0,0 +1,44 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(cpp-porto-proto) +target_link_libraries(cpp-porto-proto PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + contrib-libs-protobuf +) +target_proto_messages(cpp-porto-proto PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/porto/proto/rpc.proto +) +target_proto_addincls(cpp-porto-proto + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(cpp-porto-proto + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) diff --git a/library/cpp/porto/proto/CMakeLists.txt b/library/cpp/porto/proto/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/library/cpp/porto/proto/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
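+
+# Dispatch to the matching per-platform listfile below; each one defines the
+# same cpp-porto-proto target, differing only in platform-specific dependencies.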
+if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA)
+  include(CMakeLists.linux-aarch64.txt)
+elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
+  include(CMakeLists.darwin-x86_64.txt)
+elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA)
+  include(CMakeLists.windows-x86_64.txt)
+elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA)
+  include(CMakeLists.linux-x86_64.txt)
+endif()
diff --git a/library/cpp/porto/proto/CMakeLists.windows-x86_64.txt b/library/cpp/porto/proto/CMakeLists.windows-x86_64.txt
new file mode 100644
index 0000000000..9b8be22fe6
--- /dev/null
+++ b/library/cpp/porto/proto/CMakeLists.windows-x86_64.txt
@@ -0,0 +1,43 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties
+# like target_include_directories). These modifications will be ported to original
+# ya.make files by maintainers. Any complex modifications which can't be ported back to the
+# original buildsystem will not be accepted.
+
+
+get_built_tool_path(
+  TOOL_protoc_bin
+  TOOL_protoc_dependency
+  contrib/tools/protoc/bin
+  protoc
+)
+get_built_tool_path(
+  TOOL_cpp_styleguide_bin
+  TOOL_cpp_styleguide_dependency
+  contrib/tools/protoc/plugins/cpp_styleguide
+  cpp_styleguide
+)
+
+add_library(cpp-porto-proto)
+target_link_libraries(cpp-porto-proto PUBLIC
+  contrib-libs-cxxsupp
+  yutil
+  contrib-libs-protobuf
+)
+target_proto_messages(cpp-porto-proto PRIVATE
+  ${CMAKE_SOURCE_DIR}/library/cpp/porto/proto/rpc.proto
+)
+target_proto_addincls(cpp-porto-proto
+  ./
+  ${CMAKE_SOURCE_DIR}/
+  ${CMAKE_BINARY_DIR}
+  ${CMAKE_SOURCE_DIR}
+  ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src
+  ${CMAKE_BINARY_DIR}
+  ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src
+)
+target_proto_outs(cpp-porto-proto
+  --cpp_out=${CMAKE_BINARY_DIR}/
+  --cpp_styleguide_out=${CMAKE_BINARY_DIR}/
+)
diff --git a/library/cpp/porto/proto/rpc.proto b/library/cpp/porto/proto/rpc.proto
new file mode 100644
index 0000000000..abb8c63905
--- /dev/null
+++ b/library/cpp/porto/proto/rpc.proto
@@ -0,0 +1,1607 @@
+syntax = "proto2";
+
+option go_package = "github.com/ydb-platform/ydb/library/cpp/porto/proto;myapi";
+
+/*
+   Portod daemon listens on /run/portod.socket unix socket.
+
+   Request: Varint32 length, TPortoRequest request
+   Response: Varint32 length, TPortoResponse response
+
+   Each command is defined by an optional nested message field.
+   The result arrives in a nested message with the same name.
+
+   Push notifications are sent as out-of-order responses.
+
+   Access level depends on client container and uid.
+
+   See details in porto.md or the porto manpage.
+
+   TContainer, TVolume and related methods are Porto v5 API.
+*/
+
+package Porto;
+
+// List of error codes
+enum EError {
+    // No errors occurred.
+    Success = 0;
+
+    // Unclassified error, usually an unexpected syscall failure.
+    Unknown = 1;
+
+    // Unknown method or bad request.
+    InvalidMethod = 2;
+
+    // Container with specified name already exists.
+    ContainerAlreadyExists = 3;
+
+    // Container with specified name doesn't exist.
+    ContainerDoesNotExist = 4;
+
+    // Unknown property specified.
+    InvalidProperty = 5;
+
+    // Unknown data specified.
+    InvalidData = 6;
+
+    // Invalid value of property or data.
+    InvalidValue = 7;
+
+    // Can't perform specified operation in current container state.
+    InvalidState = 8;
+
+    // Permanent failure: old kernel version, missing feature, configuration, etc.
+    NotSupported = 9;
+
+    // Temporary failure: too many objects, not enough memory, etc.
+    ResourceNotAvailable = 10;
+
+    // Insufficient rights for performing requested operation.
+    Permission = 11;
+
+    // Can't create new volume with specified name: one already exists.
+    VolumeAlreadyExists = 12;
+
+    // Volume with specified name doesn't exist.
+    VolumeNotFound = 13;
+
+    // Not enough disk space.
+    NoSpace = 14;
+
+    // Object in use.
+    Busy = 15;
+
+    // Volume already linked with container.
+    VolumeAlreadyLinked = 16;
+
+    // Volume not linked with container.
+    VolumeNotLinked = 17;
+
+    // Layer with this name already exists.
+    LayerAlreadyExists = 18;
+
+    // Layer with this name not found.
+    LayerNotFound = 19;
+
+    // Property has no value, data source permanently not available.
+    NoValue = 20;
+
+    // Volume under construction or destruction.
+    VolumeNotReady = 21;
+
+    // Cannot parse or execute command.
+    InvalidCommand = 22;
+
+    // Error code is lost or came from the future.
+    LostError = 23;
+
+    // Device node not found.
+    DeviceNotFound = 24;
+
+    // Path does not match restrictions or does not exist.
+    InvalidPath = 25;
+
+    // Wrong or unusable IP address.
+    InvalidNetworkAddress = 26;
+
+    // Porto is in system maintenance state.
+    PortoFrozen = 27;
+
+    // Label with this name is not set.
+    LabelNotFound = 28;
+
+    // Label name does not meet restrictions.
+    InvalidLabel = 29;
+
+    // Errors from tar during archive extraction
+    HelperError = 30;
+    HelperFatalError = 31;
+
+    // Generic object not found.
+    NotFound = 404;
+
+    // Reserved error code for client library.
+    SocketError = 502;
+
+    // Reserved error code for client library.
+    SocketUnavailable = 503;
+
+    // Reserved error code for client library.
+    SocketTimeout = 504;
+
+    // Portod closes client connections on reload
+    PortodReloaded = 505;
+
+    // Reserved error code for taints.
+    Taint = 666;
+
+    // Reserved error codes 700-800 for docker
+    Docker = 700;
+    DockerImageNotFound = 701;
+
+    // Internal error code, not for users.
+    Queued = 1000;
+}
+
+
+message TPortoRequest {
+
+    /* System methods */
+
+    // Get portod version
+    optional TVersionRequest Version = 14;
+
+    // Get portod statistics
+    optional TGetSystemRequest GetSystem = 300;
+
+    // Change portod state (for host root user only)
+    optional TSetSystemRequest SetSystem = 301;
+
+    /* Container methods */
+
+    // Create new container
+    optional TCreateRequest Create = 1;
+
+    // Create new container and auto-destroy it when the client disconnects
+    optional TCreateRequest CreateWeak = 17;
+
+    // Force kill all and destroy container and nested containers
+    optional TDestroyRequest Destroy = 2;
+
+    // List container names in current namespace
+    optional TListRequest List = 3;
+
+    // Start container and parents if needed
+    optional TStartRequest Start = 7;
+
+    // Kill all and stop container
+    optional TStopRequest Stop = 8;
+
+    // Freeze execution
+    optional TPauseRequest Pause = 9;
+
+    // Resume execution
+    optional TResumeRequest Resume = 10;
+
+    // Send signal to main process
+    optional TKillRequest Kill = 13;
+
+    // Restart dead container
+    optional TRespawnRequest Respawn = 18;
+
+    // Wait for process finish or change of labels
+    optional TWaitRequest Wait = 16;
+
+    // Subscribe to push notifications
+    optional TWaitRequest AsyncWait = 19;
+    optional TWaitRequest StopAsyncWait = 128;
+
+    /* Container properties */
+
+    // List supported container properties
+    optional TListPropertiesRequest ListProperties = 11;
+
+    // Get one property
+    optional TGetPropertyRequest GetProperty = 4;
+
+    // Set one property
+    optional TSetPropertyRequest SetProperty = 5;
+
+    // Deprecated, now data properties are also read-only properties
+    optional TListDataPropertiesRequest ListDataProperties = 12;
+    optional TGetDataPropertyRequest GetDataProperty = 6;
+
+    // Get multiple properties for multiple containers
+    optional TGetRequest Get = 15;
+
+    /* Container API based on TContainer (Porto v5 API) */
+
+    // Create, configure and start container with volumes
+    optional TCreateFromSpecRequest CreateFromSpec = 230;
+
+    // Set multiple container properties
+    optional TUpdateFromSpecRequest UpdateFromSpec = 231;
+
+    // Get multiple properties for multiple containers
+    optional TListContainersRequest ListContainersBy = 232;
+
+    // Modify symlink in container
+    optional TSetSymlinkRequest SetSymlink = 125;
+
+    /* Container labels - user-defined key-value pairs */
+
+    // Find containers with labels
+    optional TFindLabelRequest FindLabel = 20;
+
+    // Atomic compare-and-set for label
+    optional TSetLabelRequest SetLabel = 21;
+
+    // Atomic add-and-return for counter in label
+    optional TIncLabelRequest IncLabel = 22;
+
+    /* Volume methods */
+
+    optional TListVolumePropertiesRequest ListVolumeProperties = 103;
+
+    // List volumes and their properties
+    optional TListVolumesRequest ListVolumes = 107;
+
+    // Create, configure and build volume
+    optional TCreateVolumeRequest CreateVolume = 104;
+
+    // Change volume properties - for now only resize
+    optional TTuneVolumeRequest TuneVolume = 108;
+
+    // Volume API based on TVolume (Porto v5 API)
+    optional TNewVolumeRequest NewVolume = 126;
+    optional TGetVolumeRequest GetVolume = 127;
+
+    // Add link between container and volume
+    optional TLinkVolumeRequest LinkVolume = 105;
+
+    // Same as LinkVolume but fails if target is not supported
+    optional TLinkVolumeRequest LinkVolumeTarget = 120;
+
+    // Delete link between container and volume
+    optional TUnlinkVolumeRequest UnlinkVolume = 106;
+
+    // Same as UnlinkVolume but fails if target is not supported
+    optional TUnlinkVolumeRequest UnlinkVolumeTarget = 121;
+
+    /* Layer methods */
+
+    // Import layer from tarball
+    optional TImportLayerRequest ImportLayer = 110;
+
+    // Remove layer
+    optional TRemoveLayerRequest RemoveLayer = 111;
+
+    // List layers
+    optional TListLayersRequest ListLayers = 112;
+
+    // Export volume or layer into tarball
+    optional TExportLayerRequest ExportLayer = 113;
+
+    // Get/set layer private (user-defined string)
+    optional TGetLayerPrivateRequest GetLayerPrivate = 114;
+    optional TSetLayerPrivateRequest SetLayerPrivate = 115;
+
+    /* Storage methods */
+
+    // Volume creation creates required storage if missing
+
+    // List storages and meta storages
+    optional TListStoragesRequest ListStorages = 116;
+
+    optional TRemoveStorageRequest RemoveStorage = 117;
+
+    // Import storage from tarball
+    optional TImportStorageRequest ImportStorage = 118;
+
+    // Export storage into tarball
+    optional TExportStorageRequest ExportStorage = 119;
+
+    // Meta storage (bundle for storages and layers)
+
+    optional TMetaStorage CreateMetaStorage = 122;
+    optional TMetaStorage ResizeMetaStorage = 123;
+    optional TMetaStorage RemoveMetaStorage = 124;
+
+    // Convert path between containers
+    optional TConvertPathRequest ConvertPath = 200;
+
+    /* Process methods */
+
+    // Attach process to nested container
+    optional TAttachProcessRequest AttachProcess = 201;
+
+    // Find container for process
+    optional TLocateProcessRequest LocateProcess = 202;
+
+    // Attach one thread to nested container
+    optional TAttachProcessRequest AttachThread = 203;
+
+    /* Docker images API */
+
+    optional TDockerImageStatusRequest dockerImageStatus = 303;
+    optional TDockerImageListRequest listDockerImages = 304;
+    optional TDockerImagePullRequest pullDockerImage = 305;
+    optional TDockerImageRemoveRequest removeDockerImage = 306;
+}
+
+
+message TPortoResponse {
+    // Actually always set; a hack for adding new error codes
+    optional EError error = 1 [ default = LostError ];
+
+    // Human-readable comment - must be shown to user as is
+    optional string errorMsg = 2;
+
+    optional uint64 timestamp = 1000; // for next changed_since
+
+    /* System methods */
+
+    optional TVersionResponse Version = 8;
+
+    optional TGetSystemResponse GetSystem = 300;
+    optional TSetSystemResponse SetSystem = 301;
+
+    /* Container methods */
+
+    optional TListResponse List = 3;
+
+    optional TWaitResponse Wait = 11;
+
+    optional TWaitResponse AsyncWait = 19;
+
+    /* Container properties */
+
+    optional TListPropertiesResponse ListProperties = 6;
+
+    optional TGetPropertyResponse GetProperty = 4;
+
+
+    // Deprecated
+    optional TListDataPropertiesResponse ListDataProperties = 7;
+    optional TGetDataPropertyResponse GetDataProperty = 5;
+
+    optional TGetResponse Get = 10;
+
+    /* Container API based on TContainer (Porto v5 API) */
+
+    optional TListContainersResponse ListContainersBy = 232;
+
+    /* Container Labels */
+
+    optional TFindLabelResponse FindLabel = 20;
+    optional TSetLabelResponse SetLabel = 21;
+    optional TIncLabelResponse IncLabel = 22;
+
+    /* Volume methods */
+
+    optional TListVolumePropertiesResponse ListVolumeProperties = 12;
+
+    optional TListVolumesResponse ListVolumes = 9;
+
+    optional TVolumeDescription CreateVolume = 13;
+
+    optional TNewVolumeResponse NewVolume = 126;
+
+    optional TGetVolumeResponse GetVolume = 127;
+
+    optional TListLayersResponse ListLayers = 14;
+
+    optional TGetLayerPrivateResponse GetLayerPrivate = 16;
+
+    // List storages and meta storages
+    optional TListStoragesResponse ListStorages = 17;
+
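+    // Illustrative note (not part of the schema): a client reads one
+    // length-prefixed response from the socket roughly like this, using the
+    // protobuf C++ coded streams:
+    //
+    //   google::protobuf::io::FileInputStream raw(sockFd);  // sockFd: connected unix socket
+    //   google::protobuf::io::CodedInputStream input(&raw);
+    //   uint32_t length;
+    //   input.ReadVarint32(&length);                         // Varint32 length prefix
+    //   const auto limit = input.PushLimit(length);
+    //   Porto::TPortoResponse rsp;
+    //   rsp.ParseFromCodedStream(&input);                    // TPortoResponse body
+    //   input.PopLimit(limit);
+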
optional TConvertPathResponse ConvertPath = 15; + + // Process + optional TLocateProcessResponse LocateProcess = 18; + + /* Docker images API */ + + optional TDockerImageStatusResponse dockerImageStatus = 302; + optional TDockerImageListResponse listDockerImages = 303; + optional TDockerImagePullResponse pullDockerImage = 304; +} + + +// Common objects + + +message TStringMap { + message TStringMapEntry { + optional string key = 1; + optional string val = 2; + } + // TODO replace with map + // map<string, string> map = 1; + repeated TStringMapEntry map = 1; + optional bool merge = 2; // in, default: replace +} + + +message TUintMap { + message TUintMapEntry { + optional string key = 1; + optional uint64 val = 2; + } + // TODO replace with map + // map<string, uint64> map = 1; + repeated TUintMapEntry map = 1; + optional bool merge = 2; // in, default: replace +} + + +message TError { + optional EError error = 1 [ default = LostError ]; + optional string msg = 2; +} + + +message TCred { + optional string user = 1; // requires user or uid or both + optional fixed32 uid = 2; + optional string group = 3; + optional fixed32 gid = 4; + repeated fixed32 grp = 5; // out, supplementary groups +} + + +message TCapabilities { + repeated string cap = 1; + optional string hex = 2; // out +} + + +message TContainerCommandArgv { + repeated string argv = 1; +} + + +// Container + + +message TContainerEnvVar { + optional string name = 1; //required + optional string value = 2; + optional bool unset = 3; // out + optional string salt = 4; + optional string hash = 5; +} + +message TContainerEnv { + repeated TContainerEnvVar var = 1; + optional bool merge = 2; // in, default: replace +} + + +message TContainerUlimit { + optional string type = 1; //required + optional bool unlimited = 2; + optional uint64 soft = 3; + optional uint64 hard = 4; + optional bool inherited = 5; // out +} + +message TContainerUlimits { + repeated TContainerUlimit ulimit = 1; + optional bool merge = 2; // in, default: replace +} + + +message TContainerControllers { + repeated string controller = 1; +} + + +message TContainerCgroup { + optional string controller = 1; //required + optional string path = 2; //required + optional bool inherited = 3; +} + +message TContainerCgroups { + repeated TContainerCgroup cgroup = 1; +} + + +message TContainerCpuSet { + optional string policy = 1; // inherit|set|node|reserve|threads|cores + optional uint32 arg = 2; // for node|reserve|threads|cores + optional string list = 3; // for set + repeated uint32 cpu = 4; // for set (used if list isn't set) + optional uint32 count = 5; // out + optional string mems = 6; +} + + +message TContainerBindMount { + optional string source = 1; //required + optional string target = 2; //required + repeated string flag = 3; +} + +message TContainerBindMounts { + repeated TContainerBindMount bind = 1; +} + + +message TContainerVolumeLink { + optional string volume = 1; //required + optional string target = 2; + optional bool required = 3; + optional bool read_only = 4; +} + +message TContainerVolumeLinks { + repeated TContainerVolumeLink link = 1; +} + + +message TContainerVolumes { + repeated string volume = 1; +} + + +message TContainerPlace { + optional string place = 1; //required + optional string alias = 2; +} + +message TContainerPlaceConfig { + repeated TContainerPlace cfg = 1; +} + + +message TContainerDevice { + optional string device = 1; //required + optional string access = 2; //required + optional string path = 3; + optional string mode = 4; + optional 
string user = 5; + optional string group = 6; +} + +message TContainerDevices { + repeated TContainerDevice device = 1; + optional bool merge = 2; // in, default: replace +} + + +message TContainerNetOption { + optional string opt = 1; //required + repeated string arg = 2; +} + +message TContainerNetConfig { + repeated TContainerNetOption cfg = 1; + optional bool inherited = 2; // out +} + + +message TContainerIpLimit { + optional string policy = 1; //required any|none|some + repeated string ip = 2; +} + + +message TContainerIpConfig { + message TContainerIp { + optional string dev = 1; //required + optional string ip = 2; //required + } + repeated TContainerIp cfg = 1; +} + + +message TVmStat { + optional uint64 count = 1; + optional uint64 size = 2; + optional uint64 max_size = 3; + optional uint64 used = 4; + optional uint64 max_used = 5; + optional uint64 anon = 6; + optional uint64 file = 7; + optional uint64 shmem = 8; + optional uint64 huge = 9; + optional uint64 swap = 10; + optional uint64 data = 11; + optional uint64 stack = 12; + optional uint64 code = 13; + optional uint64 locked = 14; + optional uint64 table = 15; +} + +message TContainerStatus { + optional string absolute_name = 1; // out, "/porto/..." + optional string state = 2; // out + optional uint64 id = 3; // out + optional uint32 level = 4; // out + optional string parent = 5; // out, "/porto/..." + + optional string absolute_namespace = 6; // out + + optional int32 root_pid = 7; // out + optional int32 exit_status = 8; // out + optional int32 exit_code = 9; // out + optional bool core_dumped = 10; // out + optional TError start_error = 11; // out + optional uint64 time = 12; // out + optional uint64 dead_time = 13; // out + + optional TCapabilities capabilities_allowed = 14; // out + optional TCapabilities capabilities_ambient_allowed = 15; // out + optional string root_path = 16; // out, in client namespace + optional uint64 stdout_offset = 17; // out + optional uint64 stderr_offset = 18; // out + optional string std_err = 69; // out + optional string std_out = 70; // out + + optional uint64 creation_time = 19; // out + optional uint64 start_time = 20; // out + optional uint64 death_time = 21; // out + optional uint64 change_time = 22; // out + optional bool no_changes = 23; // out, change_time < changed_since + optional string extra_properties = 73; + + optional TContainerCgroups cgroups = 24; // out + optional TContainerCpuSet cpu_set_affinity = 25; // out + + optional uint64 cpu_usage = 26; // out + optional uint64 cpu_usage_system = 27; // out + optional uint64 cpu_wait = 28; // out + optional uint64 cpu_throttled = 29; // out + + optional uint64 process_count = 30; // out + optional uint64 thread_count = 31; // out + + optional TUintMap io_read = 32; // out, bytes + optional TUintMap io_write = 33; // out, bytes + optional TUintMap io_ops = 34; // out, ops + optional TUintMap io_read_ops = 341; // out, ops + optional TUintMap io_write_ops = 342; // out, ops + optional TUintMap io_time = 35; // out, ns + optional TUintMap io_pressure = 351; // out + + optional TUintMap place_usage = 36; + optional uint64 memory_usage = 37; // out, bytes + + optional uint64 memory_guarantee_total = 38; // out + + optional uint64 memory_limit_total = 39; // out + + optional uint64 anon_limit_total = 40; + optional uint64 anon_usage = 41; // out, bytes + optional double cpu_guarantee_total = 42; + optional double cpu_guarantee_bound = 421; + optional double cpu_limit_total = 422; + optional double cpu_limit_bound = 423; + + 
optional uint64 cache_usage = 43; // out, bytes + + optional uint64 hugetlb_usage = 44; // out, bytes + optional uint64 hugetlb_limit = 45; + + optional uint64 minor_faults = 46; // out + optional uint64 major_faults = 47; // out + optional uint64 memory_reclaimed = 48; // out + optional TVmStat virtual_memory = 49; // out + + optional uint64 shmem_usage = 71; // out, bytes + optional uint64 mlock_usage = 72; // out, bytes + + optional uint64 oom_kills = 50; // out + optional uint64 oom_kills_total = 51; // out + optional bool oom_killed = 52; // out + + optional TUintMap net_bytes = 54; // out + optional TUintMap net_packets = 55; // out + optional TUintMap net_drops = 56; // out + optional TUintMap net_overlimits = 57; // out + optional TUintMap net_rx_bytes = 58; // out + optional TUintMap net_rx_packets = 59; // out + optional TUintMap net_rx_drops = 60; // out + optional TUintMap net_tx_bytes = 61; // out + optional TUintMap net_tx_packets = 62; // out + optional TUintMap net_tx_drops = 63; // out + + optional TContainerVolumeLinks volumes_linked = 64; // out + optional TContainerVolumes volumes_owned = 65; + + repeated TError error = 66; // out + repeated TError warning = 67; // out + repeated TError taint = 68; // out +} + +message TContainerSpec { + optional string name = 1; // required / in client namespace + optional bool weak = 2; + optional string private = 3; + optional TStringMap labels = 4; + + optional string command = 5; + optional TContainerCommandArgv command_argv = 76; + optional TContainerEnv env = 6; + optional TContainerEnv env_secret = 90; // in, out hides values + optional TContainerUlimits ulimit = 7; + optional string core_command = 8; + + optional bool isolate = 9; + optional string virt_mode = 10; + optional string enable_porto = 11; + optional string porto_namespace = 12; + optional string cgroupfs = 78; + optional bool userns = 79; + + optional uint64 aging_time = 13; + + optional TCred task_cred = 14; + optional string user = 15; + optional string group = 16; + + optional TCred owner_cred = 17; + optional string owner_user = 18; + optional string owner_group = 19; + optional string owner_containers = 77; + + optional TCapabilities capabilities = 20; + optional TCapabilities capabilities_ambient = 21; + + optional string root = 22; // in parent namespace + optional bool root_readonly = 23; + optional TContainerBindMounts bind = 24; + optional TStringMap symlink = 25; + optional TContainerDevices devices = 26; + optional TContainerPlaceConfig place = 27; + optional TUintMap place_limit = 28; + + optional string cwd = 29; + optional string stdin_path = 30; + optional string stdout_path = 31; + optional string stderr_path = 32; + optional uint64 stdout_limit = 33; + optional uint32 umask = 34; + + optional bool respawn = 35; + optional uint64 respawn_count = 36; + optional int64 max_respawns = 37; + optional uint64 respawn_delay = 38; + + optional TContainerControllers controllers = 39; + + optional string cpu_policy = 40; // normal|idle|batch|high|rt + optional double cpu_weight = 41; // 0.01 .. 100 + + optional double cpu_guarantee = 42; // in cores + optional double cpu_limit = 43; // in cores + optional double cpu_limit_total = 44; // deprecated (value moved to TContainerStatus) + optional uint64 cpu_period = 45; // ns + + optional TContainerCpuSet cpu_set = 46; + + optional uint64 thread_limit = 47; + + optional string io_policy = 48; // none|rt|high|normal|batch|idle + optional double io_weight = 49; // 0.01 .. 
100 + + optional TUintMap io_limit = 50; // bps + optional TUintMap io_guarantee = 84; // bps + optional TUintMap io_ops_limit = 51; // iops + optional TUintMap io_ops_guarantee = 85; // iops + + optional uint64 memory_guarantee = 52; // bytes + + optional uint64 memory_limit = 53; // bytes + + optional uint64 anon_limit = 54; + optional uint64 anon_max_usage = 55; + + optional uint64 dirty_limit = 56; + + optional uint64 hugetlb_limit = 57; + + optional bool recharge_on_pgfault = 58; + optional bool pressurize_on_death = 59; + optional bool anon_only = 60; + + optional int32 oom_score_adj = 61; // -1000 .. +1000 + optional bool oom_is_fatal = 62; + + optional TContainerNetConfig net = 63; + optional TContainerIpLimit ip_limit = 64; + optional TContainerIpConfig ip = 65; + optional TContainerIpConfig default_gw = 66; + optional string hostname = 67; + optional string resolv_conf = 68; + optional string etc_hosts = 69; + optional TStringMap sysctl = 70; + optional TUintMap net_guarantee = 71; // bytes per second + optional TUintMap net_limit = 72; // bytes per second + optional TUintMap net_rx_limit = 73; // bytes per second + + optional TContainerVolumes volumes_required = 75; +} + +message TContainer { + optional TContainerSpec spec = 1; //required + optional TContainerStatus status = 2; + optional TError error = 3; +} + + +// Volumes + +message TVolumeDescription { + required string path = 1; // path in client namespace + map<string, string> properties = 2; + repeated string containers = 3; // linked containers (legacy) + repeated TVolumeLink links = 4; // linked containers with details + + optional uint64 change_time = 5; // sec since epoch + optional bool no_changes = 6; // change_time < changed_since +} + + +message TVolumeLink { + optional string container = 1; + optional string target = 2; // absolute path in container, default: anon + optional bool required = 3; // container cannot work without it + optional bool read_only = 4; + optional string host_target = 5; // out, absolute path in host + optional bool container_root = 6; // in, set container root + optional bool container_cwd = 7; // in, set container cwd +} + +message TVolumeResource { + optional uint64 limit = 1; // bytes or inodes + optional uint64 guarantee = 2; // bytes or inodes + optional uint64 usage = 3; // out, bytes or inodes + optional uint64 available = 4; // out, bytes or inodes +} + +message TVolumeDirectory { + optional string path = 1; // relative path in volume + optional TCred cred = 2; // default: volume cred + optional fixed32 permissions = 3; // default: volume permissions +} + +message TVolumeSymlink { + optional string path = 1; // relative path in volume + optional string target_path = 2; +} + +message TVolumeShare { + optional string path = 1; // relative path in volume + optional string origin_path = 2; // absolute path to origin + optional bool cow = 3; // default: mutable share +} + +// Structured Volume description (Porto V5 API) + +message TVolumeSpec { + optional string path = 1; // path in container, default: auto + optional string container = 2; // defines root for paths, default: self (client container) + repeated TVolumeLink links = 3; // initial links, default: anon link to self + + optional string id = 4; // out + optional string state = 5; // out + + optional string private_value = 6; // at most 4096 bytes + + optional string device_name = 7; // out + + optional string backend = 10; // default: auto + optional string place = 11; // path in host or alias, default from client container + 
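+    // Illustrative example (not part of the schema): a minimal NewVolume spec
+    // could set just
+    //   backend: "native"  space { limit: 1073741824 }
+    // and leave path unset so porto picks one automatically.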
optional string storage = 12; // persistent storage, path or name, default: non-persistent + optional string image = 52; + repeated string layers = 13; // name or path + optional bool read_only = 14; + + // defines root directory user, group and permissions + optional TCred cred = 20; // default: self task cred + optional fixed32 permissions = 21; // default: 0775 + + optional TVolumeResource space = 22; + optional TVolumeResource inodes = 23; + + optional TCred owner = 30; // default: self owner + optional string owner_container = 31; // default: self + optional string place_key = 32; // out, key for place_limit + optional string creator = 33; // out + optional bool auto_path = 34; // out + optional uint32 device_index = 35; // out + optional uint64 build_time = 37; // out, sec since epoch + + // customization at creation + repeated TVolumeDirectory directories = 40; // in + repeated TVolumeSymlink symlinks = 41; // in + repeated TVolumeShare shares = 42; // in + + optional uint64 change_time = 50; // out, sec since epoch + optional bool no_changes = 51; // out, change_time < changed_since +} + + +message TLayer { + optional string name = 1; // name or meta/name + optional string owner_user = 2; + optional string owner_group = 3; + optional uint64 last_usage = 4; // out, sec since last usage + optional string private_value = 5; +} + + +message TStorage { + optional string name = 1; // name or meta/name + optional string owner_user = 2; + optional string owner_group = 3; + optional uint64 last_usage = 4; // out, sec since last usage + optional string private_value = 5; +} + + +message TMetaStorage { + optional string name = 1; + optional string place = 2; + optional string private_value = 3; + optional uint64 space_limit = 4; // bytes + optional uint64 inode_limit = 5; // inodes + + optional uint64 space_used = 6; // out, bytes + optional uint64 space_available = 7; // out, bytes + optional uint64 inode_used = 8; // out, inodes + optional uint64 inode_available = 9; // out, inodes + optional string owner_user = 10; // out + optional string owner_group = 11; // out + optional uint64 last_usage = 12; // out, sec since last usage +} + + +// COMMANDS + +// System + +// Get porto version +message TVersionRequest { +} + +message TVersionResponse { + optional string tag = 1; + optional string revision = 2; +} + + +// Get porto statistics +message TGetSystemRequest { +} + +message TGetSystemResponse { + optional string porto_version = 1; + optional string porto_revision = 2; + optional string kernel_version = 3; + + optional fixed64 errors = 4; + optional fixed64 warnings = 5; + optional fixed64 porto_starts = 6; + optional fixed64 porto_uptime = 7; + optional fixed64 master_uptime = 8; + optional fixed64 taints = 9; + + optional bool frozen = 10; + optional bool verbose = 100; + optional bool debug = 101; + optional fixed64 log_lines = 102; + optional fixed64 log_bytes = 103; + + optional fixed64 stream_rotate_bytes = 104; + optional fixed64 stream_rotate_errors = 105; + + optional fixed64 log_lines_lost = 106; + optional fixed64 log_bytes_lost = 107; + optional fixed64 log_open = 108; + + optional fixed64 container_count = 200; + optional fixed64 container_limit = 201; + optional fixed64 container_running = 202; + optional fixed64 container_created = 203; + optional fixed64 container_started = 204; + optional fixed64 container_start_failed = 205; + optional fixed64 container_oom = 206; + optional fixed64 container_buried = 207; + optional fixed64 container_lost = 208; + optional fixed64 
container_tainted = 209; + + optional fixed64 volume_count = 300; + optional fixed64 volume_limit = 301; + optional fixed64 volume_created = 303; + optional fixed64 volume_failed = 304; + optional fixed64 volume_links = 305; + optional fixed64 volume_links_mounted = 306; + optional fixed64 volume_lost = 307; + + optional fixed64 layer_import = 390; + optional fixed64 layer_export = 391; + optional fixed64 layer_remove = 392; + + optional fixed64 client_count = 400; + optional fixed64 client_max = 401; + optional fixed64 client_connected = 402; + + optional fixed64 request_queued = 500; + optional fixed64 request_completed = 501; + optional fixed64 request_failed = 502; + optional fixed64 request_threads = 503; + optional fixed64 request_longer_1s = 504; + optional fixed64 request_longer_3s = 505; + optional fixed64 request_longer_30s = 506; + optional fixed64 request_longer_5m = 507; + + optional fixed64 fail_system = 600; + optional fixed64 fail_invalid_value = 601; + optional fixed64 fail_invalid_command = 602; + optional fixed64 fail_memory_guarantee = 603; + optional fixed64 fail_invalid_netaddr = 604; + + optional fixed64 porto_crash = 666; + + optional fixed64 network_count = 700; + optional fixed64 network_created = 701; + optional fixed64 network_problems = 702; + optional fixed64 network_repairs = 703; +} + + +// Change porto state +message TSetSystemRequest { + optional bool frozen = 10; + optional bool verbose = 100; + optional bool debug = 101; +} + +message TSetSystemResponse { +} + +message TCreateFromSpecRequest { + optional TContainerSpec container = 1; //required + repeated TVolumeSpec volumes = 2; + optional bool start = 3; +} + +message TUpdateFromSpecRequest { + optional TContainerSpec container = 1; //required + optional bool start = 2; +} + +message TListContainersFilter { + optional string name = 1; // name or wildcards, default: all + optional TStringMap labels = 2; + optional uint64 changed_since = 3; // change_time >= changed_since +} + +message TStreamDumpOptions { + optional uint64 stdstream_offset = 2; // default: 0 + optional uint64 stdstream_limit = 3; // default: 8Mb +} + +message TListContainersFieldOptions { + repeated string properties = 1; // property names, default: all + optional TStreamDumpOptions stdout_options = 2; // for GetIndexed stdout + optional TStreamDumpOptions stderr_options = 3; // for GetIndexed stderr +} + +message TListContainersRequest { + repeated TListContainersFilter filters = 1; + optional TListContainersFieldOptions field_options = 2; +} + +message TListContainersResponse { + repeated TContainer containers = 1; +} + +// List available properties +message TListPropertiesRequest { +} + +message TListPropertiesResponse { + message TContainerPropertyListEntry { + optional string name = 1; + optional string desc = 2; + optional bool read_only = 3; + optional bool dynamic = 4; + } + repeated TContainerPropertyListEntry list = 1; +} + + +// deprecated, use ListProperties +message TListDataPropertiesRequest { +} + +message TListDataPropertiesResponse { + message TContainerDataListEntry { + optional string name = 1; + optional string desc = 2; + } + repeated TContainerDataListEntry list = 1; +} + + +// Create stopped container +message TCreateRequest { + optional string name = 1; +} + + +// Stop and destroy container +message TDestroyRequest { + optional string name = 1; +} + + +// List container names +message TListRequest { + optional string mask = 1; + optional uint64 changed_since = 2; // change_time >= changed_since +} + +message 
TListResponse {
+    repeated string name = 1;
+    optional string absolute_namespace = 2;
+}
+
+
+// Read one property
+message TGetPropertyRequest {
+    optional string name = 1;
+    optional string property = 2;
+    // update cached counters
+    optional bool sync = 3;
+    optional bool real = 4;
+}
+
+message TGetPropertyResponse {
+    optional string value = 1;
+}
+
+
+// Alias for GetProperty, deprecated
+message TGetDataPropertyRequest {
+    optional string name = 1;
+    optional string data = 2;
+    // update cached counters
+    optional bool sync = 3;
+    optional bool real = 4;
+}
+
+message TGetDataPropertyResponse {
+    optional string value = 1;
+}
+
+
+// Change one property
+message TSetPropertyRequest {
+    optional string name = 1;
+    optional string property = 2;
+    optional string value = 3;
+}
+
+
+// Get multiple properties/data of many containers with one request
+message TGetRequest {
+    // list of containers or wildcards, "***" - all
+    repeated string name = 1;
+
+    // list of properties/data
+    repeated string variable = 2;
+
+    // do not wait for busy containers
+    optional bool nonblock = 3;
+
+    // update cached counters
+    optional bool sync = 4;
+    optional bool real = 5;
+
+    // change_time >= changed_since
+    optional uint64 changed_since = 6;
+}
+
+message TGetResponse {
+    message TContainerGetValueResponse {
+        optional string variable = 1;
+        optional EError error = 2;
+        optional string errorMsg = 3;
+        optional string value = 4;
+    }
+
+    message TContainerGetListResponse {
+        optional string name = 1;
+        repeated TContainerGetValueResponse keyval = 2;
+
+        optional uint64 change_time = 3;
+        optional bool no_changes = 4; // change_time < changed_since
+    }
+
+    repeated TContainerGetListResponse list = 1;
+}
+
+
+// Start stopped container
+message TStartRequest {
+    optional string name = 1;
+}
+
+
+// Restart dead container
+message TRespawnRequest {
+    optional string name = 1;
+}
+
+
+// Stop dead or running container
+message TStopRequest {
+    optional string name = 1;
+    // Timeout in 1/1000 seconds between SIGTERM and SIGKILL, default 30s
+    optional uint32 timeout_ms = 2;
+}
+
+
+// Freeze running container
+message TPauseRequest {
+    optional string name = 1;
+}
+
+
+// Unfreeze paused container
+message TResumeRequest {
+    optional string name = 1;
+}
+
+
+// Translate filesystem path between containers
+message TConvertPathRequest {
+    optional string path = 1;
+    optional string source = 2;
+    optional string destination = 3;
+}
+
+message TConvertPathResponse {
+    optional string path = 1;
+}
+
+
+// Wait while the container(s) are in the running state
+message TWaitRequest {
+    // list of containers or wildcards, "***" - all
+    repeated string name = 1;
+
+    // timeout in 1/1000 seconds, 0 - nonblock
+    optional uint32 timeout_ms = 2;
+
+    // list of label names or wildcards
+    repeated string label = 3;
+
+    // async wait with target_state works only once
+    optional string target_state = 4;
+}
+
+message TWaitResponse {
+    optional string name = 1; // container name
+    optional string state = 2; // container state or "timeout"
+    optional uint64 when = 3; // unix timestamp in seconds
+    optional string label = 4;
+    optional string value = 5;
+}
+
+
+// Send signal to the main process in the container
+message TKillRequest {
+    optional string name = 1;
+    optional int32 sig = 2;
+}
+
+
+// Move process into container
+message TAttachProcessRequest {
+    optional string name = 1;
+    optional uint32 pid = 2;
+    optional string comm = 3; // ignored if empty
+}
+
+
+// Determine container by pid
+message TLocateProcessRequest {
+    optional uint32 pid = 1;
+    optional string comm = 2; // ignored if empty
+}
+
+message TLocateProcessResponse {
+    optional string name = 1;
+}
+
+
+// Labels
+
+
+message TFindLabelRequest {
+    optional string mask = 1; // container name or wildcard
+    optional string state = 2; // filter by container state
+    optional string label = 3; // label name or wildcard
+    optional string value = 4; // filter by label value
+}
+
+message TFindLabelResponse {
+    message TFindLabelEntry {
+        optional string name = 1;
+        optional string state = 2;
+        optional string label = 3;
+        optional string value = 4;
+    }
+    repeated TFindLabelEntry list = 1;
+}
+
+
+message TSetLabelRequest {
+    optional string name = 1;
+    optional string label = 2;
+    optional string value = 3;
+    optional string prev_value = 4; // fail with Busy if it does not match
+    optional string state = 5; // fail with InvalidState if state does not match
+}
+
+message TSetLabelResponse {
+    optional string prev_value = 1;
+    optional string state = 2;
+}
+
+
+message TIncLabelRequest {
+    optional string name = 1;
+    optional string label = 2; // missing label starts from 0
+    optional int64 add = 3 [ default = 1];
+}
+
+message TIncLabelResponse {
+    optional int64 result = 1;
+}
+
+
+message TSetSymlinkRequest {
+    optional string container = 1;
+    optional string symlink = 2;
+    optional string target = 3;
+}
+
+
+// Volumes
+
+
+message TNewVolumeRequest {
+    optional TVolumeSpec volume = 1;
+}
+
+message TNewVolumeResponse {
+    optional TVolumeSpec volume = 1;
+}
+
+
+message TGetVolumeRequest {
+    optional string container = 1; // get paths in container, default: self (client container)
+    repeated string path = 2; // volume path in container, default: all
+    optional uint64 changed_since = 3; // change_time >= changed_since
+    repeated string label = 4; // labels or wildcards
+}
+
+message TGetVolumeResponse {
+    repeated TVolumeSpec volume = 1;
+}
+
+
+// List available volume properties
+message TListVolumePropertiesRequest {
+}
+
+message TListVolumePropertiesResponse {
+    message TVolumePropertyDescription {
+        optional string name = 1;
+        optional string desc = 2;
+    }
+    repeated TVolumePropertyDescription list = 1;
+}
+
+
+// Create new volume
+// "createVolume" returns TVolumeDescription in "volume"
+message TCreateVolumeRequest {
+    optional string path = 1;
+    map<string, string> properties = 2;
+}
+
+
+message TLinkVolumeRequest {
+    optional string path = 1;
+    optional string container = 2; // default - self (client container)
+    optional string target = 3; // path in container, "" - anon
+    optional bool required = 4; // stop container on failure
+    optional bool read_only = 5;
+}
+
+
+message TUnlinkVolumeRequest {
+    optional string path = 1;
+    optional string container = 2; // default - self, "***" - all
+    optional bool strict = 3; // non-lazy umount
+    optional string target = 4; // path in container, "" - anon, default - "***" - all
+}
+
+
+message TListVolumesRequest {
+    optional string path = 1; // volume path or wildcard
+    optional string container = 2;
+    optional uint64 changed_since = 3; // change_time >= changed_since
+}
+
+message TListVolumesResponse {
+    repeated TVolumeDescription volumes = 1;
+}
+
+
+message TTuneVolumeRequest {
+    optional string path = 1;
+    map<string, string> properties = 2;
+}
+
+// Layers
+
+
+message TListLayersRequest {
+    optional string place = 1; // default from client container
+    optional string mask = 2;
+}
+
+message TListLayersResponse {
+    repeated string layer = 1; // layer names (legacy)
+    repeated TLayer layers = 2; // layer with
description +} + + +message TImportLayerRequest { + optional string layer = 1; + optional string tarball = 2; + optional bool merge = 3; + optional string place = 4; + optional string private_value = 5; + optional string compress = 6; + optional bool verbose_error = 7; +} + + +message TExportLayerRequest { + optional string volume = 1; + optional string tarball = 2; + optional string layer = 3; + optional string place = 4; + optional string compress = 5; +} + + +message TRemoveLayerRequest { + optional string layer = 1; + optional string place = 2; + optional bool async = 3; +} + + +message TGetLayerPrivateRequest { + optional string layer = 1; + optional string place = 2; +} + +message TGetLayerPrivateResponse { + optional string private_value = 1; +} + + +message TSetLayerPrivateRequest { + optional string layer = 1; + optional string place = 2; + optional string private_value = 3; +} + + +// Storages + + +message TListStoragesRequest { + optional string place = 1; + optional string mask = 2; // "name" - storage, "name/" - meta-storage +} + +message TListStoragesResponse { + repeated TStorage storages = 1; + repeated TMetaStorage meta_storages = 2; +} + + +message TRemoveStorageRequest { + optional string name = 1; + optional string place = 2; +} + + +message TImportStorageRequest { + optional string name = 1; + optional string tarball = 2; + optional string place = 3; + optional string private_value = 5; + optional string compress = 6; +} + + +message TExportStorageRequest { + optional string name = 1; + optional string tarball = 2; + optional string place = 3; + optional string compress = 4; +} + + +// Docker images API + + +message TDockerImageConfig { + repeated string cmd = 1; + repeated string env = 2; +} + +message TDockerImage { + required string id = 1; + repeated string tags = 2; + repeated string digests = 3; + repeated string layers = 4; + optional uint64 size = 5; + optional TDockerImageConfig config = 6; +} + + +message TDockerImageStatusRequest { + required string name = 1; + optional string place = 2; +} + +message TDockerImageStatusResponse { + optional TDockerImage image = 1; +} + + +message TDockerImageListRequest { + optional string place = 1; + optional string mask = 2; +} + +message TDockerImageListResponse { + repeated TDockerImage images = 1; +} + + +message TDockerImagePullRequest { + required string name = 1; + optional string place = 2; + optional string auth_token = 3; + optional string auth_path = 4; + optional string auth_service = 5; +} + +message TDockerImagePullResponse { + optional TDockerImage image = 1; +} + + +message TDockerImageRemoveRequest { + required string name = 1; + optional string place = 2; +} diff --git a/library/cpp/porto/proto/ya.make b/library/cpp/porto/proto/ya.make new file mode 100644 index 0000000000..525a807ee0 --- /dev/null +++ b/library/cpp/porto/proto/ya.make @@ -0,0 +1,5 @@ +PROTO_LIBRARY() +INCLUDE_TAGS(GO_PROTO) +SRCS(rpc.proto) +END() + diff --git a/library/cpp/porto/ut/ya.make b/library/cpp/porto/ut/ya.make new file mode 100644 index 0000000000..766a45eb56 --- /dev/null +++ b/library/cpp/porto/ut/ya.make @@ -0,0 +1,4 @@ +UNITTEST_FOR(library/cpp/porto) +TAG(ya:manual sb:portod) +SRCS(libporto_ut.cpp) +END() diff --git a/library/cpp/porto/ya.make b/library/cpp/porto/ya.make new file mode 100644 index 0000000000..e1ccbac281 --- /dev/null +++ b/library/cpp/porto/ya.make @@ -0,0 +1,17 @@ +LIBRARY() + +BUILD_ONLY_IF(WARNING WARNING LINUX) + +PEERDIR( + library/cpp/porto/proto + contrib/libs/protobuf +) + +SRCS( + libporto.cpp + 
metrics.cpp +) + +END() + +RECURSE_FOR_TESTS(ut) diff --git a/library/cpp/yt/CMakeLists.txt b/library/cpp/yt/CMakeLists.txt index b1dc1594fc..d05e8fb68e 100644 --- a/library/cpp/yt/CMakeLists.txt +++ b/library/cpp/yt/CMakeLists.txt @@ -16,7 +16,9 @@ add_subdirectory(logging) add_subdirectory(malloc) add_subdirectory(memory) add_subdirectory(misc) +add_subdirectory(mlock) add_subdirectory(small_containers) +add_subdirectory(stockpile) add_subdirectory(string) add_subdirectory(system) add_subdirectory(threading) diff --git a/library/cpp/yt/backtrace/cursors/CMakeLists.darwin-x86_64.txt b/library/cpp/yt/backtrace/cursors/CMakeLists.darwin-x86_64.txt index 6c6f5d1c50..76c3eda332 100644 --- a/library/cpp/yt/backtrace/cursors/CMakeLists.darwin-x86_64.txt +++ b/library/cpp/yt/backtrace/cursors/CMakeLists.darwin-x86_64.txt @@ -6,4 +6,6 @@ # original buildsystem will not be accepted. +add_subdirectory(frame_pointer) +add_subdirectory(interop) add_subdirectory(libunwind) diff --git a/library/cpp/yt/backtrace/cursors/CMakeLists.linux-aarch64.txt b/library/cpp/yt/backtrace/cursors/CMakeLists.linux-aarch64.txt index 6c6f5d1c50..76c3eda332 100644 --- a/library/cpp/yt/backtrace/cursors/CMakeLists.linux-aarch64.txt +++ b/library/cpp/yt/backtrace/cursors/CMakeLists.linux-aarch64.txt @@ -6,4 +6,6 @@ # original buildsystem will not be accepted. +add_subdirectory(frame_pointer) +add_subdirectory(interop) add_subdirectory(libunwind) diff --git a/library/cpp/yt/backtrace/cursors/CMakeLists.linux-x86_64.txt b/library/cpp/yt/backtrace/cursors/CMakeLists.linux-x86_64.txt index 6c6f5d1c50..76c3eda332 100644 --- a/library/cpp/yt/backtrace/cursors/CMakeLists.linux-x86_64.txt +++ b/library/cpp/yt/backtrace/cursors/CMakeLists.linux-x86_64.txt @@ -6,4 +6,6 @@ # original buildsystem will not be accepted. +add_subdirectory(frame_pointer) +add_subdirectory(interop) add_subdirectory(libunwind) diff --git a/library/cpp/yt/backtrace/cursors/CMakeLists.windows-x86_64.txt b/library/cpp/yt/backtrace/cursors/CMakeLists.windows-x86_64.txt index 961a9a908b..27fb2d8417 100644 --- a/library/cpp/yt/backtrace/cursors/CMakeLists.windows-x86_64.txt +++ b/library/cpp/yt/backtrace/cursors/CMakeLists.windows-x86_64.txt @@ -7,3 +7,6 @@ add_subdirectory(dummy) +add_subdirectory(frame_pointer) +add_subdirectory(interop) +add_subdirectory(libunwind) diff --git a/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.darwin-x86_64.txt b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..9078cd7245 --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,20 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(backtrace-cursors-frame_pointer) +target_compile_options(backtrace-cursors-frame_pointer PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(backtrace-cursors-frame_pointer PUBLIC + contrib-libs-cxxsupp + yutil +) +target_sources(backtrace-cursors-frame_pointer PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/frame_pointer/frame_pointer_cursor.cpp +) diff --git a/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.linux-aarch64.txt b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..ce9e059d81 --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.linux-aarch64.txt @@ -0,0 +1,21 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(backtrace-cursors-frame_pointer) +target_compile_options(backtrace-cursors-frame_pointer PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(backtrace-cursors-frame_pointer PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil +) +target_sources(backtrace-cursors-frame_pointer PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/frame_pointer/frame_pointer_cursor.cpp +) diff --git a/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.linux-x86_64.txt b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..ce9e059d81 --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.linux-x86_64.txt @@ -0,0 +1,21 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(backtrace-cursors-frame_pointer) +target_compile_options(backtrace-cursors-frame_pointer PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(backtrace-cursors-frame_pointer PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil +) +target_sources(backtrace-cursors-frame_pointer PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/frame_pointer/frame_pointer_cursor.cpp +) diff --git a/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.txt b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.windows-x86_64.txt b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..ebfaf2065e --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/frame_pointer/CMakeLists.windows-x86_64.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(backtrace-cursors-frame_pointer) +target_link_libraries(backtrace-cursors-frame_pointer PUBLIC + contrib-libs-cxxsupp + yutil +) +target_sources(backtrace-cursors-frame_pointer PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/frame_pointer/frame_pointer_cursor.cpp +) diff --git a/library/cpp/yt/backtrace/cursors/interop/CMakeLists.darwin-x86_64.txt b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..8ddc8397b7 --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,22 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(backtrace-cursors-interop) +target_compile_options(backtrace-cursors-interop PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(backtrace-cursors-interop PUBLIC + contrib-libs-cxxsupp + yutil + backtrace-cursors-frame_pointer + contrib-libs-libunwind +) +target_sources(backtrace-cursors-interop PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/interop/interop.cpp +) diff --git a/library/cpp/yt/backtrace/cursors/interop/CMakeLists.linux-aarch64.txt b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..221213fdae --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.linux-aarch64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(backtrace-cursors-interop) +target_compile_options(backtrace-cursors-interop PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(backtrace-cursors-interop PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + backtrace-cursors-frame_pointer + contrib-libs-libunwind +) +target_sources(backtrace-cursors-interop PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/interop/interop.cpp +) diff --git a/library/cpp/yt/backtrace/cursors/interop/CMakeLists.linux-x86_64.txt b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..221213fdae --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.linux-x86_64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(backtrace-cursors-interop) +target_compile_options(backtrace-cursors-interop PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(backtrace-cursors-interop PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + backtrace-cursors-frame_pointer + contrib-libs-libunwind +) +target_sources(backtrace-cursors-interop PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/interop/interop.cpp +) diff --git a/library/cpp/yt/backtrace/cursors/interop/CMakeLists.txt b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/library/cpp/yt/backtrace/cursors/interop/CMakeLists.windows-x86_64.txt b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..9a7660f685 --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/interop/CMakeLists.windows-x86_64.txt @@ -0,0 +1,19 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. 
Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(backtrace-cursors-interop) +target_link_libraries(backtrace-cursors-interop PUBLIC + contrib-libs-cxxsupp + yutil + backtrace-cursors-frame_pointer + contrib-libs-libunwind +) +target_sources(backtrace-cursors-interop PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/interop/interop.cpp +) diff --git a/library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.txt b/library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.txt index 606ff46b4b..f8b31df0c1 100644 --- a/library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.txt +++ b/library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.txt @@ -10,6 +10,8 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarc include(CMakeLists.linux-aarch64.txt) elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) include(CMakeLists.linux-x86_64.txt) endif() diff --git a/library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.windows-x86_64.txt b/library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..bea2a794c1 --- /dev/null +++ b/library/cpp/yt/backtrace/cursors/libunwind/CMakeLists.windows-x86_64.txt @@ -0,0 +1,18 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(backtrace-cursors-libunwind) +target_link_libraries(backtrace-cursors-libunwind PUBLIC + contrib-libs-cxxsupp + yutil + contrib-libs-libunwind +) +target_sources(backtrace-cursors-libunwind PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/backtrace/cursors/libunwind/libunwind_cursor.cpp +) diff --git a/library/cpp/yt/mlock/CMakeLists.darwin-x86_64.txt b/library/cpp/yt/mlock/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..ca50021faf --- /dev/null +++ b/library/cpp/yt/mlock/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,20 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(cpp-yt-mlock) +target_compile_options(cpp-yt-mlock PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(cpp-yt-mlock PUBLIC + contrib-libs-cxxsupp + yutil +) +target_sources(cpp-yt-mlock PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/mlock/mlock_other.cpp +) diff --git a/library/cpp/yt/mlock/CMakeLists.linux-aarch64.txt b/library/cpp/yt/mlock/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..68f500a75a --- /dev/null +++ b/library/cpp/yt/mlock/CMakeLists.linux-aarch64.txt @@ -0,0 +1,21 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(cpp-yt-mlock) +target_compile_options(cpp-yt-mlock PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(cpp-yt-mlock PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil +) +target_sources(cpp-yt-mlock PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/mlock/mlock_linux.cpp +) diff --git a/library/cpp/yt/mlock/CMakeLists.linux-x86_64.txt b/library/cpp/yt/mlock/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..68f500a75a --- /dev/null +++ b/library/cpp/yt/mlock/CMakeLists.linux-x86_64.txt @@ -0,0 +1,21 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(cpp-yt-mlock) +target_compile_options(cpp-yt-mlock PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(cpp-yt-mlock PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil +) +target_sources(cpp-yt-mlock PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/mlock/mlock_linux.cpp +) diff --git a/library/cpp/yt/mlock/CMakeLists.txt b/library/cpp/yt/mlock/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/library/cpp/yt/mlock/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/library/cpp/yt/mlock/CMakeLists.windows-x86_64.txt b/library/cpp/yt/mlock/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..1537ee764b --- /dev/null +++ b/library/cpp/yt/mlock/CMakeLists.windows-x86_64.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(cpp-yt-mlock) +target_link_libraries(cpp-yt-mlock PUBLIC + contrib-libs-cxxsupp + yutil +) +target_sources(cpp-yt-mlock PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/mlock/mlock_other.cpp +) diff --git a/library/cpp/yt/mlock/README.md b/library/cpp/yt/mlock/README.md new file mode 100644 index 0000000000..b61b6072c4 --- /dev/null +++ b/library/cpp/yt/mlock/README.md @@ -0,0 +1,11 @@ +# mlock + +MlockFileMappings loads and locks into memory all pages of the executable file. + +Unlike an mlockall call, this function does not lock the process's other pages. +mlockall explicitly allocates physical memory for every vma. A typical process first +starts and initializes its allocator, and only then calls the function to mlock pages. +At startup the allocator reserves large ranges via mmap without actually using them, +so mlockall leads to increased memory consumption. + +Also, unlike mlockall, this function can populate the pages into memory right away. diff --git a/library/cpp/yt/mlock/mlock.h b/library/cpp/yt/mlock/mlock.h new file mode 100644 index 0000000000..035fc47e37 --- /dev/null +++ b/library/cpp/yt/mlock/mlock.h @@ -0,0 +1,11 @@ +#pragma once + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +bool MlockFileMappings(bool populate = true); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/mlock/mlock_linux.cpp b/library/cpp/yt/mlock/mlock_linux.cpp new file mode 100644 index 0000000000..8791869f95 --- /dev/null +++ b/library/cpp/yt/mlock/mlock_linux.cpp @@ -0,0 +1,89 @@ +#include "mlock.h" + +#include <stdio.h> +#include <sys/mman.h> +#include <stdint.h> +#include <inttypes.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +void PopulateFile(void* ptr, size_t size) +{ + constexpr size_t PageSize = 4096; + + auto* begin = static_cast<volatile char*>(ptr); + for (auto* current = begin; current < begin + size; current += PageSize) { + *current; + } +} + +bool MlockFileMappings(bool populate) +{ + auto* file = ::fopen("/proc/self/maps", "r"); + if (!file) { + return false; + } + + // Each line of /proc/<pid>/maps has the following format: + // address perms offset dev inode path + // E.g. 
+ // 08048000-08056000 r-xp 00000000 03:0c 64593 /usr/sbin/gpm + + bool failed = false; + while (true) { + char line[1024]; + if (!fgets(line, sizeof(line), file)) { + break; + } + + char addressStr[64]; + char permsStr[64]; + char offsetStr[64]; + char devStr[64]; + int inode; + if (sscanf(line, "%s %s %s %s %d", + addressStr, + permsStr, + offsetStr, + devStr, + &inode) != 5) + { + continue; + } + + if (inode == 0) { + continue; + } + + if (permsStr[0] != 'r') { + continue; + } + + uintptr_t startAddress; + uintptr_t endAddress; + if (sscanf(addressStr, "%" PRIx64 "-%" PRIx64, + &startAddress, + &endAddress) != 2) + { + continue; + } + + if (::mlock(reinterpret_cast<const void*>(startAddress), endAddress - startAddress) != 0) { + failed = true; + continue; + } + + if (populate) { + PopulateFile(reinterpret_cast<void*>(startAddress), endAddress - startAddress); + } + } + + ::fclose(file); + return !failed; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/mlock/mlock_other.cpp b/library/cpp/yt/mlock/mlock_other.cpp new file mode 100644 index 0000000000..269c5c3cb9 --- /dev/null +++ b/library/cpp/yt/mlock/mlock_other.cpp @@ -0,0 +1,14 @@ +#include "mlock.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +bool MlockFileMappings(bool /* populate */) +{ + return false; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/mlock/unittests/mlock_ut.cpp b/library/cpp/yt/mlock/unittests/mlock_ut.cpp new file mode 100644 index 0000000000..98386622e8 --- /dev/null +++ b/library/cpp/yt/mlock/unittests/mlock_ut.cpp @@ -0,0 +1,19 @@ +#include <gtest/gtest.h> + +#include <library/cpp/yt/mlock/mlock.h> + +namespace NYT { +namespace { + +//////////////////////////////////////////////////////////////////////////////// + +TEST(TMlockTest, Call) +{ + ASSERT_TRUE(MlockFileMappings(false)); + ASSERT_TRUE(MlockFileMappings(true)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT
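For reference, the intended call pattern (per the README above) is a single call early in main(), before the allocator has faulted in much heap. A minimal sketch, assuming only the mlock.h header introduced in this commit; the main() wrapper itself is illustrative, not part of the change:

#include <library/cpp/yt/mlock/mlock.h>

#include <cstdio>

int main()
{
    // Lock the executable's file-backed mappings; populate = true also
    // pre-faults every page so first access never hits disk.
    if (!NYT::MlockFileMappings(/*populate*/ true)) {
        // Returns false on non-Linux builds (mlock_other.cpp) or when some
        // mlock() call fails, e.g. due to RLIMIT_MEMLOCK.
        std::fprintf(stderr, "MlockFileMappings failed or unsupported\n");
    }
    // ... start the actual service ...
}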
\ No newline at end of file diff --git a/library/cpp/yt/mlock/unittests/ya.make b/library/cpp/yt/mlock/unittests/ya.make new file mode 100644 index 0000000000..f1f956d468 --- /dev/null +++ b/library/cpp/yt/mlock/unittests/ya.make @@ -0,0 +1,13 @@ +GTEST() + +INCLUDE(${ARCADIA_ROOT}/library/cpp/yt/ya_cpp.make.inc) + +SRCS( + mlock_ut.cpp +) + +PEERDIR( + library/cpp/yt/mlock +) + +END() diff --git a/library/cpp/yt/mlock/ya.make b/library/cpp/yt/mlock/ya.make new file mode 100644 index 0000000000..2603d128ed --- /dev/null +++ b/library/cpp/yt/mlock/ya.make @@ -0,0 +1,16 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/library/cpp/yt/ya_cpp.make.inc) + +IF (OS_LINUX AND NOT SANITIZER_TYPE) + SRCS(mlock_linux.cpp) +ELSE() + SRCS(mlock_other.cpp) +ENDIF() + +END() + +IF (OS_LINUX AND NOT SANITIZER_TYPE) + RECURSE(unittests) +ENDIF() + diff --git a/library/cpp/yt/stockpile/CMakeLists.darwin-x86_64.txt b/library/cpp/yt/stockpile/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..8036bd5d7e --- /dev/null +++ b/library/cpp/yt/stockpile/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,20 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(cpp-yt-stockpile) +target_compile_options(cpp-yt-stockpile PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(cpp-yt-stockpile PUBLIC + contrib-libs-cxxsupp + yutil +) +target_sources(cpp-yt-stockpile PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/stockpile/stockpile_other.cpp +) diff --git a/library/cpp/yt/stockpile/CMakeLists.linux-aarch64.txt b/library/cpp/yt/stockpile/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..d023cce4da --- /dev/null +++ b/library/cpp/yt/stockpile/CMakeLists.linux-aarch64.txt @@ -0,0 +1,21 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(cpp-yt-stockpile) +target_compile_options(cpp-yt-stockpile PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(cpp-yt-stockpile PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil +) +target_sources(cpp-yt-stockpile PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/stockpile/stockpile_linux.cpp +) diff --git a/library/cpp/yt/stockpile/CMakeLists.linux-x86_64.txt b/library/cpp/yt/stockpile/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..d023cce4da --- /dev/null +++ b/library/cpp/yt/stockpile/CMakeLists.linux-x86_64.txt @@ -0,0 +1,21 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(cpp-yt-stockpile) +target_compile_options(cpp-yt-stockpile PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(cpp-yt-stockpile PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil +) +target_sources(cpp-yt-stockpile PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/stockpile/stockpile_linux.cpp +) diff --git a/library/cpp/yt/stockpile/CMakeLists.txt b/library/cpp/yt/stockpile/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/library/cpp/yt/stockpile/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/library/cpp/yt/stockpile/CMakeLists.windows-x86_64.txt b/library/cpp/yt/stockpile/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..d60191d7fe --- /dev/null +++ b/library/cpp/yt/stockpile/CMakeLists.windows-x86_64.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(cpp-yt-stockpile) +target_link_libraries(cpp-yt-stockpile PUBLIC + contrib-libs-cxxsupp + yutil +) +target_sources(cpp-yt-stockpile PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/yt/stockpile/stockpile_other.cpp +) diff --git a/library/cpp/yt/stockpile/README.md b/library/cpp/yt/stockpile/README.md new file mode 100644 index 0000000000..6ee4cd1b1f --- /dev/null +++ b/library/cpp/yt/stockpile/README.md @@ -0,0 +1,12 @@ +# stockpile + +When a memory cgroup approaches its memory limit, Linux starts the direct reclaim +mechanism to free up memory. In YT's experience, direct reclaim severely slows down +the whole process. + +The problem occurs not only when memory is occupied by anonymous pages. 50% of a container's +memory may be occupied by non-dirty page cache pages, and the problem will still show up. For example, +if a process actively reads files from disk without O_DIRECT, all memory fills up very quickly. + +To combat this, the Yandex kernel adds a `madvise(MADV_STOCKPILE)` knob. +More details at https://st.yandex-team.ru/KERNEL-186
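To make the knob concrete, a minimal usage sketch of the wrapper library, assuming only the stockpile.h header added later in this commit (the overridden values below are illustrative):

#include <library/cpp/yt/stockpile/stockpile.h>

int main()
{
    NYT::TStockpileOptions options;
    options.BufferSize = 2_GBs;                   // default is 4_GBs
    options.ThreadCount = 2;                      // default is 4
    options.Period = TDuration::MilliSeconds(50); // default is 10 ms

    // Spawns detached background threads that periodically call
    // madvise(nullptr, BufferSize, MADV_STOCKPILE). Only the first
    // invocation has any effect (guarded by std::once_flag), and the
    // non-Linux implementation is a no-op.
    NYT::ConfigureStockpile(options);
}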
\ No newline at end of file diff --git a/library/cpp/yt/stockpile/stockpile.h b/library/cpp/yt/stockpile/stockpile.h new file mode 100644 index 0000000000..1df9591de4 --- /dev/null +++ b/library/cpp/yt/stockpile/stockpile.h @@ -0,0 +1,29 @@ +#pragma once + +#include <util/system/types.h> + +#include <util/generic/size_literals.h> + +#include <util/datetime/base.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +struct TStockpileOptions +{ + static constexpr i64 DefaultBufferSize = 4_GBs; + i64 BufferSize = DefaultBufferSize; + + static constexpr int DefaultThreadCount = 4; + int ThreadCount = DefaultThreadCount; + + static constexpr TDuration DefaultPeriod = TDuration::MilliSeconds(10); + TDuration Period = DefaultPeriod; +}; + +void ConfigureStockpile(const TStockpileOptions& options); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/stockpile/stockpile_linux.cpp b/library/cpp/yt/stockpile/stockpile_linux.cpp new file mode 100644 index 0000000000..3ee83d9334 --- /dev/null +++ b/library/cpp/yt/stockpile/stockpile_linux.cpp @@ -0,0 +1,42 @@ +#include "stockpile.h" + +#include <thread> +#include <mutex> + +#include <sys/mman.h> + +#include <util/system/thread.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +void RunStockpile(const TStockpileOptions& options) +{ + TThread::SetCurrentThreadName("Stockpile"); + + constexpr int MADV_STOCKPILE = 0x59410004; + + while (true) { + ::madvise(nullptr, options.BufferSize, MADV_STOCKPILE); + Sleep(options.Period); + } +} + +} // namespace + +void ConfigureStockpile(const TStockpileOptions& options) +{ + static std::once_flag OnceFlag; + std::call_once(OnceFlag, [options] { + for (int i = 0; i < options.ThreadCount; i++) { + std::thread(RunStockpile, options).detach(); + } + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/stockpile/stockpile_other.cpp b/library/cpp/yt/stockpile/stockpile_other.cpp new file mode 100644 index 0000000000..3495d9c1cb --- /dev/null +++ b/library/cpp/yt/stockpile/stockpile_other.cpp @@ -0,0 +1,12 @@ +#include "stockpile.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +void ConfigureStockpile(const TStockpileOptions& /*options*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/stockpile/ya.make b/library/cpp/yt/stockpile/ya.make new file mode 100644 index 0000000000..39d51aaf97 --- /dev/null +++ b/library/cpp/yt/stockpile/ya.make @@ -0,0 +1,11 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/library/cpp/yt/ya_cpp.make.inc) + +IF (OS_LINUX AND NOT SANITIZER_TYPE) + SRCS(stockpile_linux.cpp) +ELSE() + SRCS(stockpile_other.cpp) +ENDIF() + +END() diff --git a/yt/yql/plugin/bridge/interface.h b/yt/yql/plugin/bridge/interface.h index 08e765b3d0..ba08b657aa 100644 --- a/yt/yql/plugin/bridge/interface.h +++ b/yt/yql/plugin/bridge/interface.h @@ -9,8 +9,18 @@ //////////////////////////////////////////////////////////////////////////////// +// NB(mpereskokova): don't forget to update min_required_abi_version at yt/yql/agent/config.cpp and abi_version in yt/yql/plugin/dynamic/impl.cpp during breaking changes +using TFuncBridgeGetABIVersion = ssize_t(); + 
+//////////////////////////////////////////////////////////////////////////////// + struct TBridgeYqlPluginOptions { + ssize_t RequiredABIVersion; + + const char* SingletonsConfig; + ssize_t SingletonsConfigLength; + const char* MRJobBinary; const char* UdfDirectory; @@ -95,6 +105,7 @@ using TFuncBridgeGetProgress = TBridgeQueryResult*(TBridgeYqlPlugin* plugin, con XX(BridgeFreeYqlPlugin) \ XX(BridgeFreeQueryResult) \ XX(BridgeRun) \ - XX(BridgeGetProgress) + XX(BridgeGetProgress) \ + XX(BridgeGetABIVersion) //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yql/plugin/bridge/plugin.cpp b/yt/yql/plugin/bridge/plugin.cpp index 18e2e9e639..5c461616e9 100644 --- a/yt/yql/plugin/bridge/plugin.cpp +++ b/yt/yql/plugin/bridge/plugin.cpp @@ -72,7 +72,12 @@ public: ? options.DefaultCluster->data() : nullptr; + TString singletonsConfig = options.SingletonsConfig ? options.SingletonsConfig.ToString() : "{}"; + TBridgeYqlPluginOptions bridgeOptions { + .RequiredABIVersion = options.RequiredABIVersion, + .SingletonsConfig = singletonsConfig.data(), + .SingletonsConfigLength = static_cast<int>(singletonsConfig.size()), .MRJobBinary = options.MRJobBinary.data(), .UdfDirectory = options.UdfDirectory.data(), .ClusterCount = ssize(bridgeClusters), diff --git a/yt/yql/plugin/dynamic/dylib.exports b/yt/yql/plugin/dynamic/dylib.exports index fc77529eaf..b979ce7e18 100644 --- a/yt/yql/plugin/dynamic/dylib.exports +++ b/yt/yql/plugin/dynamic/dylib.exports @@ -4,6 +4,7 @@ BridgeFreeYqlPlugin BridgeFreeQueryResult BridgeRun BridgeGetProgress +BridgeGetABIVersion # YQL <-> YQL UDFs interface. UdfAllocateWithSize diff --git a/yt/yql/plugin/dynamic/impl.cpp b/yt/yql/plugin/dynamic/impl.cpp index 5966a8a8ec..6beb5bf8c2 100644 --- a/yt/yql/plugin/dynamic/impl.cpp +++ b/yt/yql/plugin/dynamic/impl.cpp @@ -10,8 +10,17 @@ extern "C" { //////////////////////////////////////////////////////////////////////////////// +ssize_t BridgeGetABIVersion() +{ + return 0; +} + TBridgeYqlPlugin* BridgeCreateYqlPlugin(const TBridgeYqlPluginOptions* bridgeOptions) { + YT_VERIFY(bridgeOptions->RequiredABIVersion == BridgeGetABIVersion()); + + static const TYsonString EmptyMap = TYsonString(TString("{}")); + THashMap<TString, TString> clusters; for (auto clusterIndex = 0; clusterIndex < bridgeOptions->ClusterCount; ++clusterIndex) { const auto& Cluster = bridgeOptions->Clusters[clusterIndex]; @@ -20,9 +29,14 @@ TBridgeYqlPlugin* BridgeCreateYqlPlugin(const TBridgeYqlPluginOptions* bridgeOpt auto operationAttributes = bridgeOptions->OperationAttributes ? TYsonString(TString(bridgeOptions->OperationAttributes, bridgeOptions->OperationAttributesLength)) - : TYsonString(); + : EmptyMap; + + auto singletonsConfig = bridgeOptions->SingletonsConfig + ? 
TYsonString(TString(bridgeOptions->SingletonsConfig, bridgeOptions->SingletonsConfigLength)) + : EmptyMap; TYqlPluginOptions options{ + .SingletonsConfig = singletonsConfig, .MRJobBinary = TString(bridgeOptions->MRJobBinary), .UdfDirectory = TString(bridgeOptions->UdfDirectory), .Clusters = std::move(clusters), diff --git a/yt/yql/plugin/native/CMakeLists.darwin-x86_64.txt b/yt/yql/plugin/native/CMakeLists.darwin-x86_64.txt index 7c002160ee..04a6b539a5 100644 --- a/yt/yql/plugin/native/CMakeLists.darwin-x86_64.txt +++ b/yt/yql/plugin/native/CMakeLists.darwin-x86_64.txt @@ -20,6 +20,7 @@ target_link_libraries(yql-plugin-native PUBLIC cpp-yson-node cpp-mapreduce-client cpp-mapreduce-common + yt-library-program library-yql-ast yql-sql-pg yql-parser-pg_wrapper @@ -65,6 +66,7 @@ target_link_libraries(yql-plugin-native.global PUBLIC cpp-yson-node cpp-mapreduce-client cpp-mapreduce-common + yt-library-program library-yql-ast yql-sql-pg yql-parser-pg_wrapper diff --git a/yt/yql/plugin/native/CMakeLists.linux-aarch64.txt b/yt/yql/plugin/native/CMakeLists.linux-aarch64.txt index b85008e3c5..b9ac27203c 100644 --- a/yt/yql/plugin/native/CMakeLists.linux-aarch64.txt +++ b/yt/yql/plugin/native/CMakeLists.linux-aarch64.txt @@ -21,6 +21,7 @@ target_link_libraries(yql-plugin-native PUBLIC cpp-yson-node cpp-mapreduce-client cpp-mapreduce-common + yt-library-program library-yql-ast yql-sql-pg yql-parser-pg_wrapper @@ -67,6 +68,7 @@ target_link_libraries(yql-plugin-native.global PUBLIC cpp-yson-node cpp-mapreduce-client cpp-mapreduce-common + yt-library-program library-yql-ast yql-sql-pg yql-parser-pg_wrapper diff --git a/yt/yql/plugin/native/CMakeLists.linux-x86_64.txt b/yt/yql/plugin/native/CMakeLists.linux-x86_64.txt index b85008e3c5..b9ac27203c 100644 --- a/yt/yql/plugin/native/CMakeLists.linux-x86_64.txt +++ b/yt/yql/plugin/native/CMakeLists.linux-x86_64.txt @@ -21,6 +21,7 @@ target_link_libraries(yql-plugin-native PUBLIC cpp-yson-node cpp-mapreduce-client cpp-mapreduce-common + yt-library-program library-yql-ast yql-sql-pg yql-parser-pg_wrapper @@ -67,6 +68,7 @@ target_link_libraries(yql-plugin-native.global PUBLIC cpp-yson-node cpp-mapreduce-client cpp-mapreduce-common + yt-library-program library-yql-ast yql-sql-pg yql-parser-pg_wrapper diff --git a/yt/yql/plugin/native/CMakeLists.windows-x86_64.txt b/yt/yql/plugin/native/CMakeLists.windows-x86_64.txt index 7c002160ee..04a6b539a5 100644 --- a/yt/yql/plugin/native/CMakeLists.windows-x86_64.txt +++ b/yt/yql/plugin/native/CMakeLists.windows-x86_64.txt @@ -20,6 +20,7 @@ target_link_libraries(yql-plugin-native PUBLIC cpp-yson-node cpp-mapreduce-client cpp-mapreduce-common + yt-library-program library-yql-ast yql-sql-pg yql-parser-pg_wrapper @@ -65,6 +66,7 @@ target_link_libraries(yql-plugin-native.global PUBLIC cpp-yson-node cpp-mapreduce-client cpp-mapreduce-common + yt-library-program library-yql-ast yql-sql-pg yql-parser-pg_wrapper diff --git a/yt/yql/plugin/native/plugin.cpp b/yt/yql/plugin/native/plugin.cpp index 9bb6c2261e..38ccbbdf3b 100644 --- a/yt/yql/plugin/native/plugin.cpp +++ b/yt/yql/plugin/native/plugin.cpp @@ -24,7 +24,11 @@ #include <ydb/library/yql/utils/log/log.h> #include <ydb/library/yql/utils/backtrace/backtrace.h> -#include <yt/cpp/mapreduce/interface/config.h> +#include <yt/yt/core/ytree/convert.h> + +#include <yt/yt/library/program/config.h> +#include <yt/yt/library/program/helpers.h> + #include <yt/cpp/mapreduce/interface/logging/logger.h> #include <library/cpp/yt/threading/rw_spin_lock.h> @@ -115,6 +119,9 @@ public: 
TYqlPlugin(TYqlPluginOptions& options) { try { + auto singletonsConfig = NYTree::ConvertTo<TSingletonsConfigPtr>(options.SingletonsConfig); + ConfigureSingletons(singletonsConfig); + NYql::NLog::InitLogger(std::move(options.LogBackend)); auto& logger = NYql::NLog::YqlLogger(); diff --git a/yt/yql/plugin/native/ya.make b/yt/yql/plugin/native/ya.make index 15f3851411..fe1a657c69 100644 --- a/yt/yql/plugin/native/ya.make +++ b/yt/yql/plugin/native/ya.make @@ -13,6 +13,7 @@ PEERDIR( library/cpp/yson/node yt/cpp/mapreduce/client yt/cpp/mapreduce/common + yt/yt/library/program ydb/library/yql/ast ydb/library/yql/sql/pg ydb/library/yql/parser/pg_wrapper diff --git a/yt/yql/plugin/plugin.h b/yt/yql/plugin/plugin.h index 2d2c45d54c..68b01da922 100644 --- a/yt/yql/plugin/plugin.h +++ b/yt/yql/plugin/plugin.h @@ -26,6 +26,10 @@ using TQueryId = TGuid; class TYqlPluginOptions { public: + int RequiredABIVersion; + + TYsonString SingletonsConfig; + TString MRJobBinary = "./mrjob"; TString UdfDirectory; diff --git a/yt/yt/core/CMakeLists.darwin-x86_64.txt b/yt/yt/core/CMakeLists.darwin-x86_64.txt index 0b4da77950..5b4d4b633f 100644 --- a/yt/yt/core/CMakeLists.darwin-x86_64.txt +++ b/yt/yt/core/CMakeLists.darwin-x86_64.txt @@ -9,6 +9,8 @@ add_subdirectory(http) add_subdirectory(https) add_subdirectory(misc) +add_subdirectory(rpc) +add_subdirectory(service_discovery) add_library(yt-yt-core) target_compile_options(yt-yt-core PRIVATE diff --git a/yt/yt/core/CMakeLists.linux-aarch64.txt b/yt/yt/core/CMakeLists.linux-aarch64.txt index 51eeb4ff56..640e862300 100644 --- a/yt/yt/core/CMakeLists.linux-aarch64.txt +++ b/yt/yt/core/CMakeLists.linux-aarch64.txt @@ -9,6 +9,8 @@ add_subdirectory(http) add_subdirectory(https) add_subdirectory(misc) +add_subdirectory(rpc) +add_subdirectory(service_discovery) add_library(yt-yt-core) target_compile_options(yt-yt-core PRIVATE diff --git a/yt/yt/core/CMakeLists.linux-x86_64.txt b/yt/yt/core/CMakeLists.linux-x86_64.txt index 164e626f9b..ab2ddf3548 100644 --- a/yt/yt/core/CMakeLists.linux-x86_64.txt +++ b/yt/yt/core/CMakeLists.linux-x86_64.txt @@ -9,6 +9,8 @@ add_subdirectory(http) add_subdirectory(https) add_subdirectory(misc) +add_subdirectory(rpc) +add_subdirectory(service_discovery) add_library(yt-yt-core) target_compile_options(yt-yt-core PRIVATE diff --git a/yt/yt/core/CMakeLists.windows-x86_64.txt b/yt/yt/core/CMakeLists.windows-x86_64.txt index 76ae848b35..741f6b7a8b 100644 --- a/yt/yt/core/CMakeLists.windows-x86_64.txt +++ b/yt/yt/core/CMakeLists.windows-x86_64.txt @@ -9,6 +9,8 @@ add_subdirectory(http) add_subdirectory(https) add_subdirectory(misc) +add_subdirectory(rpc) +add_subdirectory(service_discovery) add_library(yt-yt-core) target_compile_options(yt-yt-core PRIVATE diff --git a/yt/yt/core/rpc/CMakeLists.txt b/yt/yt/core/rpc/CMakeLists.txt new file mode 100644 index 0000000000..68ea682099 --- /dev/null +++ b/yt/yt/core/rpc/CMakeLists.txt @@ -0,0 +1,9 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +add_subdirectory(grpc) diff --git a/yt/yt/core/rpc/grpc/CMakeLists.darwin-x86_64.txt b/yt/yt/core/rpc/grpc/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..eb0e7def69 --- /dev/null +++ b/yt/yt/core/rpc/grpc/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,64 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(core-rpc-grpc) +target_compile_options(core-rpc-grpc PRIVATE + -Wdeprecated-this-capture +) +target_include_directories(core-rpc-grpc PUBLIC + ${CMAKE_BINARY_DIR}/yt +) +target_include_directories(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/grpc +) +target_link_libraries(core-rpc-grpc PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + contrib-libs-grpc + contrib-libs-protobuf +) +target_proto_messages(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/proto/grpc.proto +) +target_sources(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/public.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/dispatcher.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/server.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/helpers.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/channel.cpp +) +target_proto_addincls(core-rpc-grpc + ./yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(core-rpc-grpc + --cpp_out=${CMAKE_BINARY_DIR}/yt + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/yt +) diff --git a/yt/yt/core/rpc/grpc/CMakeLists.linux-aarch64.txt b/yt/yt/core/rpc/grpc/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..c3945011ba --- /dev/null +++ b/yt/yt/core/rpc/grpc/CMakeLists.linux-aarch64.txt @@ -0,0 +1,65 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(core-rpc-grpc) +target_compile_options(core-rpc-grpc PRIVATE + -Wdeprecated-this-capture +) +target_include_directories(core-rpc-grpc PUBLIC + ${CMAKE_BINARY_DIR}/yt +) +target_include_directories(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/grpc +) +target_link_libraries(core-rpc-grpc PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + contrib-libs-grpc + contrib-libs-protobuf +) +target_proto_messages(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/proto/grpc.proto +) +target_sources(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/public.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/dispatcher.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/server.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/helpers.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/channel.cpp +) +target_proto_addincls(core-rpc-grpc + ./yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(core-rpc-grpc + --cpp_out=${CMAKE_BINARY_DIR}/yt + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/yt +) diff --git a/yt/yt/core/rpc/grpc/CMakeLists.linux-x86_64.txt b/yt/yt/core/rpc/grpc/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..c3945011ba --- /dev/null +++ b/yt/yt/core/rpc/grpc/CMakeLists.linux-x86_64.txt @@ -0,0 +1,65 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(core-rpc-grpc) +target_compile_options(core-rpc-grpc PRIVATE + -Wdeprecated-this-capture +) +target_include_directories(core-rpc-grpc PUBLIC + ${CMAKE_BINARY_DIR}/yt +) +target_include_directories(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/grpc +) +target_link_libraries(core-rpc-grpc PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + contrib-libs-grpc + contrib-libs-protobuf +) +target_proto_messages(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/proto/grpc.proto +) +target_sources(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/public.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/dispatcher.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/server.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/helpers.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/channel.cpp +) +target_proto_addincls(core-rpc-grpc + ./yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(core-rpc-grpc + --cpp_out=${CMAKE_BINARY_DIR}/yt + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/yt +) diff --git a/yt/yt/core/rpc/grpc/CMakeLists.txt b/yt/yt/core/rpc/grpc/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/core/rpc/grpc/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/core/rpc/grpc/CMakeLists.windows-x86_64.txt b/yt/yt/core/rpc/grpc/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..629594878e --- /dev/null +++ b/yt/yt/core/rpc/grpc/CMakeLists.windows-x86_64.txt @@ -0,0 +1,61 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(core-rpc-grpc) +target_include_directories(core-rpc-grpc PUBLIC + ${CMAKE_BINARY_DIR}/yt +) +target_include_directories(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/contrib/libs/grpc +) +target_link_libraries(core-rpc-grpc PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + contrib-libs-grpc + contrib-libs-protobuf +) +target_proto_messages(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/proto/grpc.proto +) +target_sources(core-rpc-grpc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/public.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/dispatcher.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/server.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/helpers.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/rpc/grpc/channel.cpp +) +target_proto_addincls(core-rpc-grpc + ./yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(core-rpc-grpc + --cpp_out=${CMAKE_BINARY_DIR}/yt + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/yt +) diff --git a/yt/yt/core/service_discovery/CMakeLists.txt b/yt/yt/core/service_discovery/CMakeLists.txt new file mode 100644 index 0000000000..7f1d154d54 --- /dev/null +++ b/yt/yt/core/service_discovery/CMakeLists.txt @@ -0,0 +1,9 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(yp) diff --git a/yt/yt/core/service_discovery/yp/CMakeLists.darwin-x86_64.txt b/yt/yt/core/service_discovery/yp/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..78ca098926 --- /dev/null +++ b/yt/yt/core/service_discovery/yp/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(core-service_discovery-yp) +target_compile_options(core-service_discovery-yp PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(core-service_discovery-yp PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + core-rpc-grpc +) +target_sources(core-service_discovery-yp PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/service_discovery/yp/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/service_discovery/yp/service_discovery_dummy.cpp +) diff --git a/yt/yt/core/service_discovery/yp/CMakeLists.linux-aarch64.txt b/yt/yt/core/service_discovery/yp/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..eb8048e099 --- /dev/null +++ b/yt/yt/core/service_discovery/yp/CMakeLists.linux-aarch64.txt @@ -0,0 +1,24 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(core-service_discovery-yp) +target_compile_options(core-service_discovery-yp PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(core-service_discovery-yp PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + core-rpc-grpc +) +target_sources(core-service_discovery-yp PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/service_discovery/yp/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/service_discovery/yp/service_discovery_dummy.cpp +) diff --git a/yt/yt/core/service_discovery/yp/CMakeLists.linux-x86_64.txt b/yt/yt/core/service_discovery/yp/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..eb8048e099 --- /dev/null +++ b/yt/yt/core/service_discovery/yp/CMakeLists.linux-x86_64.txt @@ -0,0 +1,24 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(core-service_discovery-yp) +target_compile_options(core-service_discovery-yp PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(core-service_discovery-yp PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + core-rpc-grpc +) +target_sources(core-service_discovery-yp PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/service_discovery/yp/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/service_discovery/yp/service_discovery_dummy.cpp +) diff --git a/yt/yt/core/service_discovery/yp/CMakeLists.txt b/yt/yt/core/service_discovery/yp/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/core/service_discovery/yp/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/core/service_discovery/yp/CMakeLists.windows-x86_64.txt b/yt/yt/core/service_discovery/yp/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..2aa18009a5 --- /dev/null +++ b/yt/yt/core/service_discovery/yp/CMakeLists.windows-x86_64.txt @@ -0,0 +1,20 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(core-service_discovery-yp) +target_link_libraries(core-service_discovery-yp PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + core-rpc-grpc +) +target_sources(core-service_discovery-yp PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/core/service_discovery/yp/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/core/service_discovery/yp/service_discovery_dummy.cpp +) diff --git a/yt/yt/core/ya.make b/yt/yt/core/ya.make index e985dd82ae..043d3a1dd9 100644 --- a/yt/yt/core/ya.make +++ b/yt/yt/core/ya.make @@ -324,6 +324,7 @@ PEERDIR( library/cpp/ytalloc/api yt/yt/build + yt/yt/core/misc/isa_crc64 yt/yt_proto/yt/core diff --git a/yt/yt/library/CMakeLists.darwin-x86_64.txt b/yt/yt/library/CMakeLists.darwin-x86_64.txt index d26f3bae28..2debbb7626 100644 --- a/yt/yt/library/CMakeLists.darwin-x86_64.txt +++ b/yt/yt/library/CMakeLists.darwin-x86_64.txt @@ -7,10 +7,14 @@ add_subdirectory(auth) +add_subdirectory(containers) add_subdirectory(decimal) add_subdirectory(erasure) +add_subdirectory(monitoring) add_subdirectory(numeric) +add_subdirectory(process) add_subdirectory(profiling) +add_subdirectory(program) add_subdirectory(quantile_digest) add_subdirectory(re2) add_subdirectory(syncmap) diff --git a/yt/yt/library/CMakeLists.linux-aarch64.txt b/yt/yt/library/CMakeLists.linux-aarch64.txt index d26f3bae28..524ffcf525 100644 --- a/yt/yt/library/CMakeLists.linux-aarch64.txt +++ b/yt/yt/library/CMakeLists.linux-aarch64.txt @@ -7,10 +7,15 @@ add_subdirectory(auth) +add_subdirectory(backtrace_introspector) +add_subdirectory(containers) add_subdirectory(decimal) add_subdirectory(erasure) +add_subdirectory(monitoring) add_subdirectory(numeric) +add_subdirectory(process) add_subdirectory(profiling) +add_subdirectory(program) add_subdirectory(quantile_digest) add_subdirectory(re2) add_subdirectory(syncmap) diff --git a/yt/yt/library/CMakeLists.linux-x86_64.txt b/yt/yt/library/CMakeLists.linux-x86_64.txt index d26f3bae28..524ffcf525 100644 --- a/yt/yt/library/CMakeLists.linux-x86_64.txt +++ b/yt/yt/library/CMakeLists.linux-x86_64.txt @@ -7,10 +7,15 @@ add_subdirectory(auth) +add_subdirectory(backtrace_introspector) +add_subdirectory(containers) add_subdirectory(decimal) add_subdirectory(erasure) +add_subdirectory(monitoring) add_subdirectory(numeric) +add_subdirectory(process) 
add_subdirectory(profiling) +add_subdirectory(program) add_subdirectory(quantile_digest) add_subdirectory(re2) add_subdirectory(syncmap) diff --git a/yt/yt/library/CMakeLists.windows-x86_64.txt b/yt/yt/library/CMakeLists.windows-x86_64.txt index 20f7fe76fa..4502da4e61 100644 --- a/yt/yt/library/CMakeLists.windows-x86_64.txt +++ b/yt/yt/library/CMakeLists.windows-x86_64.txt @@ -6,7 +6,11 @@ # original buildsystem will not be accepted. +add_subdirectory(containers) +add_subdirectory(monitoring) +add_subdirectory(process) add_subdirectory(profiling) +add_subdirectory(program) add_subdirectory(syncmap) add_subdirectory(tracing) add_subdirectory(tvm) diff --git a/yt/yt/library/backtrace_introspector/CMakeLists.linux-aarch64.txt b/yt/yt/library/backtrace_introspector/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..215573de83 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/CMakeLists.linux-aarch64.txt @@ -0,0 +1,28 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(http) + +add_library(yt-library-backtrace_introspector) +target_compile_options(yt-library-backtrace_introspector PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-backtrace_introspector PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + backtrace-cursors-interop + backtrace-cursors-libunwind + backtrace-cursors-frame_pointer + cpp-yt-misc +) +target_sources(yt-library-backtrace_introspector PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/backtrace_introspector/introspect.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/backtrace_introspector/introspect_linux.cpp +) diff --git a/yt/yt/library/backtrace_introspector/CMakeLists.linux-x86_64.txt b/yt/yt/library/backtrace_introspector/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..215573de83 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/CMakeLists.linux-x86_64.txt @@ -0,0 +1,28 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
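+# Linux-only library: thread introspection relies on signal delivery and
+# libunwind, so listfiles exist only for linux-x86_64 and linux-aarch64 and
+# the dispatching CMakeLists.txt below has no darwin or windows branches.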
+ + +add_subdirectory(http) + +add_library(yt-library-backtrace_introspector) +target_compile_options(yt-library-backtrace_introspector PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-backtrace_introspector PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + backtrace-cursors-interop + backtrace-cursors-libunwind + backtrace-cursors-frame_pointer + cpp-yt-misc +) +target_sources(yt-library-backtrace_introspector PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/backtrace_introspector/introspect.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/backtrace_introspector/introspect_linux.cpp +) diff --git a/yt/yt/library/backtrace_introspector/CMakeLists.txt b/yt/yt/library/backtrace_introspector/CMakeLists.txt new file mode 100644 index 0000000000..4d48dcdee6 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/CMakeLists.txt @@ -0,0 +1,13 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/backtrace_introspector/http/CMakeLists.linux-aarch64.txt b/yt/yt/library/backtrace_introspector/http/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..5b8a9100d2 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/CMakeLists.linux-aarch64.txt @@ -0,0 +1,24 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(library-backtrace_introspector-http) +target_compile_options(library-backtrace_introspector-http PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-backtrace_introspector-http PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + yt-core-http + yt-library-backtrace_introspector +) +target_sources(library-backtrace_introspector-http PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/backtrace_introspector/http/handler.cpp +) diff --git a/yt/yt/library/backtrace_introspector/http/CMakeLists.linux-x86_64.txt b/yt/yt/library/backtrace_introspector/http/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..5b8a9100d2 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/CMakeLists.linux-x86_64.txt @@ -0,0 +1,24 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
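+# HTTP front-end for the introspector: a single handler.cpp translation unit
+# that registers the /threads and /fibers dump endpoints on top of
+# yt-core-http.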
+ + + +add_library(library-backtrace_introspector-http) +target_compile_options(library-backtrace_introspector-http PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-backtrace_introspector-http PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + yt-core-http + yt-library-backtrace_introspector +) +target_sources(library-backtrace_introspector-http PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/backtrace_introspector/http/handler.cpp +) diff --git a/yt/yt/library/backtrace_introspector/http/CMakeLists.txt b/yt/yt/library/backtrace_introspector/http/CMakeLists.txt new file mode 100644 index 0000000000..4d48dcdee6 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/CMakeLists.txt @@ -0,0 +1,13 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/backtrace_introspector/http/handler.cpp b/yt/yt/library/backtrace_introspector/http/handler.cpp new file mode 100644 index 0000000000..367e3105c0 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/handler.cpp @@ -0,0 +1,89 @@ +#include "handler.h" + +#include <yt/yt/core/http/server.h> + +#include <yt/yt/core/concurrency/action_queue.h> + +#include <yt/yt/library/backtrace_introspector/introspect.h> + +namespace NYT::NBacktraceIntrospector { + +using namespace NHttp; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +class THandlerBase + : public IHttpHandler +{ +public: + void HandleRequest(const IRequestPtr& /*req*/, const IResponseWriterPtr& rsp) override + { + try { + auto dumpFuture = BIND(&THandlerBase::Dump, MakeStrong(this)) + .AsyncVia(Queue_->GetInvoker()) + .Run(); + + auto dump = WaitFor(dumpFuture) + .ValueOrThrow(); + + WaitFor(rsp->WriteBody(TSharedRef::FromString(dump))) + .ThrowOnError(); + + WaitFor(rsp->Close()) + .ThrowOnError(); + } catch (const std::exception& ex) { + if (!rsp->AreHeadersFlushed()) { + rsp->SetStatus(EStatusCode::InternalServerError); + WaitFor(rsp->WriteBody(TSharedRef::FromString(ex.what()))) + .ThrowOnError(); + } + throw; + } + } + +private: + static inline const TActionQueuePtr Queue_ = New<TActionQueue>("BacktraceIntro"); + +protected: + virtual TString Dump() = 0; +}; + +class TThreadsHandler + : public THandlerBase +{ +private: + TString Dump() override + { + return FormatIntrospectionInfos(IntrospectThreads()); + } +}; + +class TFibersHandler + : public THandlerBase +{ +private: + TString Dump() override + { + return FormatIntrospectionInfos(IntrospectFibers()); + } +}; + +void Register( + const IRequestPathMatcherPtr& handlers, + const TString& prefix) +{ + handlers->Add(prefix + "/threads", New<TThreadsHandler>()); + handlers->Add(prefix + "/fibers", New<TFibersHandler>()); +} + +void Register( + const IServerPtr& server, + const TString& prefix) +{ + Register(server->GetPathMatcher(), 
prefix); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/http/handler.h b/yt/yt/library/backtrace_introspector/http/handler.h new file mode 100644 index 0000000000..be795b7e5d --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/handler.h @@ -0,0 +1,20 @@ +#pragma once + +#include <yt/yt/core/http/public.h> + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// + +//! Registers introspector handlers. +void Register( + const NHttp::IRequestPathMatcherPtr& handlers, + const TString& prefix = {}); + +void Register( + const NHttp::IServerPtr& server, + const TString& prefix = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/http/ya.make b/yt/yt/library/backtrace_introspector/http/ya.make new file mode 100644 index 0000000000..504d20a2e3 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/ya.make @@ -0,0 +1,16 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + handler.cpp +) + +PEERDIR( + yt/yt/core + yt/yt/core/http + + yt/yt/library/backtrace_introspector +) + +END() diff --git a/yt/yt/library/backtrace_introspector/introspect.cpp b/yt/yt/library/backtrace_introspector/introspect.cpp new file mode 100644 index 0000000000..592c232f0f --- /dev/null +++ b/yt/yt/library/backtrace_introspector/introspect.cpp @@ -0,0 +1,216 @@ +#include "introspect.h" + +#include "private.h" + +#include <yt/yt/core/misc/finally.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/concurrency/fiber.h> +#include <yt/yt/core/concurrency/scheduler_api.h> + +#include <yt/yt/core/tracing/trace_context.h> + +#include <library/cpp/yt/memory/safe_memory_reader.h> + +#include <library/cpp/yt/backtrace/backtrace.h> + +#include <library/cpp/yt/backtrace/cursors/libunwind/libunwind_cursor.h> + +#include <library/cpp/yt/backtrace/cursors/frame_pointer/frame_pointer_cursor.h> + +#include <library/cpp/yt/backtrace/cursors/interop/interop.h> + +#include <util/system/yield.h> + +namespace NYT::NBacktraceIntrospector { + +using namespace NConcurrency; +using namespace NThreading; +using namespace NTracing; +using namespace NBacktrace; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = BacktraceIntrospectorLogger; + +//////////////////////////////////////////////////////////////////////////////// + +std::vector<TFiberIntrospectionInfo> IntrospectFibers() +{ + YT_LOG_INFO("Fiber introspection started"); + + auto fibers = TFiber::List(); + + YT_LOG_INFO("Collecting waiting fibers backtraces"); + + std::vector<TFiberIntrospectionInfo> infos; + THashSet<TFiberId> waitingFiberIds; + THashSet<TFiberId> fiberIds; + for (const auto& fiber : fibers) { + auto fiberId = fiber->GetFiberId(); + if (fiberId == InvalidFiberId) { + continue; + } + + InsertOrCrash(fiberIds, fiberId); + + EFiberState state; + if (!fiber->TryIntrospectWaiting(state, [&] { + YT_LOG_DEBUG("Waiting fiber is successfully locked for introspection (FiberId: %x)", + fiberId); + + const auto& propagatingStorage = fiber->GetPropagatingStorage(); + const auto* traceContext = TryGetTraceContextFromPropagatingStorage(propagatingStorage); + + TFiberIntrospectionInfo info{ + .State = EFiberState::Waiting, + .FiberId = fiberId, + 
.WaitingSince = fiber->GetWaitingSince(), + .TraceId = traceContext ? traceContext->GetTraceId() : TTraceId(), + .TraceLoggingTag = traceContext ? traceContext->GetLoggingTag() : TString(), + }; + + auto optionalContext = TrySynthesizeLibunwindContextFromMachineContext(*fiber->GetMachineContext()); + if (!optionalContext) { + YT_LOG_WARNING("Failed to synthesize libunwind context (FiberId: %x)", + fiberId); + return; + } + + TLibunwindCursor cursor(*optionalContext); + while (!cursor.IsFinished()) { + info.Backtrace.push_back(cursor.GetCurrentIP()); + cursor.MoveNext(); + } + + infos.push_back(std::move(info)); + InsertOrCrash(waitingFiberIds, fiberId); + + YT_LOG_DEBUG("Fiber introspection completed (FiberId: %x)", + info.FiberId); + })) { + YT_LOG_DEBUG("Failed to lock fiber for introspection (FiberId: %x, State: %v)", + fiberId, + state); + } + } + + YT_LOG_INFO("Collecting running fibers backtraces"); + + THashSet<TFiberId> runningFiberIds; + for (auto& info : IntrospectThreads()) { + if (info.FiberId == InvalidFiberId) { + continue; + } + + if (waitingFiberIds.contains(info.FiberId)) { + continue; + } + + if (!runningFiberIds.insert(info.FiberId).second) { + continue; + } + + infos.push_back(TFiberIntrospectionInfo{ + .State = EFiberState::Running, + .FiberId = info.FiberId, + .ThreadId = info.ThreadId, + .ThreadName = std::move(info.ThreadName), + .TraceId = info.TraceId, + .TraceLoggingTag = std::move(info.TraceLoggingTag), + .Backtrace = std::move(info.Backtrace), + }); + } + + for (const auto& fiber : fibers) { + auto fiberId = fiber->GetFiberId(); + if (fiberId == InvalidFiberId) { + continue; + } + if (runningFiberIds.contains(fiberId)) { + continue; + } + if (waitingFiberIds.contains(fiberId)) { + continue; + } + + infos.push_back(TFiberIntrospectionInfo{ + .State = fiber->GetState(), + .FiberId = fiberId, + }); + } + + YT_LOG_INFO("Fiber introspection completed"); + + return infos; +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +void FormatBacktrace(TStringBuilder* builder, const std::vector<const void*>& backtrace) +{ + if (!backtrace.empty()) { + builder->AppendString("Backtrace:\n"); + SymbolizeBacktrace( + MakeRange(backtrace), + [&] (TStringBuf str) { + builder->AppendFormat(" %v", str); + }); + } +} + +} // namespace + +TString FormatIntrospectionInfos(const std::vector<TThreadIntrospectionInfo>& infos) +{ + TStringBuilder builder; + for (const auto& info : infos) { + builder.AppendFormat("Thread id: %v\n", info.ThreadId); + builder.AppendFormat("Thread name: %v\n", info.ThreadName); + if (info.FiberId != InvalidFiberId) { + builder.AppendFormat("Fiber id: %x\n", info.FiberId); + } + if (info.TraceId) { + builder.AppendFormat("Trace id: %v\n", info.TraceId); + } + if (info.TraceLoggingTag) { + builder.AppendFormat("Trace logging tag: %v\n", info.TraceLoggingTag); + } + FormatBacktrace(&builder, info.Backtrace); + builder.AppendString("\n"); + } + return builder.Flush(); +} + +TString FormatIntrospectionInfos(const std::vector<TFiberIntrospectionInfo>& infos) +{ + TStringBuilder builder; + for (const auto& info : infos) { + builder.AppendFormat("Fiber id: %x\n", info.FiberId); + builder.AppendFormat("State: %v\n", info.State); + if (info.WaitingSince) { + builder.AppendFormat("Waiting since: %v\n", info.WaitingSince); + } + if (info.ThreadId != InvalidThreadId) { + builder.AppendFormat("Thread id: %v\n", info.ThreadId); + } + if (!info.ThreadName.empty()) { + builder.AppendFormat("Thread name: %v\n", 
info.ThreadName); + } + if (info.TraceId) { + builder.AppendFormat("Trace id: %v\n", info.TraceId); + } + if (info.TraceLoggingTag) { + builder.AppendFormat("Trace logging tag: %v\n", info.TraceLoggingTag); + } + FormatBacktrace(&builder, info.Backtrace); + builder.AppendString("\n"); + } + return builder.Flush(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/introspect.h b/yt/yt/library/backtrace_introspector/introspect.h new file mode 100644 index 0000000000..2be09d2ec8 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/introspect.h @@ -0,0 +1,57 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/concurrency/public.h> + +#include <yt/yt/core/threading/public.h> + +#include <yt/yt/core/tracing/public.h> + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// +// Thread introspection API + +struct TThreadIntrospectionInfo +{ + NThreading::TThreadId ThreadId; + NConcurrency::TFiberId FiberId; + TString ThreadName; + NTracing::TTraceId TraceId; + //! Empty if no trace context is known. + TString TraceLoggingTag; + std::vector<const void*> Backtrace; +}; + +std::vector<TThreadIntrospectionInfo> IntrospectThreads(); + +//////////////////////////////////////////////////////////////////////////////// +// Fiber introspection API + +struct TFiberIntrospectionInfo +{ + NConcurrency::EFiberState State; + NConcurrency::TFiberId FiberId; + //! Zero if fiber is not waiting. + TInstant WaitingSince; + //! |InvalidThreadId| if fiber is not running. + NThreading::TThreadId ThreadId; + //! Empty if fiber is not running. + TString ThreadName; + NTracing::TTraceId TraceId; + //! Empty if no trace context is known. 
+ TString TraceLoggingTag; + std::vector<const void*> Backtrace; +}; + +std::vector<TFiberIntrospectionInfo> IntrospectFibers(); + +//////////////////////////////////////////////////////////////////////////////// + +TString FormatIntrospectionInfos(const std::vector<TThreadIntrospectionInfo>& infos); +TString FormatIntrospectionInfos(const std::vector<TFiberIntrospectionInfo>& infos); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/introspect_dummy.cpp b/yt/yt/library/backtrace_introspector/introspect_dummy.cpp new file mode 100644 index 0000000000..e29293c7f5 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/introspect_dummy.cpp @@ -0,0 +1,14 @@ +#include "introspect.h" + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// + +std::vector<TThreadIntrospectionInfo> IntrospectThreads() +{ + return {}; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/introspect_linux.cpp b/yt/yt/library/backtrace_introspector/introspect_linux.cpp new file mode 100644 index 0000000000..3fc1a077f6 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/introspect_linux.cpp @@ -0,0 +1,211 @@ +#include "introspect.h" + +#include "private.h" + +#include <yt/yt/core/misc/finally.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/concurrency/fiber.h> +#include <yt/yt/core/concurrency/scheduler_api.h> + +#include <yt/yt/core/tracing/trace_context.h> + +#include <library/cpp/yt/memory/safe_memory_reader.h> + +#include <library/cpp/yt/backtrace/backtrace.h> + +#include <library/cpp/yt/backtrace/cursors/libunwind/libunwind_cursor.h> + +#include <library/cpp/yt/backtrace/cursors/frame_pointer/frame_pointer_cursor.h> + +#include <library/cpp/yt/backtrace/cursors/interop/interop.h> + +#include <library/cpp/yt/misc/thread_name.h> + +#include <util/system/yield.h> + +#include <sys/syscall.h> + +namespace NYT::NBacktraceIntrospector { + +using namespace NConcurrency; +using namespace NTracing; +using namespace NBacktrace; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = BacktraceIntrospectorLogger; + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +struct TStaticString +{ + TStaticString() = default; + + explicit TStaticString(TStringBuf str) + { + Length = std::min(std::ssize(str), std::ssize(Buffer)); + std::copy(str.data(), str.data() + Length, Buffer.data()); + } + + operator TString() const + { + return TString(Buffer.data(), static_cast<size_t>(Length)); + } + + std::array<char, 256> Buffer; + int Length = 0; +}; + +struct TStaticBacktrace +{ + operator std::vector<const void*>() const + { + return std::vector<const void*>(Frames.data(), Frames.data() + FrameCount); + } + + std::array<const void*, 100> Frames; + int FrameCount = 0; +}; + +struct TSignalHandlerContext +{ + TSignalHandlerContext(); + ~TSignalHandlerContext(); + + std::atomic<bool> Finished = false; + + TFiberId FiberId = {}; + TTraceId TraceId = {}; + TStaticString TraceLoggingTag; + TStaticBacktrace Backtrace; + TThreadName ThreadName = {}; + + TSafeMemoryReader* MemoryReader = Singleton<TSafeMemoryReader>(); + + void SetFinished() + { + Finished.store(true); + } + + void 
WaitUntilFinished() + { + while (!Finished.load()) { + ThreadYield(); + } + } +}; + +static TSignalHandlerContext* SignalHandlerContext; + +TSignalHandlerContext::TSignalHandlerContext() +{ + YT_VERIFY(!SignalHandlerContext); + SignalHandlerContext = this; +} + +TSignalHandlerContext::~TSignalHandlerContext() +{ + YT_VERIFY(SignalHandlerContext == this); + SignalHandlerContext = nullptr; +} + +void SignalHandler(int sig, siginfo_t* /*info*/, void* threadContext) +{ + YT_VERIFY(sig == SIGUSR1); + + SignalHandlerContext->FiberId = GetCurrentFiberId(); + SignalHandlerContext->ThreadName = GetCurrentThreadName(); + if (const auto* traceContext = TryGetCurrentTraceContext()) { + SignalHandlerContext->TraceId = traceContext->GetTraceId(); + SignalHandlerContext->TraceLoggingTag = TStaticString(traceContext->GetLoggingTag()); + } + + auto cursorContext = FramePointerCursorContextFromUcontext(*static_cast<const ucontext_t*>(threadContext)); + TFramePointerCursor cursor(SignalHandlerContext->MemoryReader, cursorContext); + while (!cursor.IsFinished() && SignalHandlerContext->Backtrace.FrameCount < std::ssize(SignalHandlerContext->Backtrace.Frames)) { + SignalHandlerContext->Backtrace.Frames[SignalHandlerContext->Backtrace.FrameCount++] = cursor.GetCurrentIP(); + cursor.MoveNext(); + } + + SignalHandlerContext->SetFinished(); +} + +} // namespace + +std::vector<TThreadIntrospectionInfo> IntrospectThreads() +{ + static std::atomic<bool> IntrospectionLock; + + if (IntrospectionLock.exchange(true)) { + THROW_ERROR_EXCEPTION("Thread introspection is already in progress"); + } + + auto introspectionLockGuard = Finally([] { + YT_VERIFY(IntrospectionLock.exchange(false)); + }); + + YT_LOG_INFO("Thread introspection started"); + + { + struct sigaction action; + action.sa_flags = SA_SIGINFO | SA_RESTART; + ::sigemptyset(&action.sa_mask); + action.sa_sigaction = SignalHandler; + + if (::sigaction(SIGUSR1, &action, nullptr) != 0) { + THROW_ERROR_EXCEPTION("Failed to install signal handler") + << TError::FromSystem(); + } + } + + std::vector<TThreadIntrospectionInfo> infos; + for (auto threadId : GetCurrentProcessThreadIds()) { + TSignalHandlerContext signalHandlerContext; + if (::syscall(SYS_tkill, threadId, SIGUSR1) != 0) { + YT_LOG_DEBUG(TError::FromSystem(), "Failed to signal to thread (ThreadId: %v)", + threadId); + continue; + } + + YT_LOG_DEBUG("Sent signal to thread (ThreadId: %v)", + threadId); + + signalHandlerContext.WaitUntilFinished(); + + YT_LOG_DEBUG("Signal handler finished (ThreadId: %v, FiberId: %x)", + threadId, + signalHandlerContext.FiberId); + + infos.push_back(TThreadIntrospectionInfo{ + .ThreadId = threadId, + .FiberId = signalHandlerContext.FiberId, + .ThreadName = TString(signalHandlerContext.ThreadName.Buffer.data(), static_cast<size_t>(signalHandlerContext.ThreadName.Length)), + .TraceId = signalHandlerContext.TraceId, + .TraceLoggingTag = signalHandlerContext.TraceLoggingTag, + .Backtrace = signalHandlerContext.Backtrace, + }); + } + + { + struct sigaction action; + action.sa_flags = SA_RESTART; + ::sigemptyset(&action.sa_mask); + action.sa_handler = SIG_IGN; + + if (::sigaction(SIGUSR1, &action, nullptr) != 0) { + THROW_ERROR_EXCEPTION("Failed to de-install signal handler") + << TError::FromSystem(); + } + } + + YT_LOG_INFO("Thread introspection completed"); + + return infos; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/private.h 
b/yt/yt/library/backtrace_introspector/private.h new file mode 100644 index 0000000000..59f25e6023 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/private.h @@ -0,0 +1,16 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger BacktraceIntrospectorLogger("BacktraceIntrospector"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector + diff --git a/yt/yt/library/backtrace_introspector/public.h b/yt/yt/library/backtrace_introspector/public.h new file mode 100644 index 0000000000..54a8bd06ed --- /dev/null +++ b/yt/yt/library/backtrace_introspector/public.h @@ -0,0 +1,12 @@ +#pragma once + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// + +struct TThreadIntrospectionInfo; +struct TFiberIntrospectionInfo; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/unittests/introspect_ut.cpp b/yt/yt/library/backtrace_introspector/unittests/introspect_ut.cpp new file mode 100644 index 0000000000..a939417958 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/unittests/introspect_ut.cpp @@ -0,0 +1,198 @@ +#include <yt/yt/core/test_framework/framework.h> + +#include <yt/yt/library/backtrace_introspector/introspect.h> + +#include <yt/yt/core/concurrency/action_queue.h> +#include <yt/yt/core/concurrency/delayed_executor.h> + +#include <yt/yt/core/actions/bind.h> +#include <yt/yt/core/actions/future.h> + +#include <yt/yt/core/tracing/trace_context.h> + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/collection_helpers.h> + +namespace NYT::NBacktraceIntrospector { +namespace { + +using namespace NConcurrency; +using namespace NTracing; + +//////////////////////////////////////////////////////////////////////////////// + +NLogging::TLogger Logger("Test"); + +//////////////////////////////////////////////////////////////////////////////// + +TEST(TBacktraceIntrospectorTest, Fibers) +{ + constexpr int HeavyQueueCount = 5; + std::vector<TActionQueuePtr> heavyQueues; + const TString HeavyThreadNamePrefix("Heavy:"); + for (int index = 0; index < HeavyQueueCount; ++index) { + heavyQueues.push_back(New<TActionQueue>(HeavyThreadNamePrefix + ToString(index))); + } + + constexpr int LightQueueCount = 3; + std::vector<TActionQueuePtr> lightQueues; + const TString LightThreadNamePrefix("Light:"); + for (int index = 0; index < LightQueueCount; ++index) { + lightQueues.push_back(New<TActionQueue>(LightThreadNamePrefix + ToString(index))); + } + + constexpr int HeavyCallbackCount = 3; + std::vector<TTraceContextPtr> heavyTraceContexts; + std::set<TTraceId> expectedHeavyTraceIds; + for (int index = 0; index < HeavyCallbackCount; ++index) { + auto traceContext = TTraceContext::NewRoot("Heavy"); + traceContext->SetLoggingTag(Format("HeavyLoggingTag:%v", index)); + heavyTraceContexts.push_back(traceContext); + InsertOrCrash(expectedHeavyTraceIds, traceContext->GetTraceId()); + } + + std::vector<TFuture<void>> heavyFutures; + for (int index = 0; index < HeavyCallbackCount; ++index) { + heavyFutures.push_back( + BIND([&, index] { + TTraceContextGuard traceContextGuard(heavyTraceContexts[index]); + YT_LOG_INFO("Heavy callback started (Index: 
%v)", index); + Sleep(TDuration::Seconds(3)); + YT_LOG_INFO("Heavy callback finished (Index: %v)", index); + }) + .AsyncVia(heavyQueues[index % HeavyQueueCount]->GetInvoker()) + .Run()); + } + + constexpr int LightCallbackCount = 10; + std::vector<TTraceContextPtr> lightTraceContexts; + std::set<TTraceId> expectedLightTraceIds; + for (int index = 0; index < LightCallbackCount; ++index) { + auto traceContext = TTraceContext::NewRoot("Light"); + traceContext->SetLoggingTag(Format("LightLoggingTag:%v", index)); + lightTraceContexts.push_back(traceContext); + InsertOrCrash(expectedLightTraceIds, traceContext->GetTraceId()); + } + + std::vector<TFuture<void>> lightFutures; + for (int index = 0; index < LightCallbackCount; ++index) { + lightFutures.push_back( + BIND([&, index] { + TTraceContextGuard traceContextGuard(lightTraceContexts[index]); + YT_LOG_INFO("Light callback started (Index: %v)", index); + TDelayedExecutor::WaitForDuration(TDuration::Seconds(1)); + YT_LOG_INFO("Light callback finished (Index: %v)", index); + }) + .AsyncVia(lightQueues[index % LightQueueCount]->GetInvoker()) + .Run()); + } + + Sleep(TDuration::MilliSeconds(100)); + + auto infos = IntrospectFibers(); + Cerr << FormatIntrospectionInfos(infos); + + std::set<TTraceId> actualHeavyTraceIds; + std::set<TTraceId> actualLightTraceIds; + for (const auto& info : infos) { + if (!info.TraceId) { + continue; + } + switch (info.State) { + case EFiberState::Running: + EXPECT_TRUE(actualHeavyTraceIds.insert(info.TraceId).second); + if (expectedHeavyTraceIds.contains(info.TraceId)) { + EXPECT_TRUE(info.ThreadName.StartsWith(HeavyThreadNamePrefix)); + } + break; + + case EFiberState::Waiting: + EXPECT_TRUE(actualLightTraceIds.insert(info.TraceId).second); + break; + + default: + break; + } + } + + EXPECT_EQ(expectedLightTraceIds, actualLightTraceIds); + EXPECT_EQ(expectedHeavyTraceIds, actualHeavyTraceIds); + + for (const auto& future : heavyFutures) { + future.Get().ThrowOnError(); + } + + for (const auto& future : lightFutures) { + future.Get().ThrowOnError(); + } + + for (const auto& queue : heavyQueues) { + queue->Shutdown(/*graceful*/ true); + } + for (const auto& queue : lightQueues) { + queue->Shutdown(/*graceful*/ true); + } +} + +TEST(TBacktraceIntrospectorTest, Threads) +{ + constexpr int QueueCount = 5; + std::vector<TActionQueuePtr> queues; + const TString ThreadNamePrefix("Queue:"); + for (int index = 0; index < QueueCount; ++index) { + queues.push_back(New<TActionQueue>(ThreadNamePrefix + ToString(index))); + } + + constexpr int CallbackCount = 3; + std::vector<TTraceContextPtr> traceContexts; + std::set<TTraceId> expectedTraceIds; + for (int index = 0; index < CallbackCount; ++index) { + auto traceContext = TTraceContext::NewRoot("Heavy"); + traceContexts.push_back(traceContext); + InsertOrCrash(expectedTraceIds, traceContext->GetTraceId()); + } + + std::vector<TFuture<void>> futures; + for (int index = 0; index < CallbackCount; ++index) { + futures.push_back( + BIND([&, index] { + TTraceContextGuard traceContextGuard(traceContexts[index]); + YT_LOG_INFO("Callback started (Index: %v)", index); + Sleep(TDuration::Seconds(3)); + YT_LOG_INFO("Callback finished (Index: %v)", index); + }) + .AsyncVia(queues[index % QueueCount]->GetInvoker()) + .Run()); + } + + Sleep(TDuration::MilliSeconds(100)); + + auto infos = IntrospectThreads(); + Cerr << FormatIntrospectionInfos(infos); + + std::set<TTraceId> actualTraceIds; + for (const auto& info : infos) { + if (!info.TraceId) { + continue; + } + 
EXPECT_TRUE(actualTraceIds.insert(info.TraceId).second); + if (expectedTraceIds.contains(info.TraceId)) { + EXPECT_TRUE(info.ThreadName.StartsWith(ThreadNamePrefix)); + } + } + + EXPECT_EQ(expectedTraceIds, actualTraceIds); + + for (const auto& future : futures) { + future.Get().ThrowOnError(); + } + for (const auto& queue : queues) { + queue->Shutdown(/*graceful*/ true); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/unittests/ya.make b/yt/yt/library/backtrace_introspector/unittests/ya.make new file mode 100644 index 0000000000..953dc020a8 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/unittests/ya.make @@ -0,0 +1,15 @@ +GTEST() + +SRCS( + introspect_ut.cpp +) + +INCLUDE(${ARCADIA_ROOT}/yt/opensource_tests.inc) + +PEERDIR( + yt/yt/library/backtrace_introspector + + yt/yt/core/test_framework +) + +END() diff --git a/yt/yt/library/backtrace_introspector/ya.make b/yt/yt/library/backtrace_introspector/ya.make new file mode 100644 index 0000000000..884b8fb562 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/ya.make @@ -0,0 +1,31 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + introspect.cpp +) +IF (OS_LINUX) + SRCS(introspect_linux.cpp) +ELSE() + SRCS(introspect_dummy.cpp) +ENDIF() + +PEERDIR( + yt/yt/core + + library/cpp/yt/backtrace/cursors/interop + library/cpp/yt/backtrace/cursors/libunwind + library/cpp/yt/backtrace/cursors/frame_pointer + library/cpp/yt/misc +) + +END() + +RECURSE( + http +) + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/containers/CMakeLists.darwin-x86_64.txt b/yt/yt/library/containers/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..faab79bbf6 --- /dev/null +++ b/yt/yt/library/containers/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,30 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-containers) +target_compile_options(yt-library-containers PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-containers PUBLIC + contrib-libs-cxxsupp + yutil + cpp-porto-proto + yt-library-process + yt-yt-core +) +target_sources(yt-library-containers PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/cgroup.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/instance.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/instance_limits_tracker.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/process.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_executor.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_resource_tracker.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_health_checker.cpp +) diff --git a/yt/yt/library/containers/CMakeLists.linux-aarch64.txt b/yt/yt/library/containers/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..d3ab3811e0 --- /dev/null +++ b/yt/yt/library/containers/CMakeLists.linux-aarch64.txt @@ -0,0 +1,32 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. 
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-containers) +target_compile_options(yt-library-containers PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-containers PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + cpp-porto-proto + yt-library-process + yt-yt-core + library-cpp-porto +) +target_sources(yt-library-containers PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/cgroup.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/instance.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/instance_limits_tracker.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/process.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_executor.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_resource_tracker.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_health_checker.cpp +) diff --git a/yt/yt/library/containers/CMakeLists.linux-x86_64.txt b/yt/yt/library/containers/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..d3ab3811e0 --- /dev/null +++ b/yt/yt/library/containers/CMakeLists.linux-x86_64.txt @@ -0,0 +1,32 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-containers) +target_compile_options(yt-library-containers PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-containers PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + cpp-porto-proto + yt-library-process + yt-yt-core + library-cpp-porto +) +target_sources(yt-library-containers PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/cgroup.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/instance.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/instance_limits_tracker.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/process.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_executor.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_resource_tracker.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_health_checker.cpp +) diff --git a/yt/yt/library/containers/CMakeLists.txt b/yt/yt/library/containers/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/containers/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
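+# Unlike backtrace_introspector, the containers library is built for darwin
+# and windows as well; note that only the linux listfiles link the porto
+# client (library-cpp-porto), while all platforms link the proto stubs
+# (cpp-porto-proto).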
+ + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/containers/CMakeLists.windows-x86_64.txt b/yt/yt/library/containers/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..998e1690fa --- /dev/null +++ b/yt/yt/library/containers/CMakeLists.windows-x86_64.txt @@ -0,0 +1,27 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-containers) +target_link_libraries(yt-library-containers PUBLIC + contrib-libs-cxxsupp + yutil + cpp-porto-proto + yt-library-process + yt-yt-core +) +target_sources(yt-library-containers PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/cgroup.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/instance.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/instance_limits_tracker.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/process.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_executor.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_resource_tracker.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/containers/porto_health_checker.cpp +) diff --git a/yt/yt/library/containers/cgroup.cpp b/yt/yt/library/containers/cgroup.cpp new file mode 100644 index 0000000000..b43ab1e14b --- /dev/null +++ b/yt/yt/library/containers/cgroup.cpp @@ -0,0 +1,752 @@ +#include "cgroup.h" +#include "private.h" + +#include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/ytree/fluent.h> + +#include <util/string/split.h> +#include <util/system/filemap.h> + +#include <util/system/yield.h> + +#ifdef _linux_ + #include <unistd.h> + #include <sys/stat.h> + #include <errno.h> +#endif + +namespace NYT::NContainers { + +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = ContainersLogger; +static const TString CGroupRootPath("/sys/fs/cgroup"); +#ifdef _linux_ +static const int ReadByAll = S_IRUSR | S_IRGRP | S_IROTH; +static const int ReadExecuteByAll = ReadByAll | S_IXUSR | S_IXGRP | S_IXOTH; +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +TString GetParentFor(const TString& type) +{ +#ifdef _linux_ + auto rawData = TUnbufferedFileInput("/proc/self/cgroup") + .ReadAll(); + auto result = ParseProcessCGroups(rawData); + return result[type]; +#else + Y_UNUSED(type); + return "_parent_"; +#endif +} + +#ifdef _linux_ + +std::vector<TString> ReadAllValues(const TString& fileName) +{ + auto raw = TUnbufferedFileInput(fileName) + .ReadAll(); + + YT_LOG_DEBUG("File %v contains %Qv", + fileName, + raw); + + TVector<TString> values; + 
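+ // Tokenize on spaces and newlines: cgroup control files are
+ // whitespace-delimited tables, so a flat token list suffices for callers.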
StringSplitter(raw.data()) + .SplitBySet(" \n") + .SkipEmpty() + .Collect(&values); + return values; +} + +TDuration FromJiffies(ui64 jiffies) +{ + static const auto TicksPerSecond = sysconf(_SC_CLK_TCK); + return TDuration::MicroSeconds(1000 * 1000 * jiffies / TicksPerSecond); +} + +#endif + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +void TKillProcessGroupTool::operator()(const TString& processGroupPath) const +{ + SafeSetUid(0); + TNonOwningCGroup group(processGroupPath); + group.Kill(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TNonOwningCGroup::TNonOwningCGroup(const TString& fullPath) + : FullPath_(fullPath) +{ } + +TNonOwningCGroup::TNonOwningCGroup(const TString& type, const TString& name) + : FullPath_(NFS::CombinePaths({ + CGroupRootPath, + type, + GetParentFor(type), + name + })) +{ } + +TNonOwningCGroup::TNonOwningCGroup(TNonOwningCGroup&& other) + : FullPath_(std::move(other.FullPath_)) +{ } + +void TNonOwningCGroup::AddTask(int pid) const +{ + YT_LOG_INFO( + "Adding task to cgroup (Task: %v, Cgroup: %v)", + pid, + FullPath_); + Append("tasks", ToString(pid)); +} + +void TNonOwningCGroup::AddCurrentTask() const +{ + YT_VERIFY(!IsNull()); +#ifdef _linux_ + auto pid = getpid(); + AddTask(pid); +#endif +} + +TString TNonOwningCGroup::Get(const TString& name) const +{ + YT_VERIFY(!IsNull()); + TString result; +#ifdef _linux_ + const auto path = GetPath(name); + result = TFileInput(path).ReadLine(); +#else + Y_UNUSED(name); +#endif + return result; +} + +void TNonOwningCGroup::Set(const TString& name, const TString& value) const +{ + YT_VERIFY(!IsNull()); +#ifdef _linux_ + auto path = GetPath(name); + TUnbufferedFileOutput output(TFile(path, EOpenModeFlag::WrOnly)); + output << value; +#else + Y_UNUSED(name); + Y_UNUSED(value); +#endif +} + +void TNonOwningCGroup::Append(const TString& name, const TString& value) const +{ + YT_VERIFY(!IsNull()); +#ifdef _linux_ + auto path = GetPath(name); + TUnbufferedFileOutput output(TFile(path, EOpenModeFlag::ForAppend)); + output << value; +#else + Y_UNUSED(name); + Y_UNUSED(value); +#endif +} + +bool TNonOwningCGroup::IsRoot() const +{ + return FullPath_ == CGroupRootPath; +} + +bool TNonOwningCGroup::IsNull() const +{ + return FullPath_.empty(); +} + +bool TNonOwningCGroup::Exists() const +{ + return NFS::Exists(FullPath_); +} + +std::vector<int> TNonOwningCGroup::GetProcesses() const +{ + std::vector<int> results; + if (!IsNull()) { +#ifdef _linux_ + auto values = ReadAllValues(GetPath("cgroup.procs")); + for (const auto& value : values) { + int pid = FromString<int>(value); + results.push_back(pid); + } +#endif + } + return results; +} + +std::vector<int> TNonOwningCGroup::GetTasks() const +{ + std::vector<int> results; + if (!IsNull()) { +#ifdef _linux_ + auto values = ReadAllValues(GetPath("tasks")); + for (const auto& value : values) { + int pid = FromString<int>(value); + results.push_back(pid); + } +#endif + } + return results; +} + +const TString& TNonOwningCGroup::GetFullPath() const +{ + return FullPath_; +} + +std::vector<TNonOwningCGroup> TNonOwningCGroup::GetChildren() const +{ + // We retry enumerating directories, since it may fail with weird diagnostics if + // number of subcgroups changes. 
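+ // NB: retries are unbounded; this assumes that enumeration failures caused
+ // by concurrent subcgroup creation and removal are transient.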
+ while (true) { + try { + std::vector<TNonOwningCGroup> result; + + if (IsNull()) { + return result; + } + + auto directories = NFS::EnumerateDirectories(FullPath_); + for (const auto& directory : directories) { + result.emplace_back(NFS::CombinePaths(FullPath_, directory)); + } + return result; + } catch (const std::exception& ex) { + YT_LOG_WARNING(ex, "Failed to list subcgroups (Path: %v)", FullPath_); + } + } +} + +void TNonOwningCGroup::EnsureExistence() const +{ + YT_LOG_INFO("Creating cgroup (Cgroup: %v)", FullPath_); + + YT_VERIFY(!IsNull()); + +#ifdef _linux_ + NFS::MakeDirRecursive(FullPath_, 0755); +#endif +} + +void TNonOwningCGroup::Lock() const +{ + Traverse( + BIND([] (const TNonOwningCGroup& group) { group.DoLock(); }), + BIND([] (const TNonOwningCGroup& /*group*/) {})); +} + +void TNonOwningCGroup::Unlock() const +{ + Traverse( + BIND([] (const TNonOwningCGroup& /*group*/) {}), + BIND([] (const TNonOwningCGroup& group) { group.DoUnlock(); })); +} + +void TNonOwningCGroup::Kill() const +{ + YT_VERIFY(!IsRoot()); + + Traverse( + BIND([] (const TNonOwningCGroup& group) { group.DoKill(); }), + BIND([] (const TNonOwningCGroup& /*group*/) {})); +} + +void TNonOwningCGroup::RemoveAllSubcgroups() const +{ + Traverse( + BIND([] (const TNonOwningCGroup& group) { + group.TryUnlock(); + }), + BIND([this_ = this] (const TNonOwningCGroup& group) { + if (this_ != &group) { + group.DoRemove(); + } + })); +} + +void TNonOwningCGroup::RemoveRecursive() const +{ + RemoveAllSubcgroups(); + DoRemove(); +} + +void TNonOwningCGroup::DoLock() const +{ + YT_LOG_INFO("Locking cgroup (Cgroup: %v)", FullPath_); + +#ifdef _linux_ + if (!IsNull()) { + int code = chmod(FullPath_.data(), ReadExecuteByAll); + YT_VERIFY(code == 0); + + code = chmod(GetPath("tasks").data(), ReadByAll); + YT_VERIFY(code == 0); + } +#endif +} + +bool TNonOwningCGroup::TryUnlock() const +{ + YT_LOG_INFO("Unlocking cgroup (Cgroup: %v)", FullPath_); + + if (!Exists()) { + return true; + } + + bool result = true; + +#ifdef _linux_ + if (!IsNull()) { + int code = chmod(GetPath("tasks").data(), ReadByAll | S_IWUSR); + if (code != 0) { + result = false; + } + + code = chmod(FullPath_.data(), ReadExecuteByAll | S_IWUSR); + if (code != 0) { + result = false; + } + } +#endif + + return result; +} + +void TNonOwningCGroup::DoUnlock() const +{ + YT_VERIFY(TryUnlock()); +} + +void TNonOwningCGroup::DoKill() const +{ + YT_LOG_DEBUG("Started killing processes in cgroup (Cgroup: %v)", FullPath_); + +#ifdef _linux_ + while (true) { + auto pids = GetTasks(); + if (pids.empty()) + break; + + YT_LOG_DEBUG("Killing processes (Pids: %v)", pids); + + for (int pid : pids) { + auto result = kill(pid, SIGKILL); + if (result == -1) { + YT_VERIFY(errno == ESRCH); + } + } + + ThreadYield(); + } +#endif + + YT_LOG_DEBUG("Finished killing processes in cgroup (Cgroup: %v)", FullPath_); +} + +void TNonOwningCGroup::DoRemove() const +{ + if (NFS::Exists(FullPath_)) { + NFS::Remove(FullPath_); + } +} + +void TNonOwningCGroup::Traverse( + const TCallback<void(const TNonOwningCGroup&)>& preorderAction, + const TCallback<void(const TNonOwningCGroup&)>& postorderAction) const +{ + preorderAction(*this); + + for (const auto& child : GetChildren()) { + child.Traverse(preorderAction, postorderAction); + } + + postorderAction(*this); +} + +TString TNonOwningCGroup::GetPath(const TString& filename) const +{ + return NFS::CombinePaths(FullPath_, filename); +} + +//////////////////////////////////////////////////////////////////////////////// + +TCGroup::TCGroup(const 
TString& type, const TString& name) + : TNonOwningCGroup(type, name) +{ } + +TCGroup::TCGroup(TCGroup&& other) + : TNonOwningCGroup(std::move(other)) + , Created_(other.Created_) +{ + other.Created_ = false; +} + +TCGroup::TCGroup(TNonOwningCGroup&& other) + : TNonOwningCGroup(std::move(other)) + , Created_(false) +{ } + +TCGroup::~TCGroup() +{ + if (Created_) { + Destroy(); + } +} + +void TCGroup::Create() +{ + EnsureExistence(); + Created_ = true; +} + +void TCGroup::Destroy() +{ + YT_LOG_INFO("Destroying cgroup (Cgroup: %v)", FullPath_); + YT_VERIFY(Created_); + +#ifdef _linux_ + try { + NFS::Remove(FullPath_); + } catch (const std::exception& ex) { + YT_LOG_FATAL(ex, "Failed to destroy cgroup (Cgroup: %v)", FullPath_); + } +#endif + Created_ = false; +} + +bool TCGroup::IsCreated() const +{ + return Created_; +} + +//////////////////////////////////////////////////////////////////////////////// + +const TString TCpuAccounting::Name = "cpuacct"; + +TCpuAccounting::TStatistics& operator-=(TCpuAccounting::TStatistics& lhs, const TCpuAccounting::TStatistics& rhs) +{ + #define XX(name) lhs.name = lhs.name.ValueOrThrow() - rhs.name.ValueOrThrow(); + XX(UserUsageTime) + XX(SystemUsageTime) + XX(WaitTime) + XX(ThrottledTime) + XX(ContextSwitchesDelta) + XX(PeakThreadCount) + #undef XX + return lhs; +} + +TCpuAccounting::TCpuAccounting(const TString& name) + : TCGroup(Name, name) +{ } + +TCpuAccounting::TCpuAccounting(TNonOwningCGroup&& nonOwningCGroup) + : TCGroup(std::move(nonOwningCGroup)) +{ } + +TCpuAccounting::TStatistics TCpuAccounting::GetStatisticsRecursive() const +{ + TCpuAccounting::TStatistics result; +#ifdef _linux_ + try { + auto path = NFS::CombinePaths(GetFullPath(), "cpuacct.stat"); + auto values = ReadAllValues(path); + YT_VERIFY(values.size() == 4); + + TString type[2]; + ui64 jiffies[2]; + + for (int i = 0; i < 2; ++i) { + type[i] = values[2 * i]; + jiffies[i] = FromString<ui64>(values[2 * i + 1]); + } + + for (int i = 0; i < 2; ++i) { + if (type[i] == "user") { + result.UserUsageTime = FromJiffies(jiffies[i]); + } else if (type[i] == "system") { + result.SystemUsageTime = FromJiffies(jiffies[i]); + } + } + } catch (const std::exception& ex) { + YT_LOG_FATAL( + ex, + "Failed to retrieve CPU statistics from cgroup (Cgroup: %v)", + GetFullPath()); + } +#endif + return result; +} + +TCpuAccounting::TStatistics TCpuAccounting::GetStatistics() const +{ + auto statistics = GetStatisticsRecursive(); + for (auto& cgroup : GetChildren()) { + auto cpuCGroup = TCpuAccounting(std::move(cgroup)); + statistics -= cpuCGroup.GetStatisticsRecursive(); + } + return statistics; +} + + +//////////////////////////////////////////////////////////////////////////////// + +const TString TCpu::Name = "cpu"; + +static const int DefaultCpuShare = 1024; + +TCpu::TCpu(const TString& name) + : TCGroup(Name, name) +{ } + +void TCpu::SetShare(double share) +{ + int cpuShare = static_cast<int>(share * DefaultCpuShare); + Set("cpu.shares", ToString(cpuShare)); +} + +//////////////////////////////////////////////////////////////////////////////// + +const TString TBlockIO::Name = "blkio"; + +TBlockIO::TBlockIO(const TString& name) + : TCGroup(Name, name) +{ } + +// For more information about format of data +// read https://www.kernel.org/doc/Documentation/cgroups/blkio-controller.txt + +TBlockIO::TStatistics TBlockIO::GetStatistics() const +{ + TBlockIO::TStatistics result; +#ifdef _linux_ + auto bytesStats = GetDetailedStatistics("blkio.io_service_bytes"); + for (const auto& item : bytesStats) { + if 
(item.Type == "Read") {
+            result.IOReadByte = result.IOReadByte.ValueOrThrow() + item.Value;
+        } else if (item.Type == "Write") {
+            result.IOWriteByte = result.IOWriteByte.ValueOrThrow() + item.Value;
+        }
+    }
+
+    auto ioStats = GetDetailedStatistics("blkio.io_serviced");
+    for (const auto& item : ioStats) {
+        if (item.Type == "Read") {
+            result.IOReadOps = result.IOReadOps.ValueOrThrow() + item.Value;
+            result.IOOps = result.IOOps.ValueOrThrow() + item.Value;
+        } else if (item.Type == "Write") {
+            result.IOWriteOps = result.IOWriteOps.ValueOrThrow() + item.Value;
+            result.IOOps = result.IOOps.ValueOrThrow() + item.Value;
+        }
+    }
+#endif
+    return result;
+}
+
+std::vector<TBlockIO::TStatisticsItem> TBlockIO::GetIOServiceBytes() const
+{
+    return GetDetailedStatistics("blkio.io_service_bytes");
+}
+
+std::vector<TBlockIO::TStatisticsItem> TBlockIO::GetIOServiced() const
+{
+    return GetDetailedStatistics("blkio.io_serviced");
+}
+
+std::vector<TBlockIO::TStatisticsItem> TBlockIO::GetDetailedStatistics(const char* filename) const
+{
+    std::vector<TBlockIO::TStatisticsItem> result;
+#ifdef _linux_
+    try {
+        auto path = NFS::CombinePaths(GetFullPath(), filename);
+        auto values = ReadAllValues(path);
+
+        int lineNumber = 0;
+        while (3 * lineNumber + 2 < std::ssize(values)) {
+            TStatisticsItem item;
+            item.DeviceId = values[3 * lineNumber];
+            item.Type = values[3 * lineNumber + 1];
+            item.Value = FromString<ui64>(values[3 * lineNumber + 2]);
+
+            {
+                auto guard = Guard(SpinLock_);
+                DeviceIds_.insert(item.DeviceId);
+            }
+
+            if (item.Type == "Read" || item.Type == "Write") {
+                result.push_back(item);
+
+                YT_LOG_DEBUG("IO operations serviced (OperationCount: %v, OperationType: %v, DeviceId: %v)",
+                    item.Value,
+                    item.Type,
+                    item.DeviceId);
+            }
+            ++lineNumber;
+        }
+    } catch (const std::exception& ex) {
+        YT_LOG_FATAL(
+            ex,
+            "Failed to retrieve block IO statistics from cgroup (Cgroup: %v)",
+            GetFullPath());
+    }
+#else
+    Y_UNUSED(filename);
+#endif
+    return result;
+}
+
+void TBlockIO::ThrottleOperations(i64 operations) const
+{
+    auto guard = Guard(SpinLock_);
+    for (const auto& deviceId : DeviceIds_) {
+        auto value = Format("%v %v", deviceId, operations);
+        Append("blkio.throttle.read_iops_device", value);
+        Append("blkio.throttle.write_iops_device", value);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+const TString TMemory::Name = "memory";
+
+TMemory::TMemory(const TString& name)
+    : TCGroup(Name, name)
+{ }
+
+TMemory::TStatistics TMemory::GetStatistics() const
+{
+    TMemory::TStatistics result;
+#ifdef _linux_
+    try {
+        auto values = ReadAllValues(GetPath("memory.stat"));
+        int lineNumber = 0;
+        while (2 * lineNumber + 1 < std::ssize(values)) {
+            const auto& type = values[2 * lineNumber];
+            const auto& unparsedValue = values[2 * lineNumber + 1];
+            if (type == "rss") {
+                result.Rss = FromString<ui64>(unparsedValue);
+            }
+            if (type == "mapped_file") {
+                result.MappedFile = FromString<ui64>(unparsedValue);
+            }
+            if (type == "pgmajfault") {
+                result.MajorPageFaults = FromString<ui64>(unparsedValue);
+            }
+            ++lineNumber;
+        }
+    } catch (const std::exception& ex) {
+        YT_LOG_FATAL(
+            ex,
+            "Failed to retrieve memory statistics from cgroup (Cgroup: %v)",
+            GetFullPath());
+    }
+#endif
+    return result;
+}
+
+i64 TMemory::GetMaxMemoryUsage() const
+{
+    return FromString<i64>(Get("memory.max_usage_in_bytes"));
+}
+
+void TMemory::SetLimitInBytes(i64 bytes) const
+{
+    Set("memory.limit_in_bytes", ToString(bytes));
+}
+
+void TMemory::ForceEmpty() const
+{
+    
Set("memory.force_empty", "0"); +} + +//////////////////////////////////////////////////////////////////////////////// + +const TString TFreezer::Name = "freezer"; + +TFreezer::TFreezer(const TString& name) + : TCGroup(Name, name) +{ } + +TString TFreezer::GetState() const +{ + return Get("freezer.state"); +} + +void TFreezer::Freeze() const +{ + Set("freezer.state", "FROZEN"); +} + +void TFreezer::Unfreeze() const +{ + Set("freezer.state", "THAWED"); +} + +//////////////////////////////////////////////////////////////////////////////// + +std::map<TString, TString> ParseProcessCGroups(const TString& str) +{ + std::map<TString, TString> result; + + TVector<TString> values; + StringSplitter(str.data()).SplitBySet(":\n").SkipEmpty().Collect(&values); + for (size_t i = 0; i + 2 < values.size(); i += 3) { + // Check format. + FromString<int>(values[i]); + + const auto& subsystemsSet = values[i + 1]; + const auto& name = values[i + 2]; + + TVector<TString> subsystems; + StringSplitter(subsystemsSet.data()).Split(',').SkipEmpty().Collect(&subsystems); + for (const auto& subsystem : subsystems) { + if (!subsystem.StartsWith("name=")) { + int start = 0; + if (name.StartsWith("/")) { + start = 1; + } + result[subsystem] = name.substr(start); + } + } + } + + return result; +} + +std::map<TString, TString> GetProcessCGroups(pid_t pid) +{ + auto cgroupsPath = Format("/proc/%v/cgroup", pid); + auto rawCgroups = TFileInput{cgroupsPath}.ReadAll(); + return ParseProcessCGroups(rawCgroups); +} + +bool IsValidCGroupType(const TString& type) +{ + return + type == TCpuAccounting::Name || + type == TCpu::Name || + type == TBlockIO::Name || + type == TMemory::Name || + type == TFreezer::Name; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/cgroup.h b/yt/yt/library/containers/cgroup.h new file mode 100644 index 0000000000..a61fbbddc3 --- /dev/null +++ b/yt/yt/library/containers/cgroup.h @@ -0,0 +1,290 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/actions/public.h> + +#include <yt/yt/core/ytree/yson_struct.h> +#include <yt/yt/core/yson/public.h> + +#include <yt/yt/core/misc/property.h> + +#include <library/cpp/yt/threading/spin_lock.h> + +#include <vector> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +void RemoveAllSubcgroups(const TString& path); + +//////////////////////////////////////////////////////////////////////////////// + +struct TKillProcessGroupTool +{ + void operator()(const TString& processGroupPath) const; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TNonOwningCGroup + : private TNonCopyable +{ +public: + DEFINE_BYREF_RO_PROPERTY(TString, FullPath); + +public: + TNonOwningCGroup() = default; + explicit TNonOwningCGroup(const TString& fullPath); + TNonOwningCGroup(const TString& type, const TString& name); + TNonOwningCGroup(TNonOwningCGroup&& other); + + void AddTask(int pid) const; + void AddCurrentTask() const; + + bool IsRoot() const; + bool IsNull() const; + bool Exists() const; + + std::vector<int> GetProcesses() const; + std::vector<int> GetTasks() const; + const TString& GetFullPath() const; + + std::vector<TNonOwningCGroup> GetChildren() const; + + void EnsureExistence() const; + + void Lock() const; + void Unlock() const; + + void Kill() const; + + void RemoveAllSubcgroups() const; + void RemoveRecursive() const; + +protected: + TString 
Get(const TString& name) const; + void Set(const TString& name, const TString& value) const; + void Append(const TString& name, const TString& value) const; + + void DoLock() const; + void DoUnlock() const; + + bool TryUnlock() const; + + void DoKill() const; + + void DoRemove() const; + + void Traverse( + const TCallback<void(const TNonOwningCGroup&)>& preorderAction, + const TCallback<void(const TNonOwningCGroup&)>& postorderAction) const; + + TString GetPath(const TString& filename) const; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TCGroup + : public TNonOwningCGroup +{ +protected: + TCGroup(const TString& type, const TString& name); + TCGroup(TNonOwningCGroup&& other); + TCGroup(TCGroup&& other); + +public: + ~TCGroup(); + + void Create(); + void Destroy(); + + bool IsCreated() const; + +private: + bool Created_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TCpuAccounting + : public TCGroup +{ +public: + static const TString Name; + + struct TStatistics + { + TErrorOr<TDuration> TotalUsageTime; + TErrorOr<TDuration> UserUsageTime; + TErrorOr<TDuration> SystemUsageTime; + TErrorOr<TDuration> WaitTime; + TErrorOr<TDuration> ThrottledTime; + + TErrorOr<ui64> ThreadCount; + TErrorOr<ui64> ContextSwitches; + TErrorOr<ui64> ContextSwitchesDelta; + TErrorOr<ui64> PeakThreadCount; + + TErrorOr<TDuration> LimitTime; + TErrorOr<TDuration> GuaranteeTime; + }; + + explicit TCpuAccounting(const TString& name); + + TStatistics GetStatisticsRecursive() const; + TStatistics GetStatistics() const; + +private: + explicit TCpuAccounting(TNonOwningCGroup&& nonOwningCGroup); +}; + +void Serialize(const TCpuAccounting::TStatistics& statistics, NYson::IYsonConsumer* consumer); + +//////////////////////////////////////////////////////////////////////////////// + +class TCpu + : public TCGroup +{ +public: + static const TString Name; + + explicit TCpu(const TString& name); + + void SetShare(double share); +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TBlockIO + : public TCGroup +{ +public: + static const TString Name; + + struct TStatistics + { + TErrorOr<ui64> IOReadByte; + TErrorOr<ui64> IOWriteByte; + TErrorOr<ui64> IOBytesLimit; + + TErrorOr<ui64> IOReadOps; + TErrorOr<ui64> IOWriteOps; + TErrorOr<ui64> IOOps; + TErrorOr<ui64> IOOpsLimit; + + TErrorOr<TDuration> IOTotalTime; + TErrorOr<TDuration> IOWaitTime; + }; + + struct TStatisticsItem + { + TString DeviceId; + TString Type; + ui64 Value = 0; + }; + + explicit TBlockIO(const TString& name); + + TStatistics GetStatistics() const; + void ThrottleOperations(i64 iops) const; + +private: + //! Guards device ids. + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, SpinLock_); + //! Set of all seen device ids. 
+ mutable THashSet<TString> DeviceIds_; + + std::vector<TBlockIO::TStatisticsItem> GetDetailedStatistics(const char* filename) const; + + std::vector<TStatisticsItem> GetIOServiceBytes() const; + std::vector<TStatisticsItem> GetIOServiced() const; +}; + +void Serialize(const TBlockIO::TStatistics& statistics, NYson::IYsonConsumer* consumer); + +//////////////////////////////////////////////////////////////////////////////// + +class TMemory + : public TCGroup +{ +public: + static const TString Name; + + struct TStatistics + { + TErrorOr<ui64> Rss; + TErrorOr<ui64> MappedFile; + TErrorOr<ui64> MinorPageFaults; + TErrorOr<ui64> MajorPageFaults; + + TErrorOr<ui64> FileCacheUsage; + TErrorOr<ui64> AnonUsage; + TErrorOr<ui64> AnonLimit; + TErrorOr<ui64> MemoryUsage; + TErrorOr<ui64> MemoryGuarantee; + TErrorOr<ui64> MemoryLimit; + TErrorOr<ui64> MaxMemoryUsage; + + TErrorOr<ui64> OomKills; + TErrorOr<ui64> OomKillsTotal; + }; + + explicit TMemory(const TString& name); + + TStatistics GetStatistics() const; + i64 GetMaxMemoryUsage() const; + + void SetLimitInBytes(i64 bytes) const; + + void ForceEmpty() const; +}; + +void Serialize(const TMemory::TStatistics& statistics, NYson::IYsonConsumer* consumer); + +//////////////////////////////////////////////////////////////////////////////// + +class TNetwork +{ +public: + struct TStatistics + { + TErrorOr<ui64> TxBytes; + TErrorOr<ui64> TxPackets; + TErrorOr<ui64> TxDrops; + TErrorOr<ui64> TxLimit; + + TErrorOr<ui64> RxBytes; + TErrorOr<ui64> RxPackets; + TErrorOr<ui64> RxDrops; + TErrorOr<ui64> RxLimit; + }; +}; + +void Serialize(const TNetwork::TStatistics& statistics, NYson::IYsonConsumer* consumer); + +//////////////////////////////////////////////////////////////////////////////// + +class TFreezer + : public TCGroup +{ +public: + static const TString Name; + + explicit TFreezer(const TString& name); + + TString GetState() const; + void Freeze() const; + void Unfreeze() const; +}; + +//////////////////////////////////////////////////////////////////////////////// + +std::map<TString, TString> ParseProcessCGroups(const TString& str); +std::map<TString, TString> GetProcessCGroups(pid_t pid); +bool IsValidCGroupType(const TString& type); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/config.cpp b/yt/yt/library/containers/config.cpp new file mode 100644 index 0000000000..39e46f2372 --- /dev/null +++ b/yt/yt/library/containers/config.cpp @@ -0,0 +1,64 @@ +#include "config.h" + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +void TPodSpecConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cpu_to_vcpu_factor", &TThis::CpuToVCpuFactor) + .Default(); +} + +//////////////////////////////////////////////////////////////////////////////// + +bool TCGroupConfig::IsCGroupSupported(const TString& cgroupType) const +{ + auto it = std::find_if( + SupportedCGroups.begin(), + SupportedCGroups.end(), + [&] (const TString& type) { + return type == cgroupType; + }); + return it != SupportedCGroups.end(); +} + +void TCGroupConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("supported_cgroups", &TThis::SupportedCGroups) + .Default(); + + registrar.Postprocessor([] (TThis* config) { + for (const auto& type : config->SupportedCGroups) { + if (!IsValidCGroupType(type)) { + THROW_ERROR_EXCEPTION("Invalid cgroup type %Qv", type); + } + } + }); +} + 
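The classes and config declared above come together in a fairly small surface area. Below is a brief illustrative sketch (not part of this commit) of the intended usage; the cgroup name "example-job" and the helper function are hypothetical, and it assumes a Linux host with the "cpuacct" hierarchy mounted and writable:

    #include <yt/yt/library/containers/cgroup.h>
    #include <yt/yt/library/containers/config.h>

    using namespace NYT::NContainers;

    void RunWithCpuAccounting(const TIntrusivePtr<TCGroupConfig>& config)
    {
        // Register() has already validated every entry of "supported_cgroups"
        // with IsValidCGroupType(); here we only honor the operator's choice.
        if (!config->IsCGroupSupported(TCpuAccounting::Name)) {
            return;
        }

        TCpuAccounting cgroup("example-job");
        cgroup.Create();            // EnsureExistence(): mkdir under the cpuacct hierarchy
        cgroup.AddCurrentTask();    // move the calling process into the cgroup

        // ... run the workload ...

        auto statistics = cgroup.GetStatistics();
        Y_UNUSED(statistics);
        // ~TCGroup() destroys the cgroup, since it was created by this object.
    }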
+//////////////////////////////////////////////////////////////////////////////// + +void TPortoExecutorDynamicConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("retries_timeout", &TThis::RetriesTimeout) + .Default(TDuration::Seconds(10)); + registrar.Parameter("poll_period", &TThis::PollPeriod) + .Default(TDuration::MilliSeconds(100)); + registrar.Parameter("api_timeout", &TThis::ApiTimeout) + .Default(TDuration::Minutes(5)); + registrar.Parameter("api_disk_timeout", &TThis::ApiDiskTimeout) + .Default(TDuration::Minutes(30)); + registrar.Parameter("enable_network_isolation", &TThis::EnableNetworkIsolation) + .Default(true); + registrar.Parameter("enable_test_porto_failures", &TThis::EnableTestPortoFailures) + .Default(false); + registrar.Parameter("stub_error_code", &TThis::StubErrorCode) + .Default(EPortoErrorCode::SocketError); + registrar.Parameter("enable_test_porto_not_responding", &TThis::EnableTestPortoNotResponding) + .Default(false); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/config.h b/yt/yt/library/containers/config.h new file mode 100644 index 0000000000..3639274cff --- /dev/null +++ b/yt/yt/library/containers/config.h @@ -0,0 +1,64 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +class TPodSpecConfig + : public virtual NYTree::TYsonStruct +{ +public: + std::optional<double> CpuToVCpuFactor; + + REGISTER_YSON_STRUCT(TPodSpecConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TPodSpecConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TCGroupConfig + : public virtual NYTree::TYsonStruct +{ +public: + std::vector<TString> SupportedCGroups; + + bool IsCGroupSupported(const TString& cgroupType) const; + + REGISTER_YSON_STRUCT(TCGroupConfig); + + static void Register(TRegistrar registrar); +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoExecutorDynamicConfig + : public NYTree::TYsonStruct +{ +public: + TDuration RetriesTimeout; + TDuration PollPeriod; + TDuration ApiTimeout; + TDuration ApiDiskTimeout; + bool EnableNetworkIsolation; + bool EnableTestPortoFailures; + bool EnableTestPortoNotResponding; + + EPortoErrorCode StubErrorCode; + + REGISTER_YSON_STRUCT(TPortoExecutorDynamicConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TPortoExecutorDynamicConfig) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/cri/config.cpp b/yt/yt/library/containers/cri/config.cpp new file mode 100644 index 0000000000..5572f4d980 --- /dev/null +++ b/yt/yt/library/containers/cri/config.cpp @@ -0,0 +1,54 @@ +#include "config.h" +#include "cri_api.h" + +namespace NYT::NContainers::NCri { + +//////////////////////////////////////////////////////////////////////////////// + +void TCriExecutorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("runtime_endpoint", &TThis::RuntimeEndpoint) + .Default(TString(DefaultCriEndpoint)); + + registrar.Parameter("image_endpoint", &TThis::ImageEndpoint) + .Default(TString(DefaultCriEndpoint)); + + registrar.Parameter("namespace", &TThis::Namespace) + .NonEmpty(); + + 
registrar.Parameter("runtime_handler", &TThis::RuntimeHandler) + .Optional(); + + registrar.Parameter("base_cgroup", &TThis::BaseCgroup) + .NonEmpty(); + + registrar.Parameter("cpu_period", &TThis::CpuPeriod) + .Default(TDuration::MilliSeconds(100)); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCriAuthConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("username", &TThis::Username) + .Optional(); + + registrar.Parameter("password", &TThis::Password) + .Optional(); + + registrar.Parameter("auth", &TThis::Auth) + .Optional(); + + registrar.Parameter("server_address", &TThis::ServerAddress) + .Optional(); + + registrar.Parameter("identity_token", &TThis::IdentityToken) + .Optional(); + + registrar.Parameter("registry_token", &TThis::RegistryToken) + .Optional(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers::NCri diff --git a/yt/yt/library/containers/cri/config.h b/yt/yt/library/containers/cri/config.h new file mode 100644 index 0000000000..4ea33fd390 --- /dev/null +++ b/yt/yt/library/containers/cri/config.h @@ -0,0 +1,70 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/rpc/config.h> + +namespace NYT::NContainers::NCri { + +//////////////////////////////////////////////////////////////////////////////// + +class TCriExecutorConfig + : public NRpc::TRetryingChannelConfig +{ +public: + //! gRPC endpoint for CRI container runtime service. + TString RuntimeEndpoint; + + //! gRPC endpoint for CRI image manager service. + TString ImageEndpoint; + + //! CRI namespace where this executor operates. + TString Namespace; + + //! Name of CRI runtime configuration to use. + TString RuntimeHandler; + + //! Common parent cgroup for all pods. + TString BaseCgroup; + + //! Cpu quota period for cpu limits. + TDuration CpuPeriod; + + REGISTER_YSON_STRUCT(TCriExecutorConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TCriExecutorConfig) + +//////////////////////////////////////////////////////////////////////////////// + +// TODO(khlebnikov): split docker registry stuff into common "docker" library. + +//! 
TCriAuthConfig describes Docker registry authentication.
+class TCriAuthConfig
+    : public NYTree::TYsonStruct
+{
+public:
+    TString Username;
+
+    TString Password;
+
+    TString Auth;
+
+    TString ServerAddress;
+
+    TString IdentityToken;
+
+    TString RegistryToken;
+
+    REGISTER_YSON_STRUCT(TCriAuthConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCriAuthConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NContainers::NCri
diff --git a/yt/yt/library/containers/cri/cri_api.cpp b/yt/yt/library/containers/cri/cri_api.cpp
new file mode 100644
index 0000000000..93457017ba
--- /dev/null
+++ b/yt/yt/library/containers/cri/cri_api.cpp
@@ -0,0 +1,33 @@
+#include "cri_api.h"
+
+namespace NYT::NContainers::NCri {
+
+using namespace NRpc;
+
+////////////////////////////////////////////////////////////////////////////////
+
+TCriRuntimeApi::TCriRuntimeApi(IChannelPtr channel)
+    : TProxyBase(std::move(channel), GetDescriptor())
+{ }
+
+const TServiceDescriptor& TCriRuntimeApi::GetDescriptor()
+{
+    static const auto Descriptor = TServiceDescriptor(NProto::RuntimeService::service_full_name());
+    return Descriptor;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TCriImageApi::TCriImageApi(IChannelPtr channel)
+    : TProxyBase(std::move(channel), GetDescriptor())
+{ }
+
+const TServiceDescriptor& TCriImageApi::GetDescriptor()
+{
+    static const auto Descriptor = TServiceDescriptor(NProto::ImageService::service_full_name());
+    return Descriptor;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NContainers::NCri
diff --git a/yt/yt/library/containers/cri/cri_api.h b/yt/yt/library/containers/cri/cri_api.h
new file mode 100644
index 0000000000..74fe9a64a0
--- /dev/null
+++ b/yt/yt/library/containers/cri/cri_api.h
@@ -0,0 +1,99 @@
+#pragma once
+
+#include <yt/yt/core/rpc/client.h>
+
+#include <k8s.io/cri-api/pkg/apis/runtime/v1/api.grpc.pb.h>
+
+namespace NYT::NContainers::NCri {
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace NProto = ::runtime::v1;
+
+//! Reasonable default for the CRI gRPC socket address.
+constexpr TStringBuf DefaultCriEndpoint = "unix:///run/containerd/containerd.sock";
+
+//! RuntimeReady means the runtime is up and ready to accept basic containers.
+constexpr TStringBuf RuntimeReady = "RuntimeReady";
+
+//! NetworkReady means the runtime network is up and ready to accept containers which require network.
+constexpr TStringBuf NetworkReady = "NetworkReady";
+
+//! CRI uses cgroupfs notation for systemd slices, but each name must end with ".slice".
+constexpr TStringBuf SystemdSliceSuffix = ".slice";
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! CRI labels for pods and containers managed by YT.
+constexpr TStringBuf YTPodNamespaceLabel = "tech.ytsaurus.pod.namespace";
+constexpr TStringBuf YTPodNameLabel = "tech.ytsaurus.pod.name";
+constexpr TStringBuf YTContainerNameLabel = "tech.ytsaurus.container.name";
+constexpr TStringBuf YTJobIdLabel = "tech.ytsaurus.job.id";
+
+////////////////////////////////////////////////////////////////////////////////
+
+#define DEFINE_CRI_API_METHOD(method, ...) \
+    DEFINE_RPC_PROXY_METHOD_GENERIC(method, NProto::method##Request, NProto::method##Response, __VA_ARGS__)
+
+//! 
See https://github.com/kubernetes/cri-api +class TCriRuntimeApi + : public NRpc::TProxyBase +{ +public: + explicit TCriRuntimeApi(NRpc::IChannelPtr channel); + + static const NRpc::TServiceDescriptor& GetDescriptor(); + + DEFINE_CRI_API_METHOD(Version); + DEFINE_CRI_API_METHOD(RunPodSandbox); + DEFINE_CRI_API_METHOD(StopPodSandbox); + DEFINE_CRI_API_METHOD(RemovePodSandbox); + DEFINE_CRI_API_METHOD(PodSandboxStatus); + DEFINE_CRI_API_METHOD(ListPodSandbox); + DEFINE_CRI_API_METHOD(CreateContainer); + DEFINE_CRI_API_METHOD(StartContainer); + DEFINE_CRI_API_METHOD(StopContainer); + DEFINE_CRI_API_METHOD(RemoveContainer); + DEFINE_CRI_API_METHOD(ListContainers); + DEFINE_CRI_API_METHOD(ContainerStatus); + DEFINE_CRI_API_METHOD(UpdateContainerResources); + DEFINE_CRI_API_METHOD(ReopenContainerLog); + DEFINE_CRI_API_METHOD(ExecSync); + DEFINE_CRI_API_METHOD(Exec); + DEFINE_CRI_API_METHOD(Attach); + DEFINE_CRI_API_METHOD(PortForward); + DEFINE_CRI_API_METHOD(ContainerStats); + DEFINE_CRI_API_METHOD(ListContainerStats); + DEFINE_CRI_API_METHOD(PodSandboxStats); + DEFINE_CRI_API_METHOD(ListPodSandboxStats); + DEFINE_CRI_API_METHOD(UpdateRuntimeConfig); + DEFINE_CRI_API_METHOD(Status); + DEFINE_CRI_API_METHOD(CheckpointContainer); + DEFINE_CRI_API_METHOD(ListMetricDescriptors); + DEFINE_CRI_API_METHOD(ListPodSandboxMetrics); + + // FIXME(khlebnikov): figure out streaming results + // DEFINE_RPC_PROXY_METHOD_GENERIC(GetContainerEvents, NProto::GetEventsRequest, NProto::ContainerEventResponse, + // .SetStreamingEnabled(true)); +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TCriImageApi + : public NRpc::TProxyBase +{ +public: + explicit TCriImageApi(NRpc::IChannelPtr channel); + + static const NRpc::TServiceDescriptor& GetDescriptor(); + + DEFINE_CRI_API_METHOD(ListImages); + DEFINE_CRI_API_METHOD(ImageStatus); + DEFINE_CRI_API_METHOD(PullImage); + DEFINE_CRI_API_METHOD(RemoveImage); + DEFINE_CRI_API_METHOD(ImageFsInfo); +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers::NCri diff --git a/yt/yt/library/containers/cri/cri_executor.cpp b/yt/yt/library/containers/cri/cri_executor.cpp new file mode 100644 index 0000000000..428fd93165 --- /dev/null +++ b/yt/yt/library/containers/cri/cri_executor.cpp @@ -0,0 +1,666 @@ +#include "cri_executor.h" +#include "private.h" + +#include <yt/yt/core/actions/bind.h> + +#include <yt/yt/core/rpc/grpc/channel.h> + +#include <yt/yt/core/rpc/retrying_channel.h> + +#include <yt/yt/core/misc/error.h> +#include <yt/yt/core/misc/proc.h> +#include <yt/yt/core/misc/protobuf_helpers.h> + +#include <yt/yt/core/concurrency/periodic_executor.h> + +namespace NYT::NContainers::NCri { + +using namespace NRpc; +using namespace NRpc::NGrpc; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +void FormatValue(TStringBuilderBase* builder, const TCriDescriptor& descriptor, TStringBuf /*spec*/) +{ + builder->AppendFormat("%v (%s)", descriptor.Id.substr(0, 12), descriptor.Name); +} + +void FormatValue(TStringBuilderBase* builder, const TCriPodDescriptor& descriptor, TStringBuf /*spec*/) +{ + builder->AppendFormat("%v (%s)", descriptor.Id.substr(0, 12), descriptor.Name); +} + +void FormatValue(TStringBuilderBase* builder, const TCriImageDescriptor& descriptor, TStringBuf /*spec*/) +{ + builder->AppendString(descriptor.Image); +} + +static TError DecodeExitCode(int exitCode, const TString& reason) +{ + if 
(exitCode == 0) {
+        return TError();
+    }
+
+    // TODO(khlebnikov) map reason == "OOMKilled"
+
+    // Common bash notation for signals: 128 + signal
+    if (exitCode > 128) {
+        int signalNumber = exitCode - 128;
+        return TError(
+            EProcessErrorCode::Signal,
+            "Process terminated by signal %v",
+            signalNumber)
+            << TErrorAttribute("signal", signalNumber)
+            << TErrorAttribute("reason", reason);
+    }
+
+    // TODO(khlebnikov) check these
+    // 125 - container failed to run
+    // 126 - non executable
+    // 127 - command not found
+    // 128 - invalid exit code
+    // 255 - exit code out of range
+
+    return TError(
+        EProcessErrorCode::NonZeroExitCode,
+        "Process exited with code %v",
+        exitCode)
+        << TErrorAttribute("exit_code", exitCode)
+        << TErrorAttribute("reason", reason);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+class TCriProcess
+    : public TProcessBase
+{
+public:
+    TCriProcess(
+        const TString& path,
+        ICriExecutorPtr executor,
+        TCriContainerSpecPtr containerSpec,
+        const TCriPodDescriptor& podDescriptor,
+        TCriPodSpecPtr podSpec,
+        TDuration pollPeriod = TDuration::MilliSeconds(100))
+        : TProcessBase(path)
+        , Executor_(std::move(executor))
+        , ContainerSpec_(std::move(containerSpec))
+        , PodDescriptor_(podDescriptor)
+        , PodSpec_(std::move(podSpec))
+        , PollPeriod_(pollPeriod)
+    {
+        // Just for symmetry with sibling classes.
+        AddArgument(Path_);
+    }
+
+    void Kill(int /*signal*/) override
+    {
+        WaitFor(Executor_->StopContainer(ContainerDescriptor_))
+            .ThrowOnError();
+    }
+
+    NNet::IConnectionWriterPtr GetStdInWriter() override
+    {
+        THROW_ERROR_EXCEPTION("Not implemented for CRI process");
+    }
+
+    NNet::IConnectionReaderPtr GetStdOutReader() override
+    {
+        THROW_ERROR_EXCEPTION("Not implemented for CRI process");
+    }
+
+    NNet::IConnectionReaderPtr GetStdErrReader() override
+    {
+        THROW_ERROR_EXCEPTION("Not implemented for CRI process");
+    }
+
+private:
+    const ICriExecutorPtr Executor_;
+    const TCriContainerSpecPtr ContainerSpec_;
+    const TCriPodDescriptor PodDescriptor_;
+    const TCriPodSpecPtr PodSpec_;
+    const TDuration PollPeriod_;
+
+    TCriDescriptor ContainerDescriptor_;
+
+    TPeriodicExecutorPtr AsyncWaitExecutor_;
+
+    void DoSpawn() override
+    {
+        if (ContainerSpec_->Command.empty()) {
+            ContainerSpec_->Command = {Path_};
+        }
+        ContainerSpec_->Arguments = std::vector<TString>(Args_.begin() + 1, Args_.end());
+        ContainerSpec_->WorkingDirectory = WorkingDirectory_;
+
+        ContainerSpec_->BindMounts.emplace_back(
+            NCri::TCriBindMount {
+                .ContainerPath = WorkingDirectory_,
+                .HostPath = WorkingDirectory_,
+                .ReadOnly = false,
+            }
+        );
+
+        for (const auto& keyVal : Env_) {
+            TStringBuf key, val;
+            if (TStringBuf(keyVal).TrySplit('=', key, val)) {
+                ContainerSpec_->Environment[key] = val;
+            }
+        }
+
+        ContainerDescriptor_ = WaitFor(Executor_->CreateContainer(ContainerSpec_, PodDescriptor_, PodSpec_))
+            .ValueOrThrow();
+
+        YT_LOG_DEBUG("Spawning process (Command: %v, Container: %v)", ContainerSpec_->Command[0], ContainerDescriptor_);
+        WaitFor(Executor_->StartContainer(ContainerDescriptor_))
+            .ThrowOnError();
+
+        // TODO(khlebnikov) replace polling with CRI event
+        AsyncWaitExecutor_ = New<TPeriodicExecutor>(
+            GetSyncInvoker(),
+            BIND(&TCriProcess::PollContainerStatus, MakeStrong(this)),
+            PollPeriod_);
+
+        AsyncWaitExecutor_->Start();
+    }
+
+    void PollContainerStatus()
+    {
+        Executor_->GetContainerStatus(ContainerDescriptor_)
+            .SubscribeUnique(BIND(&TCriProcess::OnContainerStatus, MakeStrong(this)));
+    }
+
+    void 
OnContainerStatus(TErrorOr<TCriRuntimeApi::TRspContainerStatusPtr>&& responseOrError) + { + auto response = responseOrError.ValueOrThrow(); + if (!response->has_status()) { + return; + } + auto status = response->status(); + if (status.state() == NProto::CONTAINER_EXITED) { + auto error = DecodeExitCode(status.exit_code(), status.reason()); + YT_LOG_DEBUG(error, "Process finished (Container: %v)", ContainerDescriptor_); + YT_UNUSED_FUTURE(AsyncWaitExecutor_->Stop()); + FinishedPromise_.TrySet(error); + } + } +}; + +DEFINE_REFCOUNTED_TYPE(TCriProcess) + +//////////////////////////////////////////////////////////////////////////////// + +class TCriExecutor + : public ICriExecutor +{ +public: + TCriExecutor( + TCriExecutorConfigPtr config, + IChannelFactoryPtr channelFactory) + : Config_(std::move(config)) + , RuntimeApi_(CreateRetryingChannel(Config_, channelFactory->CreateChannel(Config_->RuntimeEndpoint))) + , ImageApi_(CreateRetryingChannel(Config_, channelFactory->CreateChannel(Config_->ImageEndpoint))) + { } + + TString GetPodCgroup(TString podName) const override + { + TStringBuilder cgroup; + cgroup.AppendString(Config_->BaseCgroup); + cgroup.AppendString("/"); + cgroup.AppendString(podName); + if (Config_->BaseCgroup.EndsWith(SystemdSliceSuffix)) { + cgroup.AppendString(SystemdSliceSuffix); + } + return cgroup.Flush(); + } + + TFuture<TCriRuntimeApi::TRspStatusPtr> GetRuntimeStatus(bool verbose = false) override + { + auto req = RuntimeApi_.Status(); + req->set_verbose(verbose); + return req->Invoke(); + } + + TFuture<TCriRuntimeApi::TRspListPodSandboxPtr> ListPodSandbox( + std::function<void(NProto::PodSandboxFilter&)> initFilter = nullptr) override + { + auto req = RuntimeApi_.ListPodSandbox(); + + { + auto* filter = req->mutable_filter(); + + if (auto namespace_ = Config_->Namespace) { + auto& labels = *filter->mutable_label_selector(); + labels[YTPodNamespaceLabel] = namespace_; + } + + if (initFilter) { + initFilter(*filter); + } + } + + return req->Invoke(); + } + + TFuture<TCriRuntimeApi::TRspListContainersPtr> ListContainers( + std::function<void(NProto::ContainerFilter&)> initFilter = nullptr) override + { + auto req = RuntimeApi_.ListContainers(); + + { + auto* filter = req->mutable_filter(); + + if (auto namespace_ = Config_->Namespace) { + auto& labels = *filter->mutable_label_selector(); + labels[YTPodNamespaceLabel] = namespace_; + } + + if (initFilter) { + initFilter(*filter); + } + } + + return req->Invoke(); + } + + TFuture<void> ForEachPodSandbox( + const TCallback<void(const TCriPodDescriptor&, const NProto::PodSandbox&)>& callback, + std::function<void(NProto::PodSandboxFilter&)> initFilter) override + { + return ListPodSandbox(initFilter).Apply(BIND([=] (const TCriRuntimeApi::TRspListPodSandboxPtr& rsp) { + for (const auto& pod : rsp->items()) { + TCriPodDescriptor descriptor{.Name=pod.metadata().name(), .Id=pod.id()}; + callback(descriptor, pod); + } + })); + } + + TFuture<void> ForEachContainer( + const TCallback<void(const TCriDescriptor&, const NProto::Container&)>& callback, + std::function<void(NProto::ContainerFilter&)> initFilter = nullptr) override + { + return ListContainers(initFilter).Apply(BIND([=] (const TCriRuntimeApi::TRspListContainersPtr& rsp) { + for (const auto& ct : rsp->containers()) { + TCriDescriptor descriptor{.Name=ct.metadata().name(), .Id=ct.id()}; + callback(descriptor, ct); + } + })); + } + + TFuture<TCriRuntimeApi::TRspPodSandboxStatusPtr> GetPodSandboxStatus( + const TCriPodDescriptor& podDescriptor, bool verbose = false) override 
+ { + auto req = RuntimeApi_.PodSandboxStatus(); + req->set_pod_sandbox_id(podDescriptor.Id); + req->set_verbose(verbose); + return req->Invoke(); + } + + TFuture<TCriRuntimeApi::TRspContainerStatusPtr> GetContainerStatus( + const TCriDescriptor& descriptor, bool verbose = false) override + { + auto req = RuntimeApi_.ContainerStatus(); + req->set_container_id(descriptor.Id); + req->set_verbose(verbose); + return req->Invoke(); + } + + TFuture<TCriPodDescriptor> RunPodSandbox(TCriPodSpecPtr podSpec) override + { + auto req = RuntimeApi_.RunPodSandbox(); + + FillPodSandboxConfig(req->mutable_config(), *podSpec); + + if (Config_->RuntimeHandler) { + req->set_runtime_handler(Config_->RuntimeHandler); + } + + return req->Invoke().Apply(BIND([name = podSpec->Name] (const TCriRuntimeApi::TRspRunPodSandboxPtr& rsp) -> TCriPodDescriptor { + return TCriPodDescriptor{.Name = name, .Id = rsp->pod_sandbox_id()}; + })); + } + + TFuture<void> StopPodSandbox(const TCriPodDescriptor& podDescriptor) override + { + auto req = RuntimeApi_.StopPodSandbox(); + req->set_pod_sandbox_id(podDescriptor.Id); + return req->Invoke().AsVoid(); + } + + TFuture<void> RemovePodSandbox(const TCriPodDescriptor& podDescriptor) override + { + auto req = RuntimeApi_.RemovePodSandbox(); + req->set_pod_sandbox_id(podDescriptor.Id); + return req->Invoke().AsVoid(); + } + + TFuture<void> UpdatePodResources( + const TCriPodDescriptor& /*pod*/, + const TCriContainerResources& /*resources*/) override + { + return MakeFuture(TError("Not implemented")); + } + + TFuture<TCriDescriptor> CreateContainer( + TCriContainerSpecPtr ctSpec, + const TCriPodDescriptor& podDescriptor, + TCriPodSpecPtr podSpec) override + { + auto req = RuntimeApi_.CreateContainer(); + req->set_pod_sandbox_id(podDescriptor.Id); + + auto* config = req->mutable_config(); + + { + auto* metadata = config->mutable_metadata(); + metadata->set_name(ctSpec->Name); + } + + { + auto& labels = *config->mutable_labels(); + + for (const auto& [key, val] : ctSpec->Labels) { + labels[key] = val; + } + + labels[YTPodNamespaceLabel] = Config_->Namespace; + labels[YTPodNameLabel] = podSpec->Name; + labels[YTContainerNameLabel] = ctSpec->Name; + } + + FillImageSpec(config->mutable_image(), ctSpec->Image); + + for (const auto& mountSpec : ctSpec->BindMounts) { + auto* mount = config->add_mounts(); + mount->set_container_path(mountSpec.ContainerPath); + mount->set_host_path(mountSpec.HostPath); + mount->set_readonly(mountSpec.ReadOnly); + mount->set_propagation(NProto::PROPAGATION_PRIVATE); + } + + { + ToProto(config->mutable_command(), ctSpec->Command); + ToProto(config->mutable_args(), ctSpec->Arguments); + + config->set_working_dir(ctSpec->WorkingDirectory); + + for (const auto& [key, val] : ctSpec->Environment) { + auto* env = config->add_envs(); + env->set_key(key); + env->set_value(val); + } + } + + { + auto* linux = config->mutable_linux(); + FillLinuxContainerResources(linux->mutable_resources(), ctSpec->Resources); + + auto* security = linux->mutable_security_context(); + + auto* namespaces = security->mutable_namespace_options(); + namespaces->set_network(NProto::NODE); + + security->set_readonly_rootfs(ctSpec->ReadOnlyRootFS); + + if (ctSpec->Credentials.Uid) { + security->mutable_run_as_user()->set_value(*ctSpec->Credentials.Uid); + } + if (ctSpec->Credentials.Gid) { + security->mutable_run_as_group()->set_value(*ctSpec->Credentials.Gid); + } + ToProto(security->mutable_supplemental_groups(), ctSpec->Credentials.Groups); + } + + 
FillPodSandboxConfig(req->mutable_sandbox_config(), *podSpec); + + return req->Invoke().Apply(BIND([name = ctSpec->Name] (const TCriRuntimeApi::TRspCreateContainerPtr& rsp) -> TCriDescriptor { + return TCriDescriptor{.Name = "", .Id = rsp->container_id()}; + })); + } + + TFuture<void> StartContainer(const TCriDescriptor& descriptor) override + { + auto req = RuntimeApi_.StartContainer(); + req->set_container_id(descriptor.Id); + return req->Invoke().AsVoid(); + } + + TFuture<void> StopContainer(const TCriDescriptor& descriptor, TDuration timeout) override + { + auto req = RuntimeApi_.StopContainer(); + req->set_container_id(descriptor.Id); + req->set_timeout(timeout.Seconds()); + return req->Invoke().AsVoid(); + } + + TFuture<void> RemoveContainer(const TCriDescriptor& descriptor) override + { + auto req = RuntimeApi_.RemoveContainer(); + req->set_container_id(descriptor.Id); + return req->Invoke().AsVoid(); + } + + TFuture<void> UpdateContainerResources(const TCriDescriptor& descriptor, const TCriContainerResources& resources) override + { + auto req = RuntimeApi_.UpdateContainerResources(); + req->set_container_id(descriptor.Id); + FillLinuxContainerResources(req->mutable_linux(), resources); + return req->Invoke().AsVoid(); + } + + void CleanNamespace() override + { + YT_VERIFY(Config_->Namespace); + auto pods = WaitFor(ListPodSandbox()) + .ValueOrThrow(); + + { + std::vector<TFuture<void>> futures; + futures.reserve(pods->items_size()); + for (const auto& pod : pods->items()) { + TCriPodDescriptor podDescriptor{.Name = pod.metadata().name(), .Id = pod.id() }; + futures.push_back(StopPodSandbox(podDescriptor)); + } + WaitFor(AllSucceeded(std::move(futures))) + .ThrowOnError(); + } + + { + std::vector<TFuture<void>> futures; + futures.reserve(pods->items_size()); + for (const auto& pod : pods->items()) { + TCriPodDescriptor podDescriptor{.Name = pod.metadata().name(), .Id = pod.id()}; + futures.push_back(RemovePodSandbox(podDescriptor)); + } + WaitFor(AllSucceeded(std::move(futures))) + .ThrowOnError(); + } + } + + void CleanPodSandbox(const TCriPodDescriptor& podDescriptor) override + { + auto containers = WaitFor(ListContainers([=] (NProto::ContainerFilter& filter) { + filter.set_pod_sandbox_id(podDescriptor.Id); + })) + .ValueOrThrow(); + + { + std::vector<TFuture<void>> futures; + futures.reserve(containers->containers_size()); + for (const auto& ct : containers->containers()) { + TCriDescriptor ctDescriptor{.Name = ct.metadata().name(), .Id = ct.id()}; + futures.push_back(StopContainer(ctDescriptor, TDuration::Zero())); + } + WaitFor(AllSucceeded(std::move(futures))) + .ThrowOnError(); + } + + { + std::vector<TFuture<void>> futures; + futures.reserve(containers->containers_size()); + for (const auto& ct : containers->containers()) { + TCriDescriptor ctDescriptor{.Name = ct.metadata().name(), .Id = ct.id()}; + futures.push_back(RemoveContainer(ctDescriptor)); + } + WaitFor(AllSucceeded(std::move(futures))) + .ThrowOnError(); + } + } + + TFuture<TCriImageApi::TRspListImagesPtr> ListImages( + std::function<void(NProto::ImageFilter&)> initFilter = nullptr) override + { + auto req = ImageApi_.ListImages(); + if (initFilter) { + initFilter(*req->mutable_filter()); + } + return req->Invoke(); + } + + TFuture<TCriImageApi::TRspImageStatusPtr> GetImageStatus( + const TCriImageDescriptor& image, + bool verbose = false) override + { + auto req = ImageApi_.ImageStatus(); + FillImageSpec(req->mutable_image(), image); + req->set_verbose(verbose); + return req->Invoke(); + } + + 
TFuture<TCriImageDescriptor> PullImage( + const TCriImageDescriptor& image, + bool always, + TCriAuthConfigPtr authConfig, + TCriPodSpecPtr podSpec) override + { + if (!always) { + return GetImageStatus(image) + .Apply(BIND([=, this, this_ = MakeStrong(this)] (const TCriImageApi::TRspImageStatusPtr& imageStatus) { + if (imageStatus->has_image()) { + return MakeFuture(TCriImageDescriptor{.Image = imageStatus->image().id()}); + } + return PullImage(image, /*always*/ true, authConfig, podSpec); + })); + } + + auto req = ImageApi_.PullImage(); + FillImageSpec(req->mutable_image(), image); + if (authConfig) { + FillAuthConfig(req->mutable_auth(), *authConfig); + } + if (podSpec) { + FillPodSandboxConfig(req->mutable_sandbox_config(), *podSpec); + } + return req->Invoke().Apply(BIND([] (const TCriImageApi::TRspPullImagePtr& rsp) -> TCriImageDescriptor { + return TCriImageDescriptor{.Image = rsp->image_ref()}; + })); + } + + TFuture<void> RemoveImage(const TCriImageDescriptor& image) override + { + auto req = ImageApi_.RemoveImage(); + FillImageSpec(req->mutable_image(), image); + return req->Invoke().AsVoid(); + } + + TProcessBasePtr CreateProcess( + const TString& path, + TCriContainerSpecPtr containerSpec, + const TCriPodDescriptor& podDescriptor, + TCriPodSpecPtr podSpec) override + { + return New<TCriProcess>(path, this, std::move(containerSpec), podDescriptor, std::move(podSpec)); + } + +private: + const TCriExecutorConfigPtr Config_; + TCriRuntimeApi RuntimeApi_; + TCriImageApi ImageApi_; + + void FillLinuxContainerResources(NProto::LinuxContainerResources* resources, const TCriContainerResources& spec) + { + auto* unified = resources->mutable_unified(); + + if (spec.CpuLimit) { + i64 period = Config_->CpuPeriod.MicroSeconds(); + i64 quota = period * *spec.CpuLimit; + + resources->set_cpu_period(period); + resources->set_cpu_quota(quota); + } + + if (spec.MemoryLimit) { + resources->set_memory_limit_in_bytes(*spec.MemoryLimit); + } + + if (spec.MemoryRequest) { + (*unified)["memory.low"] = ToString(*spec.MemoryRequest); + } + } + + void FillPodSandboxConfig(NProto::PodSandboxConfig* config, const TCriPodSpec& spec) + { + { + auto* metadata = config->mutable_metadata(); + metadata->set_namespace_(Config_->Namespace); + metadata->set_name(spec.Name); + metadata->set_uid(spec.Name); + } + + { + auto& labels = *config->mutable_labels(); + labels[YTPodNamespaceLabel] = Config_->Namespace; + labels[YTPodNameLabel] = spec.Name; + } + + { + auto* linux = config->mutable_linux(); + linux->set_cgroup_parent(GetPodCgroup(spec.Name)); + + auto* security = linux->mutable_security_context(); + auto* namespaces = security->mutable_namespace_options(); + namespaces->set_network(NProto::NODE); + } + } + + void FillImageSpec(NProto::ImageSpec* spec, const TCriImageDescriptor& image) + { + spec->set_image(image.Image); + } + + void FillAuthConfig(NProto::AuthConfig* auth, const TCriAuthConfig& authConfig) + { + if (!authConfig.Username.empty()) { + auth->set_username(authConfig.Username); + } + if (!authConfig.Password.empty()) { + auth->set_password(authConfig.Password); + } + if (!authConfig.Auth.empty()) { + auth->set_auth(authConfig.Auth); + } + if (!authConfig.ServerAddress.empty()) { + auth->set_server_address(authConfig.ServerAddress); + } + if (!authConfig.IdentityToken.empty()) { + auth->set_identity_token(authConfig.IdentityToken); + } + if (!authConfig.RegistryToken.empty()) { + auth->set_registry_token(authConfig.RegistryToken); + } + } +}; + 
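For orientation, here is an illustrative end-to-end sketch (not part of this commit) of driving the executor defined above through the ICriExecutor interface; the pod, container, and image names are hypothetical, error handling is elided, and it assumes a CRI runtime such as containerd is listening on the configured endpoint:

    // config is a TCriExecutorConfigPtr, normally deserialized from YSON.
    auto executor = CreateCriExecutor(config);

    auto podSpec = New<TCriPodSpec>();
    podSpec->Name = "example-pod";
    auto pod = NConcurrency::WaitFor(executor->RunPodSandbox(podSpec))
        .ValueOrThrow();

    auto containerSpec = New<TCriContainerSpec>();
    containerSpec->Name = "example-container";
    containerSpec->Image.Image = "docker.io/library/busybox:latest";
    containerSpec->Command = {"/bin/sleep", "1000"};
    auto ct = NConcurrency::WaitFor(executor->CreateContainer(containerSpec, pod, podSpec))
        .ValueOrThrow();

    NConcurrency::WaitFor(executor->StartContainer(ct))
        .ThrowOnError();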
+//////////////////////////////////////////////////////////////////////////////// + +ICriExecutorPtr CreateCriExecutor(TCriExecutorConfigPtr config) +{ + return New<TCriExecutor>( + std::move(config), + GetGrpcChannelFactory()); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers::NCri diff --git a/yt/yt/library/containers/cri/cri_executor.h b/yt/yt/library/containers/cri/cri_executor.h new file mode 100644 index 0000000000..de9741721f --- /dev/null +++ b/yt/yt/library/containers/cri/cri_executor.h @@ -0,0 +1,207 @@ +#pragma once + +#include "public.h" +#include "config.h" +#include "cri_api.h" + +#include <yt/yt/library/process/process.h> + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NContainers::NCri { + +//////////////////////////////////////////////////////////////////////////////// + +struct TCriDescriptor +{ + TString Name; + TString Id; +}; + +struct TCriPodDescriptor +{ + TString Name; + TString Id; +}; + +struct TCriImageDescriptor +{ + TString Image; +}; + +void FormatValue(TStringBuilderBase* builder, const TCriDescriptor& descriptor, TStringBuf spec); +void FormatValue(TStringBuilderBase* builder, const TCriPodDescriptor& descriptor, TStringBuf spec); +void FormatValue(TStringBuilderBase* builder, const TCriImageDescriptor& descriptor, TStringBuf spec); + +//////////////////////////////////////////////////////////////////////////////// + +struct TCriContainerResources +{ + std::optional<double> CpuLimit; + std::optional<double> CpuRequest; + std::optional<i64> MemoryLimit; + std::optional<i64> MemoryRequest; +}; + +struct TCriPodSpec + : public TRefCounted +{ + TString Name; + TCriContainerResources Resources; +}; + +DEFINE_REFCOUNTED_TYPE(TCriPodSpec) + +struct TCriBindMount +{ + TString ContainerPath; + TString HostPath; + bool ReadOnly; +}; + +struct TCriCredentials +{ + std::optional<i64> Uid; + std::optional<i64> Gid; + std::vector<i64> Groups; +}; + +struct TCriContainerSpec + : public TRefCounted +{ + TString Name; + + THashMap<TString, TString> Labels; + + TCriImageDescriptor Image; + + bool ReadOnlyRootFS; + + std::vector<TCriBindMount> BindMounts; + + TCriCredentials Credentials; + + TCriContainerResources Resources; + + //! Command to execute (i.e., entrypoint for docker). + std::vector<TString> Command; + + //! Arguments for the Command (i.e., command for docker). + std::vector<TString> Arguments; + + //! Current working directory of the command. + TString WorkingDirectory; + + //! Environment variable to set in the container. + THashMap<TString, TString> Environment; +}; + +DEFINE_REFCOUNTED_TYPE(TCriContainerSpec) + +//////////////////////////////////////////////////////////////////////////////// + +//! Wrapper around CRI gRPC API +//! +//! @see yt/yt/contrib/cri-api/k8s.io/cri-api/pkg/apis/runtime/v1/api.proto +//! @see https://github.com/kubernetes/cri-api +struct ICriExecutor + : public TRefCounted +{ + //! Returns status of the CRI runtime. + //! @param verbose fill field "info" with runtime-specific debug. 
+ virtual TFuture<TCriRuntimeApi::TRspStatusPtr> GetRuntimeStatus(bool verbose = false) = 0; + + // PodSandbox + + virtual TString GetPodCgroup(TString podName) const = 0; + + virtual TFuture<TCriRuntimeApi::TRspListPodSandboxPtr> ListPodSandbox( + std::function<void(NProto::PodSandboxFilter&)> initFilter = nullptr) = 0; + + virtual TFuture<TCriRuntimeApi::TRspListContainersPtr> ListContainers( + std::function<void(NProto::ContainerFilter&)> initFilter = nullptr) = 0; + + virtual TFuture<void> ForEachPodSandbox( + const TCallback<void(const TCriPodDescriptor&, const NProto::PodSandbox&)>& callback, + std::function<void(NProto::PodSandboxFilter&)> initFilter = nullptr) = 0; + + virtual TFuture<void> ForEachContainer( + const TCallback<void(const TCriDescriptor&, const NProto::Container&)>& callback, + std::function<void(NProto::ContainerFilter&)> initFilter = nullptr) = 0; + + //! Returns status of the pod. + //! @param verbose fill field "info" with runtime-specific debug. + virtual TFuture<TCriRuntimeApi::TRspPodSandboxStatusPtr> GetPodSandboxStatus( + const TCriPodDescriptor& pod, bool verbose = false) = 0; + + //! Returns status of the container. + //! @param verbose fill "info" with runtime-specific debug information. + virtual TFuture<TCriRuntimeApi::TRspContainerStatusPtr> GetContainerStatus( + const TCriDescriptor& ct, bool verbose = false) = 0; + + virtual TFuture<TCriPodDescriptor> RunPodSandbox(TCriPodSpecPtr podSpec) = 0; + virtual TFuture<void> StopPodSandbox(const TCriPodDescriptor& pod) = 0; + virtual TFuture<void> RemovePodSandbox(const TCriPodDescriptor& pod) = 0; + virtual TFuture<void> UpdatePodResources( + const TCriPodDescriptor& pod, + const TCriContainerResources& resources) = 0; + + //! Remove all pods and containers in namespace managed by executor. + virtual void CleanNamespace() = 0; + + //! Remove all containers in one pod. + virtual void CleanPodSandbox(const TCriPodDescriptor& pod) = 0; + + virtual TFuture<TCriDescriptor> CreateContainer( + TCriContainerSpecPtr containerSpec, + const TCriPodDescriptor& pod, + TCriPodSpecPtr podSpec) = 0; + + virtual TFuture<void> StartContainer(const TCriDescriptor& ct) = 0; + + //! Stops container if it's running. + //! @param timeout defines timeout for graceful stop, timeout=0 - kill instantly. + virtual TFuture<void> StopContainer( + const TCriDescriptor& ct, + TDuration timeout = TDuration::Zero()) = 0; + + virtual TFuture<void> RemoveContainer(const TCriDescriptor& ct) = 0; + + virtual TFuture<void> UpdateContainerResources( + const TCriDescriptor& ct, + const TCriContainerResources& resources) = 0; + + virtual TFuture<TCriImageApi::TRspListImagesPtr> ListImages( + std::function<void(NProto::ImageFilter&)> initFilter = nullptr) = 0; + + //! Returns status of the image. + //! @param verbose fill field "info" with runtime-specific debug. 
+ virtual TFuture<TCriImageApi::TRspImageStatusPtr> GetImageStatus( + const TCriImageDescriptor& image, + bool verbose = false) = 0; + + virtual TFuture<TCriImageDescriptor> PullImage( + const TCriImageDescriptor& image, + bool always = false, + TCriAuthConfigPtr authConfig = nullptr, + TCriPodSpecPtr podSpec = nullptr) = 0; + + virtual TFuture<void> RemoveImage(const TCriImageDescriptor& image) = 0; + + // FIXME(khlebnikov): temporary compat + virtual TProcessBasePtr CreateProcess( + const TString& path, + TCriContainerSpecPtr containerSpec, + const TCriPodDescriptor& pod, + TCriPodSpecPtr podSpec) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(ICriExecutor) + +//////////////////////////////////////////////////////////////////////////////// + +ICriExecutorPtr CreateCriExecutor(TCriExecutorConfigPtr config); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers::NCri diff --git a/yt/yt/library/containers/cri/private.h b/yt/yt/library/containers/cri/private.h new file mode 100644 index 0000000000..36fdf194f5 --- /dev/null +++ b/yt/yt/library/containers/cri/private.h @@ -0,0 +1,13 @@ +#pragma once + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NContainers::NCri { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger Logger("Cri"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers::NCri diff --git a/yt/yt/library/containers/cri/public.h b/yt/yt/library/containers/cri/public.h new file mode 100644 index 0000000000..a12ee86d57 --- /dev/null +++ b/yt/yt/library/containers/cri/public.h @@ -0,0 +1,17 @@ +#pragma once + +#include <yt/yt/core/misc/intrusive_ptr.h> + +namespace NYT::NContainers::NCri { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(TCriPodSpec) +DECLARE_REFCOUNTED_STRUCT(TCriContainerSpec) +DECLARE_REFCOUNTED_CLASS(TCriExecutorConfig) +DECLARE_REFCOUNTED_CLASS(TCriAuthConfig) +DECLARE_REFCOUNTED_STRUCT(ICriExecutor) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers::NCri diff --git a/yt/yt/library/containers/cri/ya.make b/yt/yt/library/containers/cri/ya.make new file mode 100644 index 0000000000..dc9dd15a0b --- /dev/null +++ b/yt/yt/library/containers/cri/ya.make @@ -0,0 +1,22 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +PEERDIR( + yt/yt/core + yt/yt/core/rpc/grpc + yt/yt/contrib/cri-api +) + +SRCS( + cri_api.cpp + cri_executor.cpp + config.cpp +) + +ADDINCL( + ONE_LEVEL + yt/yt/contrib/cri-api +) + +END() diff --git a/yt/yt/library/containers/disk_manager/config.cpp b/yt/yt/library/containers/disk_manager/config.cpp new file mode 100644 index 0000000000..84484db630 --- /dev/null +++ b/yt/yt/library/containers/disk_manager/config.cpp @@ -0,0 +1,61 @@ +#include "config.h" + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +void TMockedDiskConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("disk_id", &TThis::DiskId) + .Default(); + registrar.Parameter("device_path", &TThis::DevicePath) + .Default(); + registrar.Parameter("device_name", &TThis::DeviceName) + .Default(); + registrar.Parameter("disk_model", &TThis::DiskModel) + .Default(); + registrar.Parameter("partition_fs_labels", &TThis::PartitionFsLabels) + .Default(); + registrar.Parameter("state", 
&TThis::State) + .Default(EDiskState::Ok); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TDiskInfoProviderConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("disk_ids", &TThis::DiskIds) + .Default(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TDiskManagerProxyConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("disk_manager_address", &TThis::DiskManagerAddress) + .Default("unix:/run/yandex-diskmanager/yandex-diskmanager.sock"); + registrar.Parameter("disk_manager_service_name", &TThis::DiskManagerServiceName) + .Default("diskman.DiskManager"); + + registrar.Parameter("is_mock", &TThis::IsMock) + .Default(false); + registrar.Parameter("mock_disks", &TThis::MockDisks) + .Default(); + registrar.Parameter("mock_yt_paths", &TThis::MockYtPaths) + .Default(); + + registrar.Parameter("request_timeout", &TThis::RequestTimeout) + .Default(TDuration::Seconds(10)); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TDiskManagerProxyDynamicConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("request_timeout", &TThis::RequestTimeout) + .Default(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/disk_manager/config.h b/yt/yt/library/containers/disk_manager/config.h new file mode 100644 index 0000000000..4f01d378b9 --- /dev/null +++ b/yt/yt/library/containers/disk_manager/config.h @@ -0,0 +1,79 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +struct TMockedDiskConfig + : public NYTree::TYsonStruct +{ + TString DiskId; + TString DevicePath; + TString DeviceName; + TString DiskModel; + std::vector<TString> PartitionFsLabels; + EDiskState State; + + REGISTER_YSON_STRUCT(TMockedDiskConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TMockedDiskConfig) + +//////////////////////////////////////////////////////////////////////////////// + +struct TDiskManagerProxyConfig + : public NYTree::TYsonStruct +{ + TString DiskManagerAddress; + TString DiskManagerServiceName; + + bool IsMock; + std::vector<TMockedDiskConfigPtr> MockDisks; + std::vector<TString> MockYtPaths; + + TDuration RequestTimeout; + + REGISTER_YSON_STRUCT(TDiskManagerProxyConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TDiskManagerProxyConfig) + +//////////////////////////////////////////////////////////////////////////////// + +struct TDiskInfoProviderConfig + : public NYTree::TYsonStruct +{ + std::vector<TString> DiskIds; + + REGISTER_YSON_STRUCT(TDiskInfoProviderConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TDiskInfoProviderConfig) + +//////////////////////////////////////////////////////////////////////////////// + +struct TDiskManagerProxyDynamicConfig + : public NYTree::TYsonStruct +{ + std::optional<TDuration> RequestTimeout; + + REGISTER_YSON_STRUCT(TDiskManagerProxyDynamicConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TDiskManagerProxyDynamicConfig) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/disk_manager/disk_info_provider.cpp 
b/yt/yt/library/containers/disk_manager/disk_info_provider.cpp new file mode 100644 index 0000000000..0ee3a5b6cb --- /dev/null +++ b/yt/yt/library/containers/disk_manager/disk_info_provider.cpp @@ -0,0 +1,64 @@ +#include "disk_info_provider.h" + +#include <yt/yt/library/containers/disk_manager/disk_manager_proxy.h> + +#include <yt/yt/core/actions/future.h> +#include <yt/yt/core/actions/invoker_util.h> + +#include <yt/yt/core/concurrency/public.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +TDiskInfoProvider::TDiskInfoProvider( + IDiskManagerProxyPtr diskManagerProxy, + TDiskInfoProviderConfigPtr config) + : DiskManagerProxy_(std::move(diskManagerProxy)) + , Config_(std::move(config)) +{ } + +const std::vector<TString>& TDiskInfoProvider::GetConfigDiskIds() const +{ + return Config_->DiskIds; +} + +TFuture<std::vector<TDiskInfo>> TDiskInfoProvider::GetYTDiskInfos() +{ + auto diskInfosFuture = DiskManagerProxy_->GetDisks(); + auto ytDiskPathsFuture = DiskManagerProxy_->GetYtDiskMountPaths(); + + // Merge two futures and filter disks placed in /yt. + return diskInfosFuture.Apply(BIND([=] (const std::vector<TDiskInfo>& diskInfos) { + return ytDiskPathsFuture.Apply(BIND([=] (const THashSet<TString>& diskPaths) { + std::vector<TDiskInfo> disks; + + for (const auto& diskInfo : diskInfos) { + for (const auto& partitionFsLabel : diskInfo.PartitionFsLabels) { + if (diskPaths.contains(partitionFsLabel)) { + disks.push_back(diskInfo); + break; + } + } + } + + return disks; + })); + })); +} + +TFuture<void> TDiskInfoProvider::RecoverDisk(const TString& diskId) +{ + return DiskManagerProxy_->RecoverDiskById(diskId, ERecoverPolicy::RecoverAuto); +} + +TFuture<void> TDiskInfoProvider::FailDisk( + const TString& diskId, + const TString& reason) +{ + return DiskManagerProxy_->FailDiskById(diskId, reason); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/disk_manager/disk_info_provider.h b/yt/yt/library/containers/disk_manager/disk_info_provider.h new file mode 100644 index 0000000000..b8d686438d --- /dev/null +++ b/yt/yt/library/containers/disk_manager/disk_info_provider.h @@ -0,0 +1,38 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/actions/future.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +class TDiskInfoProvider + : public TRefCounted +{ +public: + TDiskInfoProvider( + IDiskManagerProxyPtr diskManagerProxy, + TDiskInfoProviderConfigPtr config); + + const std::vector<TString>& GetConfigDiskIds() const; + + TFuture<std::vector<TDiskInfo>> GetYTDiskInfos(); + + TFuture<void> RecoverDisk(const TString& diskId); + + TFuture<void> FailDisk( + const TString& diskId, + const TString& reason); + +private: + const IDiskManagerProxyPtr DiskManagerProxy_; + const TDiskInfoProviderConfigPtr Config_; +}; + +DEFINE_REFCOUNTED_TYPE(TDiskInfoProvider) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/disk_manager/disk_manager_proxy.cpp b/yt/yt/library/containers/disk_manager/disk_manager_proxy.cpp new file mode 100644 index 0000000000..961723c51f --- /dev/null +++ b/yt/yt/library/containers/disk_manager/disk_manager_proxy.cpp @@ -0,0 +1,49 @@ +#include "disk_manager_proxy.h" + +namespace NYT::NContainers { + 
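As a side note, an illustrative sketch (not part of this commit) of how the TDiskInfoProvider defined above might be used by a periodic health check; the function name is hypothetical, and the choice to auto-recover every failed disk is an assumption:

    void RecoverFailedDisks(const TDiskInfoProviderPtr& provider)
    {
        // GetYTDiskInfos() returns only disks whose partition labels match
        // the mount paths reported by IDiskManagerProxy::GetYtDiskMountPaths().
        auto disks = NConcurrency::WaitFor(provider->GetYTDiskInfos())
            .ValueOrThrow();
        for (const auto& disk : disks) {
            if (disk.State == EDiskState::Failed) {
                // RecoverDisk() forwards to RecoverDiskById() with ERecoverPolicy::RecoverAuto.
                NConcurrency::WaitFor(provider->RecoverDisk(disk.DiskId))
                    .ThrowOnError();
            }
        }
    }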
+//////////////////////////////////////////////////////////////////////////////// + +struct TDiskManagerProxyMock + : public IDiskManagerProxy +{ + virtual TFuture<THashSet<TString>> GetYtDiskMountPaths() + { + THROW_ERROR_EXCEPTION("Disk manager library is not available under this build configuration"); + } + + virtual TFuture<std::vector<TDiskInfo>> GetDisks() + { + THROW_ERROR_EXCEPTION("Disk manager library is not available under this build configuration"); + } + + virtual TFuture<void> RecoverDiskById(const TString& /*diskId*/, ERecoverPolicy /*recoverPolicy*/) + { + THROW_ERROR_EXCEPTION("Disk manager library is not available under this build configuration"); + } + + virtual TFuture<void> FailDiskById(const TString& /*diskId*/, const TString& /*reason*/) + { + THROW_ERROR_EXCEPTION("Disk manager library is not available under this build configuration"); + } + + virtual void OnDynamicConfigChanged(const TDiskManagerProxyDynamicConfigPtr& /*newConfig*/) + { + // Do nothing + } +}; + +DEFINE_REFCOUNTED_TYPE(TDiskManagerProxyMock) + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK IDiskManagerProxyPtr CreateDiskManagerProxy(TDiskManagerProxyConfigPtr /*config*/) +{ + // This implementation is used when disk_manager_proxy_impl.cpp is not linked. + + return New<TDiskManagerProxyMock>(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/disk_manager/disk_manager_proxy.h b/yt/yt/library/containers/disk_manager/disk_manager_proxy.h new file mode 100644 index 0000000000..d2da5c1873 --- /dev/null +++ b/yt/yt/library/containers/disk_manager/disk_manager_proxy.h @@ -0,0 +1,38 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/containers/disk_manager/config.h> + +#include <yt/yt/core/misc/atomic_object.h> + +#include <yt/yt/core/rpc/client.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +struct IDiskManagerProxy + : public virtual TRefCounted +{ + virtual TFuture<THashSet<TString>> GetYtDiskMountPaths() = 0; + + virtual TFuture<std::vector<TDiskInfo>> GetDisks() = 0; + + virtual TFuture<void> RecoverDiskById(const TString& diskId, ERecoverPolicy recoverPolicy) = 0; + + virtual TFuture<void> FailDiskById(const TString& diskId, const TString& reason) = 0; + + virtual void OnDynamicConfigChanged(const TDiskManagerProxyDynamicConfigPtr& newConfig) = 0; + +}; + +DEFINE_REFCOUNTED_TYPE(IDiskManagerProxy) + +//////////////////////////////////////////////////////////////////////////////// + +IDiskManagerProxyPtr CreateDiskManagerProxy(TDiskManagerProxyConfigPtr config); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/disk_manager/public.h b/yt/yt/library/containers/disk_manager/public.h new file mode 100644 index 0000000000..8a812638f3 --- /dev/null +++ b/yt/yt/library/containers/disk_manager/public.h @@ -0,0 +1,48 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(EDiskState, + ((Unknown) (0)) + ((Ok) (1)) + ((Failed) (2)) + ((RecoverWait) (3)) +); + +// 1. Remount all disk volumes to their default state +// 2. Recreate disk layout, all data on disk will be lost +// 3. 
Replace physical disk +DEFINE_ENUM(ERecoverPolicy, + ((RecoverAuto) (0)) + ((RecoverMount) (1)) + ((RecoverLayout) (2)) + ((RecoverDisk) (3)) +); + +struct TDiskInfo +{ + TString DiskId; + TString DevicePath; + TString DeviceName; + TString DiskModel; + THashSet<TString> PartitionFsLabels; + EDiskState State; +}; + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(TMockedDiskConfig) +DECLARE_REFCOUNTED_STRUCT(TDiskManagerProxyConfig) +DECLARE_REFCOUNTED_STRUCT(TDiskManagerProxyDynamicConfig) +DECLARE_REFCOUNTED_STRUCT(TDiskInfoProviderConfig) + +DECLARE_REFCOUNTED_STRUCT(IDiskManagerProxy) +DECLARE_REFCOUNTED_CLASS(TDiskInfoProvider) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/disk_manager/ya.make b/yt/yt/library/containers/disk_manager/ya.make new file mode 100644 index 0000000000..dcb260cf38 --- /dev/null +++ b/yt/yt/library/containers/disk_manager/ya.make @@ -0,0 +1,19 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +PEERDIR( + yt/yt/core +) + +SRCS( + config.cpp + disk_info_provider.cpp + disk_manager_proxy.cpp +) + +IF (NOT OPENSOURCE) + INCLUDE(ya_non_opensource.inc) +ENDIF() + +END() diff --git a/yt/yt/library/containers/instance.cpp b/yt/yt/library/containers/instance.cpp new file mode 100644 index 0000000000..0a56987e1b --- /dev/null +++ b/yt/yt/library/containers/instance.cpp @@ -0,0 +1,812 @@ +#ifdef __linux__ + +#include "instance.h" + +#include "porto_executor.h" +#include "private.h" + +#include <yt/yt/library/containers/cgroup.h> +#include <yt/yt/library/containers/config.h> + +#include <yt/yt/core/concurrency/scheduler.h> + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/collection_helpers.h> +#include <yt/yt/core/misc/error.h> +#include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/proc.h> + +#include <library/cpp/porto/libporto.hpp> + +#include <util/stream/file.h> + +#include <util/string/cast.h> +#include <util/string/split.h> + +#include <util/system/env.h> + +#include <initializer_list> +#include <string> + +namespace NYT::NContainers { + +using namespace NConcurrency; +using namespace NNet; + +//////////////////////////////////////////////////////////////////////////////// + +namespace NDetail { + +// Porto passes the command string to wordexp, where the quote (') symbol +// is a delimiter. So we must replace it with concatenation ('"'"'). +TString EscapeForWordexp(const char* in) +{ + TString buffer; + while (*in) { + if (*in == '\'') { + buffer.append(R"('"'"')"); + } else { + buffer.append(*in); + } + in++; + } + return buffer; +} + +i64 Extract( + const TString& input, + const TString& pattern, + const TString& terminator = "\n") +{ + auto start = input.find(pattern) + pattern.length(); + auto end = input.find(terminator, start); + return std::stol(input.substr(start, (end == input.npos) ? end : end - start)); +} + +i64 ExtractSum( + const TString& input, + const TString& pattern, + const TString& delimiter, + const TString& terminator = "\n") +{ + i64 sum = 0; + TString::size_type pos = 0; + while (pos < input.length()) { + pos = input.find(pattern, pos); + if (pos == input.npos) { + break; + } + pos += pattern.length(); + + pos = input.find(delimiter, pos); + if (pos == input.npos) { + break; + } + + pos++; + auto end = input.find(terminator, pos); + sum += std::stol(input.substr(pos, (end == input.npos) ? 
end : end - pos)); + } + return sum; +} + +using TPortoStatRule = std::pair<TString, std::function<i64(const TString& input)>>; + +static const std::function<i64(const TString&)> LongExtractor = [] (const TString& in) { + return std::stol(in); +}; + +static const std::function<i64(const TString&)> CoreNsPerSecondExtractor = [] (const TString& in) { + int pos = in.find("c", 0); + return (std::stod(in.substr(0, pos))) * 1'000'000'000; +}; + +static const std::function<i64(const TString&)> GetIOStatExtractor(const TString& rwMode = "") +{ + return [rwMode] (const TString& in) { + return ExtractSum(in, "hw", rwMode + ":", ";"); + }; +} + +static const std::function<i64(const TString&)> GetStatByKeyExtractor(const TString& statKey) +{ + return [statKey] (const TString& in) { + return Extract(in, statKey); + }; +} + +const THashMap<EStatField, TPortoStatRule> PortoStatRules = { + {EStatField::CpuUsage, {"cpu_usage", LongExtractor}}, + {EStatField::CpuSystemUsage, {"cpu_usage_system", LongExtractor}}, + {EStatField::CpuWait, {"cpu_wait", LongExtractor}}, + {EStatField::CpuThrottled, {"cpu_throttled", LongExtractor}}, + {EStatField::ThreadCount, {"thread_count", LongExtractor}}, + {EStatField::CpuLimit, {"cpu_limit_bound", CoreNsPerSecondExtractor}}, + {EStatField::CpuGuarantee, {"cpu_guarantee_bound", CoreNsPerSecondExtractor}}, + {EStatField::Rss, {"memory.stat", GetStatByKeyExtractor("total_rss")}}, + {EStatField::MappedFile, {"memory.stat", GetStatByKeyExtractor("total_mapped_file")}}, + {EStatField::MinorPageFaults, {"minor_faults", LongExtractor}}, + {EStatField::MajorPageFaults, {"major_faults", LongExtractor}}, + {EStatField::FileCacheUsage, {"cache_usage", LongExtractor}}, + {EStatField::AnonMemoryUsage, {"anon_usage", LongExtractor}}, + {EStatField::AnonMemoryLimit, {"anon_limit_total", LongExtractor}}, + {EStatField::MemoryUsage, {"memory_usage", LongExtractor}}, + {EStatField::MemoryGuarantee, {"memory_guarantee", LongExtractor}}, + {EStatField::MemoryLimit, {"memory_limit_total", LongExtractor}}, + {EStatField::MaxMemoryUsage, {"memory.max_usage_in_bytes", LongExtractor}}, + {EStatField::OomKills, {"oom_kills", LongExtractor}}, + {EStatField::OomKillsTotal, {"oom_kills_total", LongExtractor}}, + + {EStatField::IOReadByte, {"io_read", GetIOStatExtractor()}}, + {EStatField::IOWriteByte, {"io_write", GetIOStatExtractor()}}, + {EStatField::IOBytesLimit, {"io_limit", GetIOStatExtractor()}}, + {EStatField::IOReadOps, {"io_read_ops", GetIOStatExtractor()}}, + {EStatField::IOWriteOps, {"io_write_ops", GetIOStatExtractor()}}, + {EStatField::IOOps, {"io_ops", GetIOStatExtractor()}}, + {EStatField::IOOpsLimit, {"io_ops_limit", GetIOStatExtractor()}}, + {EStatField::IOTotalTime, {"io_time", GetIOStatExtractor()}}, + {EStatField::IOWaitTime, {"io_wait", GetIOStatExtractor()}}, + + {EStatField::NetTxBytes, {"net_tx_bytes[veth]", LongExtractor}}, + {EStatField::NetTxPackets, {"net_tx_packets[veth]", LongExtractor}}, + {EStatField::NetTxDrops, {"net_tx_drops[veth]", LongExtractor}}, + {EStatField::NetTxLimit, {"net_limit[veth]", LongExtractor}}, + {EStatField::NetRxBytes, {"net_rx_bytes[veth]", LongExtractor}}, + {EStatField::NetRxPackets, {"net_rx_packets[veth]", LongExtractor}}, + {EStatField::NetRxDrops, {"net_rx_drops[veth]", LongExtractor}}, + {EStatField::NetRxLimit, {"net_rx_limit[veth]", LongExtractor}}, +}; + +std::optional<TString> GetParentName(const TString& name) +{ + if (name.empty()) { + return std::nullopt; + } + + auto slashPosition = name.rfind('/'); + if (slashPosition == 
TString::npos) { + return ""; + } + + return name.substr(0, slashPosition); +} + +std::optional<TString> GetRootName(const TString& name) +{ + if (name.empty()) { + return std::nullopt; + } + + if (name == "/") { + return name; + } + + auto slashPosition = name.find('/'); + if (slashPosition == TString::npos) { + return name; + } + + return name.substr(0, slashPosition); +} + +} // namespace NDetail + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoInstanceLauncher + : public IInstanceLauncher +{ +public: + TPortoInstanceLauncher(const TString& name, IPortoExecutorPtr executor) + : Executor_(std::move(executor)) + , Logger(ContainersLogger.WithTag("Container: %v", name)) + { + Spec_.Name = name; + Spec_.CGroupControllers = { + "freezer", + "cpu", + "cpuacct", + "net_cls", + "blkio", + "devices", + "pids" + }; + } + + const TString& GetName() const override + { + return Spec_.Name; + } + + bool HasRoot() const override + { + return static_cast<bool>(Spec_.RootFS); + } + + void SetStdIn(const TString& inputPath) override + { + Spec_.StdinPath = inputPath; + } + + void SetStdOut(const TString& outPath) override + { + Spec_.StdoutPath = outPath; + } + + void SetStdErr(const TString& errorPath) override + { + Spec_.StderrPath = errorPath; + } + + void SetCwd(const TString& pwd) override + { + Spec_.CurrentWorkingDirectory = pwd; + } + + void SetCoreDumpHandler(const std::optional<TString>& handler) override + { + if (handler) { + Spec_.CoreCommand = *handler; + Spec_.EnableCoreDumps = true; + } else { + Spec_.EnableCoreDumps = false; + } + } + + void SetRoot(const TRootFS& rootFS) override + { + Spec_.RootFS = rootFS; + } + + void SetThreadLimit(i64 threadLimit) override + { + Spec_.ThreadLimit = threadLimit; + } + + void SetDevices(const std::vector<TDevice>& devices) override + { + Spec_.Devices = devices; + } + + void SetEnablePorto(EEnablePorto enablePorto) override + { + Spec_.EnablePorto = enablePorto; + } + + void SetIsolate(bool isolate) override + { + Spec_.Isolate = isolate; + } + + void EnableMemoryTracking() override + { + Spec_.CGroupControllers.push_back("memory"); + } + + void SetGroup(int groupId) override + { + Spec_.GroupId = groupId; + } + + void SetUser(const TString& user) override + { + Spec_.User = user; + } + + void SetIPAddresses(const std::vector<NNet::TIP6Address>& addresses, bool enableNat64) override + { + Spec_.IPAddresses = addresses; + Spec_.EnableNat64 = enableNat64; + Spec_.DisableNetwork = false; + } + + void DisableNetwork() override + { + Spec_.DisableNetwork = true; + Spec_.IPAddresses.clear(); + Spec_.EnableNat64 = false; + } + + void SetHostName(const TString& hostName) override + { + Spec_.HostName = hostName; + } + + TFuture<IInstancePtr> Launch( + const TString& path, + const std::vector<TString>& args, + const THashMap<TString, TString>& env) override + { + TStringBuilder commandBuilder; + auto append = [&] (const auto& value) { + commandBuilder.AppendString("'"); + commandBuilder.AppendString(NDetail::EscapeForWordexp(value.c_str())); + commandBuilder.AppendString("' "); + }; + + append(path); + for (const auto& arg : args) { + append(arg); + } + + Spec_.Command = commandBuilder.Flush(); + YT_LOG_DEBUG("Executing Porto container (Name: %v, Command: %v)", + Spec_.Name, + Spec_.Command); + + Spec_.Env = env; + + auto onContainerCreated = [this, this_ = MakeStrong(this)] (const TError& error) -> IInstancePtr { + if (!error.IsOK()) { + THROW_ERROR_EXCEPTION(EErrorCode::FailedToStartContainer, "Unable to 
start container") + << error; + } + + return GetPortoInstance(Executor_, Spec_.Name); + }; + + return Executor_->CreateContainer(Spec_, /* start */ true) + .Apply(BIND(onContainerCreated)); + } + +private: + IPortoExecutorPtr Executor_; + TRunnableContainerSpec Spec_; + const NLogging::TLogger Logger; +}; + +IInstanceLauncherPtr CreatePortoInstanceLauncher(const TString& name, IPortoExecutorPtr executor) +{ + return New<TPortoInstanceLauncher>(name, executor); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoInstance + : public IInstance +{ +public: + static IInstancePtr GetSelf(IPortoExecutorPtr executor) + { + return New<TPortoInstance>(GetSelfContainerName(executor), executor); + } + + static IInstancePtr GetInstance(IPortoExecutorPtr executor, const TString& name) + { + return New<TPortoInstance>(name, executor); + } + + void Kill(int signal) override + { + auto error = WaitFor(Executor_->KillContainer(Name_, signal)); + // Killing already finished process is not an error. + if (error.FindMatching(EPortoErrorCode::InvalidState)) { + return; + } + if (!error.IsOK()) { + THROW_ERROR_EXCEPTION("Failed to send signal to Porto instance") + << TErrorAttribute("signal", signal) + << TErrorAttribute("container", Name_) + << error; + } + } + + void Destroy() override + { + WaitFor(Executor_->DestroyContainer(Name_)) + .ThrowOnError(); + Destroyed_ = true; + } + + void Stop() override + { + WaitFor(Executor_->StopContainer(Name_)) + .ThrowOnError(); + } + + TErrorOr<ui64> CalculateCpuUserUsage( + TErrorOr<ui64>& cpuUsage, + TErrorOr<ui64>& cpuSystemUsage) const + { + if (cpuUsage.IsOK() && cpuSystemUsage.IsOK()) { + return cpuUsage.Value() > cpuSystemUsage.Value() ? cpuUsage.Value() - cpuSystemUsage.Value() : 0; + } else if (cpuUsage.IsOK()) { + return TError("Missing property %Qlv in Porto response", EStatField::CpuSystemUsage) + << TErrorAttribute("container", Name_); + } else { + return TError("Missing property %Qlv in Porto response", EStatField::CpuUsage) + << TErrorAttribute("container", Name_); + } + } + + TResourceUsage GetResourceUsage( + const std::vector<EStatField>& fields) const override + { + std::vector<TString> properties; + properties.push_back("absolute_name"); + + bool userTimeRequested = false; + bool contextSwitchesRequested = false; + for (auto field : fields) { + if (auto it = NDetail::PortoStatRules.find(field)) { + const auto& rule = it->second; + properties.push_back(rule.first); + } else if (field == EStatField::ContextSwitchesDelta || field == EStatField::ContextSwitches) { + contextSwitchesRequested = true; + } else if (field == EStatField::CpuUserUsage) { + userTimeRequested = true; + } else { + THROW_ERROR_EXCEPTION("Unknown resource field %Qlv requested", field) + << TErrorAttribute("container", Name_); + } + } + + auto propertyMap = WaitFor(Executor_->GetContainerProperties(Name_, properties)) + .ValueOrThrow(); + + TResourceUsage result; + + for (auto field : fields) { + auto ruleIt = NDetail::PortoStatRules.find(field); + if (ruleIt == NDetail::PortoStatRules.end()) { + continue; + } + + const auto& [property, callback] = ruleIt->second; + auto& record = result[field]; + if (auto responseIt = propertyMap.find(property); responseIt != propertyMap.end()) { + const auto& valueOrError = responseIt->second; + if (valueOrError.IsOK()) { + const auto& value = valueOrError.Value(); + + try { + record = callback(value); + } catch (const std::exception& ex) { + record = TError("Error parsing Porto property %Qlv", 
field) + << TErrorAttribute("container", Name_) + << TErrorAttribute("property_value", value) + << ex; + } + } else { + record = TError("Error getting Porto property %Qlv", field) + << TErrorAttribute("container", Name_) + << valueOrError; + } + } else { + record = TError("Missing property %Qlv in Porto response", field) + << TErrorAttribute("container", Name_); + } + } + + // We should maintain context switch information even if this field + // is not requested since metrics of individual containers can go up and down. + auto subcontainers = WaitFor(Executor_->ListSubcontainers(Name_, /*includeRoot*/ true)) + .ValueOrThrow(); + + auto metricMap = WaitFor(Executor_->GetContainerMetrics(subcontainers, "ctxsw")) + .ValueOrThrow(); + + // TODO(don-dron): remove diff calculation from GetResourceUsage, because GetResourceUsage must return only snapshot stat. + { + auto guard = Guard(ContextSwitchMapLock_); + + for (const auto& [container, newValue] : metricMap) { + auto& prevValue = ContextSwitchMap_[container]; + TotalContextSwitches_ += std::max<i64>(0LL, newValue - prevValue); + prevValue = newValue; + } + + if (contextSwitchesRequested) { + result[EStatField::ContextSwitchesDelta] = TotalContextSwitches_; + } + } + + if (contextSwitchesRequested) { + ui64 totalContextSwitches = 0; + + for (const auto& [container, newValue] : metricMap) { + totalContextSwitches += std::max<ui64>(0UL, newValue); + } + + result[EStatField::ContextSwitches] = totalContextSwitches; + } + + if (userTimeRequested) { + result[EStatField::CpuUserUsage] = CalculateCpuUserUsage( + result[EStatField::CpuUsage], + result[EStatField::CpuSystemUsage]); + } + + return result; + } + + TResourceLimits GetResourceLimits() const override + { + std::vector<TString> properties; + static TString memoryLimitProperty = "memory_limit_total"; + static TString cpuLimitProperty = "cpu_limit_bound"; + static TString cpuGuaranteeProperty = "cpu_guarantee_bound"; + properties.push_back(memoryLimitProperty); + properties.push_back(cpuLimitProperty); + properties.push_back(cpuGuaranteeProperty); + + auto responseOrError = WaitFor(Executor_->GetContainerProperties(Name_, properties)); + THROW_ERROR_EXCEPTION_IF_FAILED(responseOrError, "Failed to get Porto container resource limits"); + + const auto& response = responseOrError.Value(); + + const auto& memoryLimitRsp = response.at(memoryLimitProperty); + THROW_ERROR_EXCEPTION_IF_FAILED(memoryLimitRsp, "Failed to get memory limit from Porto"); + + i64 memoryLimit; + if (!TryFromString<i64>(memoryLimitRsp.Value(), memoryLimit)) { + THROW_ERROR_EXCEPTION("Failed to parse memory limit value from Porto") + << TErrorAttribute(memoryLimitProperty, memoryLimitRsp.Value()); + } + + const auto& cpuLimitRsp = response.at(cpuLimitProperty); + THROW_ERROR_EXCEPTION_IF_FAILED(cpuLimitRsp, "Failed to get CPU limit from Porto"); + + double cpuLimit; + YT_VERIFY(cpuLimitRsp.Value().EndsWith('c')); + auto cpuLimitValue = TStringBuf(cpuLimitRsp.Value().begin(), cpuLimitRsp.Value().size() - 1); + if (!TryFromString<double>(cpuLimitValue, cpuLimit)) { + THROW_ERROR_EXCEPTION("Failed to parse CPU limit value from Porto") + << TErrorAttribute(cpuLimitProperty, cpuLimitRsp.Value()); + } + + const auto& cpuGuaranteeRsp = response.at(cpuGuaranteeProperty); + THROW_ERROR_EXCEPTION_IF_FAILED(cpuGuaranteeRsp, "Failed to get CPU guarantee from Porto"); + + double cpuGuarantee; + if (!cpuGuaranteeRsp.Value()) { + // XXX(ignat): hack for missing response from Porto. 
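// [Editor's aside, not part of this commit.] Porto reports CPU quantities as
// core-suffixed strings, e.g. "1.5c" for one and a half cores; the branch
// below treats an empty cpu_guarantee_bound response as "no guarantee".
// A worked parse with a made-up value, mirroring the code above:
//
//     TString raw = "1.5c";
//     double cores;
//     TryFromString<double>(TStringBuf(raw.begin(), raw.size() - 1), cores);
//     // cores == 1.5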
+ cpuGuarantee = 0.0; + } else { + YT_VERIFY(cpuGuaranteeRsp.Value().EndsWith('c')); + auto cpuGuaranteeValue = TStringBuf(cpuGuaranteeRsp.Value().begin(), cpuGuaranteeRsp.Value().size() - 1); + if (!TryFromString<double>(cpuGuaranteeValue, cpuGuarantee)) { + THROW_ERROR_EXCEPTION("Failed to parse CPU guarantee value from Porto") + << TErrorAttribute(cpuGuaranteeProperty, cpuGuaranteeRsp.Value()); + } + } + + return TResourceLimits{ + .CpuLimit = cpuLimit, + .CpuGuarantee = cpuGuarantee, + .Memory = memoryLimit, + }; + } + + void SetCpuGuarantee(double cores) override + { + SetProperty("cpu_guarantee", ToString(cores) + "c"); + } + + void SetCpuLimit(double cores) override + { + SetProperty("cpu_limit", ToString(cores) + "c"); + } + + void SetCpuWeight(double weight) override + { + SetProperty("cpu_weight", weight); + } + + void SetMemoryGuarantee(i64 memoryGuarantee) override + { + SetProperty("memory_guarantee", memoryGuarantee); + } + + void SetIOWeight(double weight) override + { + SetProperty("io_weight", weight); + } + + void SetIOThrottle(i64 operations) override + { + SetProperty("io_ops_limit", operations); + } + + TString GetStderr() const override + { + return *WaitFor(Executor_->GetContainerProperty(Name_, "stderr")) + .ValueOrThrow(); + } + + TString GetName() const override + { + return Name_; + } + + std::optional<TString> GetParentName() const override + { + return NDetail::GetParentName(Name_); + } + + std::optional<TString> GetRootName() const override + { + return NDetail::GetRootName(Name_); + } + + pid_t GetPid() const override + { + auto pid = *WaitFor(Executor_->GetContainerProperty(Name_, "root_pid")) + .ValueOrThrow(); + return std::stoi(pid); + } + + i64 GetMajorPageFaultCount() const override + { + auto faults = WaitFor(Executor_->GetContainerProperty(Name_, "major_faults")) + .ValueOrThrow(); + return faults + ? std::stoll(*faults) + : 0; + } + + double GetCpuGuarantee() const override + { + auto result = WaitFor(Executor_->GetContainerProperty(Name_, "cpu_guarantee")) + .ValueOrThrow(); + return result + ? std::stod(*result) + : 0; + } + + std::vector<pid_t> GetPids() const override + { + auto getPidCgroup = [&] (const TString& cgroups) { + for (TStringBuf cgroup : StringSplitter(cgroups).SplitByString("; ")) { + if (cgroup.StartsWith("pids:")) { + auto startPosition = cgroup.find('/'); + YT_VERIFY(startPosition != TString::npos); + return cgroup.substr(startPosition); + } + } + THROW_ERROR_EXCEPTION("Pids cgroup not found for container %Qv", GetName()) + << TErrorAttribute("cgroups", cgroups); + }; + + auto cgroups = *WaitFor(Executor_->GetContainerProperty(Name_, "cgroups")) + .ValueOrThrow(); + // Porto returns full cgroup name, with mount prefix, such as "/sys/fs/cgroup/pids". + auto instanceCgroup = getPidCgroup(cgroups); + + std::vector<pid_t> pids; + for (auto pid : ListPids()) { + std::map<TString, TString> cgroups; + try { + cgroups = GetProcessCGroups(pid); + } catch (const std::exception& ex) { + YT_LOG_DEBUG(ex, "Failed to get CGroups for process (Pid: %v)", pid); + continue; + } + + // Pid cgroups are returned in short form. 
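// [Editor's example, container name hypothetical.] The instance cgroup
// resolved above may look like "/sys/fs/cgroup/pids/porto/ct_foo", while
// GetProcessCGroups(pid) yields the short form "/porto/ct_foo" for the "pids"
// controller; the EndsWith() check below matches the two representations.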
+ auto processPidCgroup = cgroups["pids"]; + if (!processPidCgroup.empty() && instanceCgroup.EndsWith(processPidCgroup)) { + pids.push_back(pid); + } + } + + return pids; + } + + TFuture<void> Wait() override + { + return Executor_->PollContainer(Name_) + .Apply(BIND([] (int status) { + StatusToError(status) + .ThrowOnError(); + })); + } + +private: + const TString Name_; + const IPortoExecutorPtr Executor_; + const NLogging::TLogger Logger; + + bool Destroyed_ = false; + + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, ContextSwitchMapLock_); + mutable i64 TotalContextSwitches_ = 0; + mutable THashMap<TString, i64> ContextSwitchMap_; + + TPortoInstance(TString name, IPortoExecutorPtr executor) + : Name_(std::move(name)) + , Executor_(std::move(executor)) + , Logger(ContainersLogger.WithTag("Container: %v", Name_)) + { } + + void SetProperty(const TString& key, const TString& value) + { + WaitFor(Executor_->SetContainerProperty(Name_, key, value)) + .ThrowOnError(); + } + + void SetProperty(const TString& key, i64 value) + { + SetProperty(key, ToString(value)); + } + + void SetProperty(const TString& key, double value) + { + SetProperty(key, ToString(value)); + } + + DECLARE_NEW_FRIEND() +}; + +//////////////////////////////////////////////////////////////////////////////// + +TString GetSelfContainerName(const IPortoExecutorPtr& executor) +{ + try { + auto properties = WaitFor(executor->GetContainerProperties( + "self", + std::vector<TString>{"absolute_name", "absolute_namespace"})) + .ValueOrThrow(); + + auto absoluteName = properties.at("absolute_name") + .ValueOrThrow(); + auto absoluteNamespace = properties.at("absolute_namespace") + .ValueOrThrow(); + + if (absoluteName == "/") { + return absoluteName; + } + + if (absoluteName.length() < absoluteNamespace.length()) { + YT_VERIFY(absoluteName + "/" == absoluteNamespace); + return ""; + } else { + YT_VERIFY(absoluteName.StartsWith(absoluteNamespace)); + return absoluteName.substr(absoluteNamespace.length()); + } + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Failed to get name for container \"self\"") + << ex; + } +} + +IInstancePtr GetSelfPortoInstance(IPortoExecutorPtr executor) +{ + return TPortoInstance::GetSelf(executor); +} + +IInstancePtr GetPortoInstance(IPortoExecutorPtr executor, const TString& name) +{ + return TPortoInstance::GetInstance(executor, name); +} + +IInstancePtr GetRootPortoInstance(IPortoExecutorPtr executor) +{ + auto self = GetSelfPortoInstance(executor); + return TPortoInstance::GetInstance(executor, *self->GetRootName()); +} + +double GetSelfPortoInstanceVCpuFactor() +{ + auto config = New<TPortoExecutorDynamicConfig>(); + auto executorPtr = CreatePortoExecutor(config, ""); + auto currentContainer = GetSelfPortoInstance(executorPtr); + double cpuLimit = currentContainer->GetResourceLimits().CpuLimit; + if (cpuLimit <= 0) { + THROW_ERROR_EXCEPTION("Cpu limit must be greater than 0"); + } + + // DEPLOY_VCPU_LIMIT stores value in millicores + if (TString vcpuLimitStr = GetEnv("DEPLOY_VCPU_LIMIT"); !vcpuLimitStr.Empty()) { + double vcpuLimit = FromString<double>(vcpuLimitStr) / 1000.0; + return vcpuLimit / cpuLimit; + } + THROW_ERROR_EXCEPTION("Failed to get vcpu limit from env variable"); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers + +#endif diff --git a/yt/yt/library/containers/instance.h b/yt/yt/library/containers/instance.h new file mode 100644 index 0000000000..ff6e0b3ce1 --- /dev/null +++ 
b/yt/yt/library/containers/instance.h @@ -0,0 +1,168 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/actions/future.h> + +#include <yt/yt/core/net/address.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +using TResourceUsage = THashMap<EStatField, TErrorOr<ui64>>; + +const std::vector<EStatField> InstanceStatFields{ + EStatField::CpuUsage, + EStatField::CpuUserUsage, + EStatField::CpuSystemUsage, + EStatField::CpuWait, + EStatField::CpuThrottled, + EStatField::ContextSwitches, + EStatField::ContextSwitchesDelta, + EStatField::ThreadCount, + EStatField::CpuLimit, + EStatField::CpuGuarantee, + + EStatField::Rss, + EStatField::MappedFile, + EStatField::MajorPageFaults, + EStatField::MinorPageFaults, + EStatField::FileCacheUsage, + EStatField::AnonMemoryUsage, + EStatField::AnonMemoryLimit, + EStatField::MemoryUsage, + EStatField::MemoryGuarantee, + EStatField::MemoryLimit, + EStatField::MaxMemoryUsage, + EStatField::OomKills, + EStatField::OomKillsTotal, + + EStatField::IOReadByte, + EStatField::IOWriteByte, + EStatField::IOBytesLimit, + EStatField::IOReadOps, + EStatField::IOWriteOps, + EStatField::IOOps, + EStatField::IOOpsLimit, + EStatField::IOTotalTime, + EStatField::IOWaitTime, + + EStatField::NetTxBytes, + EStatField::NetTxPackets, + EStatField::NetTxDrops, + EStatField::NetTxLimit, + EStatField::NetRxBytes, + EStatField::NetRxPackets, + EStatField::NetRxDrops, + EStatField::NetRxLimit, +}; + +struct TResourceLimits +{ + double CpuLimit; + double CpuGuarantee; + i64 Memory; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct IInstanceLauncher + : public TRefCounted +{ + virtual bool HasRoot() const = 0; + virtual const TString& GetName() const = 0; + + virtual void SetStdIn(const TString& inputPath) = 0; + virtual void SetStdOut(const TString& outPath) = 0; + virtual void SetStdErr(const TString& errorPath) = 0; + virtual void SetCwd(const TString& pwd) = 0; + + // Null core dump handler implies disabled core dumps. 
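// [Editor's sketch, not part of this commit.] A caller would either disable
// core dumps outright or install a handler command (path hypothetical):
//
//     launcher->SetCoreDumpHandler(std::nullopt);                  // disabled
//     launcher->SetCoreDumpHandler(TString("/usr/bin/core_sink")); // enabled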
+ virtual void SetCoreDumpHandler(const std::optional<TString>& handler) = 0; + virtual void SetRoot(const TRootFS& rootFS) = 0; + + virtual void SetThreadLimit(i64 threadLimit) = 0; + virtual void SetDevices(const std::vector<TDevice>& devices) = 0; + + virtual void SetEnablePorto(EEnablePorto enablePorto) = 0; + virtual void SetIsolate(bool isolate) = 0; + virtual void EnableMemoryTracking() = 0; + virtual void SetGroup(int groupId) = 0; + virtual void SetUser(const TString& user) = 0; + virtual void SetIPAddresses( + const std::vector<NNet::TIP6Address>& addresses, + bool enableNat64 = false) = 0; + virtual void DisableNetwork() = 0; + virtual void SetHostName(const TString& hostName) = 0; + + virtual TFuture<IInstancePtr> Launch( + const TString& path, + const std::vector<TString>& args, + const THashMap<TString, TString>& env) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(IInstanceLauncher) + +#ifdef _linux_ +IInstanceLauncherPtr CreatePortoInstanceLauncher(const TString& name, IPortoExecutorPtr executor); +#endif + +//////////////////////////////////////////////////////////////////////////////// + +struct IInstance + : public TRefCounted +{ + virtual void Kill(int signal) = 0; + virtual void Stop() = 0; + virtual void Destroy() = 0; + + virtual TResourceUsage GetResourceUsage( + const std::vector<EStatField>& fields = InstanceStatFields) const = 0; + virtual TResourceLimits GetResourceLimits() const = 0; + virtual void SetCpuGuarantee(double cores) = 0; + virtual void SetCpuLimit(double cores) = 0; + virtual void SetCpuWeight(double weight) = 0; + virtual void SetIOWeight(double weight) = 0; + virtual void SetIOThrottle(i64 operations) = 0; + virtual void SetMemoryGuarantee(i64 memoryGuarantee) = 0; + + virtual TString GetStderr() const = 0; + + virtual TString GetName() const = 0; + virtual std::optional<TString> GetParentName() const = 0; + virtual std::optional<TString> GetRootName() const = 0; + + //! Returns externally visible pid of the root process inside container. + //! Throws if container is not running. + virtual pid_t GetPid() const = 0; + //! Returns the list of externally visible pids of processes running inside container. + virtual std::vector<pid_t> GetPids() const = 0; + + virtual i64 GetMajorPageFaultCount() const = 0; + virtual double GetCpuGuarantee() const = 0; + + //! Future is set when container reaches terminal state (stopped or dead). + //! Resulting error is OK iff container exited with code 0. + virtual TFuture<void> Wait() = 0; +}; + +DEFINE_REFCOUNTED_TYPE(IInstance) + +//////////////////////////////////////////////////////////////////////////////// + +#ifdef _linux_ +TString GetSelfContainerName(const IPortoExecutorPtr& executor); + +IInstancePtr GetSelfPortoInstance(IPortoExecutorPtr executor); +IInstancePtr GetRootPortoInstance(IPortoExecutorPtr executor); +IInstancePtr GetPortoInstance(IPortoExecutorPtr executor, const TString& name); + +//! Works only in Yandex.Deploy pod environment where env DEPLOY_VCPU_LIMIT is set. +//! Throws if this env is absent. 
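// [Editor's example with made-up numbers.] DEPLOY_VCPU_LIMIT is in millicores,
// so DEPLOY_VCPU_LIMIT=2000 means 2 vCPUs; with a Porto cpu_limit of "4c" the
// function below returns 2.0 / 4.0 = 0.5.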
+double GetSelfPortoInstanceVCpuFactor(); +#endif + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/instance_limits_tracker.cpp b/yt/yt/library/containers/instance_limits_tracker.cpp new file mode 100644 index 0000000000..55ef7d2d67 --- /dev/null +++ b/yt/yt/library/containers/instance_limits_tracker.cpp @@ -0,0 +1,179 @@ +#include "public.h" +#include "instance_limits_tracker.h" +#include "instance.h" +#include "porto_resource_tracker.h" +#include "private.h" + +#include <yt/yt/core/concurrency/periodic_executor.h> + +#include <yt/yt/core/ytree/fluent.h> +#include <yt/yt/core/ytree/ypath_service.h> + +namespace NYT::NContainers { + +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = ContainersLogger; + +//////////////////////////////////////////////////////////////////////////////// + +TInstanceLimitsTracker::TInstanceLimitsTracker( + IInstancePtr instance, + IInstancePtr root, + IInvokerPtr invoker, + TDuration updatePeriod) + : Invoker_(std::move(invoker)) + , Executor_(New<NConcurrency::TPeriodicExecutor>( + Invoker_, + BIND(&TInstanceLimitsTracker::DoUpdateLimits, MakeWeak(this)), + updatePeriod)) +{ +#ifdef _linux_ + SelfTracker_ = New<TPortoResourceTracker>(std::move(instance), updatePeriod / 2); + RootTracker_ = New<TPortoResourceTracker>(std::move(root), updatePeriod / 2); +#else + Y_UNUSED(instance); + Y_UNUSED(root); +#endif +} + +void TInstanceLimitsTracker::Start() +{ + if (!Running_) { + Executor_->Start(); + Running_ = true; + YT_LOG_INFO("Instance limits tracker started"); + } +} + +void TInstanceLimitsTracker::Stop() +{ + if (Running_) { + YT_UNUSED_FUTURE(Executor_->Stop()); + Running_ = false; + YT_LOG_INFO("Instance limits tracker stopped"); + } +} + +void TInstanceLimitsTracker::DoUpdateLimits() +{ + VERIFY_INVOKER_AFFINITY(Invoker_); + +#ifdef _linux_ + YT_LOG_DEBUG("Checking for instance limits update"); + + auto setIfOk = [] (auto* destination, const auto& valueOrError, const TString& fieldName, bool alert = true) { + if (valueOrError.IsOK()) { + *destination = valueOrError.Value(); + } else { + YT_LOG_ALERT_IF(alert, valueOrError, "Failed to get container property (Field: %v)", + fieldName); + + YT_LOG_DEBUG(valueOrError, "Failed to get container property (Field: %v)", + fieldName); + } + }; + + try { + auto memoryStatistics = SelfTracker_->GetMemoryStatistics(); + auto netStatistics = RootTracker_->GetNetworkStatistics(); + auto cpuStatistics = SelfTracker_->GetCpuStatistics(); + + setIfOk(&MemoryUsage_, memoryStatistics.Rss, "MemoryRss"); + + TDuration cpuGuarantee; + TDuration cpuLimit; + + if (cpuStatistics.GuaranteeTime.IsOK()) { + setIfOk(&cpuGuarantee, cpuStatistics.GuaranteeTime, "CpuGuarantee"); + } else { + // XXX(don-dron, ignat): do nothing, see NContainers::TPortoInstance::GetResourceLimits, hack for missing response from Porto. + } + + setIfOk(&cpuLimit, cpuStatistics.LimitTime, "CpuLimit"); + + if (CpuGuarantee_ != cpuGuarantee) { + YT_LOG_INFO("Instance CPU guarantee updated (OldCpuGuarantee: %v, NewCpuGuarantee: %v)", + CpuGuarantee_, + cpuGuarantee); + CpuGuarantee_ = cpuGuarantee; + // NB: We do not fire LimitsUpdated since this value is used only for diagnostics. 
+ } + + TInstanceLimits limits; + limits.Cpu = cpuLimit.SecondsFloat(); + + if (memoryStatistics.AnonLimit.IsOK() && memoryStatistics.MemoryLimit.IsOK()) { + i64 anonLimit = memoryStatistics.AnonLimit.Value(); + i64 memoryLimit = memoryStatistics.MemoryLimit.Value(); + + if (anonLimit > 0 && memoryLimit > 0) { + limits.Memory = std::min(anonLimit, memoryLimit); + } else if (anonLimit > 0) { + limits.Memory = anonLimit; + } else { + limits.Memory = memoryLimit; + } + } else { + setIfOk(&limits.Memory, memoryStatistics.MemoryLimit, "MemoryLimit"); + } + + static constexpr bool DontFireAlertOnError = {}; + setIfOk(&limits.NetTx, netStatistics.TxLimit, "NetTxLimit", DontFireAlertOnError); + setIfOk(&limits.NetRx, netStatistics.RxLimit, "NetRxLimit", DontFireAlertOnError); + + if (InstanceLimits_ != limits) { + YT_LOG_INFO("Instance limits updated (OldLimits: %v, NewLimits: %v)", + InstanceLimits_, + limits); + InstanceLimits_ = limits; + LimitsUpdated_.Fire(limits); + } + } catch (const std::exception& ex) { + YT_LOG_WARNING(ex, "Failed to get instance limits"); + } +#endif +} + +IYPathServicePtr TInstanceLimitsTracker::GetOrchidService() +{ + return IYPathService::FromProducer(BIND(&TInstanceLimitsTracker::DoBuildOrchid, MakeStrong(this))) + ->Via(Invoker_); +} + +void TInstanceLimitsTracker::DoBuildOrchid(NYson::IYsonConsumer* consumer) const +{ + NYTree::BuildYsonFluently(consumer) + .BeginMap() + .DoIf(static_cast<bool>(InstanceLimits_), [&] (auto fluent) { + fluent.Item("cpu_limit").Value(InstanceLimits_->Cpu); + }) + .DoIf(static_cast<bool>(CpuGuarantee_), [&] (auto fluent) { + fluent.Item("cpu_guarantee").Value(*CpuGuarantee_); + }) + .DoIf(static_cast<bool>(InstanceLimits_), [&] (auto fluent) { + fluent.Item("memory_limit").Value(InstanceLimits_->Memory); + }) + .DoIf(static_cast<bool>(MemoryUsage_), [&] (auto fluent) { + fluent.Item("memory_usage").Value(*MemoryUsage_); + }) + .EndMap(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void FormatValue(TStringBuilderBase* builder, const TInstanceLimits& limits, TStringBuf /*format*/) +{ + builder->AppendFormat( + "{Cpu: %v, Memory: %v, NetTx: %v, NetRx: %v}", + limits.Cpu, + limits.Memory, + limits.NetTx, + limits.NetRx); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/instance_limits_tracker.h b/yt/yt/library/containers/instance_limits_tracker.h new file mode 100644 index 0000000000..e652fff446 --- /dev/null +++ b/yt/yt/library/containers/instance_limits_tracker.h @@ -0,0 +1,59 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/actions/signal.h> + +#include <yt/yt/core/concurrency/public.h> + +#include <yt/yt/core/yson/public.h> + +#include <yt/yt/core/ytree/public.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +class TInstanceLimitsTracker + : public TRefCounted +{ +public: + //! Raises when container limits change. 
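// [Editor's note, not part of this commit.] DEFINE_SIGNAL generates
// Subscribe/Unsubscribe methods, so a consumer would typically wire up the
// signal roughly as follows (handler name hypothetical):
//
//     tracker->SubscribeLimitsUpdated(BIND([] (const TInstanceLimits& limits) {
//         ApplyNewLimits(limits); // e.g. resize worker pools to the new CPU/memory caps
//     }));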
+ DEFINE_SIGNAL(void(const TInstanceLimits&), LimitsUpdated); + +public: + TInstanceLimitsTracker( + IInstancePtr instance, + IInstancePtr root, + IInvokerPtr invoker, + TDuration updatePeriod); + + void Start(); + void Stop(); + + NYTree::IYPathServicePtr GetOrchidService(); + +private: + void DoUpdateLimits(); + void DoBuildOrchid(NYson::IYsonConsumer* consumer) const; + + TPortoResourceTrackerPtr SelfTracker_; + TPortoResourceTrackerPtr RootTracker_; + const IInvokerPtr Invoker_; + const NConcurrency::TPeriodicExecutorPtr Executor_; + + std::optional<TDuration> CpuGuarantee_; + std::optional<TInstanceLimits> InstanceLimits_; + std::optional<i64> MemoryUsage_; + bool Running_ = false; +}; + +DEFINE_REFCOUNTED_TYPE(TInstanceLimitsTracker) + +//////////////////////////////////////////////////////////////////////////////// + +void FormatValue(TStringBuilderBase* builder, const TInstanceLimits& limits, TStringBuf format); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_executor.cpp b/yt/yt/library/containers/porto_executor.cpp new file mode 100644 index 0000000000..a6a44fd20f --- /dev/null +++ b/yt/yt/library/containers/porto_executor.cpp @@ -0,0 +1,1079 @@ +#include "porto_executor.h" +#include "config.h" + +#include "private.h" + +#include <yt/yt/core/concurrency/action_queue.h> +#include <yt/yt/core/concurrency/periodic_executor.h> +#include <yt/yt/core/concurrency/scheduler.h> + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/fs.h> + +#include <yt/yt/core/profiling/timing.h> + +#include <yt/yt/core/ytree/convert.h> + +#include <library/cpp/porto/proto/rpc.pb.h> + +#include <library/cpp/yt/memory/atomic_intrusive_ptr.h> + +#include <string> + +namespace NYT::NContainers { + +using namespace NConcurrency; +using Porto::EError; + +//////////////////////////////////////////////////////////////////////////////// + +#ifdef _linux_ + +static const NLogging::TLogger& Logger = ContainersLogger; +static constexpr auto RetryInterval = TDuration::MilliSeconds(100); + +//////////////////////////////////////////////////////////////////////////////// + +TString PortoErrorCodeFormatter(int code) +{ + return TEnumTraits<EPortoErrorCode>::ToString(static_cast<EPortoErrorCode>(code)); +} + +YT_DEFINE_ERROR_CODE_RANGE(12000, 13999, "NYT::NContainers::EPortoErrorCode", PortoErrorCodeFormatter); + +//////////////////////////////////////////////////////////////////////////////// + +EPortoErrorCode ConvertPortoErrorCode(EError portoError) +{ + return static_cast<EPortoErrorCode>(PortoErrorCodeBase + portoError); +} + +bool IsRetriableErrorCode(EPortoErrorCode error, bool idempotent) +{ + return + error == EPortoErrorCode::Unknown || + // TODO(babenko): it's not obvious that we can always retry SocketError, + // but this is how it has worked for a while. 
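// [Editor's note.] '&&' binds tighter than '||', so the expression below reads
// "Unknown || SocketError || (SocketTimeout && idempotent)": socket timeouts
// are retried only when the command is idempotent.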
+ error == EPortoErrorCode::SocketError || + error == EPortoErrorCode::SocketTimeout && idempotent; +} + +THashMap<TString, TErrorOr<TString>> ParsePortoGetResponse( + const Porto::TGetResponse_TContainerGetListResponse& response) +{ + THashMap<TString, TErrorOr<TString>> result; + for (const auto& property : response.keyval()) { + if (property.error() == EError::Success) { + result[property.variable()] = property.value(); + } else { + result[property.variable()] = TError(ConvertPortoErrorCode(property.error()), property.errormsg()) + << TErrorAttribute("porto_error", ConvertPortoErrorCode(property.error())); + } + } + return result; +} + +THashMap<TString, TErrorOr<TString>> ParseSinglePortoGetResponse( + const TString& name, + const Porto::TGetResponse& getResponse) +{ + for (const auto& container : getResponse.list()) { + if (container.name() == name) { + return ParsePortoGetResponse(container); + } + } + THROW_ERROR_EXCEPTION("Unable to get properties from Porto") + << TErrorAttribute("container", name); +} + +THashMap<TString, THashMap<TString, TErrorOr<TString>>> ParseMultiplePortoGetResponse( + const Porto::TGetResponse& getResponse) +{ + THashMap<TString, THashMap<TString, TErrorOr<TString>>> result; + for (const auto& container : getResponse.list()) { + result[container.name()] = ParsePortoGetResponse(container); + } + return result; +} + +TString FormatEnablePorto(EEnablePorto value) +{ + switch (value) { + case EEnablePorto::None: return "none"; + case EEnablePorto::Isolate: return "isolate"; + case EEnablePorto::Full: return "full"; + default: YT_ABORT(); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoExecutor + : public IPortoExecutor +{ +public: + TPortoExecutor( + TPortoExecutorDynamicConfigPtr config, + const TString& threadNameSuffix, + const NProfiling::TProfiler& profiler) + : Config_(std::move(config)) + , Queue_(New<TActionQueue>(Format("Porto:%v", threadNameSuffix))) + , Profiler_(profiler) + , PollExecutor_(New<TPeriodicExecutor>( + Queue_->GetInvoker(), + BIND(&TPortoExecutor::DoPoll, MakeWeak(this)), + Config_->PollPeriod)) + { + DynamicConfig_.Store(New<TPortoExecutorDynamicConfig>()); + + Api_->SetTimeout(Config_->ApiTimeout.Seconds()); + Api_->SetDiskTimeout(Config_->ApiDiskTimeout.Seconds()); + + PollExecutor_->Start(); + } + + void SubscribeFailed(const TCallback<void (const TError&)>& callback) override + { + Failed_.Subscribe(callback); + } + + void UnsubscribeFailed(const TCallback<void (const TError&)>& callback) override + { + Failed_.Unsubscribe(callback); + } + + void OnDynamicConfigChanged(const TPortoExecutorDynamicConfigPtr& newConfig) override + { + DynamicConfig_.Store(newConfig); + } + +private: + template <class T, class... TArgs1, class... TArgs2> + auto ExecutePortoApiAction( + T(TPortoExecutor::*Method)(TArgs1...), + const TString& command, + TArgs2&&... args) + { + YT_LOG_DEBUG("Enqueue Porto API action (Command: %v)", command); + return BIND(Method, MakeStrong(this), std::forward<TArgs2>(args)...) 
+ .AsyncVia(Queue_->GetInvoker()) + .Run(); + }; + +public: + TFuture<void> CreateContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoCreateContainer, + "CreateContainer", + container); + } + + TFuture<void> CreateContainer(const TRunnableContainerSpec& containerSpec, bool start) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoCreateContainerFromSpec, + "CreateContainerFromSpec", + containerSpec, + start); + } + + TFuture<std::optional<TString>> GetContainerProperty( + const TString& container, + const TString& property) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoGetContainerProperty, + "GetContainerProperty", + container, + property); + } + + TFuture<THashMap<TString, TErrorOr<TString>>> GetContainerProperties( + const TString& container, + const std::vector<TString>& properties) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoGetContainerProperties, + "GetContainerProperty", + container, + properties); + } + + TFuture<THashMap<TString, THashMap<TString, TErrorOr<TString>>>> GetContainerProperties( + const std::vector<TString>& containers, + const std::vector<TString>& properties) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoGetContainerMultipleProperties, + "GetContainerProperty", + containers, + properties); + } + + TFuture<THashMap<TString, i64>> GetContainerMetrics( + const std::vector<TString>& containers, + const TString& metric) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoGetContainerMetrics, + "GetContainerMetrics", + containers, + metric); + } + + TFuture<void> SetContainerProperty( + const TString& container, + const TString& property, + const TString& value) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoSetContainerProperty, + "SetContainerProperty", + container, + property, + value); + } + + TFuture<void> DestroyContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoDestroyContainer, + "DestroyContainer", + container); + } + + TFuture<void> StopContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoStopContainer, + "StopContainer", + container); + } + + TFuture<void> StartContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoStartContainer, + "StartContainer", + container); + } + + TFuture<TString> ConvertPath(const TString& path, const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoConvertPath, + "ConvertPath", + path, + container); + } + + TFuture<void> KillContainer(const TString& container, int signal) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoKillContainer, + "KillContainer", + container, + signal); + } + + TFuture<std::vector<TString>> ListSubcontainers( + const TString& rootContainer, + bool includeRoot) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoListSubcontainers, + "ListSubcontainers", + rootContainer, + includeRoot); + } + + TFuture<int> PollContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoPollContainer, + "PollContainer", + container); + } + + TFuture<int> WaitContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoWaitContainer, + "WaitContainer", + container); + } + + // This method allocates Porto "resources", so it should be uncancellable. 
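// [Editor's aside, not part of this commit.] ToUncancelable() below detaches
// the returned future from caller-side cancellation:
//
//     auto future = executor->CreateVolume(path, properties);
//     future.Cancel(TError("caller gave up")); // does not abort the API call
//
// so a volume that Porto has already allocated is not silently leaked by a
// cancelled caller.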
+ TFuture<TString> CreateVolume( + const TString& path, + const THashMap<TString, TString>& properties) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoCreateVolume, + "CreateVolume", + path, + properties) + .ToUncancelable(); + } + + // This method allocates Porto "resources", so it should be uncancellable. + TFuture<void> LinkVolume( + const TString& path, + const TString& name) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoLinkVolume, + "LinkVolume", + path, + name) + .ToUncancelable(); + } + + // This method deallocates Porto "resources", so it should be uncancellable. + TFuture<void> UnlinkVolume( + const TString& path, + const TString& name) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoUnlinkVolume, + "UnlinkVolume", + path, + name) + .ToUncancelable(); + } + + TFuture<std::vector<TString>> ListVolumePaths() override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoListVolumePaths, + "ListVolumePaths"); + } + + // This method allocates Porto "resources", so it should be uncancellable. + TFuture<void> ImportLayer(const TString& archivePath, const TString& layerId, const TString& place) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoImportLayer, + "ImportLayer", + archivePath, + layerId, + place) + .ToUncancelable(); + } + + // This method deallocates Porto "resources", so it should be uncancellable. + TFuture<void> RemoveLayer(const TString& layerId, const TString& place, bool async) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoRemoveLayer, + "RemoveLayer", + layerId, + place, + async) + .ToUncancelable(); + } + + TFuture<std::vector<TString>> ListLayers(const TString& place) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoListLayers, + "ListLayers", + place); + } + + IInvokerPtr GetInvoker() const override + { + return Queue_->GetInvoker(); + } + +private: + const TPortoExecutorDynamicConfigPtr Config_; + const TActionQueuePtr Queue_; + const NProfiling::TProfiler Profiler_; + const std::unique_ptr<Porto::TPortoApi> Api_ = std::make_unique<Porto::TPortoApi>(); + const TPeriodicExecutorPtr PollExecutor_; + TAtomicIntrusivePtr<TPortoExecutorDynamicConfig> DynamicConfig_; + + std::vector<TString> Containers_; + THashMap<TString, TPromise<int>> ContainerMap_; + TSingleShotCallbackList<void(const TError&)> Failed_; + + struct TCommandEntry + { + explicit TCommandEntry(const NProfiling::TProfiler& registry) + : TimeGauge(registry.Timer("/command_time")) + , RetryCounter(registry.Counter("/command_retries")) + , SuccessCounter(registry.Counter("/command_successes")) + , FailureCounter(registry.Counter("/command_failures")) + { } + + NProfiling::TEventTimer TimeGauge; + NProfiling::TCounter RetryCounter; + NProfiling::TCounter SuccessCounter; + NProfiling::TCounter FailureCounter; + }; + + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, CommandLock_); + THashMap<TString, TCommandEntry> CommandToEntry_; + + static const std::vector<TString> ContainerRequestVars_; + + bool IsTestPortoFailureEnabled() const + { + auto config = DynamicConfig_.Acquire(); + return config->EnableTestPortoFailures; + } + + bool IsTestPortoTimeout() const + { + auto config = DynamicConfig_.Acquire(); + return config->EnableTestPortoNotResponding; + } + + EPortoErrorCode GetFailedStubError() const + { + auto config = DynamicConfig_.Acquire(); + return config->StubErrorCode; + } + + static TError CreatePortoError(EPortoErrorCode errorCode, const TString& message) + { + return TError(errorCode, "Porto 
API error") + << TErrorAttribute("original_porto_error_code", static_cast<int>(errorCode) - PortoErrorCodeBase) + << TErrorAttribute("porto_error_message", message); + } + + THashMap<TString, TErrorOr<TString>> DoGetContainerProperties( + const TString& container, + const std::vector<TString>& properties) + { + auto response = DoRequestContainerProperties({container}, properties); + return ParseSinglePortoGetResponse(container, response); + } + + THashMap<TString, THashMap<TString, TErrorOr<TString>>> DoGetContainerMultipleProperties( + const std::vector<TString>& containers, + const std::vector<TString>& properties) + { + auto response = DoRequestContainerProperties(containers, properties); + return ParseMultiplePortoGetResponse(response); + } + + std::optional<TString> DoGetContainerProperty( + const TString& container, + const TString& property) + { + auto response = DoRequestContainerProperties({container}, {property}); + auto parsedResponse = ParseSinglePortoGetResponse(container, response); + auto it = parsedResponse.find(property); + if (it == parsedResponse.end()) { + return std::nullopt; + } else { + return it->second.ValueOrThrow(); + } + } + + void DoCreateContainer(const TString& container) + { + ExecuteApiCall( + [&] { return Api_->Create(container); }, + "Create", + /*idempotent*/ false); + } + + void DoCreateContainerFromSpec(const TRunnableContainerSpec& spec, bool start) + { + Porto::TContainerSpec portoSpec; + + // Required properties. + portoSpec.set_name(spec.Name); + portoSpec.set_command(spec.Command); + + portoSpec.set_enable_porto(FormatEnablePorto(spec.EnablePorto)); + portoSpec.set_isolate(spec.Isolate); + + if (spec.StdinPath) { + portoSpec.set_stdin_path(*spec.StdinPath); + } + if (spec.StdoutPath) { + portoSpec.set_stdout_path(*spec.StdoutPath); + } + if (spec.StderrPath) { + portoSpec.set_stderr_path(*spec.StderrPath); + } + + if (spec.CurrentWorkingDirectory) { + portoSpec.set_cwd(*spec.CurrentWorkingDirectory); + } + + if (spec.CoreCommand) { + portoSpec.set_core_command(*spec.CoreCommand); + } + if (spec.User) { + portoSpec.set_user(*spec.User); + } + + // Useful for jobs, where we operate with numeric group ids. + if (spec.GroupId) { + portoSpec.set_group(ToString(*spec.GroupId)); + } + + if (spec.ThreadLimit) { + portoSpec.set_thread_limit(*spec.ThreadLimit); + } + + if (spec.HostName) { + // To get a reasonable and unique host name inside container. + portoSpec.set_hostname(*spec.HostName); + if (!spec.IPAddresses.empty()) { + const auto& address = spec.IPAddresses[0]; + auto etcHosts = Format("%v %v\n", address, *spec.HostName); + // To be able to resolve hostname into IP inside container. + portoSpec.set_etc_hosts(etcHosts); + } + } + + if (spec.DisableNetwork) { + auto* netConfig = portoSpec.mutable_net()->add_cfg(); + netConfig->set_opt("none"); + } else if (!spec.IPAddresses.empty() && Config_->EnableNetworkIsolation) { + // This label is intended for HBF-agent: YT-12512. + auto* label = portoSpec.mutable_labels()->add_map(); + label->set_key("HBF.ignore_address"); + label->set_val("1"); + + auto* netConfig = portoSpec.mutable_net()->add_cfg(); + netConfig->set_opt("L3"); + netConfig->add_arg("veth0"); + + for (const auto& address : spec.IPAddresses) { + auto* ipConfig = portoSpec.mutable_ip()->add_cfg(); + ipConfig->set_dev("veth0"); + ipConfig->set_ip(ToString(address)); + } + + if (spec.EnableNat64) { + // Behave like nanny does. 
+ portoSpec.set_resolv_conf("nameserver fd64::1;nameserver 2a02:6b8:0:3400::5005;options attempts:1 timeout:1"); + } + } + + for (const auto& [key, value] : spec.Labels) { + auto* map = portoSpec.mutable_labels()->add_map(); + map->set_key(key); + map->set_val(value); + } + + for (const auto& [name, value] : spec.Env) { + auto* var = portoSpec.mutable_env()->add_var(); + var->set_name(name); + var->set_value(value); + } + + for (const auto& controller : spec.CGroupControllers) { + portoSpec.mutable_controllers()->add_controller(controller); + } + + for (const auto& device : spec.Devices) { + auto* portoDevice = portoSpec.mutable_devices()->add_device(); + portoDevice->set_device(device.DeviceName); + portoDevice->set_access(device.Enabled ? "rw" : "-"); + } + + auto addBind = [&] (const TBind& bind) { + auto* portoBind = portoSpec.mutable_bind()->add_bind(); + portoBind->set_target(bind.TargetPath); + portoBind->set_source(bind.SourcePath); + portoBind->add_flag(bind.ReadOnly ? "ro" : "rw"); + }; + + if (spec.RootFS) { + portoSpec.set_root_readonly(spec.RootFS->IsRootReadOnly); + portoSpec.set_root(spec.RootFS->RootPath); + + for (const auto& bind : spec.RootFS->Binds) { + addBind(bind); + } + } + + { + auto* ulimit = portoSpec.mutable_ulimit()->add_ulimit(); + ulimit->set_type("core"); + if (spec.EnableCoreDumps) { + ulimit->set_unlimited(true); + } else { + ulimit->set_hard(0); + ulimit->set_soft(0); + } + } + + // Set some universal defaults. + portoSpec.set_oom_is_fatal(false); + + ExecuteApiCall( + [&] { return Api_->CreateFromSpec(portoSpec, {}, start); }, + "CreateFromSpec", + /*idempotent*/ false); + } + + void DoSetContainerProperty(const TString& container, const TString& property, const TString& value) + { + ExecuteApiCall( + [&] { return Api_->SetProperty(container, property, value); }, + "SetProperty", + /*idempotent*/ true); + } + + void DoDestroyContainer(const TString& container) + { + try { + ExecuteApiCall( + [&] { return Api_->Destroy(container); }, + "Destroy", + /*idempotent*/ true); + } catch (const TErrorException& ex) { + if (!ex.Error().FindMatching(EPortoErrorCode::ContainerDoesNotExist)) { + throw; + } + } + } + + void DoStopContainer(const TString& container) + { + ExecuteApiCall( + [&] { return Api_->Stop(container); }, + "Stop", + /*idempotent*/ true); + } + + void DoStartContainer(const TString& container) + { + ExecuteApiCall( + [&] { return Api_->Start(container); }, + "Start", + /*idempotent*/ false); + } + + TString DoConvertPath(const TString& path, const TString& container) + { + TString result; + ExecuteApiCall( + [&] { return Api_->ConvertPath(path, container, "self", result); }, + "ConvertPath", + /*idempotent*/ true); + return result; + } + + void DoKillContainer(const TString& container, int signal) + { + ExecuteApiCall( + [&] { return Api_->Kill(container, signal); }, + "Kill", + /*idempotent*/ false); + } + + std::vector<TString> DoListSubcontainers(const TString& rootContainer, bool includeRoot) + { + Porto::TListContainersRequest req; + auto filter = req.add_filters(); + filter->set_name(rootContainer + "/*"); + if (includeRoot) { + auto rootFilter = req.add_filters(); + rootFilter->set_name(rootContainer); + } + auto fieldOptions = req.mutable_field_options(); + fieldOptions->add_properties("absolute_name"); + TVector<Porto::TContainer> containers; + ExecuteApiCall( + [&] { return Api_->ListContainersBy(req, containers); }, + "ListContainersBy", + /*idempotent*/ true); + + std::vector<TString> containerNames; + 
containerNames.reserve(containers.size()); + for (const auto& container : containers) { + const auto& absoluteName = container.status().absolute_name(); + if (!absoluteName.empty()) { + containerNames.push_back(absoluteName); + } + } + return containerNames; + } + + TFuture<int> DoWaitContainer(const TString& container) + { + auto result = NewPromise<int>(); + auto waitCallback = [=, this, this_ = MakeStrong(this)] (const Porto::TWaitResponse& rsp) { + return OnContainerTerminated(rsp, result); + }; + + ExecuteApiCall( + [&] { return Api_->AsyncWait({container}, {}, waitCallback); }, + "AsyncWait", + /*idempotent*/ false); + + return result.ToFuture().ToImmediatelyCancelable(); + } + + void OnContainerTerminated(const Porto::TWaitResponse& portoWaitResponse, TPromise<int> result) + { + const auto& container = portoWaitResponse.name(); + const auto& state = portoWaitResponse.state(); + if (state != "dead" && state != "stopped") { + result.TrySet(TError("Container finished with unexpected state") + << TErrorAttribute("container_name", container) + << TErrorAttribute("container_state", state)); + return; + } + + // TODO(max42): switch to Subscribe. + YT_UNUSED_FUTURE(GetContainerProperty(container, "exit_status").Apply(BIND( + [=] (const TErrorOr<std::optional<TString>>& errorOrExitCode) { + if (!errorOrExitCode.IsOK()) { + result.TrySet(TError("Container finished, but exit status is unknown") + << errorOrExitCode); + return; + } + + const auto& optionalExitCode = errorOrExitCode.Value(); + if (!optionalExitCode) { + result.TrySet(TError("Container finished, but exit status is unknown") + << TErrorAttribute("container_name", container) + << TErrorAttribute("container_state", state)); + return; + } + + try { + int exitStatus = FromString<int>(*optionalExitCode); + result.TrySet(exitStatus); + } catch (const std::exception& ex) { + auto error = TError("Failed to parse Porto exit status") + << TErrorAttribute("container_name", container) + << TErrorAttribute("exit_status", optionalExitCode.value()); + error.MutableInnerErrors()->push_back(TError(ex)); + result.TrySet(error); + } + }))); + } + + TFuture<int> DoPollContainer(const TString& container) + { + auto [it, inserted] = ContainerMap_.insert({container, NewPromise<int>()}); + if (!inserted) { + YT_LOG_WARNING("Container already added for polling (Container: %v)", + container); + } else { + Containers_.push_back(container); + } + return it->second.ToFuture(); + } + + Porto::TGetResponse DoRequestContainerProperties( + const std::vector<TString>& containers, + const std::vector<TString>& vars) + { + TVector<TString> containers_(containers.begin(), containers.end()); + TVector<TString> vars_(vars.begin(), vars.end()); + + const Porto::TGetResponse* getResponse; + + ExecuteApiCall( + [&] { + getResponse = Api_->Get(containers_, vars_); + return getResponse ? 
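+                    // A null response means the Get call itself failed; report EError::Unknown so ExecuteApiCall handles it.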
+                    EError::Success : EError::Unknown;
+            },
+            "Get",
+            /*idempotent*/ true);
+
+        YT_VERIFY(getResponse);
+        return *getResponse;
+    }
+
+    THashMap<TString, i64> DoGetContainerMetrics(
+        const std::vector<TString>& containers,
+        const TString& metric)
+    {
+        TVector<TString> containers_(containers.begin(), containers.end());
+
+        TMap<TString, uint64_t> result;
+
+        ExecuteApiCall(
+            [&] { return Api_->GetProcMetric(containers_, metric, result); },
+            "GetProcMetric",
+            /*idempotent*/ true);
+
+        return {result.begin(), result.end()};
+    }
+
+    void DoPoll()
+    {
+        try {
+            if (Containers_.empty()) {
+                return;
+            }
+
+            auto getResponse = DoRequestContainerProperties(Containers_, ContainerRequestVars_);
+
+            if (getResponse.list().empty()) {
+                return;
+            }
+
+            auto getProperty = [] (
+                const Porto::TGetResponse::TContainerGetListResponse& container,
+                const TString& name) -> Porto::TGetResponse::TContainerGetValueResponse
+            {
+                for (const auto& property : container.keyval()) {
+                    if (property.variable() == name) {
+                        return property;
+                    }
+                }
+
+                return {};
+            };
+
+            for (const auto& container : getResponse.list()) {
+                auto state = getProperty(container, "state");
+                if (state.error() == EError::ContainerDoesNotExist) {
+                    HandleResult(container.name(), state);
+                } else if (state.value() == "dead" || state.value() == "stopped") {
+                    HandleResult(container.name(), getProperty(container, "exit_status"));
+                }
+                // TODO(dcherednik): other states.
+            }
+        } catch (const std::exception& ex) {
+            YT_LOG_ERROR(ex, "Fatal exception occurred while polling Porto");
+            Failed_.Fire(TError(ex));
+        }
+    }
+
+    TString DoCreateVolume(
+        const TString& path,
+        const THashMap<TString, TString>& properties)
+    {
+        auto volume = path;
+        TMap<TString, TString> propertyMap(properties.begin(), properties.end());
+        ExecuteApiCall(
+            [&] { return Api_->CreateVolume(volume, propertyMap); },
+            "CreateVolume",
+            /*idempotent*/ false);
+        return volume;
+    }
+
+    void DoLinkVolume(const TString& path, const TString& container)
+    {
+        ExecuteApiCall(
+            [&] { return Api_->LinkVolume(path, container); },
+            "LinkVolume",
+            /*idempotent*/ false);
+    }
+
+    void DoUnlinkVolume(const TString& path, const TString& container)
+    {
+        ExecuteApiCall(
+            [&] { return Api_->UnlinkVolume(path, container); },
+            "UnlinkVolume",
+            /*idempotent*/ false);
+    }
+
+    std::vector<TString> DoListVolumePaths()
+    {
+        TVector<TString> volumes;
+        ExecuteApiCall(
+            [&] { return Api_->ListVolumes(volumes); },
+            "ListVolume",
+            /*idempotent*/ true);
+        return {volumes.begin(), volumes.end()};
+    }
+
+    void DoImportLayer(const TString& archivePath, const TString& layerId, const TString& place)
+    {
+        ExecuteApiCall(
+            [&] { return Api_->ImportLayer(layerId, archivePath, false, place); },
+            "ImportLayer",
+            /*idempotent*/ false);
+    }
+
+    void DoRemoveLayer(const TString& layerId, const TString& place, bool async)
+    {
+        ExecuteApiCall(
+            [&] { return Api_->RemoveLayer(layerId, place, async); },
+            "RemoveLayer",
+            /*idempotent*/ false);
+    }
+
+    std::vector<TString> DoListLayers(const TString& place)
+    {
+        TVector<TString> layers;
+        ExecuteApiCall(
+            [&] { return Api_->ListLayers(layers, place); },
+            "ListLayers",
+            /*idempotent*/ true);
+        return {layers.begin(), layers.end()};
+    }
+
+    TCommandEntry* GetCommandEntry(const TString& command)
+    {
+        auto guard = Guard(CommandLock_);
+        if (auto it = CommandToEntry_.find(command)) {
+            return &it->second;
+        }
+        return &CommandToEntry_.emplace(command, TCommandEntry(Profiler_.WithTag("command", command))).first->second;
+    }
+
+    void ExecuteApiCall(
+        std::function<EError()> callback,
+        const TString& command,
+        bool idempotent)
+    {
+        YT_LOG_DEBUG("Porto API call started (Command: %v)", command);
+
+        if (IsTestPortoTimeout()) {
+            YT_LOG_DEBUG("Testing Porto timeout (Command: %v)", command);
+
+            auto config = DynamicConfig_.Acquire();
+            TDelayedExecutor::WaitForDuration(config->ApiTimeout);
+
+            THROW_ERROR CreatePortoError(GetFailedStubError(), "Porto timeout");
+        }
+
+        if (IsTestPortoFailureEnabled()) {
+            YT_LOG_DEBUG("Testing Porto failure (Command: %v)", command);
+            THROW_ERROR CreatePortoError(GetFailedStubError(), "Porto stub error");
+        }
+
+        auto* entry = GetCommandEntry(command);
+        auto startTime = NProfiling::GetInstant();
+        while (true) {
+            EError error;
+
+            {
+                NProfiling::TWallTimer timer;
+                error = callback();
+                entry->TimeGauge.Record(timer.GetElapsedTime());
+            }
+
+            if (error == EError::Success) {
+                entry->SuccessCounter.Increment();
+                break;
+            }
+
+            entry->FailureCounter.Increment();
+            HandleApiError(command, startTime, idempotent);
+
+            YT_LOG_DEBUG("Sleeping and retrying Porto API call (Command: %v)", command);
+            entry->RetryCounter.Increment();
+
+            TDelayedExecutor::WaitForDuration(RetryInterval);
+        }
+
+        YT_LOG_DEBUG("Porto API call completed (Command: %v)", command);
+    }
+
+    void HandleApiError(
+        const TString& command,
+        TInstant startTime,
+        bool idempotent)
+    {
+        TString errorMessage;
+        auto error = ConvertPortoErrorCode(Api_->GetLastError(errorMessage));
+
+        // These errors are typical during job cleanup: we might try to kill a container that is already stopped.
+        bool debug = (error == EPortoErrorCode::ContainerDoesNotExist || error == EPortoErrorCode::InvalidState);
+        YT_LOG_EVENT(
+            Logger,
+            debug ? NLogging::ELogLevel::Debug : NLogging::ELogLevel::Error,
+            "Porto API call error (Error: %v, Command: %v, Message: %v)",
+            error,
+            command,
+            errorMessage);
+
+        if (!IsRetriableErrorCode(error, idempotent) || NProfiling::GetInstant() - startTime > Config_->RetriesTimeout) {
+            THROW_ERROR CreatePortoError(error, errorMessage);
+        }
+    }
+
+    void HandleResult(const TString& container, const Porto::TGetResponse::TContainerGetValueResponse& rsp)
+    {
+        auto portoErrorCode = ConvertPortoErrorCode(rsp.error());
+        auto it = ContainerMap_.find(container);
+        if (it == ContainerMap_.end()) {
+            YT_LOG_ERROR("Got an unexpected container "
+                "(Container: %v, ResponseError: %v, ErrorMessage: %v, Value: %v)",
+                container,
+                portoErrorCode,
+                rsp.errormsg(),
+                rsp.value());
+            return;
+        } else {
+            if (portoErrorCode != EPortoErrorCode::Success) {
+                YT_LOG_ERROR("Container finished with Porto API error "
+                    "(Container: %v, ResponseError: %v, ErrorMessage: %v, Value: %v)",
+                    container,
+                    portoErrorCode,
+                    rsp.errormsg(),
+                    rsp.value());
+                it->second.Set(CreatePortoError(portoErrorCode, rsp.errormsg()));
+            } else {
+                try {
+                    int exitStatus = std::stoi(rsp.value());
+                    YT_LOG_DEBUG("Container finished with exit code (Container: %v, ExitCode: %v)",
+                        container,
+                        exitStatus);
+
+                    it->second.Set(exitStatus);
+                } catch (const std::exception& ex) {
+                    it->second.Set(TError("Failed to parse Porto exit status") << ex);
+                }
+            }
+        }
+        RemoveFromPoller(container);
+    }
+
+    void RemoveFromPoller(const TString& container)
+    {
+        ContainerMap_.erase(container);
+
+        Containers_.clear();
+        for (const auto& [name, promise] : ContainerMap_) {
+            Containers_.push_back(name);
+        }
+    }
+};
+
+const std::vector<TString> TPortoExecutor::ContainerRequestVars_ = {
+    "state",
+    "exit_status"
+};
+
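+// ContainerRequestVars_ lists the only properties the polling loop requests for each tracked container.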
+//////////////////////////////////////////////////////////////////////////////// + +IPortoExecutorPtr CreatePortoExecutor( + TPortoExecutorDynamicConfigPtr config, + const TString& threadNameSuffix, + const NProfiling::TProfiler& profiler) +{ + return New<TPortoExecutor>( + std::move(config), + threadNameSuffix, + profiler); +} + +//////////////////////////////////////////////////////////////////////////////// + +#else + +IPortoExecutorPtr CreatePortoExecutor( + TPortoExecutorDynamicConfigPtr /* config */, + const TString& /* threadNameSuffix */, + const NProfiling::TProfiler& /* profiler */) +{ + THROW_ERROR_EXCEPTION("Porto executor is not available on this platform"); +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_executor.h b/yt/yt/library/containers/porto_executor.h new file mode 100644 index 0000000000..d629ab6275 --- /dev/null +++ b/yt/yt/library/containers/porto_executor.h @@ -0,0 +1,142 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/profiling/sensor.h> + +#include <yt/yt/core/actions/future.h> +#include <yt/yt/core/actions/signal.h> + +#include <yt/yt/core/net/address.h> + +#include <library/cpp/porto/libporto.hpp> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +struct TVolumeId +{ + TString Path; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TRunnableContainerSpec +{ + TString Name; + TString Command; + + EEnablePorto EnablePorto = EEnablePorto::None; + bool Isolate = true; + + std::optional<TString> StdinPath; + std::optional<TString> StdoutPath; + std::optional<TString> StderrPath; + std::optional<TString> CurrentWorkingDirectory; + std::optional<TString> CoreCommand; + std::optional<TString> User; + std::optional<int> GroupId; + + bool EnableCoreDumps = true; + + std::optional<i64> ThreadLimit; + + std::optional<TString> HostName; + std::vector<NYT::NNet::TIP6Address> IPAddresses; + bool EnableNat64 = false; + bool DisableNetwork = false; + + THashMap<TString, TString> Labels; + THashMap<TString, TString> Env; + std::vector<TString> CGroupControllers; + std::vector<TDevice> Devices; + std::optional<TRootFS> RootFS; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct IPortoExecutor + : public TRefCounted +{ + virtual void OnDynamicConfigChanged(const TPortoExecutorDynamicConfigPtr& newConfig) = 0; + + virtual TFuture<void> CreateContainer(const TString& container) = 0; + + virtual TFuture<void> CreateContainer(const TRunnableContainerSpec& containerSpec, bool start) = 0; + + virtual TFuture<void> SetContainerProperty( + const TString& container, + const TString& property, + const TString& value) = 0; + + virtual TFuture<std::optional<TString>> GetContainerProperty( + const TString& container, + const TString& property) = 0; + + virtual TFuture<THashMap<TString, TErrorOr<TString>>> GetContainerProperties( + const TString& container, + const std::vector<TString>& properties) = 0; + virtual TFuture<THashMap<TString, THashMap<TString, TErrorOr<TString>>>> GetContainerProperties( + const std::vector<TString>& containers, + const std::vector<TString>& properties) = 0; + + virtual TFuture<THashMap<TString, i64>> GetContainerMetrics( + const std::vector<TString>& containers, + const TString& metric) = 0; + virtual TFuture<void> DestroyContainer(const TString& container) = 
0; + virtual TFuture<void> StopContainer(const TString& container) = 0; + virtual TFuture<void> StartContainer(const TString& container) = 0; + virtual TFuture<void> KillContainer(const TString& container, int signal) = 0; + + virtual TFuture<TString> ConvertPath(const TString& path, const TString& container) = 0; + + // Returns absolute names of immediate children only. + virtual TFuture<std::vector<TString>> ListSubcontainers( + const TString& rootContainer, + bool includeRoot) = 0; + // Starts polling a given container, returns future with exit code of finished process. + virtual TFuture<int> PollContainer(const TString& container) = 0; + + // Returns future with exit code of finished process. + // NB: temporarily broken, see https://st.yandex-team.ru/PORTO-846 for details. + virtual TFuture<int> WaitContainer(const TString& container) = 0; + + virtual TFuture<TString> CreateVolume( + const TString& path, + const THashMap<TString, TString>& properties) = 0; + virtual TFuture<void> LinkVolume( + const TString& path, + const TString& name) = 0; + virtual TFuture<void> UnlinkVolume( + const TString& path, + const TString& name) = 0; + virtual TFuture<std::vector<TString>> ListVolumePaths() = 0; + + virtual TFuture<void> ImportLayer( + const TString& archivePath, + const TString& layerId, + const TString& place) = 0; + virtual TFuture<void> RemoveLayer( + const TString& layerId, + const TString& place, + bool async) = 0; + virtual TFuture<std::vector<TString>> ListLayers(const TString& place) = 0; + + virtual IInvokerPtr GetInvoker() const = 0; + + DECLARE_INTERFACE_SIGNAL(void(const TError&), Failed); +}; + +DEFINE_REFCOUNTED_TYPE(IPortoExecutor) + +//////////////////////////////////////////////////////////////////////////////// + +IPortoExecutorPtr CreatePortoExecutor( + TPortoExecutorDynamicConfigPtr config, + const TString& threadNameSuffix, + const NProfiling::TProfiler& profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_health_checker.cpp b/yt/yt/library/containers/porto_health_checker.cpp new file mode 100644 index 0000000000..5a5d358441 --- /dev/null +++ b/yt/yt/library/containers/porto_health_checker.cpp @@ -0,0 +1,69 @@ + +#include "porto_health_checker.h" + +#include "porto_executor.h" +#include "private.h" +#include "config.h" + +#include <yt/yt/core/actions/future.h> + +#include <yt/yt/core/misc/fs.h> + +#include <util/random/random.h> + +namespace NYT::NContainers { + +using namespace NConcurrency; +using namespace NLogging; +using namespace NProfiling; + +//////////////////////////////////////////////////////////////////////////////// + +TPortoHealthChecker::TPortoHealthChecker( + TPortoExecutorDynamicConfigPtr config, + IInvokerPtr invoker, + TLogger logger) + : Config_(std::move(config)) + , Logger(std::move(logger)) + , CheckInvoker_(std::move(invoker)) + , Executor_(CreatePortoExecutor( + Config_, + "porto_check")) +{ } + +void TPortoHealthChecker::Start() +{ + YT_LOG_DEBUG("Porto health checker started"); + + PeriodicExecutor_ = New<TPeriodicExecutor>( + CheckInvoker_, + BIND(&TPortoHealthChecker::OnCheck, MakeWeak(this)), + Config_->RetriesTimeout); + PeriodicExecutor_->Start(); +} + +void TPortoHealthChecker::OnDynamicConfigChanged(const TPortoExecutorDynamicConfigPtr& newConfig) +{ + YT_LOG_DEBUG( + "Porto health checker dynamic config changed (EnableTestPortoFailures: %v, StubErrorCode: %v)", + Config_->EnableTestPortoFailures, + 
Config_->StubErrorCode); + + Executor_->OnDynamicConfigChanged(newConfig); +} + +void TPortoHealthChecker::OnCheck() +{ + YT_LOG_DEBUG("Run Porto health check"); + + auto result = WaitFor(Executor_->ListVolumePaths().AsVoid()); + if (result.IsOK()) { + Success_.Fire(); + } else { + Failed_.Fire(result); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_health_checker.h b/yt/yt/library/containers/porto_health_checker.h new file mode 100644 index 0000000000..f0fb8f0908 --- /dev/null +++ b/yt/yt/library/containers/porto_health_checker.h @@ -0,0 +1,52 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/profiling/sensor.h> + +#include <yt/yt/core/actions/signal.h> + +#include <yt/yt/core/concurrency/periodic_executor.h> + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/error.h> + +#include <atomic> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoHealthChecker + : public TRefCounted +{ +public: + TPortoHealthChecker( + TPortoExecutorDynamicConfigPtr config, + IInvokerPtr invoker, + NLogging::TLogger logger); + + void Start(); + + void OnDynamicConfigChanged(const TPortoExecutorDynamicConfigPtr& newConfig); + + DEFINE_SIGNAL(void(), Success); + + DEFINE_SIGNAL(void(const TError&), Failed); + +private: + const TPortoExecutorDynamicConfigPtr Config_; + const NLogging::TLogger Logger; + const IInvokerPtr CheckInvoker_; + const IPortoExecutorPtr Executor_; + NConcurrency::TPeriodicExecutorPtr PeriodicExecutor_; + + void OnCheck(); +}; + +DEFINE_REFCOUNTED_TYPE(TPortoHealthChecker) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_resource_tracker.cpp b/yt/yt/library/containers/porto_resource_tracker.cpp new file mode 100644 index 0000000000..c1fe48d6af --- /dev/null +++ b/yt/yt/library/containers/porto_resource_tracker.cpp @@ -0,0 +1,711 @@ +#include "porto_resource_tracker.h" +#include "private.h" + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/error.h> + +#include <yt/yt/core/net/address.h> + +#include <yt/yt/core/ytree/public.h> + +#include <yt/yt/library/process/process.h> + +#include <yt/yt/library/containers/cgroup.h> +#include <yt/yt/library/containers/config.h> +#include <yt/yt/library/containers/instance.h> +#include <yt/yt/library/containers/porto_executor.h> +#include <yt/yt/library/containers/public.h> + +namespace NYT::NContainers { + +using namespace NProfiling; + +static const auto& Logger = ContainersLogger; + +#ifdef _linux_ + +//////////////////////////////////////////////////////////////////////////////// + +struct TPortoProfilers + : public TRefCounted +{ + TPortoResourceProfilerPtr DaemonProfiler; + TPortoResourceProfilerPtr ContainerProfiler; + + TPortoProfilers( + TPortoResourceProfilerPtr daemonProfiler, + TPortoResourceProfilerPtr containerProfiler) + : DaemonProfiler(std::move(daemonProfiler)) + , ContainerProfiler(std::move(containerProfiler)) + { } +}; + +DEFINE_REFCOUNTED_TYPE(TPortoProfilers) + +//////////////////////////////////////////////////////////////////////////////// + +static TErrorOr<ui64> GetFieldOrError( + const TResourceUsage& usage, + EStatField field) +{ + auto it = usage.find(field); + if (it == usage.end()) { + return TError("Resource usage is missing %Qlv field", field); + } + const auto& 
errorOrValue = it->second; + if (errorOrValue.FindMatching(EPortoErrorCode::NotSupported)) { + return TError("Property %Qlv not supported in Porto response", field); + } + return errorOrValue; +} + +//////////////////////////////////////////////////////////////////////////////// + +TPortoResourceTracker::TPortoResourceTracker( + IInstancePtr instance, + TDuration updatePeriod, + bool isDeltaTracker, + bool isForceUpdate) + : Instance_(std::move(instance)) + , UpdatePeriod_(updatePeriod) + , IsDeltaTracker_(isDeltaTracker) + , IsForceUpdate_(isForceUpdate) +{ + ResourceUsage_ = { + {EStatField::IOReadByte, 0}, + {EStatField::IOWriteByte, 0}, + {EStatField::IOBytesLimit, 0}, + {EStatField::IOReadOps, 0}, + {EStatField::IOWriteOps, 0}, + {EStatField::IOOps, 0}, + {EStatField::IOOpsLimit, 0}, + {EStatField::IOTotalTime, 0}, + {EStatField::IOWaitTime, 0} + }; + ResourceUsageDelta_ = ResourceUsage_; +} + +static TErrorOr<TDuration> ExtractDuration(TErrorOr<ui64> timeNs) +{ + if (timeNs.IsOK()) { + return TErrorOr<TDuration>(TDuration::MicroSeconds(timeNs.Value() / 1000)); + } else { + return TError(timeNs); + } +} + +TCpuStatistics TPortoResourceTracker::ExtractCpuStatistics(const TResourceUsage& resourceUsage) const +{ + // NB: Job proxy uses last sample of CPU statistics but we are interested in + // peak thread count value. + auto currentThreadCountPeak = GetFieldOrError(resourceUsage, EStatField::ThreadCount); + + PeakThreadCount_ = currentThreadCountPeak.IsOK() && PeakThreadCount_.IsOK() + ? std::max<ui64>( + PeakThreadCount_.Value(), + currentThreadCountPeak.Value()) + : currentThreadCountPeak.IsOK() ? currentThreadCountPeak : PeakThreadCount_; + + auto totalTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuUsage); + auto systemTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuSystemUsage); + auto userTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuUserUsage); + auto waitTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuWait); + auto throttledNs = GetFieldOrError(resourceUsage, EStatField::CpuThrottled); + auto limitTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuLimit); + auto guaranteeTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuGuarantee); + + return TCpuStatistics{ + .TotalUsageTime = ExtractDuration(totalTimeNs), + .UserUsageTime = ExtractDuration(userTimeNs), + .SystemUsageTime = ExtractDuration(systemTimeNs), + .WaitTime = ExtractDuration(waitTimeNs), + .ThrottledTime = ExtractDuration(throttledNs), + .ThreadCount = GetFieldOrError(resourceUsage, EStatField::ThreadCount), + .ContextSwitches = GetFieldOrError(resourceUsage, EStatField::ContextSwitches), + .ContextSwitchesDelta = GetFieldOrError(resourceUsage, EStatField::ContextSwitchesDelta), + .PeakThreadCount = PeakThreadCount_, + .LimitTime = ExtractDuration(limitTimeNs), + .GuaranteeTime = ExtractDuration(guaranteeTimeNs), + }; +} + +TMemoryStatistics TPortoResourceTracker::ExtractMemoryStatistics(const TResourceUsage& resourceUsage) const +{ + return TMemoryStatistics{ + .Rss = GetFieldOrError(resourceUsage, EStatField::Rss), + .MappedFile = GetFieldOrError(resourceUsage, EStatField::MappedFile), + .MinorPageFaults = GetFieldOrError(resourceUsage, EStatField::MinorPageFaults), + .MajorPageFaults = GetFieldOrError(resourceUsage, EStatField::MajorPageFaults), + .FileCacheUsage = GetFieldOrError(resourceUsage, EStatField::FileCacheUsage), + .AnonUsage = GetFieldOrError(resourceUsage, EStatField::AnonMemoryUsage), + .AnonLimit = GetFieldOrError(resourceUsage, EStatField::AnonMemoryLimit), + 
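+        // Unlike anon_usage, memory_usage also accounts for file cache.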
.MemoryUsage = GetFieldOrError(resourceUsage, EStatField::MemoryUsage), + .MemoryGuarantee = GetFieldOrError(resourceUsage, EStatField::MemoryGuarantee), + .MemoryLimit = GetFieldOrError(resourceUsage, EStatField::MemoryLimit), + .MaxMemoryUsage = GetFieldOrError(resourceUsage, EStatField::MaxMemoryUsage), + .OomKills = GetFieldOrError(resourceUsage, EStatField::OomKills), + .OomKillsTotal = GetFieldOrError(resourceUsage, EStatField::OomKillsTotal) + }; +} + +TBlockIOStatistics TPortoResourceTracker::ExtractBlockIOStatistics(const TResourceUsage& resourceUsage) const +{ + auto totalTimeNs = GetFieldOrError(resourceUsage, EStatField::IOTotalTime); + auto waitTimeNs = GetFieldOrError(resourceUsage, EStatField::IOWaitTime); + + return TBlockIOStatistics{ + .IOReadByte = GetFieldOrError(resourceUsage, EStatField::IOReadByte), + .IOWriteByte = GetFieldOrError(resourceUsage, EStatField::IOWriteByte), + .IOBytesLimit = GetFieldOrError(resourceUsage, EStatField::IOBytesLimit), + .IOReadOps = GetFieldOrError(resourceUsage, EStatField::IOReadOps), + .IOWriteOps = GetFieldOrError(resourceUsage, EStatField::IOWriteOps), + .IOOps = GetFieldOrError(resourceUsage, EStatField::IOOps), + .IOOpsLimit = GetFieldOrError(resourceUsage, EStatField::IOOpsLimit), + .IOTotalTime = ExtractDuration(totalTimeNs), + .IOWaitTime = ExtractDuration(waitTimeNs) + }; +} + +TNetworkStatistics TPortoResourceTracker::ExtractNetworkStatistics(const TResourceUsage& resourceUsage) const +{ + return TNetworkStatistics{ + .TxBytes = GetFieldOrError(resourceUsage, EStatField::NetTxBytes), + .TxPackets = GetFieldOrError(resourceUsage, EStatField::NetTxPackets), + .TxDrops = GetFieldOrError(resourceUsage, EStatField::NetTxDrops), + .TxLimit = GetFieldOrError(resourceUsage, EStatField::NetTxLimit), + + .RxBytes = GetFieldOrError(resourceUsage, EStatField::NetRxBytes), + .RxPackets = GetFieldOrError(resourceUsage, EStatField::NetRxPackets), + .RxDrops = GetFieldOrError(resourceUsage, EStatField::NetRxDrops), + .RxLimit = GetFieldOrError(resourceUsage, EStatField::NetRxLimit), + }; +} + +TTotalStatistics TPortoResourceTracker::ExtractTotalStatistics(const TResourceUsage& resourceUsage) const +{ + return TTotalStatistics{ + .CpuStatistics = ExtractCpuStatistics(resourceUsage), + .MemoryStatistics = ExtractMemoryStatistics(resourceUsage), + .BlockIOStatistics = ExtractBlockIOStatistics(resourceUsage), + .NetworkStatistics = ExtractNetworkStatistics(resourceUsage), + }; +} + +TCpuStatistics TPortoResourceTracker::GetCpuStatistics() const +{ + return GetStatistics( + CachedCpuStatistics_, + "CPU", + [&] (TResourceUsage& resourceUsage) { + return ExtractCpuStatistics(resourceUsage); + }); +} + +TMemoryStatistics TPortoResourceTracker::GetMemoryStatistics() const +{ + return GetStatistics( + CachedMemoryStatistics_, + "memory", + [&] (TResourceUsage& resourceUsage) { + return ExtractMemoryStatistics(resourceUsage); + }); +} + +TBlockIOStatistics TPortoResourceTracker::GetBlockIOStatistics() const +{ + return GetStatistics( + CachedBlockIOStatistics_, + "block IO", + [&] (TResourceUsage& resourceUsage) { + return ExtractBlockIOStatistics(resourceUsage); + }); +} + +TNetworkStatistics TPortoResourceTracker::GetNetworkStatistics() const +{ + return GetStatistics( + CachedNetworkStatistics_, + "network", + [&] (TResourceUsage& resourceUsage) { + return ExtractNetworkStatistics(resourceUsage); + }); +} + +TTotalStatistics TPortoResourceTracker::GetTotalStatistics() const +{ + return GetStatistics( + CachedTotalStatistics_, + "total", + [&] 
(TResourceUsage& resourceUsage) { + return ExtractTotalStatistics(resourceUsage); + }); +} + +template <class T, class F> +T TPortoResourceTracker::GetStatistics( + std::optional<T>& cachedStatistics, + const TString& statisticsKind, + F extractor) const +{ + UpdateResourceUsageStatisticsIfExpired(); + + auto guard = Guard(SpinLock_); + try { + auto newStatistics = extractor(IsDeltaTracker_ ? ResourceUsageDelta_ : ResourceUsage_); + cachedStatistics = newStatistics; + return newStatistics; + } catch (const std::exception& ex) { + if (!cachedStatistics) { + THROW_ERROR_EXCEPTION("Unable to get %v statistics", statisticsKind) + << ex; + } + YT_LOG_WARNING(ex, "Unable to get %v statistics; using the last one", statisticsKind); + return *cachedStatistics; + } +} + +bool TPortoResourceTracker::AreResourceUsageStatisticsExpired() const +{ + return TInstant::Now() - LastUpdateTime_.load() > UpdatePeriod_; +} + +TInstant TPortoResourceTracker::GetLastUpdateTime() const +{ + return LastUpdateTime_.load(); +} + +void TPortoResourceTracker::UpdateResourceUsageStatisticsIfExpired() const +{ + if (IsForceUpdate_ || AreResourceUsageStatisticsExpired()) { + DoUpdateResourceUsage(); + } +} + +TErrorOr<ui64> TPortoResourceTracker::CalculateCounterDelta( + const TErrorOr<ui64>& oldValue, + const TErrorOr<ui64>& newValue) const +{ + if (oldValue.IsOK() && newValue.IsOK()) { + return newValue.Value() - oldValue.Value(); + } else if (newValue.IsOK()) { + // It is better to return an error than an incorrect value. + return oldValue; + } else { + return newValue; + } +} + +static bool IsCumulativeStatistics(EStatField statistic) +{ + return + statistic == EStatField::CpuUsage || + statistic == EStatField::CpuUserUsage || + statistic == EStatField::CpuSystemUsage || + statistic == EStatField::CpuWait || + statistic == EStatField::CpuThrottled || + + statistic == EStatField::ContextSwitches || + + statistic == EStatField::MinorPageFaults || + statistic == EStatField::MajorPageFaults || + + statistic == EStatField::IOReadByte || + statistic == EStatField::IOWriteByte || + statistic == EStatField::IOReadOps || + statistic == EStatField::IOWriteOps || + statistic == EStatField::IOOps || + statistic == EStatField::IOTotalTime || + statistic == EStatField::IOWaitTime || + + statistic == EStatField::NetTxBytes || + statistic == EStatField::NetTxPackets || + statistic == EStatField::NetTxDrops || + statistic == EStatField::NetRxBytes || + statistic == EStatField::NetRxPackets || + statistic == EStatField::NetRxDrops; +} + +void TPortoResourceTracker::ReCalculateResourceUsage(const TResourceUsage& newResourceUsage) const +{ + auto guard = Guard(SpinLock_); + + TResourceUsage resourceUsage; + TResourceUsage resourceUsageDelta; + + for (const auto& stat : InstanceStatFields) { + TErrorOr<ui64> oldValue; + TErrorOr<ui64> newValue; + + if (auto newValueIt = newResourceUsage.find(stat); newValueIt.IsEnd()) { + newValue = TError("Missing property %Qlv in Porto response", stat) + << TErrorAttribute("container", Instance_->GetName()); + } else { + newValue = newValueIt->second; + } + + if (auto oldValueIt = ResourceUsage_.find(stat); oldValueIt.IsEnd()) { + oldValue = newValue; + } else { + oldValue = oldValueIt->second; + } + + if (newValue.IsOK()) { + resourceUsage[stat] = newValue; + } else { + resourceUsage[stat] = oldValue; + } + + if (IsCumulativeStatistics(stat)) { + resourceUsageDelta[stat] = CalculateCounterDelta(oldValue, newValue); + } else { + if (newValue.IsOK()) { + resourceUsageDelta[stat] = newValue; + } else { + 
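+                    // Keep the last known value rather than propagating a transient fetch error.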
resourceUsageDelta[stat] = oldValue; + } + } + } + + ResourceUsage_ = resourceUsage; + ResourceUsageDelta_ = resourceUsageDelta; + LastUpdateTime_.store(TInstant::Now()); +} + +void TPortoResourceTracker::DoUpdateResourceUsage() const +{ + try { + ReCalculateResourceUsage(Instance_->GetResourceUsage()); + } catch (const std::exception& ex) { + YT_LOG_ERROR( + ex, + "Couldn't get metrics from Porto"); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +TPortoResourceProfiler::TPortoResourceProfiler( + TPortoResourceTrackerPtr tracker, + TPodSpecConfigPtr podSpec, + const TProfiler& profiler) + : ResourceTracker_(std::move(tracker)) + , PodSpec_(std::move(podSpec)) +{ + profiler.AddProducer("", MakeStrong(this)); +} + +static void WriteGaugeIfOk( + ISensorWriter* writer, + const TString& path, + TErrorOr<ui64> valueOrError) +{ + if (valueOrError.IsOK()) { + i64 value = static_cast<i64>(valueOrError.Value()); + + if (value >= 0) { + writer->AddGauge(path, value); + } + } +} + +static void WriteCumulativeGaugeIfOk( + ISensorWriter* writer, + const TString& path, + TErrorOr<ui64> valueOrError, + i64 timeDeltaUsec) +{ + if (valueOrError.IsOK()) { + i64 value = static_cast<i64>(valueOrError.Value()); + + if (value >= 0) { + writer->AddGauge(path, + 1.0 * value * ResourceUsageUpdatePeriod.MicroSeconds() / timeDeltaUsec); + } + } +} + +void TPortoResourceProfiler::WriteCpuMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec) +{ + { + if (totalStatistics.CpuStatistics.UserUsageTime.IsOK()) { + i64 userUsageTimeUs = totalStatistics.CpuStatistics.UserUsageTime.Value().MicroSeconds(); + double userUsagePercent = std::max<double>(0.0, 100. * userUsageTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/user", userUsagePercent); + } + + if (totalStatistics.CpuStatistics.SystemUsageTime.IsOK()) { + i64 systemUsageTimeUs = totalStatistics.CpuStatistics.SystemUsageTime.Value().MicroSeconds(); + double systemUsagePercent = std::max<double>(0.0, 100. * systemUsageTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/system", systemUsagePercent); + } + + if (totalStatistics.CpuStatistics.WaitTime.IsOK()) { + i64 waitTimeUs = totalStatistics.CpuStatistics.WaitTime.Value().MicroSeconds(); + double waitPercent = std::max<double>(0.0, 100. * waitTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/wait", waitPercent); + } + + if (totalStatistics.CpuStatistics.ThrottledTime.IsOK()) { + i64 throttledTimeUs = totalStatistics.CpuStatistics.ThrottledTime.Value().MicroSeconds(); + double throttledPercent = std::max<double>(0.0, 100. * throttledTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/throttled", throttledPercent); + } + + if (totalStatistics.CpuStatistics.TotalUsageTime.IsOK()) { + i64 totalUsageTimeUs = totalStatistics.CpuStatistics.TotalUsageTime.Value().MicroSeconds(); + double totalUsagePercent = std::max<double>(0.0, 100. * totalUsageTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/total", totalUsagePercent); + } + + if (totalStatistics.CpuStatistics.GuaranteeTime.IsOK()) { + i64 guaranteeTimeUs = totalStatistics.CpuStatistics.GuaranteeTime.Value().MicroSeconds(); + double guaranteePercent = std::max<double>(0.0, (100. * guaranteeTimeUs) / (1'000'000L)); + writer->AddGauge("/cpu/guarantee", guaranteePercent); + } + + if (totalStatistics.CpuStatistics.LimitTime.IsOK()) { + i64 limitTimeUs = totalStatistics.CpuStatistics.LimitTime.Value().MicroSeconds(); + double limitPercent = std::max<double>(0.0, (100. 
* limitTimeUs) / (1'000'000L)); + writer->AddGauge("/cpu/limit", limitPercent); + } + } + + if (PodSpec_->CpuToVCpuFactor) { + auto factor = *PodSpec_->CpuToVCpuFactor; + + writer->AddGauge("/cpu_to_vcpu_factor", factor); + + if (totalStatistics.CpuStatistics.UserUsageTime.IsOK()) { + i64 userUsageTimeUs = totalStatistics.CpuStatistics.UserUsageTime.Value().MicroSeconds(); + double userUsagePercent = std::max<double>(0.0, 100. * userUsageTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/user", userUsagePercent); + } + + if (totalStatistics.CpuStatistics.SystemUsageTime.IsOK()) { + i64 systemUsageTimeUs = totalStatistics.CpuStatistics.SystemUsageTime.Value().MicroSeconds(); + double systemUsagePercent = std::max<double>(0.0, 100. * systemUsageTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/system", systemUsagePercent); + } + + if (totalStatistics.CpuStatistics.WaitTime.IsOK()) { + i64 waitTimeUs = totalStatistics.CpuStatistics.WaitTime.Value().MicroSeconds(); + double waitPercent = std::max<double>(0.0, 100. * waitTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/wait", waitPercent); + } + + if (totalStatistics.CpuStatistics.ThrottledTime.IsOK()) { + i64 throttledTimeUs = totalStatistics.CpuStatistics.ThrottledTime.Value().MicroSeconds(); + double throttledPercent = std::max<double>(0.0, 100. * throttledTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/throttled", throttledPercent); + } + + if (totalStatistics.CpuStatistics.TotalUsageTime.IsOK()) { + i64 totalUsageTimeUs = totalStatistics.CpuStatistics.TotalUsageTime.Value().MicroSeconds(); + double totalUsagePercent = std::max<double>(0.0, 100. * totalUsageTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/total", totalUsagePercent); + } + + if (totalStatistics.CpuStatistics.GuaranteeTime.IsOK()) { + i64 guaranteeTimeUs = totalStatistics.CpuStatistics.GuaranteeTime.Value().MicroSeconds(); + double guaranteePercent = std::max<double>(0.0, 100. * guaranteeTimeUs * factor / 1'000'000L); + writer->AddGauge("/vcpu/guarantee", guaranteePercent); + } + + if (totalStatistics.CpuStatistics.LimitTime.IsOK()) { + i64 limitTimeUs = totalStatistics.CpuStatistics.LimitTime.Value().MicroSeconds(); + double limitPercent = std::max<double>(0.0, 100. 
* limitTimeUs * factor / 1'000'000L); + writer->AddGauge("/vcpu/limit", limitPercent); + } + } + + WriteGaugeIfOk(writer, "/cpu/thread_count", totalStatistics.CpuStatistics.ThreadCount); + WriteGaugeIfOk(writer, "/cpu/context_switches", totalStatistics.CpuStatistics.ContextSwitches); +} + +void TPortoResourceProfiler::WriteMemoryMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec) +{ + WriteCumulativeGaugeIfOk(writer, + "/memory/minor_page_faults", + totalStatistics.MemoryStatistics.MinorPageFaults, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/memory/major_page_faults", + totalStatistics.MemoryStatistics.MajorPageFaults, + timeDeltaUsec); + + WriteGaugeIfOk(writer, "/memory/oom_kills", totalStatistics.MemoryStatistics.OomKills); + WriteGaugeIfOk(writer, "/memory/oom_kills_total", totalStatistics.MemoryStatistics.OomKillsTotal); + + WriteGaugeIfOk(writer, "/memory/file_cache_usage", totalStatistics.MemoryStatistics.FileCacheUsage); + WriteGaugeIfOk(writer, "/memory/anon_usage", totalStatistics.MemoryStatistics.AnonUsage); + WriteGaugeIfOk(writer, "/memory/anon_limit", totalStatistics.MemoryStatistics.AnonLimit); + WriteGaugeIfOk(writer, "/memory/memory_usage", totalStatistics.MemoryStatistics.MemoryUsage); + WriteGaugeIfOk(writer, "/memory/memory_guarantee", totalStatistics.MemoryStatistics.MemoryGuarantee); + WriteGaugeIfOk(writer, "/memory/memory_limit", totalStatistics.MemoryStatistics.MemoryLimit); +} + +void TPortoResourceProfiler::WriteBlockingIOMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec) +{ + WriteCumulativeGaugeIfOk(writer, + "/io/read_bytes", + totalStatistics.BlockIOStatistics.IOReadByte, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/io/write_bytes", + totalStatistics.BlockIOStatistics.IOWriteByte, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/io/read_ops", + totalStatistics.BlockIOStatistics.IOReadOps, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/io/write_ops", + totalStatistics.BlockIOStatistics.IOWriteOps, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/io/ops", + totalStatistics.BlockIOStatistics.IOOps, + timeDeltaUsec); + + WriteGaugeIfOk(writer, + "/io/bytes_limit", + totalStatistics.BlockIOStatistics.IOBytesLimit); + WriteGaugeIfOk(writer, + "/io/ops_limit", + totalStatistics.BlockIOStatistics.IOOpsLimit); + + if (totalStatistics.BlockIOStatistics.IOTotalTime.IsOK()) { + i64 totalTimeUs = totalStatistics.BlockIOStatistics.IOTotalTime.Value().MicroSeconds(); + double totalPercent = std::max<double>(0.0, 100. * totalTimeUs / timeDeltaUsec); + writer->AddGauge("/io/total", totalPercent); + } + + if (totalStatistics.BlockIOStatistics.IOWaitTime.IsOK()) { + i64 waitTimeUs = totalStatistics.BlockIOStatistics.IOWaitTime.Value().MicroSeconds(); + double waitPercent = std::max<double>(0.0, 100. 
* waitTimeUs / timeDeltaUsec); + writer->AddGauge("/io/wait", waitPercent); + } +} + +void TPortoResourceProfiler::WriteNetworkMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec) +{ + WriteCumulativeGaugeIfOk( + writer, + "/network/rx_bytes", + totalStatistics.NetworkStatistics.RxBytes, + timeDeltaUsec); + WriteCumulativeGaugeIfOk( + writer, + "/network/rx_drops", + totalStatistics.NetworkStatistics.RxDrops, + timeDeltaUsec); + WriteCumulativeGaugeIfOk( + writer, + "/network/rx_packets", + totalStatistics.NetworkStatistics.RxPackets, + timeDeltaUsec); + WriteGaugeIfOk( + writer, + "/network/rx_limit", + totalStatistics.NetworkStatistics.RxLimit); + + WriteCumulativeGaugeIfOk( + writer, + "/network/tx_bytes", + totalStatistics.NetworkStatistics.TxBytes, + timeDeltaUsec); + WriteCumulativeGaugeIfOk( + writer, + "/network/tx_drops", + totalStatistics.NetworkStatistics.TxDrops, + timeDeltaUsec); + WriteCumulativeGaugeIfOk( + writer, + "/network/tx_packets", + totalStatistics.NetworkStatistics.TxPackets, + timeDeltaUsec); + WriteGaugeIfOk( + writer, + "/network/tx_limit", + totalStatistics.NetworkStatistics.TxLimit); +} + +void TPortoResourceProfiler::CollectSensors(ISensorWriter* writer) +{ + i64 lastUpdate = ResourceTracker_->GetLastUpdateTime().MicroSeconds(); + + auto totalStatistics = ResourceTracker_->GetTotalStatistics(); + i64 timeDeltaUsec = TInstant::Now().MicroSeconds() - lastUpdate; + + WriteCpuMetrics(writer, totalStatistics, timeDeltaUsec); + WriteMemoryMetrics(writer, totalStatistics, timeDeltaUsec); + WriteBlockingIOMetrics(writer, totalStatistics, timeDeltaUsec); + WriteNetworkMetrics(writer, totalStatistics, timeDeltaUsec); +} + +//////////////////////////////////////////////////////////////////////////////// + +TPortoResourceProfilerPtr CreatePortoProfilerWithTags( + const IInstancePtr& instance, + const TString containerCategory, + const TPodSpecConfigPtr& podSpec) +{ + auto portoResourceTracker = New<TPortoResourceTracker>( + instance, + ResourceUsageUpdatePeriod, + true, + true); + + return New<TPortoResourceProfiler>( + portoResourceTracker, + podSpec, + TProfiler("/porto") + .WithTag("container_category", containerCategory)); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif + +#ifdef __linux__ +void EnablePortoResourceTracker(const TPodSpecConfigPtr& podSpec) +{ + BIND([=] { + auto executor = CreatePortoExecutor(New<TPortoExecutorDynamicConfig>(), "porto-tracker"); + + executor->SubscribeFailed(BIND([=] (const TError& error) { + YT_LOG_ERROR(error, "Fatal error during Porto polling"); + })); + + LeakyRefCountedSingleton<TPortoProfilers>( + CreatePortoProfilerWithTags(GetSelfPortoInstance(executor), "daemon", podSpec), + CreatePortoProfilerWithTags(GetRootPortoInstance(executor), "pod", podSpec)); + }).AsyncVia(GetCurrentInvoker()) + .Run() + .Subscribe(BIND([] (const TError& error) { + YT_LOG_ERROR_IF(!error.IsOK(), error, "Failed to enable Porto profiler"); + })); +} +#else +void EnablePortoResourceTracker(const TPodSpecConfigPtr& /*podSpec*/) +{ + YT_LOG_WARNING("Porto resource tracker not supported"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_resource_tracker.h b/yt/yt/library/containers/porto_resource_tracker.h new file mode 100644 index 0000000000..8a0f781949 --- /dev/null +++ b/yt/yt/library/containers/porto_resource_tracker.h @@ -0,0 +1,158 @@ 
+#pragma once + +#include <yt/yt/library/containers/instance.h> +#include <yt/yt/library/containers/public.h> + +#include <yt/yt/library/containers/cgroup.h> + +#include <yt/yt/core/misc/singleton.h> +#include <yt/yt/core/net/address.h> +#include <yt/yt/core/ytree/public.h> + +#include <yt/yt/library/process/process.h> +#include <yt/yt/library/profiling/producer.h> + +namespace NYT::NContainers { + +using namespace NProfiling; + +//////////////////////////////////////////////////////////////////////////////// + +static constexpr auto ResourceUsageUpdatePeriod = TDuration::MilliSeconds(1000); + +//////////////////////////////////////////////////////////////////////////////// + +using TCpuStatistics = TCpuAccounting::TStatistics; +using TBlockIOStatistics = TBlockIO::TStatistics; +using TMemoryStatistics = TMemory::TStatistics; +using TNetworkStatistics = TNetwork::TStatistics; + +struct TTotalStatistics +{ +public: + TCpuStatistics CpuStatistics; + TMemoryStatistics MemoryStatistics; + TBlockIOStatistics BlockIOStatistics; + TNetworkStatistics NetworkStatistics; +}; + +#ifdef _linux_ + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoResourceTracker + : public TRefCounted +{ +public: + TPortoResourceTracker( + IInstancePtr instance, + TDuration updatePeriod, + bool isDeltaTracker = false, + bool isForceUpdate = false); + + TCpuStatistics GetCpuStatistics() const; + + TBlockIOStatistics GetBlockIOStatistics() const; + + TMemoryStatistics GetMemoryStatistics() const; + + TNetworkStatistics GetNetworkStatistics() const; + + TTotalStatistics GetTotalStatistics() const; + + bool AreResourceUsageStatisticsExpired() const; + + TInstant GetLastUpdateTime() const; + +private: + const IInstancePtr Instance_; + const TDuration UpdatePeriod_; + const bool IsDeltaTracker_; + const bool IsForceUpdate_; + + mutable std::atomic<TInstant> LastUpdateTime_ = {}; + + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, SpinLock_); + mutable TResourceUsage ResourceUsage_; + mutable TResourceUsage ResourceUsageDelta_; + + mutable std::optional<TCpuStatistics> CachedCpuStatistics_; + mutable std::optional<TMemoryStatistics> CachedMemoryStatistics_; + mutable std::optional<TBlockIOStatistics> CachedBlockIOStatistics_; + mutable std::optional<TNetworkStatistics> CachedNetworkStatistics_; + mutable std::optional<TTotalStatistics> CachedTotalStatistics_; + mutable TErrorOr<ui64> PeakThreadCount_ = 0; + + template <class T, class F> + T GetStatistics( + std::optional<T>& cachedStatistics, + const TString& statisticsKind, + F extractor) const; + + TCpuStatistics ExtractCpuStatistics(const TResourceUsage& resourceUsage) const; + TMemoryStatistics ExtractMemoryStatistics(const TResourceUsage& resourceUsage) const; + TBlockIOStatistics ExtractBlockIOStatistics(const TResourceUsage& resourceUsage) const; + TNetworkStatistics ExtractNetworkStatistics(const TResourceUsage& resourceUsage) const; + TTotalStatistics ExtractTotalStatistics(const TResourceUsage& resourceUsage) const; + + TErrorOr<ui64> CalculateCounterDelta( + const TErrorOr<ui64>& oldValue, + const TErrorOr<ui64>& newValue) const; + + void ReCalculateResourceUsage(const TResourceUsage& newResourceUsage) const; + + void UpdateResourceUsageStatisticsIfExpired() const; + + void DoUpdateResourceUsage() const; +}; + +DEFINE_REFCOUNTED_TYPE(TPortoResourceTracker) + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoResourceProfiler + : public ISensorProducer +{ +public: + 
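+    // The constructor registers this object as a sensor producer under the given profiler.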
TPortoResourceProfiler( + TPortoResourceTrackerPtr tracker, + TPodSpecConfigPtr podSpec, + const TProfiler& profiler = TProfiler{"/porto"}); + + void CollectSensors(ISensorWriter* writer) override; + +private: + const TPortoResourceTrackerPtr ResourceTracker_; + const TPodSpecConfigPtr PodSpec_; + + void WriteCpuMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec); + + void WriteMemoryMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec); + + void WriteBlockingIOMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec); + + void WriteNetworkMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec); +}; + +DECLARE_REFCOUNTED_TYPE(TPortoResourceProfiler) +DEFINE_REFCOUNTED_TYPE(TPortoResourceProfiler) + +//////////////////////////////////////////////////////////////////////////////// + +#endif + +void EnablePortoResourceTracker(const TPodSpecConfigPtr& podSpec); + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/private.h b/yt/yt/library/containers/private.h new file mode 100644 index 0000000000..62682cb364 --- /dev/null +++ b/yt/yt/library/containers/private.h @@ -0,0 +1,13 @@ +#pragma once + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger ContainersLogger("Containers"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/process.cpp b/yt/yt/library/containers/process.cpp new file mode 100644 index 0000000000..ad1c8d35dc --- /dev/null +++ b/yt/yt/library/containers/process.cpp @@ -0,0 +1,154 @@ +#ifdef __linux__ + +#include "process.h" + +#include <yt/yt/library/containers/instance.h> + +#include <yt/yt/core/misc/proc.h> +#include <yt/yt/core/misc/fs.h> + +namespace NYT::NContainers { + +using namespace NPipes; +using namespace NNet; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +static inline const NLogging::TLogger Logger("Process"); + +static constexpr pid_t InvalidProcessId = -1; + +//////////////////////////////////////////////////////////////////////////////// + +TPortoProcess::TPortoProcess( + const TString& path, + IInstanceLauncherPtr containerLauncher, + bool copyEnv) + : TProcessBase(path) + , ContainerLauncher_(std::move(containerLauncher)) +{ + AddArgument(NFS::GetFileName(path)); + if (copyEnv) { + for (char** envIt = environ; *envIt; ++envIt) { + Env_.push_back(Capture(*envIt)); + } + } +} + +void TPortoProcess::Kill(int signal) +{ + if (auto instance = GetInstance()) { + instance->Kill(signal); + } +} + +void TPortoProcess::DoSpawn() +{ + YT_VERIFY(ProcessId_ == InvalidProcessId && !Finished_); + YT_VERIFY(!GetInstance()); + YT_VERIFY(!Started_); + YT_VERIFY(!Args_.empty()); + + if (!WorkingDirectory_.empty()) { + ContainerLauncher_->SetCwd(WorkingDirectory_); + } + + Started_ = true; + + try { + // TPortoProcess doesn't support running processes inside rootFS. 
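+            // The launcher therefore must not have a root volume configured; verified below.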
+ YT_VERIFY(!ContainerLauncher_->HasRoot()); + std::vector<TString> args(Args_.begin() + 1, Args_.end()); + auto instance = WaitFor(ContainerLauncher_->Launch(ResolvedPath_, args, DecomposeEnv())) + .ValueOrThrow(); + ContainerInstance_.Store(instance); + FinishedPromise_.SetFrom(instance->Wait()); + + try { + ProcessId_ = instance->GetPid(); + } catch (const std::exception& ex) { + // This could happen if Porto container has already died or pid namespace of + // parent container is not a parent of pid namespace of child container. + // It's not a problem, since for Porto process pid is used for logging purposes only. + YT_LOG_DEBUG(ex, "Failed to get pid of root process (Container: %v)", + instance->GetName()); + } + + YT_LOG_DEBUG("Process inside Porto spawned successfully (Path: %v, ExternalPid: %v, Container: %v)", + ResolvedPath_, + ProcessId_, + instance->GetName()); + + FinishedPromise_.ToFuture().Subscribe(BIND([=, this, this_ = MakeStrong(this)] (const TError& exitStatus) { + Finished_ = true; + if (exitStatus.IsOK()) { + YT_LOG_DEBUG("Process inside Porto exited gracefully (ExternalPid: %v, Container: %v)", + ProcessId_, + instance->GetName()); + } else { + YT_LOG_DEBUG(exitStatus, "Process inside Porto exited with an error (ExternalPid: %v, Container: %v)", + ProcessId_, + instance->GetName()); + } + })); + } catch (const std::exception& ex) { + Finished_ = true; + THROW_ERROR_EXCEPTION("Failed to start child process inside Porto") + << TErrorAttribute("path", ResolvedPath_) + << TErrorAttribute("container", ContainerLauncher_->GetName()) + << ex; + } +} + +IInstancePtr TPortoProcess::GetInstance() +{ + return ContainerInstance_.Acquire(); +} + +THashMap<TString, TString> TPortoProcess::DecomposeEnv() const +{ + THashMap<TString, TString> result; + for (const auto& env : Env_) { + TStringBuf name, value; + TStringBuf(env).TrySplit('=', name, value); + result[name] = value; + } + return result; +} + +static TString CreateStdIONamedPipePath() +{ + const TString name = ToString(TGuid::Create()); + return NFS::GetRealPath(NFS::CombinePaths("/tmp", name)); +} + +IConnectionWriterPtr TPortoProcess::GetStdInWriter() +{ + auto pipe = TNamedPipe::Create(CreateStdIONamedPipePath()); + ContainerLauncher_->SetStdIn(pipe->GetPath()); + NamedPipes_.push_back(pipe); + return pipe->CreateAsyncWriter(); +} + +IConnectionReaderPtr TPortoProcess::GetStdOutReader() +{ + auto pipe = TNamedPipe::Create(CreateStdIONamedPipePath()); + ContainerLauncher_->SetStdOut(pipe->GetPath()); + NamedPipes_.push_back(pipe); + return pipe->CreateAsyncReader(); +} + +IConnectionReaderPtr TPortoProcess::GetStdErrReader() +{ + auto pipe = TNamedPipe::Create(CreateStdIONamedPipePath()); + ContainerLauncher_->SetStdErr(pipe->GetPath()); + NamedPipes_.push_back(pipe); + return pipe->CreateAsyncReader(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers + +#endif diff --git a/yt/yt/library/containers/process.h b/yt/yt/library/containers/process.h new file mode 100644 index 0000000000..75255165d8 --- /dev/null +++ b/yt/yt/library/containers/process.h @@ -0,0 +1,46 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/process/process.h> + +#include <library/cpp/yt/memory/atomic_intrusive_ptr.h> + +#include <library/cpp/porto/libporto.hpp> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +// NB(psushin): this class is deprecated and only used to run job proxy. 
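+// It launches a single process via IInstanceLauncher and wires the std streams through named pipes.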
+// ToDo(psushin): kill me. +class TPortoProcess + : public TProcessBase +{ +public: + TPortoProcess( + const TString& path, + NContainers::IInstanceLauncherPtr containerLauncher, + bool copyEnv = true); + void Kill(int signal) override; + NNet::IConnectionWriterPtr GetStdInWriter() override; + NNet::IConnectionReaderPtr GetStdOutReader() override; + NNet::IConnectionReaderPtr GetStdErrReader() override; + + NContainers::IInstancePtr GetInstance(); + +private: + const NContainers::IInstanceLauncherPtr ContainerLauncher_; + + TAtomicIntrusivePtr<NContainers::IInstance> ContainerInstance_; + std::vector<NPipes::TNamedPipePtr> NamedPipes_; + + void DoSpawn() override; + THashMap<TString, TString> DecomposeEnv() const; +}; + +DEFINE_REFCOUNTED_TYPE(TPortoProcess) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/public.h b/yt/yt/library/containers/public.h new file mode 100644 index 0000000000..d8e3cf3491 --- /dev/null +++ b/yt/yt/library/containers/public.h @@ -0,0 +1,163 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +#include <library/cpp/porto/proto/rpc.pb.h> +#include <library/cpp/yt/misc/enum.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +const int PortoErrorCodeBase = 12000; + +DEFINE_ENUM(EPortoErrorCode, + ((Success) ((PortoErrorCodeBase + Porto::EError::Success))) + ((Unknown) ((PortoErrorCodeBase + Porto::EError::Unknown))) + ((InvalidMethod) ((PortoErrorCodeBase + Porto::EError::InvalidMethod))) + ((ContainerAlreadyExists) ((PortoErrorCodeBase + Porto::EError::ContainerAlreadyExists))) + ((ContainerDoesNotExist) ((PortoErrorCodeBase + Porto::EError::ContainerDoesNotExist))) + ((InvalidProperty) ((PortoErrorCodeBase + Porto::EError::InvalidProperty))) + ((InvalidData) ((PortoErrorCodeBase + Porto::EError::InvalidData))) + ((InvalidValue) ((PortoErrorCodeBase + Porto::EError::InvalidValue))) + ((InvalidState) ((PortoErrorCodeBase + Porto::EError::InvalidState))) + ((NotSupported) ((PortoErrorCodeBase + Porto::EError::NotSupported))) + ((ResourceNotAvailable) ((PortoErrorCodeBase + Porto::EError::ResourceNotAvailable))) + ((Permission) ((PortoErrorCodeBase + Porto::EError::Permission))) + ((VolumeAlreadyExists) ((PortoErrorCodeBase + Porto::EError::VolumeAlreadyExists))) + ((VolumeNotFound) ((PortoErrorCodeBase + Porto::EError::VolumeNotFound))) + ((NoSpace) ((PortoErrorCodeBase + Porto::EError::NoSpace))) + ((Busy) ((PortoErrorCodeBase + Porto::EError::Busy))) + ((VolumeAlreadyLinked) ((PortoErrorCodeBase + Porto::EError::VolumeAlreadyLinked))) + ((VolumeNotLinked) ((PortoErrorCodeBase + Porto::EError::VolumeNotLinked))) + ((LayerAlreadyExists) ((PortoErrorCodeBase + Porto::EError::LayerAlreadyExists))) + ((LayerNotFound) ((PortoErrorCodeBase + Porto::EError::LayerNotFound))) + ((NoValue) ((PortoErrorCodeBase + Porto::EError::NoValue))) + ((VolumeNotReady) ((PortoErrorCodeBase + Porto::EError::VolumeNotReady))) + ((InvalidCommand) ((PortoErrorCodeBase + Porto::EError::InvalidCommand))) + ((LostError) ((PortoErrorCodeBase + Porto::EError::LostError))) + ((DeviceNotFound) ((PortoErrorCodeBase + Porto::EError::DeviceNotFound))) + ((InvalidPath) ((PortoErrorCodeBase + Porto::EError::InvalidPath))) + ((InvalidNetworkAddress) ((PortoErrorCodeBase + Porto::EError::InvalidNetworkAddress))) + ((PortoFrozen) ((PortoErrorCodeBase + Porto::EError::PortoFrozen))) + ((LabelNotFound) ((PortoErrorCodeBase + 
Porto::EError::LabelNotFound))) + ((InvalidLabel) ((PortoErrorCodeBase + Porto::EError::InvalidLabel))) + ((NotFound) ((PortoErrorCodeBase + Porto::EError::NotFound))) + ((SocketError) ((PortoErrorCodeBase + Porto::EError::SocketError))) + ((SocketUnavailable) ((PortoErrorCodeBase + Porto::EError::SocketUnavailable))) + ((SocketTimeout) ((PortoErrorCodeBase + Porto::EError::SocketTimeout))) + ((Taint) ((PortoErrorCodeBase + Porto::EError::Taint))) + ((Queued) ((PortoErrorCodeBase + Porto::EError::Queued))) +); + +//////////////////////////////////////////////////////////////////////////////// + +YT_DEFINE_ERROR_ENUM( + ((FailedToStartContainer) (14000)) +); + +DEFINE_ENUM(EStatField, + // CPU + (CpuUsage) + (CpuUserUsage) + (CpuSystemUsage) + (CpuWait) + (CpuThrottled) + (ContextSwitches) + (ContextSwitchesDelta) + (ThreadCount) + (CpuLimit) + (CpuGuarantee) + + // Memory + (Rss) + (MappedFile) + (MajorPageFaults) + (MinorPageFaults) + (FileCacheUsage) + (AnonMemoryUsage) + (AnonMemoryLimit) + (MemoryUsage) + (MemoryGuarantee) + (MemoryLimit) + (MaxMemoryUsage) + (OomKills) + (OomKillsTotal) + + // IO + (IOReadByte) + (IOWriteByte) + (IOBytesLimit) + (IOReadOps) + (IOWriteOps) + (IOOps) + (IOOpsLimit) + (IOTotalTime) + (IOWaitTime) + + // Network + (NetTxBytes) + (NetTxPackets) + (NetTxDrops) + (NetTxLimit) + (NetRxBytes) + (NetRxPackets) + (NetRxDrops) + (NetRxLimit) +); + +DEFINE_ENUM(EEnablePorto, + (None) + (Isolate) + (Full) +); + +struct TBind +{ + TString SourcePath; + TString TargetPath; + bool ReadOnly; +}; + +struct TRootFS +{ + TString RootPath; + bool IsRootReadOnly; + std::vector<TBind> Binds; +}; + +struct TDevice +{ + TString DeviceName; + bool Enabled; +}; + +struct TInstanceLimits +{ + double Cpu = 0; + i64 Memory = 0; + std::optional<i64> NetTx; + std::optional<i64> NetRx; + + bool operator==(const TInstanceLimits&) const = default; +}; + +DECLARE_REFCOUNTED_STRUCT(IContainerManager) +DECLARE_REFCOUNTED_STRUCT(IInstanceLauncher) +DECLARE_REFCOUNTED_STRUCT(IInstance) +DECLARE_REFCOUNTED_STRUCT(IPortoExecutor) + +DECLARE_REFCOUNTED_CLASS(TPortoHealthChecker) +DECLARE_REFCOUNTED_CLASS(TInstanceLimitsTracker) +DECLARE_REFCOUNTED_CLASS(TPortoProcess) +DECLARE_REFCOUNTED_CLASS(TPortoResourceTracker) +DECLARE_REFCOUNTED_CLASS(TPortoExecutorDynamicConfig) +DECLARE_REFCOUNTED_CLASS(TPodSpecConfig) + +//////////////////////////////////////////////////////////////////////////////// + +bool IsValidCGroupType(const TString& type); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/unittests/containers_ut.cpp b/yt/yt/library/containers/unittests/containers_ut.cpp new file mode 100644 index 0000000000..4f1c10a435 --- /dev/null +++ b/yt/yt/library/containers/unittests/containers_ut.cpp @@ -0,0 +1,133 @@ +#include <yt/yt/core/test_framework/framework.h> + +#ifdef _linux_ + +#include <yt/yt/library/containers/config.h> +#include <yt/yt/library/containers/porto_executor.h> +#include <yt/yt/library/containers/instance.h> + +#include <util/system/platform.h> +#include <util/system/env.h> + +namespace NYT::NContainers { +namespace { + +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +class TContainersTest + : public ::testing::Test +{ + void SetUp() override + { + if (GetEnv("SKIP_PORTO_TESTS") != "") { + GTEST_SKIP(); + } + } +}; + +static TString GetUniqueName() +{ + return "yt_ut_" + ToString(TGuid::Create()); +} + 
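+// The tests below assume a reachable local Porto daemon; set SKIP_PORTO_TESTS
+// (checked in SetUp above) to skip the whole suite where none is available.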
+IPortoExecutorPtr CreatePortoExecutor() +{ + return CreatePortoExecutor(New<TPortoExecutorDynamicConfig>(), "default"); +} + +TEST_F(TContainersTest, ListSubcontainers) +{ + auto executor = CreatePortoExecutor(); + auto name = GetUniqueName(); + + WaitFor(executor->CreateContainer(name)) + .ThrowOnError(); + + auto absoluteName = *WaitFor(executor->GetContainerProperty(name, "absolute_name")) + .ValueOrThrow(); + + auto nestedName = absoluteName + "/nested"; + WaitFor(executor->CreateContainer(nestedName)) + .ThrowOnError(); + + auto withRoot = WaitFor(executor->ListSubcontainers(name, true)) + .ValueOrThrow(); + EXPECT_EQ(std::vector<TString>({absoluteName, nestedName}), withRoot); + + auto withoutRoot = WaitFor(executor->ListSubcontainers(name, false)) + .ValueOrThrow(); + EXPECT_EQ(std::vector<TString>({nestedName}), withoutRoot); + + WaitFor(executor->DestroyContainer(absoluteName)) + .ThrowOnError(); +} + +// See https://st.yandex-team.ru/PORTO-846. +TEST_F(TContainersTest, DISABLED_WaitContainer) +{ + auto executor = CreatePortoExecutor(); + auto name = GetUniqueName(); + + WaitFor(executor->CreateContainer(name)) + .ThrowOnError(); + + WaitFor(executor->SetContainerProperty(name, "command", "sleep 10")) + .ThrowOnError(); + + WaitFor(executor->StartContainer(name)) + .ThrowOnError(); + + auto exitCode = WaitFor(executor->WaitContainer(name)) + .ValueOrThrow(); + + EXPECT_EQ(0, exitCode); + + WaitFor(executor->DestroyContainer(name)) + .ThrowOnError(); +} + +TEST_F(TContainersTest, CreateFromSpec) +{ + auto executor = CreatePortoExecutor(); + auto name = GetUniqueName(); + + auto spec = TRunnableContainerSpec { + .Name = name, + .Command = "sleep 2", + }; + + WaitFor(executor->CreateContainer(spec, /*start*/ true)) + .ThrowOnError(); + + auto exitCode = WaitFor(executor->PollContainer(name)) + .ValueOrThrow(); + + EXPECT_EQ(0, exitCode); + + WaitFor(executor->DestroyContainer(name)) + .ThrowOnError(); +} + +TEST_F(TContainersTest, ListPids) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + + auto instance = WaitFor(launcher->Launch("sleep", {"5"}, {})) + .ValueOrThrow(); + + auto pids = instance->GetPids(); + EXPECT_LT(0u, pids.size()); + + instance->Destroy(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT::NContainers + +#endif diff --git a/yt/yt/library/containers/unittests/porto_resource_tracker_ut.cpp b/yt/yt/library/containers/unittests/porto_resource_tracker_ut.cpp new file mode 100644 index 0000000000..04d169ba4e --- /dev/null +++ b/yt/yt/library/containers/unittests/porto_resource_tracker_ut.cpp @@ -0,0 +1,251 @@ +#include <yt/yt/core/test_framework/framework.h> + +#include <yt/yt/core/ytree/convert.h> + +#include <util/system/fs.h> +#include <util/system/tempfile.h> + +#include <yt/yt/library/profiling/producer.h> +#include <yt/yt/library/containers/config.h> +#include <yt/yt/library/containers/porto_executor.h> +#include <yt/yt/library/containers/porto_resource_tracker.h> +#include <yt/yt/library/containers/instance.h> + +#include <util/system/platform.h> +#include <util/system/env.h> + +namespace NYT::NContainers { +namespace { + +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +static constexpr auto TestUpdatePeriod = TDuration::MilliSeconds(10); + +class TPortoTrackerTest + : public ::testing::Test +{ +public: + IPortoExecutorPtr Executor; + + void SetUp() override + { + if 
(GetEnv("SKIP_PORTO_TESTS") != "") { + GTEST_SKIP(); + } + + Executor = CreatePortoExecutor(New<TPortoExecutorDynamicConfig>(), "default"); + } +}; + +TString GetUniqueName() +{ + return "yt_porto_ut_" + ToString(TGuid::Create()); +} + +TPortoResourceTrackerPtr CreateSumPortoTracker(IPortoExecutorPtr Executor, const TString& name) +{ + return New<TPortoResourceTracker>( + GetPortoInstance(Executor, name), + TestUpdatePeriod, + false); +} + +TPortoResourceProfilerPtr CreateDeltaPortoProfiler(IPortoExecutorPtr executor, const TString& name) +{ + auto instance = GetPortoInstance(executor, name); + auto portoResourceTracker = New<TPortoResourceTracker>( + instance, + ResourceUsageUpdatePeriod, + true, + true + ); + + // Init metrics for delta tracker. + portoResourceTracker->GetTotalStatistics(); + + return LeakyRefCountedSingleton<TPortoResourceProfiler>( + portoResourceTracker, + New<TPodSpecConfig>(), + TProfiler("/porto") + .WithTag("porto_name", instance->GetName()) + .WithTag("container_category", "yt_daemon")); +} + +void AssertGauges(const std::vector<std::tuple<TString, TTagList, double>>& gauges) { + THashSet<TString> sensors{ + "/cpu/user", + "/cpu/total", + "/cpu/system", + "/cpu/wait", + "/cpu/throttled", + "/cpu/guarantee", + "/cpu/limit", + "/cpu/thread_count", + "/cpu/context_switches", + + "/memory/minor_page_faults", + "/memory/major_page_faults", + "/memory/file_cache_usage", + "/memory/anon_usage", + "/memory/anon_limit", + "/memory/memory_usage", + "/memory/memory_guarantee", + "/memory/memory_limit", + + "/io/read_bytes", + "/io/write_bytes", + "/io/bytes_limit", + + "/io/read_ops", + "/io/write_ops", + "/io/ops", + "/io/ops_limit", + "/io/total", + + "/network/rx_bytes", + "/network/rx_drops", + "/network/rx_packets", + "/network/rx_limit", + "/network/tx_bytes", + "/network/tx_drops", + "/network/tx_packets", + "/network/tx_limit" + }; + + THashSet<TString> mayBeEmpty{ + "/cpu/wait", + "/cpu/throttled", + "/cpu/guarantee", + "/cpu/context_switches", + "/memory/major_page_faults", + "/memory/memory_guarantee", + "/io/ops_limit", + "/io/read_ops", + "/io/write_ops", + "/io/wait", + "/io/bytes_limit", + "/network/rx_bytes", + "/network/rx_drops", + "/network/rx_packets", + "/network/rx_limit", + "/network/tx_bytes", + "/network/tx_drops", + "/network/tx_packets", + "/network/tx_limit" + }; + + for (const auto& [name, tags, value] : gauges) { + EXPECT_TRUE(value >= 0 && sensors.find(name) || mayBeEmpty.find(name)); + } +} + +TEST_F(TPortoTrackerTest, ValidateSummaryPortoTracker) +{ + auto name = GetUniqueName(); + + WaitFor(Executor->CreateContainer( + TRunnableContainerSpec { + .Name = name, + .Command = "sleep .1", + }, true)) + .ThrowOnError(); + + auto tracker = CreateSumPortoTracker(Executor, name); + + auto firstStatistics = tracker->GetTotalStatistics(); + + WaitFor(Executor->StopContainer(name)) + .ThrowOnError(); + WaitFor(Executor->SetContainerProperty( + name, + "command", + "find /")) + .ThrowOnError(); + WaitFor(Executor->StartContainer(name)) + .ThrowOnError(); + Sleep(TDuration::MilliSeconds(500)); + + auto secondStatistics = tracker->GetTotalStatistics(); + + WaitFor(Executor->DestroyContainer(name)) + .ThrowOnError(); +} + +TEST_F(TPortoTrackerTest, ValidateDeltaPortoTracker) +{ + auto name = GetUniqueName(); + + auto spec = TRunnableContainerSpec { + .Name = name, + .Command = "sleep .1", + }; + + WaitFor(Executor->CreateContainer(spec, true)) + .ThrowOnError(); + + auto profiler = CreateDeltaPortoProfiler(Executor, name); + + 
WaitFor(Executor->StopContainer(name)) + .ThrowOnError(); + WaitFor(Executor->SetContainerProperty( + name, + "command", + "find /")) + .ThrowOnError(); + WaitFor(Executor->StartContainer(name)) + .ThrowOnError(); + + Sleep(TDuration::MilliSeconds(500)); + + auto buffer = New<TSensorBuffer>(); + profiler->CollectSensors(buffer.Get()); + AssertGauges(buffer->GetGauges()); + + WaitFor(Executor->DestroyContainer(name)) + .ThrowOnError(); +} + +TEST_F(TPortoTrackerTest, ValidateDeltaRootPortoTracker) +{ + auto name = GetUniqueName(); + + auto spec = TRunnableContainerSpec { + .Name = name, + .Command = "sleep .1", + }; + + WaitFor(Executor->CreateContainer(spec, true)) + .ThrowOnError(); + + auto profiler = CreateDeltaPortoProfiler( + Executor, + GetPortoInstance( + Executor, + *GetPortoInstance(Executor, name)->GetRootName())->GetName()); + + WaitFor(Executor->StopContainer(name)) + .ThrowOnError(); + WaitFor(Executor->SetContainerProperty( + name, + "command", + "find /")) + .ThrowOnError(); + WaitFor(Executor->StartContainer(name)) + .ThrowOnError(); + + Sleep(TDuration::MilliSeconds(500)); + + auto buffer = New<TSensorBuffer>(); + profiler->CollectSensors(buffer.Get()); + AssertGauges(buffer->GetGauges()); + + WaitFor(Executor->DestroyContainer(name)) + .ThrowOnError(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/unittests/process_ut.cpp b/yt/yt/library/containers/unittests/process_ut.cpp new file mode 100644 index 0000000000..b9c0d844f4 --- /dev/null +++ b/yt/yt/library/containers/unittests/process_ut.cpp @@ -0,0 +1,302 @@ +#include <yt/yt/core/test_framework/framework.h> + +#ifdef _linux_ + +#include <yt/yt/core/actions/bind.h> + +#include <yt/yt/core/concurrency/action_queue.h> +#include <yt/yt/core/concurrency/delayed_executor.h> +#include <yt/yt/core/concurrency/scheduler.h> + +#include <yt/yt/core/misc/guid.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/net/connection.h> + +#include <yt/yt/library/containers/process.h> + +#include <yt/yt/library/containers/config.h> +#include <yt/yt/library/containers/porto_executor.h> +#include <yt/yt/library/containers/instance.h> + +#include <util/system/platform.h> +#include <util/system/env.h> + +namespace NYT::NContainers { +namespace { + +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoProcessTest + : public ::testing::Test +{ + void SetUp() override + { + if (GetEnv("SKIP_PORTO_TESTS") != "") { + GTEST_SKIP(); + } + } +}; + +static TString GetUniqueName() +{ + return "yt_ut_" + ToString(TGuid::Create()); +} + +IPortoExecutorPtr CreatePortoExecutor() +{ + return CreatePortoExecutor(New<TPortoExecutorDynamicConfig>(), "default"); +} + +TEST_F(TPortoProcessTest, Basic) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/ls", launcher, true); + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, RunFromPathEnv) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("ls", launcher, true); + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + 
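// Note: "ls" is given without a directory, so spawning relies on PATH-based
+    // resolution (cf. ResolveBinaryPath in yt/yt/library/process/process.cpp).
+    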
ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, MultiBasic) +{ + auto portoExecutor = CreatePortoExecutor(); + auto l1 = CreatePortoInstanceLauncher(GetUniqueName(), portoExecutor); + auto l2 = CreatePortoInstanceLauncher(GetUniqueName(), portoExecutor); + auto p1 = New<TPortoProcess>("/bin/ls", l1, true); + auto p2 = New<TPortoProcess>("/bin/ls", l2, true); + TFuture<void> f1; + TFuture<void> f2; + ASSERT_NO_THROW(f1 = p1->Spawn()); + ASSERT_NO_THROW(f2 = p2->Spawn()); + auto error = WaitFor((AllSucceeded(std::vector<TFuture<void>>{f1, f2}))); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p1->IsFinished()); + EXPECT_TRUE(p2->IsFinished()); + p1->GetInstance()->Destroy(); + p2->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, InvalidPath) +{ + auto portoExecutor = CreatePortoExecutor(); + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + portoExecutor); + auto p = New<TPortoProcess>("/some/bad/path/binary", launcher, true); + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_FALSE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_FALSE(p->IsFinished()); + EXPECT_FALSE(error.IsOK()); + WaitFor(portoExecutor->DestroyContainer(launcher->GetName())) + .ThrowOnError(); +} + +TEST_F(TPortoProcessTest, StdOut) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/date", launcher, true); + + auto outStream = p->GetStdOutReader(); + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); + + auto buffer = TSharedMutableRef::Allocate(4_KB, {.InitializeStorage = false}); + auto future = outStream->Read(buffer); + TErrorOr<size_t> result = WaitFor(future); + size_t sz = result.ValueOrThrow(); + EXPECT_TRUE(sz > 0); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, GetCommandLine) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/bash", launcher, true); + EXPECT_EQ("/bin/bash", p->GetCommandLine()); + p->AddArgument("-c"); + EXPECT_EQ("/bin/bash -c", p->GetCommandLine()); + p->AddArgument("exit 0"); + EXPECT_EQ("/bin/bash -c \"exit 0\"", p->GetCommandLine()); +} + +TEST_F(TPortoProcessTest, ProcessReturnCode0) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/bash", launcher, true); + p->AddArgument("-c"); + p->AddArgument("exit 0"); + + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, ProcessReturnCode123) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/bash", launcher, true); + p->AddArgument("-c"); + p->AddArgument("exit 123"); + + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_EQ(EProcessErrorCode::NonZeroExitCode, error.GetCode()); + EXPECT_EQ(123, 
error.Attributes().Get<int>("exit_code")); + EXPECT_TRUE(p->IsFinished()); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, Params1) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/bash", launcher, true); + p->AddArgument("-c"); + p->AddArgument("if test 3 -gt 1; then exit 7; fi"); + + auto error = WaitFor(p->Spawn()); + EXPECT_FALSE(error.IsOK()); + EXPECT_TRUE(p->IsFinished()); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, Params2) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/bash", launcher, true); + p->AddArgument("-c"); + p->AddArgument("if test 1 -gt 3; then exit 7; fi"); + + auto error = WaitFor(p->Spawn()); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, InheritEnvironment) +{ + const char* name = "SPAWN_TEST_ENV_VAR"; + const char* value = "42"; + setenv(name, value, 1); + + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/bash", launcher, true); + p->AddArgument("-c"); + p->AddArgument("if test $SPAWN_TEST_ENV_VAR = 42; then exit 7; fi"); + + auto error = WaitFor(p->Spawn()); + EXPECT_FALSE(error.IsOK()); + EXPECT_TRUE(p->IsFinished()); + + unsetenv(name); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, Kill) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/sleep", launcher, true); + p->AddArgument("5"); + + auto finished = p->Spawn(); + + NConcurrency::TDelayedExecutor::Submit( + BIND([&] () { + p->Kill(SIGKILL); + }), + TDuration::MilliSeconds(100)); + + auto error = WaitFor(finished); + EXPECT_FALSE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, KillFinished) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/bash", launcher, true); + p->AddArgument("-c"); + p->AddArgument("true"); + + auto finished = p->Spawn(); + + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()); + + p->Kill(SIGKILL); + p->GetInstance()->Destroy(); +} + +TEST_F(TPortoProcessTest, PollDuration) +{ + auto launcher = CreatePortoInstanceLauncher( + GetUniqueName(), + CreatePortoExecutor()); + auto p = New<TPortoProcess>("/bin/sleep", launcher, true); + p->AddArgument("1"); + + auto error = WaitFor(p->Spawn()); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); + p->GetInstance()->Destroy(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT::NContainers + +#endif diff --git a/yt/yt/library/containers/unittests/ya.make b/yt/yt/library/containers/unittests/ya.make new file mode 100644 index 0000000000..42984e2dc7 --- /dev/null +++ b/yt/yt/library/containers/unittests/ya.make @@ -0,0 +1,35 @@ +GTEST(unittester-containers) + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +ALLOCATOR(TCMALLOC) + +IF (AUTOCHECK) + ENV(SKIP_PORTO_TESTS=1) +ENDIF() + +IF (DISTBUILD) # TODO(prime@): this is always on + ENV(SKIP_PORTO_TESTS=1) +ENDIF() + +SRCS( + containers_ut.cpp + process_ut.cpp +) + +IF(OS_LINUX) + SRCS( + porto_resource_tracker_ut.cpp + ) +ENDIF() + 
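+# porto_resource_tracker_ut relies on Porto resource metrics and therefore
+# builds only on Linux (hence the OS_LINUX guard above).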
+INCLUDE(${ARCADIA_ROOT}/yt/opensource_tests.inc) + +PEERDIR( + yt/yt/build + yt/yt/library/containers +) + +SIZE(MEDIUM) + +END() diff --git a/yt/yt/library/containers/ya.make b/yt/yt/library/containers/ya.make new file mode 100644 index 0000000000..499b8d9da8 --- /dev/null +++ b/yt/yt/library/containers/ya.make @@ -0,0 +1,37 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + cgroup.cpp + config.cpp + instance.cpp + instance_limits_tracker.cpp + process.cpp + porto_executor.cpp + porto_resource_tracker.cpp + porto_health_checker.cpp +) + +PEERDIR( + library/cpp/porto/proto + yt/yt/library/process + yt/yt/core +) + +IF(OS_LINUX) + PEERDIR( + library/cpp/porto + ) +ENDIF() + +END() + +RECURSE( + disk_manager + cri +) + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/monitoring/CMakeLists.darwin-x86_64.txt b/yt/yt/library/monitoring/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..cbddba12a3 --- /dev/null +++ b/yt/yt/library/monitoring/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,26 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-monitoring) +target_compile_options(yt-library-monitoring PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-monitoring PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + yt-yt-build + yt-library-profiling + library-profiling-solomon + library-cpp-cgiparam +) +target_sources(yt-library-monitoring PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/monitoring/http_integration.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/monitoring/monitoring_manager.cpp +) diff --git a/yt/yt/library/monitoring/CMakeLists.linux-aarch64.txt b/yt/yt/library/monitoring/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..192d29cfb1 --- /dev/null +++ b/yt/yt/library/monitoring/CMakeLists.linux-aarch64.txt @@ -0,0 +1,30 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(yt-library-monitoring) +target_compile_options(yt-library-monitoring PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-monitoring PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + yt-yt-build + yt-library-profiling + library-profiling-solomon + library-cpp-cgiparam + yt-library-ytprof + library-ytprof-http + library-backtrace_introspector-http +) +target_sources(yt-library-monitoring PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/monitoring/http_integration.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/monitoring/monitoring_manager.cpp +) diff --git a/yt/yt/library/monitoring/CMakeLists.linux-x86_64.txt b/yt/yt/library/monitoring/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..192d29cfb1 --- /dev/null +++ b/yt/yt/library/monitoring/CMakeLists.linux-x86_64.txt @@ -0,0 +1,30 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-monitoring) +target_compile_options(yt-library-monitoring PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-monitoring PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + yt-yt-build + yt-library-profiling + library-profiling-solomon + library-cpp-cgiparam + yt-library-ytprof + library-ytprof-http + library-backtrace_introspector-http +) +target_sources(yt-library-monitoring PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/monitoring/http_integration.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/monitoring/monitoring_manager.cpp +) diff --git a/yt/yt/library/monitoring/CMakeLists.txt b/yt/yt/library/monitoring/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/monitoring/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/monitoring/CMakeLists.windows-x86_64.txt b/yt/yt/library/monitoring/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..adebf50e70 --- /dev/null +++ b/yt/yt/library/monitoring/CMakeLists.windows-x86_64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. 
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-monitoring) +target_link_libraries(yt-library-monitoring PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + yt-yt-build + yt-library-profiling + library-profiling-solomon + library-cpp-cgiparam +) +target_sources(yt-library-monitoring PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/monitoring/http_integration.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/monitoring/monitoring_manager.cpp +) diff --git a/yt/yt/library/monitoring/http_integration.cpp b/yt/yt/library/monitoring/http_integration.cpp new file mode 100644 index 0000000000..a526d2ede6 --- /dev/null +++ b/yt/yt/library/monitoring/http_integration.cpp @@ -0,0 +1,203 @@ +#include "http_integration.h" + +#include "monitoring_manager.h" + +#include <yt/yt/build/build.h> + +#include <yt/yt/core/json/config.h> +#include <yt/yt/core/json/json_writer.h> + +#include <yt/yt/core/ytree/fluent.h> + +#include <yt/yt/core/yson/parser.h> +#include <yt/yt/core/yson/consumer.h> + +#include <yt/yt/core/concurrency/scheduler.h> + +#include <yt/yt/core/ytree/helpers.h> +#include <yt/yt/core/ytree/virtual.h> +#include <yt/yt/core/ytree/ypath_detail.h> +#include <yt/yt/core/ytree/ypath_proxy.h> + +#include <yt/yt/core/http/http.h> +#include <yt/yt/core/http/helpers.h> +#include <yt/yt/core/http/server.h> + +#include <yt/yt/core/misc/ref_counted_tracker_statistics_producer.h> + +#include <yt/yt/library/profiling/solomon/exporter.h> + +#ifdef _linux_ +#include <yt/yt/library/ytprof/http/handler.h> +#include <yt/yt/library/ytprof/build_info.h> + +#include <yt/yt/library/backtrace_introspector/http/handler.h> +#endif + +#include <library/cpp/cgiparam/cgiparam.h> + +#include <util/string/vector.h> + +namespace NYT::NMonitoring { + +using namespace NYTree; +using namespace NYson; +using namespace NHttp; +using namespace NConcurrency; +using namespace NJson; + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(EVerb, + (Get) + (List) +); + +//////////////////////////////////////////////////////////////////////////////// + +void Initialize( + const NHttp::IServerPtr& monitoringServer, + const NProfiling::TSolomonExporterConfigPtr& config, + TMonitoringManagerPtr* monitoringManager, + NYTree::IMapNodePtr* orchidRoot) +{ + *monitoringManager = New<TMonitoringManager>(); + (*monitoringManager)->Register("/ref_counted", CreateRefCountedTrackerStatisticsProducer()); + (*monitoringManager)->Register("/solomon", BIND([] (NYson::IYsonConsumer* consumer) { + auto tags = NProfiling::TSolomonRegistry::Get()->GetDynamicTags(); + + BuildYsonFluently(consumer) + .BeginMap() + .Item("dynamic_tags").Value(THashMap<TString, TString>(tags.begin(), tags.end())) + .EndMap(); + })); + (*monitoringManager)->Start(); + + *orchidRoot = NYTree::GetEphemeralNodeFactory(true)->CreateMap(); + SetNodeByYPath( + *orchidRoot, + "/monitoring", + CreateVirtualNode((*monitoringManager)->GetService())); + +#ifdef _linux_ + auto buildInfo = NYTProf::TBuildInfo::GetDefault(); + buildInfo.BinaryVersion = GetVersion(); + + SetNodeByYPath( + *orchidRoot, + "/build_info", + NYTree::BuildYsonNodeFluently() + .BeginMap() + .Item("arc_revision").Value(buildInfo.ArcRevision) + 
.Item("binary_version").Value(buildInfo.BinaryVersion) + .Item("build_type").Value(buildInfo.BuildType) + .EndMap()); +#endif + + if (monitoringServer) { + auto exporter = New<NProfiling::TSolomonExporter>(config); + exporter->Register("/solomon", monitoringServer); + exporter->Start(); + + SetNodeByYPath( + *orchidRoot, + "/sensors", + CreateVirtualNode(exporter->GetSensorService())); + +#ifdef _linux_ + NYTProf::Register(monitoringServer, "/ytprof", buildInfo); + NBacktraceIntrospector::Register(monitoringServer, "/backtrace"); +#endif + monitoringServer->AddHandler( + "/orchid/", + GetOrchidYPathHttpHandler(*orchidRoot)); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +class TYPathHttpHandler + : public IHttpHandler +{ +public: + explicit TYPathHttpHandler(IYPathServicePtr service) + : Service_(std::move(service)) + { } + + void HandleRequest( + const IRequestPtr& req, + const IResponseWriterPtr& rsp) override + { + const TStringBuf orchidPrefix = "/orchid"; + + TString path{req->GetUrl().Path}; + if (!path.StartsWith(orchidPrefix)) { + THROW_ERROR_EXCEPTION("HTTP request must start with %Qv prefix", + orchidPrefix) + << TErrorAttribute("path", path); + } + + path = path.substr(orchidPrefix.size(), TString::npos); + TCgiParameters params(req->GetUrl().RawQuery); + + auto verb = EVerb::Get; + + auto options = CreateEphemeralAttributes(); + for (const auto& param : params) { + if (param.first == "verb") { + verb = ParseEnum<EVerb>(param.second); + } else { + // Just a check, IAttributeDictionary takes raw YSON anyway. + try { + ValidateYson(TYsonString(param.second), DefaultYsonParserNestingLevelLimit); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Error parsing value of query parameter %Qv", + param.first) + << ex; + } + + options->SetYson(param.first, TYsonString(param.second)); + } + } + + TYsonString result; + switch (verb) { + case EVerb::Get: { + auto ypathReq = TYPathProxy::Get(path); + ToProto(ypathReq->mutable_options(), *options); + auto ypathRsp = WaitFor(ExecuteVerb(Service_, ypathReq)) + .ValueOrThrow(); + result = TYsonString(ypathRsp->value()); + break; + } + case EVerb::List: { + auto ypathReq = TYPathProxy::List(path); + auto ypathRsp = WaitFor(ExecuteVerb(Service_, ypathReq)) + .ValueOrThrow(); + result = TYsonString(ypathRsp->value()); + break; + } + default: + YT_ABORT(); + } + + rsp->SetStatus(EStatusCode::OK); + NHttp::ReplyJson(rsp, [&] (NYson::IYsonConsumer* writer) { + Serialize(result, writer); + }); + WaitFor(rsp->Close()) + .ThrowOnError(); + } + +private: + const IYPathServicePtr Service_; +}; + +IHttpHandlerPtr GetOrchidYPathHttpHandler(const IYPathServicePtr& service) +{ + return WrapYTException(New<TYPathHttpHandler>(service)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NMonitoring diff --git a/yt/yt/library/monitoring/http_integration.h b/yt/yt/library/monitoring/http_integration.h new file mode 100644 index 0000000000..48c12ca8a8 --- /dev/null +++ b/yt/yt/library/monitoring/http_integration.h @@ -0,0 +1,28 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/ypath_service.h> + +#include <yt/yt/core/http/public.h> + +#include <yt/yt/library/profiling/solomon/public.h> + +namespace NYT::NMonitoring { + +//////////////////////////////////////////////////////////////////////////////// + +void Initialize( + const NHttp::IServerPtr& monitoringServer, + const NProfiling::TSolomonExporterConfigPtr& 
solomonExporterConfig,
+    TMonitoringManagerPtr* monitoringManager,
+    NYTree::IMapNodePtr* orchidRoot);
+
+NHttp::IHttpHandlerPtr CreateTracingHttpHandler();
+
+NHttp::IHttpHandlerPtr GetOrchidYPathHttpHandler(
+    const NYTree::IYPathServicePtr& service);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NMonitoring
diff --git a/yt/yt/library/monitoring/monitoring_manager.cpp b/yt/yt/library/monitoring/monitoring_manager.cpp
new file mode 100644
index 0000000000..ef642034a4
--- /dev/null
+++ b/yt/yt/library/monitoring/monitoring_manager.cpp
@@ -0,0 +1,177 @@
+#include "monitoring_manager.h"
+#include "private.h"
+
+#include <yt/yt/core/concurrency/action_queue.h>
+#include <yt/yt/core/concurrency/periodic_executor.h>
+
+#include <yt/yt/core/ytree/convert.h>
+#include <yt/yt/core/ytree/ephemeral_node_factory.h>
+#include <yt/yt/core/ytree/node.h>
+#include <yt/yt/core/ytree/tree_visitor.h>
+#include <yt/yt/core/ytree/ypath_detail.h>
+#include <yt/yt/core/ytree/ypath_client.h>
+
+#include <yt/yt/library/profiling/sensor.h>
+
+namespace NYT::NMonitoring {
+
+using namespace NYTree;
+using namespace NYPath;
+using namespace NYson;
+using namespace NConcurrency;
+
+////////////////////////////////////////////////////////////////////////////////
+
+static const auto& Logger = MonitoringLogger;
+
+static const auto UpdatePeriod = TDuration::Seconds(3);
+static const auto EmptyRoot = GetEphemeralNodeFactory()->CreateMap();
+
+////////////////////////////////////////////////////////////////////////////////
+
+class TMonitoringManager::TImpl
+    : public TRefCounted
+{
+public:
+    void Register(const TYPath& path, TYsonProducer producer)
+    {
+        auto guard = Guard(SpinLock_);
+        YT_VERIFY(PathToProducer_.emplace(path, producer).second);
+    }
+
+    void Unregister(const TYPath& path)
+    {
+        auto guard = Guard(SpinLock_);
+        YT_VERIFY(PathToProducer_.erase(path) == 1);
+    }
+
+    IYPathServicePtr GetService()
+    {
+        return New<TYPathService>(this);
+    }
+
+    void Start()
+    {
+        auto guard = Guard(SpinLock_);
+
+        YT_VERIFY(!Started_);
+
+        PeriodicExecutor_ = New<TPeriodicExecutor>(
+            ActionQueue_->GetInvoker(),
+            BIND(&TImpl::Update, MakeWeak(this)),
+            UpdatePeriod);
+        PeriodicExecutor_->Start();
+
+        Started_ = true;
+    }
+
+    void Stop()
+    {
+        auto guard = Guard(SpinLock_);
+
+        if (!Started_) {
+            return;
+        }
+
+        Started_ = false;
+        YT_UNUSED_FUTURE(PeriodicExecutor_->Stop());
+        Root_.Reset();
+    }
+
+private:
+    class TYPathService
+        : public TYPathServiceBase
+    {
+    public:
+        explicit TYPathService(TIntrusivePtr<TImpl> owner)
+            : Owner_(std::move(owner))
+        { }
+
+        TResolveResult Resolve(const TYPath& path, const IYPathServiceContextPtr& /*context*/) override
+        {
+            return TResolveResultThere{Owner_->GetRoot(), path};
+        }
+
+    private:
+        const TIntrusivePtr<TImpl> Owner_;
+    };
+
+    bool Started_ = false;
+    TActionQueuePtr ActionQueue_ = New<TActionQueue>("Monitoring");
+    TPeriodicExecutorPtr PeriodicExecutor_;
+
+    YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, SpinLock_);
+    THashMap<TString, NYson::TYsonProducer> PathToProducer_;
+    IMapNodePtr Root_;
+
+    void Update()
+    {
+        YT_LOG_DEBUG("Started updating monitoring state");
+
+        YT_PROFILE_TIMING("/monitoring/update_time") {
+            auto newRoot = GetEphemeralNodeFactory()->CreateMap();
+
+            THashMap<TString, NYson::TYsonProducer> pathToProducer;
+            {
+                auto guard = Guard(SpinLock_);
+                pathToProducer = PathToProducer_;
+            }
+
+            for (const auto& [path, producer] : pathToProducer) {
+                auto value = ConvertToYsonString(producer);
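+                // Install the freshly produced subtree into the new root.
+                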
SyncYPathSet(newRoot, path, value); + } + + if (Started_) { + auto guard = Guard(SpinLock_); + std::swap(Root_, newRoot); + } + } + YT_LOG_DEBUG("Finished updating monitoring state"); + } + + IMapNodePtr GetRoot() + { + auto guard = Guard(SpinLock_); + return Root_ ? Root_ : EmptyRoot; + } +}; + +DEFINE_REFCOUNTED_TYPE(TMonitoringManager) + +//////////////////////////////////////////////////////////////////////////////// + +TMonitoringManager::TMonitoringManager() + : Impl_(New<TImpl>()) +{ } + +TMonitoringManager::~TMonitoringManager() = default; + +void TMonitoringManager::Register(const TYPath& path, TYsonProducer producer) +{ + Impl_->Register(path, producer); +} + +void TMonitoringManager::Unregister(const TYPath& path) +{ + Impl_->Unregister(path); +} + +IYPathServicePtr TMonitoringManager::GetService() +{ + return Impl_->GetService(); +} + +void TMonitoringManager::Start() +{ + Impl_->Start(); +} + +void TMonitoringManager::Stop() +{ + Impl_->Stop(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NMonitoring diff --git a/yt/yt/library/monitoring/monitoring_manager.h b/yt/yt/library/monitoring/monitoring_manager.h new file mode 100644 index 0000000000..fc5c3de6c7 --- /dev/null +++ b/yt/yt/library/monitoring/monitoring_manager.h @@ -0,0 +1,54 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/yson/consumer.h> + +#include <yt/yt/core/ypath/public.h> + +#include <yt/yt/core/ytree/public.h> + +namespace NYT::NMonitoring { + +//////////////////////////////////////////////////////////////////////////////// + +//! Exposes a tree assembled from results returned by a set of +//! registered NYson::TYsonProducer-s. +/*! + * \note + * The results are cached and periodically updated. + */ +class TMonitoringManager + : public TRefCounted +{ +public: + TMonitoringManager(); + ~TMonitoringManager(); + + //! Registers a new #producer for a given #path. + void Register(const NYPath::TYPath& path, NYson::TYsonProducer producer); + + //! Unregisters an existing producer for the specified #path. + void Unregister(const NYPath::TYPath& path); + + //! Returns the service representing the whole tree. + /*! + * \note The service is thread-safe. + */ + NYTree::IYPathServicePtr GetService(); + + //! Starts periodic updates. + void Start(); + + //! Stops periodic updates. 
+    void Stop();
+
+private:
+    class TImpl;
+    TIntrusivePtr<TImpl> Impl_;
+
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NMonitoring
diff --git a/yt/yt/library/monitoring/private.h b/yt/yt/library/monitoring/private.h
new file mode 100644
index 0000000000..e2bfb31c78
--- /dev/null
+++ b/yt/yt/library/monitoring/private.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include "public.h"
+
+#include <yt/yt/core/logging/log.h>
+
+namespace NYT::NMonitoring {
+
+////////////////////////////////////////////////////////////////////////////////
+
+inline const NLogging::TLogger MonitoringLogger("Monitoring");
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NMonitoring
diff --git a/yt/yt/library/monitoring/public.h b/yt/yt/library/monitoring/public.h
new file mode 100644
index 0000000000..3514bdd858
--- /dev/null
+++ b/yt/yt/library/monitoring/public.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <yt/yt/core/misc/public.h>
+
+namespace NYT::NMonitoring {
+
+////////////////////////////////////////////////////////////////////////////////
+
+DECLARE_REFCOUNTED_CLASS(TMonitoringManager)
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NMonitoring
diff --git a/yt/yt/library/monitoring/ya.make b/yt/yt/library/monitoring/ya.make
new file mode 100644
index 0000000000..c2fccd99ac
--- /dev/null
+++ b/yt/yt/library/monitoring/ya.make
@@ -0,0 +1,27 @@
+LIBRARY()
+
+INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc)
+
+SRCS(
+    http_integration.cpp
+    monitoring_manager.cpp
+)
+
+PEERDIR(
+    yt/yt/core
+    yt/yt/build
+    yt/yt/library/profiling
+    yt/yt/library/profiling/solomon
+    library/cpp/cgiparam
+)
+
+IF (OS_LINUX)
+    PEERDIR(
+        yt/yt/library/ytprof
+        yt/yt/library/ytprof/http
+
+        yt/yt/library/backtrace_introspector/http
+    )
+ENDIF()
+
+END()
diff --git a/yt/yt/library/process/CMakeLists.darwin-x86_64.txt b/yt/yt/library/process/CMakeLists.darwin-x86_64.txt
new file mode 100644
index 0000000000..b66c679390
--- /dev/null
+++ b/yt/yt/library/process/CMakeLists.darwin-x86_64.txt
@@ -0,0 +1,26 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-process) +target_compile_options(yt-library-process PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-process PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + contrib-libs-re2 +) +target_sources(yt-library-process PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/io_dispatcher.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/pipe.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/process.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/pty.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/subprocess.cpp +) diff --git a/yt/yt/library/process/CMakeLists.linux-x86_64.txt b/yt/yt/library/process/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..e065bba3b1 --- /dev/null +++ b/yt/yt/library/process/CMakeLists.linux-x86_64.txt @@ -0,0 +1,27 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-process) +target_compile_options(yt-library-process PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-process PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + contrib-libs-re2 +) +target_sources(yt-library-process PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/io_dispatcher.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/pipe.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/process.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/pty.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/subprocess.cpp +) diff --git a/yt/yt/library/process/CMakeLists.txt b/yt/yt/library/process/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/process/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/process/CMakeLists.windows-x86_64.txt b/yt/yt/library/process/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..3637ee7dae --- /dev/null +++ b/yt/yt/library/process/CMakeLists.windows-x86_64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-process) +target_link_libraries(yt-library-process PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + contrib-libs-re2 +) +target_sources(yt-library-process PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/io_dispatcher.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/pipe.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/process.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/pty.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/process/subprocess.cpp +) diff --git a/yt/yt/library/process/io_dispatcher.cpp b/yt/yt/library/process/io_dispatcher.cpp new file mode 100644 index 0000000000..7da757658d --- /dev/null +++ b/yt/yt/library/process/io_dispatcher.cpp @@ -0,0 +1,37 @@ +#include "io_dispatcher.h" + +#include <yt/yt/core/concurrency/thread_pool_poller.h> +#include <yt/yt/core/concurrency/poller.h> + +#include <yt/yt/core/misc/singleton.h> + +namespace NYT::NPipes { + +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +TIODispatcher::TIODispatcher() + : Poller_(BIND([] { return CreateThreadPoolPoller(1, "Pipes"); })) +{ } + +TIODispatcher::~TIODispatcher() = default; + +TIODispatcher* TIODispatcher::Get() +{ + return Singleton<TIODispatcher>(); +} + +IInvokerPtr TIODispatcher::GetInvoker() +{ + return Poller_.Value()->GetInvoker(); +} + +IPollerPtr TIODispatcher::GetPoller() +{ + return Poller_.Value(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/io_dispatcher.h b/yt/yt/library/process/io_dispatcher.h new file mode 100644 index 0000000000..2db1b34386 --- /dev/null +++ b/yt/yt/library/process/io_dispatcher.h @@ -0,0 +1,34 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/concurrency/public.h> + +#include <yt/yt/core/misc/lazy_ptr.h> + +namespace NYT::NPipes { + +//////////////////////////////////////////////////////////////////////////////// + +class TIODispatcher +{ +public: + ~TIODispatcher(); + + static TIODispatcher* Get(); + + IInvokerPtr GetInvoker(); + + NConcurrency::IPollerPtr GetPoller(); + +private: + TIODispatcher(); + + Y_DECLARE_SINGLETON_FRIEND() + + TLazyIntrusivePtr<NConcurrency::IThreadPoolPoller> Poller_; +}; + 
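+// Note: all pipe readers and writers share the single "Pipes" poller thread
+// owned by this dispatcher; the poller is created lazily on first use.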
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/pipe.cpp b/yt/yt/library/process/pipe.cpp new file mode 100644 index 0000000000..f51d043f22 --- /dev/null +++ b/yt/yt/library/process/pipe.cpp @@ -0,0 +1,256 @@ +#include "pipe.h" +#include "private.h" +#include "io_dispatcher.h" + +#include <yt/yt/core/net/connection.h> + +#include <yt/yt/core/misc/proc.h> +#include <yt/yt/core/misc/fs.h> + +#include <sys/types.h> +#include <sys/stat.h> + +namespace NYT::NPipes { + +using namespace NNet; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = PipesLogger; + +//////////////////////////////////////////////////////////////////////////////// + +TNamedPipe::TNamedPipe(const TString& path, bool owning) + : Path_(path) + , Owning_(owning) +{ } + +TNamedPipe::~TNamedPipe() +{ + if (!Owning_) { + return; + } + + if (unlink(Path_.c_str()) == -1) { + YT_LOG_INFO(TError::FromSystem(), "Failed to unlink pipe %v", Path_); + } +} + +TNamedPipePtr TNamedPipe::Create(const TString& path, int permissions) +{ + auto pipe = New<TNamedPipe>(path, /* owning */ true); + pipe->Open(permissions); + YT_LOG_DEBUG("Named pipe created (Path: %v, Permissions: %v)", path, permissions); + return pipe; +} + +TNamedPipePtr TNamedPipe::FromPath(const TString& path) +{ + return New<TNamedPipe>(path, /* owning */ false); +} + +void TNamedPipe::Open(int permissions) +{ + if (mkfifo(Path_.c_str(), permissions) == -1) { + THROW_ERROR_EXCEPTION("Failed to create named pipe %v", Path_) + << TError::FromSystem(); + } +} + +IConnectionReaderPtr TNamedPipe::CreateAsyncReader() +{ + YT_VERIFY(!Path_.empty()); + return CreateInputConnectionFromPath(Path_, TIODispatcher::Get()->GetPoller(), MakeStrong(this)); +} + +IConnectionWriterPtr TNamedPipe::CreateAsyncWriter() +{ + YT_VERIFY(!Path_.empty()); + return CreateOutputConnectionFromPath(Path_, TIODispatcher::Get()->GetPoller(), MakeStrong(this)); +} + +TString TNamedPipe::GetPath() const +{ + return Path_; +} + +//////////////////////////////////////////////////////////////////////////////// + +TNamedPipeConfigPtr TNamedPipeConfig::Create(TString path, int fd, bool write) +{ + auto result = New<TNamedPipeConfig>(); + result->Path = std::move(path); + result->FD = fd; + result->Write = write; + + return result; +} + +void TNamedPipeConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("path", &TThis::Path) + .Default(); + + registrar.Parameter("fd", &TThis::FD) + .Default(0); + + registrar.Parameter("write", &TThis::Write) + .Default(false); +} + +DEFINE_REFCOUNTED_TYPE(TNamedPipeConfig) + +//////////////////////////////////////////////////////////////////////////////// + +TPipe::TPipe() +{ } + +TPipe::TPipe(TPipe&& pipe) +{ + Init(std::move(pipe)); +} + +TPipe::TPipe(int fd[2]) + : ReadFD_(fd[0]) + , WriteFD_(fd[1]) +{ } + +void TPipe::Init(TPipe&& other) +{ + ReadFD_ = other.ReadFD_; + WriteFD_ = other.WriteFD_; + other.ReadFD_ = InvalidFD; + other.WriteFD_ = InvalidFD; +} + +TPipe::~TPipe() +{ + if (ReadFD_ != InvalidFD) { + YT_VERIFY(TryClose(ReadFD_, false)); + } + + if (WriteFD_ != InvalidFD) { + YT_VERIFY(TryClose(WriteFD_, false)); + } +} + +void TPipe::operator=(TPipe&& other) +{ + if (this == &other) { + return; + } + + Init(std::move(other)); +} + +IConnectionWriterPtr TPipe::CreateAsyncWriter() +{ + YT_VERIFY(WriteFD_ != InvalidFD); + SafeMakeNonblocking(WriteFD_); + return 
CreateConnectionFromFD(ReleaseWriteFD(), {}, {}, TIODispatcher::Get()->GetPoller()); +} + +IConnectionReaderPtr TPipe::CreateAsyncReader() +{ + YT_VERIFY(ReadFD_ != InvalidFD); + SafeMakeNonblocking(ReadFD_); + return CreateConnectionFromFD(ReleaseReadFD(), {}, {}, TIODispatcher::Get()->GetPoller()); +} + +int TPipe::ReleaseReadFD() +{ + YT_VERIFY(ReadFD_ != InvalidFD); + auto fd = ReadFD_; + ReadFD_ = InvalidFD; + return fd; +} + +int TPipe::ReleaseWriteFD() +{ + YT_VERIFY(WriteFD_ != InvalidFD); + auto fd = WriteFD_; + WriteFD_ = InvalidFD; + return fd; +} + +int TPipe::GetReadFD() const +{ + YT_VERIFY(ReadFD_ != InvalidFD); + return ReadFD_; +} + +int TPipe::GetWriteFD() const +{ + YT_VERIFY(WriteFD_ != InvalidFD); + return WriteFD_; +} + +void TPipe::CloseReadFD() +{ + if (ReadFD_ == InvalidFD) { + return; + } + auto fd = ReadFD_; + ReadFD_ = InvalidFD; + SafeClose(fd, false); +} + +void TPipe::CloseWriteFD() +{ + if (WriteFD_ == InvalidFD) { + return; + } + auto fd = WriteFD_; + WriteFD_ = InvalidFD; + SafeClose(fd, false); +} + +//////////////////////////////////////////////////////////////////////////////// + +TString ToString(const TPipe& pipe) +{ + return Format("{ReadFD: %v, WriteFD: %v}", + pipe.GetReadFD(), + pipe.GetWriteFD()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TPipeFactory::TPipeFactory(int minFD) + : MinFD_(minFD) +{ } + +TPipeFactory::~TPipeFactory() +{ + for (int fd : ReservedFDs_) { + YT_VERIFY(TryClose(fd, false)); + } +} + +TPipe TPipeFactory::Create() +{ + while (true) { + int fd[2]; + SafePipe(fd); + if (fd[0] >= MinFD_ && fd[1] >= MinFD_) { + TPipe pipe(fd); + return pipe; + } else { + ReservedFDs_.push_back(fd[0]); + ReservedFDs_.push_back(fd[1]); + } + } +} + +void TPipeFactory::Clear() +{ + for (int& fd : ReservedFDs_) { + YT_VERIFY(TryClose(fd, false)); + fd = TPipe::InvalidFD; + } + ReservedFDs_.clear(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/pipe.h b/yt/yt/library/process/pipe.h new file mode 100644 index 0000000000..10da81cc8a --- /dev/null +++ b/yt/yt/library/process/pipe.h @@ -0,0 +1,114 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/net/public.h> + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NPipes { + +//////////////////////////////////////////////////////////////////////////////// + +class TNamedPipe + : public TRefCounted +{ +public: + ~TNamedPipe(); + static TNamedPipePtr Create(const TString& path, int permissions = 0660); + static TNamedPipePtr FromPath(const TString& path); + + NNet::IConnectionReaderPtr CreateAsyncReader(); + NNet::IConnectionWriterPtr CreateAsyncWriter(); + + TString GetPath() const; + +private: + const TString Path_; + + //! Whether pipe was created by this class + //! and should be removed in destructor. 
+ const bool Owning_; + + explicit TNamedPipe(const TString& path, bool owning); + void Open(int permissions); + DECLARE_NEW_FRIEND() +}; + +DEFINE_REFCOUNTED_TYPE(TNamedPipe) + +//////////////////////////////////////////////////////////////////////////////// + +class TNamedPipeConfig + : public NYTree::TYsonStruct +{ +public: + TString Path; + int FD = 0; + bool Write = false; + + static TNamedPipeConfigPtr Create(TString path, int fd, bool write); + + REGISTER_YSON_STRUCT(TNamedPipeConfig); + + static void Register(TRegistrar registrar); +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TPipe + : public TNonCopyable +{ +public: + static const int InvalidFD = -1; + + TPipe(); + TPipe(TPipe&& pipe); + ~TPipe(); + + void operator=(TPipe&& other); + + void CloseReadFD(); + void CloseWriteFD(); + + NNet::IConnectionReaderPtr CreateAsyncReader(); + NNet::IConnectionWriterPtr CreateAsyncWriter(); + + int ReleaseReadFD(); + int ReleaseWriteFD(); + + int GetReadFD() const; + int GetWriteFD() const; + +private: + int ReadFD_ = InvalidFD; + int WriteFD_ = InvalidFD; + + TPipe(int fd[2]); + void Init(TPipe&& other); + + friend class TPipeFactory; +}; + +TString ToString(const TPipe& pipe); + +//////////////////////////////////////////////////////////////////////////////// + +class TPipeFactory +{ +public: + explicit TPipeFactory(int minFD = 0); + ~TPipeFactory(); + + TPipe Create(); + + void Clear(); + +private: + const int MinFD_; + std::vector<int> ReservedFDs_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/private.h b/yt/yt/library/process/private.h new file mode 100644 index 0000000000..95b2ffb0f5 --- /dev/null +++ b/yt/yt/library/process/private.h @@ -0,0 +1,14 @@ +#pragma once + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NPipes { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger PipesLogger("Pipes"); +inline const NLogging::TLogger PtyLogger("Pty"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/process.cpp b/yt/yt/library/process/process.cpp new file mode 100644 index 0000000000..809a50ed9a --- /dev/null +++ b/yt/yt/library/process/process.cpp @@ -0,0 +1,697 @@ +#include "process.h" +#include "pipe.h" + +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/error.h> +#include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/finally.h> + +#include <yt/yt/core/concurrency/periodic_executor.h> +#include <yt/yt/core/concurrency/delayed_executor.h> + +#include <library/cpp/yt/system/handle_eintr.h> + +#include <util/folder/dirut.h> + +#include <util/generic/guid.h> + +#include <util/string/ascii.h> + +#include <util/string/util.h> + +#include <util/system/env.h> +#include <util/system/execpath.h> +#include <util/system/maxlen.h> +#include <util/system/shellcommand.h> + +#ifdef _unix_ + #include <unistd.h> + #include <errno.h> + #include <sys/wait.h> + #include <sys/resource.h> +#endif + +#ifdef _darwin_ + #include <crt_externs.h> + #define environ (*_NSGetEnviron()) +#endif + +namespace NYT { + +using namespace NPipes; +using namespace NNet; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +static inline const NLogging::TLogger Logger("Process"); + 
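+// Process-spawning constants: an invalid-pid marker plus retry parameters for
+// execve attempts and binary path resolution (cf. ResolveBinaryPath below).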
+static constexpr pid_t InvalidProcessId = -1; + +static constexpr int ExecveRetryCount = 5; +static constexpr auto ExecveRetryTimeout = TDuration::Seconds(1); + +static constexpr int ResolveRetryCount = 5; +static constexpr auto ResolveRetryTimeout = TDuration::Seconds(1); + +//////////////////////////////////////////////////////////////////////////////// + +TErrorOr<TString> ResolveBinaryPath(const TString& binary) +{ + auto Logger = NYT::Logger + .WithTag("Binary: %v", binary); + + YT_LOG_DEBUG("Resolving binary path"); + + std::vector<TError> accumulatedErrors; + + auto test = [&] (const char* path) { + YT_LOG_DEBUG("Probing path (Path: %v)", path); + if (access(path, R_OK | X_OK) == 0) { + return true; + } else { + auto error = TError("Cannot run %Qlv", path) << TError::FromSystem(); + accumulatedErrors.push_back(std::move(error)); + return false; + } + }; + + auto failure = [&] { + auto error = TError( + EProcessErrorCode::CannotResolveBinary, + "Cannot resolve binary %Qlv", + binary); + error.MutableInnerErrors()->swap(accumulatedErrors); + YT_LOG_DEBUG(error, "Error resolving binary path"); + return error; + }; + + auto success = [&] (const TString& path) { + YT_LOG_DEBUG("Binary resolved (Path: %v)", path); + return path; + }; + + if (test(binary.c_str())) { + return success(binary); + } + + // If this is an absolute path, stop here. + if (binary.empty() || binary[0] == '/') { + return failure(); + } + + // XXX(sandello): Sometimes we drop PATH from environment when spawning isolated processes. + // In this case, try to locate somewhere nearby. + { + auto execPathDirName = GetDirName(GetExecPath()); + YT_LOG_DEBUG("Looking in our exec path directory (ExecPathDir: %v)", execPathDirName); + auto probe = TString::Join(execPathDirName, "/", binary); + if (test(probe.c_str())) { + return success(probe); + } + } + + std::array<char, MAX_PATH> buffer; + + auto envPathStr = GetEnv("PATH"); + TStringBuf envPath(envPathStr); + TStringBuf envPathItem; + + YT_LOG_DEBUG("Looking for binary in PATH (Path: %v)", envPathStr); + + while (envPath.NextTok(':', envPathItem)) { + if (buffer.size() < 2 + envPathItem.size() + binary.size()) { + continue; + } + + size_t index = 0; + std::copy(envPathItem.begin(), envPathItem.end(), buffer.begin() + index); + index += envPathItem.size(); + buffer[index] = '/'; + index += 1; + std::copy(binary.begin(), binary.end(), buffer.begin() + index); + index += binary.size(); + buffer[index] = 0; + + if (test(buffer.data())) { + return success(TString(buffer.data(), index)); + } + } + + return failure(); +} + +bool TryKillProcessByPid(int pid, int signal) +{ +#ifdef _unix_ + YT_VERIFY(pid != -1); + int result = ::kill(pid, signal); + // Ignore ESRCH because process may have died just before TryKillProcessByPid. + if (result < 0 && errno != ESRCH) { + return false; + } + return true; +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +#ifdef _unix_ + +bool TryWaitid(idtype_t idtype, id_t id, siginfo_t *infop, int options) +{ + if (infop != nullptr) { + // See comment below. + infop->si_pid = 0; + } + + siginfo_t info; + ::memset(&info, 0, sizeof(info)); + auto res = HandleEintr(::waitid, idtype, id, infop != nullptr ? infop : &info, options); + + if (res == 0) { + // According to man wait. + // If WNOHANG was specified in options and there were + // no children in a waitable state, then waitid() returns 0 immediately. 
+ // To distinguish this case from that where a child + // was in a waitable state, zero out the si_pid field + // before the call and check for a nonzero value in this field after + // the call returns. + if (infop && infop->si_pid == 0) { + return false; + } + return true; + } + + return false; +} + +void Wait4OrDie(pid_t id, int* status, int options, rusage* rusage) +{ + auto res = HandleEintr(::wait4, id, status, options, rusage); + if (res == -1) { + YT_LOG_FATAL(TError::FromSystem(), "Wait4 failed"); + } +} + +void Cleanup(int pid) +{ + YT_VERIFY(pid > 0); + + YT_VERIFY(TryKillProcessByPid(pid, 9)); + YT_VERIFY(TryWaitid(P_PID, pid, nullptr, WEXITED)); +} + +bool TrySetSignalMask(const sigset_t* sigmask, sigset_t* oldSigmask) +{ + int error = pthread_sigmask(SIG_SETMASK, sigmask, oldSigmask); + if (error != 0) { + return false; + } + return true; +} + +bool TryResetSignals() +{ + for (int sig = 1; sig < NSIG; ++sig) { + // Ignore invalid signal errors. + ::signal(sig, SIG_DFL); + } + return true; +} + +#endif + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +TProcessBase::TProcessBase(const TString& path) + : Path_(path) + , ProcessId_(InvalidProcessId) +{ } + +void TProcessBase::AddArgument(TStringBuf arg) +{ + YT_VERIFY(ProcessId_ == InvalidProcessId && !Finished_); + + Args_.push_back(Capture(arg)); +} + +void TProcessBase::AddEnvVar(TStringBuf var) +{ + YT_VERIFY(ProcessId_ == InvalidProcessId && !Finished_); + + Env_.push_back(Capture(var)); +} + +void TProcessBase::AddArguments(std::initializer_list<TStringBuf> args) +{ + for (auto arg : args) { + AddArgument(arg); + } +} + +void TProcessBase::AddArguments(const std::vector<TString>& args) +{ + for (const auto& arg : args) { + AddArgument(arg); + } +} + +void TProcessBase::SetWorkingDirectory(const TString& path) +{ + WorkingDirectory_ = path; +} + +void TProcessBase::CreateProcessGroup() +{ + CreateProcessGroup_ = true; +} + +//////////////////////////////////////////////////////////////////////////////// + +TSimpleProcess::TSimpleProcess(const TString& path, bool copyEnv, TDuration pollPeriod) + // TString is guaranteed to be zero-terminated. 
+ // https://wiki.yandex-team.ru/Development/Poisk/arcadia/util/TStringAndTStringBuf#sobstvennosimvoly + : TProcessBase(path) + , PollPeriod_(pollPeriod) + , PipeFactory_(3) +{ + AddArgument(path); + + if (copyEnv) { + for (char** envIt = environ; *envIt; ++envIt) { + Env_.push_back(Capture(*envIt)); + } + } +} + +void TSimpleProcess::AddDup2FileAction(int oldFD, int newFD) +{ + TSpawnAction action{ + std::bind(TryDup2, oldFD, newFD), + Format("Error duplicating %v file descriptor to %v in child process", oldFD, newFD) + }; + + MaxSpawnActionFD_ = std::max(MaxSpawnActionFD_, newFD); + SpawnActions_.push_back(action); +} + +IConnectionReaderPtr TSimpleProcess::GetStdOutReader() +{ + auto& pipe = StdPipes_[STDOUT_FILENO]; + pipe = PipeFactory_.Create(); + AddDup2FileAction(pipe.GetWriteFD(), STDOUT_FILENO); + return pipe.CreateAsyncReader(); +} + +IConnectionReaderPtr TSimpleProcess::GetStdErrReader() +{ + auto& pipe = StdPipes_[STDERR_FILENO]; + pipe = PipeFactory_.Create(); + AddDup2FileAction(pipe.GetWriteFD(), STDERR_FILENO); + return pipe.CreateAsyncReader(); +} + +IConnectionWriterPtr TSimpleProcess::GetStdInWriter() +{ + auto& pipe = StdPipes_[STDIN_FILENO]; + pipe = PipeFactory_.Create(); + AddDup2FileAction(pipe.GetReadFD(), STDIN_FILENO); + return pipe.CreateAsyncWriter(); +} + +TFuture<void> TProcessBase::Spawn() +{ + try { + // Resolve binary path. + std::vector<TError> innerErrors; + for (int retryIndex = ResolveRetryCount; retryIndex >= 0; --retryIndex) { + auto errorOrPath = ResolveBinaryPath(Path_); + if (errorOrPath.IsOK()) { + ResolvedPath_ = errorOrPath.Value(); + break; + } + + innerErrors.push_back(errorOrPath); + + if (retryIndex == 0) { + THROW_ERROR_EXCEPTION("Failed to resolve binary path %v", Path_) + << innerErrors; + } + + TDelayedExecutor::WaitForDuration(ResolveRetryTimeout); + } + + DoSpawn(); + } catch (const std::exception& ex) { + FinishedPromise_.TrySet(ex); + } + return FinishedPromise_; +} + +void TSimpleProcess::DoSpawn() +{ +#ifdef _unix_ + auto finally = Finally([&] () { + StdPipes_[STDIN_FILENO].CloseReadFD(); + StdPipes_[STDOUT_FILENO].CloseWriteFD(); + StdPipes_[STDERR_FILENO].CloseWriteFD(); + PipeFactory_.Clear(); + }); + + YT_VERIFY(ProcessId_ == InvalidProcessId && !Finished_); + + // Make sure no spawn action closes Pipe_.WriteFD + TPipeFactory pipeFactory(MaxSpawnActionFD_ + 1); + Pipe_ = pipeFactory.Create(); + pipeFactory.Clear(); + + YT_LOG_DEBUG("Spawning new process (Path: %v, ErrorPipe: %v, Arguments: %v, Environment: %v)", + ResolvedPath_, + Pipe_, + Args_, + Env_); + + Env_.push_back(nullptr); + Args_.push_back(nullptr); + + // Block all signals around vfork; see http://ewontfix.com/7/ + + // As the child may run in the same address space as the parent until + // the actual execve() system call, any (custom) signal handlers that + // the parent has might alter parent's memory if invoked in the child, + // with undefined results. So we block all signals in the parent before + // vfork(), which will cause them to be blocked in the child as well (we + // rely on the fact that Linux, just like all sane implementations, only + // clones the calling thread). 
Then, in the child, we reset all signals
+    // to their default dispositions (while still blocked), and unblock them
+    // (so the exec()ed process inherits the parent's signal mask).
+
+    sigset_t allBlocked;
+    sigfillset(&allBlocked);
+    sigset_t oldSignals;
+
+    if (!TrySetSignalMask(&allBlocked, &oldSignals)) {
+        THROW_ERROR_EXCEPTION("Failed to block all signals")
+            << TError::FromSystem();
+    }
+
+    SpawnActions_.push_back(TSpawnAction{
+        TryResetSignals,
+        "Error resetting signals to default disposition in child process: signal failed"
+    });
+
+    SpawnActions_.push_back(TSpawnAction{
+        std::bind(TrySetSignalMask, &oldSignals, nullptr),
+        "Error unblocking signals in child process: pthread_sigmask failed"
+    });
+
+    if (!WorkingDirectory_.empty()) {
+        SpawnActions_.push_back(TSpawnAction{
+            [&] () {
+                NFs::SetCurrentWorkingDirectory(WorkingDirectory_);
+                return true;
+            },
+            "Error changing working directory"
+        });
+    }
+
+    if (CreateProcessGroup_) {
+        SpawnActions_.push_back(TSpawnAction{
+            [&] () {
+                setpgrp();
+                return true;
+            },
+            "Error creating process group"
+        });
+    }
+
+    SpawnActions_.push_back(TSpawnAction{
+        [this] {
+            for (int retryIndex = 0; retryIndex < ExecveRetryCount; ++retryIndex) {
+                // Execve may fail if the called binary is being updated, e.g. during a yandex-yt package update.
+                // So we'd better retry several times.
+                // For example, see YT-6352.
+                TryExecve(ResolvedPath_.c_str(), Args_.data(), Env_.data());
+                if (retryIndex < ExecveRetryCount - 1) {
+                    Sleep(ExecveRetryTimeout);
+                }
+            }
+            // If we are still here, return failure.
+            return false;
+        },
+        "Error starting child process: execve failed"
+    });
+
+    SpawnChild();
+
+    // This should never fail.
+    YT_VERIFY(TrySetSignalMask(&oldSignals, nullptr));
+
+    Pipe_.CloseWriteFD();
+
+    ValidateSpawnResult();
+
+    AsyncWaitExecutor_ = New<TPeriodicExecutor>(
+        GetSyncInvoker(),
+        BIND(&TSimpleProcess::AsyncPeriodicTryWait, MakeStrong(this)),
+        PollPeriod_);
+
+    AsyncWaitExecutor_->Start();
+#else
+    THROW_ERROR_EXCEPTION("Unsupported platform");
+#endif
+}
+
+void TSimpleProcess::SpawnChild()
+{
+    // NB: fork() will cause data corruption when run concurrently with
+    // disk IO on an O_DIRECT file descriptor. vfork() does not seem to suffer from the same issue.
+
+#ifdef _unix_
+    int pid = vfork();
+
+    if (pid < 0) {
+        THROW_ERROR_EXCEPTION("Error starting child process: vfork failed")
+            << TErrorAttribute("path", ResolvedPath_)
+            << TError::FromSystem();
+    }
+
+    if (pid == 0) {
+        try {
+            Child();
+        } catch (...) {
+            YT_ABORT();
+        }
+    }
+
+    ProcessId_ = pid;
+    Started_ = true;
+#else
+    THROW_ERROR_EXCEPTION("Unsupported platform");
+#endif
+}
+
+void TSimpleProcess::ValidateSpawnResult()
+{
+#ifdef _unix_
+    int data[2];
+    ssize_t res;
+    res = HandleEintr(::read, Pipe_.GetReadFD(), &data, sizeof(data));
+    Pipe_.CloseReadFD();
+
+    if (res == 0) {
+        // Child successfully spawned or was killed by a signal.
+        // But there is no way to distinguish between these two cases:
+        //   * child killed by signal before exec
+        //   * child killed by signal after exec
+        // So we treat kill-before-exec the same way as kill-after-exec.
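+        // (If the child is killed by a signal after a successful execve(), the
+        // failure still surfaces to the caller: AsyncPeriodicTryWait below
+        // observes the signal-death and reports it via FinishedPromise_.)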
+ YT_LOG_DEBUG("Child process spawned successfully (Pid: %v)", ProcessId_); + return; + } + + YT_VERIFY(res == sizeof(data)); + Finished_ = true; + + Cleanup(ProcessId_); + ProcessId_ = InvalidProcessId; + + int actionIndex = data[0]; + int errorCode = data[1]; + + YT_VERIFY(0 <= actionIndex && actionIndex < std::ssize(SpawnActions_)); + const auto& action = SpawnActions_[actionIndex]; + THROW_ERROR_EXCEPTION("%v", action.ErrorMessage) + << TError::FromSystem(errorCode); +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +#ifdef _unix_ +void TSimpleProcess::AsyncPeriodicTryWait() +{ + siginfo_t processInfo; + memset(&processInfo, 0, sizeof(siginfo_t)); + + // Note WNOWAIT flag. + // This call just waits for a process to be finished but does not clear zombie flag. + + if (!TryWaitid(P_PID, ProcessId_, &processInfo, WEXITED | WNOWAIT | WNOHANG) || + processInfo.si_pid != ProcessId_) + { + return; + } + + YT_UNUSED_FUTURE(AsyncWaitExecutor_->Stop()); + AsyncWaitExecutor_ = nullptr; + + // This call just should return immediately + // because we have already waited for this process with WNOHANG + rusage rusage; + Wait4OrDie(ProcessId_, nullptr, WNOHANG, &rusage); + + Finished_ = true; + auto error = ProcessInfoToError(processInfo); + YT_LOG_DEBUG("Process finished (Pid: %v, MajFaults: %d, Error: %v)", ProcessId_, rusage.ru_majflt, error); + + FinishedPromise_.Set(error); +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +void TSimpleProcess::Kill(int signal) +{ +#ifdef _unix_ + if (!Started_) { + THROW_ERROR_EXCEPTION("Process is not started yet"); + } + + if (Finished_) { + return; + } + + YT_LOG_DEBUG("Killing child process (Pid: %v)", ProcessId_); + + bool result = false; + if (!CreateProcessGroup_) { + result = TryKillProcessByPid(ProcessId_, signal); + } else { + result = TryKillProcessByPid(-1 * ProcessId_, signal); + } + + if (!result) { + THROW_ERROR_EXCEPTION("Failed to kill child process %v", ProcessId_) + << TError::FromSystem(); + } + return; +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +TString TProcessBase::GetPath() const +{ + return Path_; +} + +int TProcessBase::GetProcessId() const +{ + return ProcessId_; +} + +bool TProcessBase::IsStarted() const +{ + return Started_; +} + +bool TProcessBase::IsFinished() const +{ + return Finished_; +} + +TString TProcessBase::GetCommandLine() const +{ + TStringBuilder builder; + builder.AppendString(Path_); + + bool first = true; + for (const auto& arg_ : Args_) { + TStringBuf arg(arg_); + if (first) { + first = false; + } else { + if (arg) { + builder.AppendChar(' '); + bool needQuote = false; + for (size_t i = 0; i < arg.length(); ++i) { + if (!IsAsciiAlnum(arg[i]) && + arg[i] != '-' && arg[i] != '_' && arg[i] != '=' && arg[i] != '/') + { + needQuote = true; + break; + } + } + if (needQuote) { + builder.AppendChar('"'); + TStringBuf left, right; + while (arg.TrySplit('"', left, right)) { + builder.AppendString(left); + builder.AppendString("\\\""); + arg = right; + } + builder.AppendString(arg); + builder.AppendChar('"'); + } else { + builder.AppendString(arg); + } + } + } + } + + return builder.Flush(); +} + +const char* TProcessBase::Capture(TStringBuf arg) +{ + StringHolders_.push_back(TString(arg)); + return StringHolders_.back().c_str(); +} + +void TSimpleProcess::Child() +{ +#ifdef _unix_ + for (int actionIndex = 0; actionIndex < std::ssize(SpawnActions_); ++actionIndex) { + auto& action = SpawnActions_[actionIndex]; + if (!action.Callback()) { + // Report error through 
+            int data[] = {
+                actionIndex,
+                errno
+            };
+
+            // According to pipe(7), writes of small buffers are atomic.
+            ssize_t size = HandleEintr(::write, Pipe_.GetWriteFD(), &data, sizeof(data));
+            YT_VERIFY(size == sizeof(data));
+            _exit(1);
+        }
+    }
+#else
+    THROW_ERROR_EXCEPTION("Unsupported platform");
+#endif
+    YT_ABORT();
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT
diff --git a/yt/yt/library/process/process.h b/yt/yt/library/process/process.h
new file mode 100644
index 0000000000..b38ae3f4b3
--- /dev/null
+++ b/yt/yt/library/process/process.h
@@ -0,0 +1,125 @@
+#pragma once
+
+#include "pipe.h"
+
+#include <yt/yt/core/misc/error.h>
+
+#include <yt/yt/core/actions/future.h>
+
+#include <yt/yt/core/concurrency/public.h>
+
+#include <atomic>
+#include <array>
+#include <deque>
+#include <vector>
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+TErrorOr<TString> ResolveBinaryPath(const TString& binary);
+bool TryKillProcessByPid(int pid, int signal);
+
+////////////////////////////////////////////////////////////////////////////////
+
+class TProcessBase
+    : public TRefCounted
+{
+public:
+    explicit TProcessBase(const TString& path);
+
+    void AddArgument(TStringBuf arg);
+    void AddEnvVar(TStringBuf var);
+
+    void AddArguments(std::initializer_list<TStringBuf> args);
+    void AddArguments(const std::vector<TString>& args);
+
+    void SetWorkingDirectory(const TString& path);
+    void CreateProcessGroup();
+
+    virtual NNet::IConnectionWriterPtr GetStdInWriter() = 0;
+    virtual NNet::IConnectionReaderPtr GetStdOutReader() = 0;
+    virtual NNet::IConnectionReaderPtr GetStdErrReader() = 0;
+
+    TFuture<void> Spawn();
+    virtual void Kill(int signal) = 0;
+
+    TString GetPath() const;
+    int GetProcessId() const;
+    bool IsStarted() const;
+    bool IsFinished() const;
+
+    TString GetCommandLine() const;
+
+protected:
+    const TString Path_;
+
+    int ProcessId_;
+    std::atomic<bool> Started_ = {false};
+    std::atomic<bool> Finished_ = {false};
+    int MaxSpawnActionFD_ = -1;
+    NPipes::TPipe Pipe_;
+    // Container for owning string data. Use std::deque because it never moves contained objects.
+    std::deque<std::string> StringHolders_;
+    std::vector<const char*> Args_;
+    std::vector<const char*> Env_;
+    TString ResolvedPath_;
+    TString WorkingDirectory_;
+    bool CreateProcessGroup_ = false;
+    TPromise<void> FinishedPromise_ = NewPromise<void>();
+
+    virtual void DoSpawn() = 0;
+    const char* Capture(TStringBuf arg);
+
+private:
+    void SpawnChild();
+    void ValidateSpawnResult();
+    void Child();
+    void AsyncPeriodicTryWait();
+};
+
+DEFINE_REFCOUNTED_TYPE(TProcessBase)
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Read this
+// http://ewontfix.com/7/
+// before making any changes.
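+//
+// A minimal usage sketch (hypothetical, assuming a Unix host and the fiber-aware
+// WaitFor from yt/yt/core/concurrency):
+//
+//     auto process = New<TSimpleProcess>("/bin/echo");
+//     process->AddArgument("hello");
+//     auto stdoutReader = process->GetStdOutReader();
+//     auto finished = process->Spawn();
+//     // ... read from |stdoutReader|, then:
+//     WaitFor(finished)
+//         .ThrowOnError();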
+class TSimpleProcess + : public TProcessBase +{ +public: + explicit TSimpleProcess( + const TString& path, + bool copyEnv = true, + TDuration pollPeriod = TDuration::MilliSeconds(100)); + void Kill(int signal) override; + NNet::IConnectionWriterPtr GetStdInWriter() override; + NNet::IConnectionReaderPtr GetStdOutReader() override; + NNet::IConnectionReaderPtr GetStdErrReader() override; + +private: + const TDuration PollPeriod_; + + NPipes::TPipeFactory PipeFactory_; + std::array<NPipes::TPipe, 3> StdPipes_; + + NConcurrency::TPeriodicExecutorPtr AsyncWaitExecutor_; + struct TSpawnAction + { + std::function<bool()> Callback; + TString ErrorMessage; + }; + + std::vector<TSpawnAction> SpawnActions_; + + void AddDup2FileAction(int oldFD, int newFD); + void DoSpawn() override; + void SpawnChild(); + void ValidateSpawnResult(); + void AsyncPeriodicTryWait(); + void Child(); +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/process/pty.cpp b/yt/yt/library/process/pty.cpp new file mode 100644 index 0000000000..fc972d38ea --- /dev/null +++ b/yt/yt/library/process/pty.cpp @@ -0,0 +1,64 @@ +#include "pty.h" + +#include "io_dispatcher.h" + +#include <yt/yt/core/misc/common.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/net/connection.h> + +namespace NYT::NPipes { + +using namespace NNet; + +//////////////////////////////////////////////////////////////////////////////// + +TPty::TPty(int height, int width) +{ + SafeOpenPty(&MasterFD_, &SlaveFD_, height, width); +} + +TPty::~TPty() +{ + if (MasterFD_ != InvalidFD) { + YT_VERIFY(TryClose(MasterFD_, false)); + } + + if (SlaveFD_ != InvalidFD) { + YT_VERIFY(TryClose(SlaveFD_, false)); + } +} + +IConnectionWriterPtr TPty::CreateMasterAsyncWriter() +{ + YT_VERIFY(MasterFD_ != InvalidFD); + int fd = SafeDup(MasterFD_); + SafeSetCloexec(fd); + SafeMakeNonblocking(fd); + return CreateConnectionFromFD(fd, {}, {}, TIODispatcher::Get()->GetPoller()); +} + +IConnectionReaderPtr TPty::CreateMasterAsyncReader() +{ + YT_VERIFY(MasterFD_ != InvalidFD); + int fd = SafeDup(MasterFD_); + SafeSetCloexec(fd); + SafeMakeNonblocking(fd); + return CreateConnectionFromFD(fd, {}, {}, TIODispatcher::Get()->GetPoller()); +} + +int TPty::GetMasterFD() const +{ + YT_VERIFY(MasterFD_ != InvalidFD); + return MasterFD_; +} + +int TPty::GetSlaveFD() const +{ + YT_VERIFY(SlaveFD_ != InvalidFD); + return SlaveFD_; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/pty.h b/yt/yt/library/process/pty.h new file mode 100644 index 0000000000..b585782d12 --- /dev/null +++ b/yt/yt/library/process/pty.h @@ -0,0 +1,33 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/net/public.h> + +namespace NYT::NPipes { + +//////////////////////////////////////////////////////////////////////////////// + +class TPty + : public TNonCopyable +{ +public: + static const int InvalidFD = -1; + + TPty(int height, int width); + ~TPty(); + + NNet::IConnectionReaderPtr CreateMasterAsyncReader(); + NNet::IConnectionWriterPtr CreateMasterAsyncWriter(); + + int GetMasterFD() const; + int GetSlaveFD() const; + +private: + int MasterFD_ = InvalidFD; + int SlaveFD_ = InvalidFD; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/public.h b/yt/yt/library/process/public.h new file mode 100644 index 
0000000000..0fa1d3d0a9
--- /dev/null
+++ b/yt/yt/library/process/public.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <yt/yt/core/misc/public.h>
+
+namespace NYT::NPipes {
+
+////////////////////////////////////////////////////////////////////////////////
+
+DECLARE_REFCOUNTED_CLASS(TNamedPipe)
+DECLARE_REFCOUNTED_CLASS(TNamedPipeConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NPipes
diff --git a/yt/yt/library/process/subprocess.cpp b/yt/yt/library/process/subprocess.cpp
new file mode 100644
index 0000000000..02555b0c9b
--- /dev/null
+++ b/yt/yt/library/process/subprocess.cpp
@@ -0,0 +1,153 @@
+#include "subprocess.h"
+
+#include <yt/yt/core/misc/blob.h>
+#include <yt/yt/core/misc/proc.h>
+#include <yt/yt/core/misc/finally.h>
+
+#include <yt/yt/core/logging/log.h>
+
+#include <yt/yt/core/net/connection.h>
+
+#include <util/system/execpath.h>
+
+#include <array>
+
+namespace NYT {
+
+using namespace NConcurrency;
+using namespace NPipes;
+
+////////////////////////////////////////////////////////////////////////////////
+
+static const size_t PipeBlockSize = 64 * 1024;
+static NLogging::TLogger Logger("Subprocess");
+
+////////////////////////////////////////////////////////////////////////////////
+
+TSubprocess::TSubprocess(const TString& path, bool copyEnv)
+    : Process_(New<TSimpleProcess>(path, copyEnv))
+{ }
+
+TSubprocess TSubprocess::CreateCurrentProcessSpawner()
+{
+    return TSubprocess(GetExecPath());
+}
+
+void TSubprocess::AddArgument(TStringBuf arg)
+{
+    Process_->AddArgument(arg);
+}
+
+void TSubprocess::AddArguments(std::initializer_list<TStringBuf> args)
+{
+    Process_->AddArguments(args);
+}
+
+TSubprocessResult TSubprocess::Execute(const TSharedRef& input)
+{
+#ifdef _unix_
+    auto inputStream = Process_->GetStdInWriter();
+    auto outputStream = Process_->GetStdOutReader();
+    auto errorStream = Process_->GetStdErrReader();
+    auto finished = Process_->Spawn();
+
+    auto readIntoBlob = [] (IAsyncInputStreamPtr stream) {
+        TBlob output;
+        auto buffer = TSharedMutableRef::Allocate(PipeBlockSize, {.InitializeStorage = false});
+        while (true) {
+            auto size = WaitFor(stream->Read(buffer))
+                .ValueOrThrow();
+
+            if (size == 0) {
+                break;
+            }
+
+            // ToDo(psushin): eliminate copying.
+            output.Append(buffer.Begin(), size);
+        }
+        return TSharedRef::FromBlob(std::move(output));
+    };
+
+    auto writeStdin = BIND([=] {
+        if (input.Size() > 0) {
+            WaitFor(inputStream->Write(input))
+                .ThrowOnError();
+        }
+
+        WaitFor(inputStream->Close())
+            .ThrowOnError();
+
+        //! Return a dummy ref so that later we can put this future into a vector
+        //! along with stdout and stderr.
+        return TSharedRef::MakeEmpty();
+    });
+
+    std::vector<TFuture<TSharedRef>> futures = {
+        BIND(readIntoBlob, outputStream).AsyncVia(GetCurrentInvoker()).Run(),
+        BIND(readIntoBlob, errorStream).AsyncVia(GetCurrentInvoker()).Run(),
+        writeStdin.AsyncVia(GetCurrentInvoker()).Run(),
+    };
+
+    try {
+        auto outputsOrError = WaitFor(AllSucceeded(futures));
+        THROW_ERROR_EXCEPTION_IF_FAILED(
+            outputsOrError,
+            "IO error occurred during subprocess call");
+
+        const auto& outputs = outputsOrError.Value();
+        YT_VERIFY(outputs.size() == 3);
+
+        // This can block indefinitely.
+        auto exitCode = WaitFor(finished);
+        return TSubprocessResult{outputs[0], outputs[1], exitCode};
+    } catch (...) {
+        try {
+            Process_->Kill(SIGKILL);
+        } catch (...)
{ } + Y_UNUSED(WaitFor(finished)); + throw; + } +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +void TSubprocess::Kill(int signal) +{ + Process_->Kill(signal); +} + +TString TSubprocess::GetCommandLine() const +{ + return Process_->GetCommandLine(); +} + +TProcessBasePtr TSubprocess::GetProcess() const +{ + return Process_; +} + +//////////////////////////////////////////////////////////////////////////////// + +void RunSubprocess(const std::vector<TString>& cmd) +{ + if (cmd.empty()) { + THROW_ERROR_EXCEPTION("Command can't be empty"); + } + + auto process = TSubprocess(cmd[0]); + for (int index = 1; index < std::ssize(cmd); ++index) { + process.AddArgument(cmd[index]); + } + + auto result = process.Execute(); + if (!result.Status.IsOK()) { + THROW_ERROR_EXCEPTION("Failed to run %v", cmd[0]) + << result.Status + << TErrorAttribute("command_line", process.GetCommandLine()) + << TErrorAttribute("error", TString(result.Error.Begin(), result.Error.End())); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/process/subprocess.h b/yt/yt/library/process/subprocess.h new file mode 100644 index 0000000000..223db533f6 --- /dev/null +++ b/yt/yt/library/process/subprocess.h @@ -0,0 +1,48 @@ +#pragma once + +#include "public.h" +#include "process.h" + +#include <library/cpp/yt/memory/ref.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +struct TSubprocessResult +{ + TSharedRef Output; + TSharedRef Error; + TError Status; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TSubprocess +{ +public: + explicit TSubprocess(const TString& path, bool copyEnv = true); + + static TSubprocess CreateCurrentProcessSpawner(); + + void AddArgument(TStringBuf arg); + void AddArguments(std::initializer_list<TStringBuf> args); + + TSubprocessResult Execute(const TSharedRef& input = TSharedRef::MakeEmpty()); + void Kill(int signal); + + TString GetCommandLine() const; + + TProcessBasePtr GetProcess() const; + +private: + const TProcessBasePtr Process_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +void RunSubprocess(const std::vector<TString>& cmd); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/process/unittests/pipes_ut.cpp b/yt/yt/library/process/unittests/pipes_ut.cpp new file mode 100644 index 0000000000..0c7e2a0cf2 --- /dev/null +++ b/yt/yt/library/process/unittests/pipes_ut.cpp @@ -0,0 +1,319 @@ +#include <yt/yt/core/test_framework/framework.h> + +#include <yt/yt/core/concurrency/action_queue.h> +#include <yt/yt/core/concurrency/scheduler.h> + +#include <yt/yt/core/misc/blob.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/net/connection.h> + +#include <yt/yt/library/process/pipe.h> + +#include <random> + +namespace NYT::NPipes { + +//////////////////////////////////////////////////////////////////////////////// + +using namespace NConcurrency; +using namespace NNet; + +#ifndef _win_ + +TEST(TPipeIOHolder, CanInstantiate) +{ + auto pipe = TPipeFactory().Create(); + + auto readerHolder = pipe.CreateAsyncReader(); + auto writerHolder = pipe.CreateAsyncWriter(); + + readerHolder->Abort().Get(); + writerHolder->Abort().Get(); +} + +TEST(TPipeTest, PrematureEOF) +{ + auto pipe = TNamedPipe::Create("./namedpipe"); + auto reader = pipe->CreateAsyncReader(); + + auto 
buffer = TSharedMutableRef::Allocate(1024 * 1024); + EXPECT_THROW(reader->Read(buffer).WithTimeout(TDuration::Seconds(1)).Get().ValueOrThrow(), TErrorException); +} + +//////////////////////////////////////////////////////////////////////////////// + +TBlob ReadAll(IConnectionReaderPtr reader, bool useWaitFor) +{ + auto buffer = TSharedMutableRef::Allocate(1_MB, {.InitializeStorage = false}); + auto whole = TBlob(GetRefCountedTypeCookie<TDefaultBlobTag>()); + + while (true) { + TErrorOr<size_t> result; + auto future = reader->Read(buffer); + if (useWaitFor) { + result = WaitFor(future); + } else { + result = future.Get(); + } + + if (result.ValueOrThrow() == 0) { + break; + } + + whole.Append(buffer.Begin(), result.Value()); + } + return whole; +} + +TEST(TAsyncWriterTest, AsyncCloseFail) +{ + auto pipe = TPipeFactory().Create(); + + auto reader = pipe.CreateAsyncReader(); + auto writer = pipe.CreateAsyncWriter(); + + auto queue = New<NConcurrency::TActionQueue>(); + auto readFromPipe = + BIND(&ReadAll, reader, false) + .AsyncVia(queue->GetInvoker()) + .Run(); + + int length = 200*1024; + auto buffer = TSharedMutableRef::Allocate(length); + ::memset(buffer.Begin(), 'a', buffer.Size()); + + auto writeResult = writer->Write(buffer).Get(); + + EXPECT_TRUE(writeResult.IsOK()) + << ToString(writeResult); + + auto error = writer->Close(); + + auto readResult = readFromPipe.Get(); + ASSERT_TRUE(readResult.IsOK()) + << ToString(readResult); + + auto closeStatus = error.Get(); +} + +TEST(TAsyncWriterTest, WriteFailed) +{ + auto pipe = TPipeFactory().Create(); + auto reader = pipe.CreateAsyncReader(); + auto writer = pipe.CreateAsyncWriter(); + + int length = 200*1024; + auto buffer = TSharedMutableRef::Allocate(length); + ::memset(buffer.Begin(), 'a', buffer.Size()); + + auto asyncWriteResult = writer->Write(buffer); + reader->Abort(); + + EXPECT_FALSE(asyncWriteResult.Get().IsOK()) + << ToString(asyncWriteResult.Get()); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TPipeReadWriteTest + : public ::testing::Test +{ +protected: + void SetUp() override + { + auto pipe = TPipeFactory().Create(); + + Reader = pipe.CreateAsyncReader(); + Writer = pipe.CreateAsyncWriter(); + } + + void TearDown() override + { } + + IConnectionReaderPtr Reader; + IConnectionWriterPtr Writer; +}; + +class TNamedPipeReadWriteTest + : public ::testing::Test +{ +protected: + void SetUp() override + { + auto pipe = TNamedPipe::Create("./namedpipe"); + Reader = pipe->CreateAsyncReader(); + Writer = pipe->CreateAsyncWriter(); + } + + void TearDown() override + { } + + IConnectionReaderPtr Reader; + IConnectionWriterPtr Writer; +}; + +TEST_F(TPipeReadWriteTest, ReadSomethingSpin) +{ + TString message("Hello pipe!\n"); + auto buffer = TSharedRef::FromString(message); + Writer->Write(buffer).Get().ThrowOnError(); + Writer->Close().Get().ThrowOnError(); + + auto data = TSharedMutableRef::Allocate(1); + auto whole = TBlob(GetRefCountedTypeCookie<TDefaultBlobTag>()); + + while (true) { + auto result = Reader->Read(data).Get(); + if (result.ValueOrThrow() == 0) { + break; + } + whole.Append(data.Begin(), result.Value()); + } + + EXPECT_EQ(message, TString(whole.Begin(), whole.End())); +} + +TEST_F(TNamedPipeReadWriteTest, ReadSomethingSpin) +{ + TString message("Hello pipe!\n"); + auto buffer = TSharedRef::FromString(message); + + Writer->Write(buffer).Get().ThrowOnError(); + Writer->Close().Get().ThrowOnError(); + + auto data = TSharedMutableRef::Allocate(1); + auto whole = 
TBlob(GetRefCountedTypeCookie<TDefaultBlobTag>()); + + while (true) { + auto result = Reader->Read(data).Get(); + if (result.ValueOrThrow() == 0) { + break; + } + whole.Append(data.Begin(), result.Value()); + } + EXPECT_EQ(message, TString(whole.Begin(), whole.End())); +} + + +TEST_F(TPipeReadWriteTest, ReadSomethingWait) +{ + TString message("Hello pipe!\n"); + auto buffer = TSharedRef::FromString(message); + EXPECT_TRUE(Writer->Write(buffer).Get().IsOK()); + WaitFor(Writer->Close()) + .ThrowOnError(); + auto whole = ReadAll(Reader, false); + EXPECT_EQ(message, TString(whole.Begin(), whole.End())); +} + +TEST_F(TNamedPipeReadWriteTest, ReadSomethingWait) +{ + TString message("Hello pipe!\n"); + auto buffer = TSharedRef::FromString(message); + EXPECT_TRUE(Writer->Write(buffer).Get().IsOK()); + WaitFor(Writer->Close()) + .ThrowOnError(); + auto whole = ReadAll(Reader, false); + EXPECT_EQ(message, TString(whole.Begin(), whole.End())); +} + +TEST_F(TPipeReadWriteTest, ReadWrite) +{ + TString text("Hello cruel world!\n"); + auto buffer = TSharedRef::FromString(text); + Writer->Write(buffer).Get(); + auto errorsOnClose = Writer->Close(); + + auto textFromPipe = ReadAll(Reader, false); + + auto error = errorsOnClose.Get(); + EXPECT_TRUE(error.IsOK()) << error.GetMessage(); + EXPECT_EQ(text, TString(textFromPipe.Begin(), textFromPipe.End())); +} + +TEST_F(TNamedPipeReadWriteTest, ReadWrite) +{ + TString text("Hello cruel world!\n"); + auto buffer = TSharedRef::FromString(text); + Writer->Write(buffer).Get(); + auto errorsOnClose = Writer->Close(); + + auto textFromPipe = ReadAll(Reader, false); + + auto error = errorsOnClose.Get(); + EXPECT_TRUE(error.IsOK()) << error.GetMessage(); + EXPECT_EQ(text, TString(textFromPipe.Begin(), textFromPipe.End())); +} + +void WriteAll(IConnectionWriterPtr writer, const char* data, size_t size, size_t blockSize) +{ + while (size > 0) { + const size_t currentBlockSize = std::min(blockSize, size); + auto buffer = TSharedRef(data, currentBlockSize, nullptr); + auto error = WaitFor(writer->Write(buffer)); + THROW_ERROR_EXCEPTION_IF_FAILED(error); + size -= currentBlockSize; + data += currentBlockSize; + } + + { + auto error = WaitFor(writer->Close()); + THROW_ERROR_EXCEPTION_IF_FAILED(error); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +class TPipeBigReadWriteTest + : public TPipeReadWriteTest + , public ::testing::WithParamInterface<std::pair<size_t, size_t>> +{ }; + +TEST_P(TPipeBigReadWriteTest, RealReadWrite) +{ + size_t dataSize, blockSize; + std::tie(dataSize, blockSize) = GetParam(); + + auto queue = New<NConcurrency::TActionQueue>(); + + std::vector<char> data(dataSize, 'a'); + + BIND([&] () { + auto dice = std::bind( + std::uniform_int_distribution<int>(0, 127), + std::default_random_engine()); + for (size_t i = 0; i < data.size(); ++i) { + data[i] = dice(); + } + }) + .AsyncVia(queue->GetInvoker()).Run(); + + auto writeError = BIND(&WriteAll, Writer, data.data(), data.size(), blockSize) + .AsyncVia(queue->GetInvoker()) + .Run(); + auto readFromPipe = BIND(&ReadAll, Reader, true) + .AsyncVia(queue->GetInvoker()) + .Run(); + + auto textFromPipe = readFromPipe.Get().ValueOrThrow(); + EXPECT_EQ(data.size(), textFromPipe.Size()); + auto result = std::mismatch(textFromPipe.Begin(), textFromPipe.End(), data.begin()); + EXPECT_TRUE(std::equal(textFromPipe.Begin(), textFromPipe.End(), data.begin())) << + (result.first - textFromPipe.Begin()) << " " << (int)(*result.first); +} + +INSTANTIATE_TEST_SUITE_P( + 
ValueParametrized, + TPipeBigReadWriteTest, + ::testing::Values( + std::make_pair(2000 * 4096, 4096), + std::make_pair(100 * 4096, 10000), + std::make_pair(100 * 4096, 100), + std::make_pair(100, 4096))); + +#endif + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/unittests/process_ut.cpp b/yt/yt/library/process/unittests/process_ut.cpp new file mode 100644 index 0000000000..55f7fc65c8 --- /dev/null +++ b/yt/yt/library/process/unittests/process_ut.cpp @@ -0,0 +1,235 @@ +#include <yt/yt/library/process/process.h> + +#include <yt/yt/core/test_framework/framework.h> + +#include <yt/yt/core/actions/bind.h> + +#include <yt/yt/core/concurrency/action_queue.h> +#include <yt/yt/core/concurrency/delayed_executor.h> +#include <yt/yt/core/concurrency/scheduler.h> + +#include <yt/yt/core/net/connection.h> + +#include <library/cpp/yt/system/handle_eintr.h> + +namespace NYT { +namespace { + +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(_unix_) and not defined(_asan_enabled_) + +TEST(TProcessTest, Basic) +{ + auto p = New<TSimpleProcess>("/bin/ls"); + TFuture<void> finished; + + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); +} + +// NB: We cannot rely on 'ls' and 'sleep' in arcadia tests. +TEST(TProcessTest, RunFromPathEnv) +{ + auto p = New<TSimpleProcess>("/bin/ls", false); + TFuture<void> finished; + + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); +} + +TEST(TProcessTest, PollDuration) +{ + auto p = New<TSimpleProcess>("/bin/sleep", true, TDuration::MilliSeconds(1)); + p->AddArgument("0.1"); + + auto error = WaitFor(p->Spawn()); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); +} + +TEST(TProcessTest, InvalidPath) +{ + auto p = New<TSimpleProcess>("/some/bad/path/binary"); + + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_FALSE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_FALSE(p->IsFinished()); + EXPECT_FALSE(error.IsOK()); +} + +TEST(TProcessTest, StdOut) +{ + auto p = New<TSimpleProcess>("/bin/date"); + + auto outStream = p->GetStdOutReader(); + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); + + auto buffer = TSharedMutableRef::Allocate(4_KB, {.InitializeStorage = false}); + auto future = outStream->Read(buffer); + auto result = WaitFor(future); + size_t sz = result.ValueOrThrow(); + EXPECT_TRUE(sz > 0); +} + +TEST(TSimpleProcess, GetCommandLine1) +{ + auto p = New<TSimpleProcess>("/bin/bash"); + EXPECT_EQ("/bin/bash", p->GetCommandLine()); + p->AddArgument("-c"); + EXPECT_EQ("/bin/bash -c", p->GetCommandLine()); + p->AddArgument("exit 0"); + EXPECT_EQ("/bin/bash -c \"exit 0\"", p->GetCommandLine()); +} + +TEST(TProcessBase, GetCommandLine2) +{ + auto p = New<TSimpleProcess>("/bin/bash"); + EXPECT_EQ("/bin/bash", p->GetCommandLine()); + p->AddArgument("-c"); + EXPECT_EQ("/bin/bash -c", p->GetCommandLine()); + p->AddArgument("\"quoted\""); + EXPECT_EQ("/bin/bash -c \"\\\"quoted\\\"\"", p->GetCommandLine()); +} + 
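+// NB: As the two GetCommandLine tests above demonstrate, an argument gets quoted
+// whenever it contains a character outside [A-Za-z0-9], '-', '_', '=' and '/',
+// and embedded double quotes are escaped with backslashes.
+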
+TEST(TProcessTest, ProcessReturnCode0) +{ + auto p = New<TSimpleProcess>("/bin/bash"); + p->AddArgument("-c"); + p->AddArgument("exit 0"); + + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); +} + +TEST(TProcessTest, ProcessReturnCode123) +{ + auto p = New<TSimpleProcess>("/bin/bash"); + p->AddArgument("-c"); + p->AddArgument("exit 123"); + + TFuture<void> finished; + ASSERT_NO_THROW(finished = p->Spawn()); + ASSERT_TRUE(p->IsStarted()); + auto error = WaitFor(finished); + EXPECT_EQ(EProcessErrorCode::NonZeroExitCode, error.GetCode()); + EXPECT_EQ(123, error.Attributes().Get<int>("exit_code")); + EXPECT_TRUE(p->IsFinished()); +} + +TEST(TProcessTest, Params1) +{ + auto p = New<TSimpleProcess>("/bin/bash"); + p->AddArgument("-c"); + p->AddArgument("if test 3 -gt 1; then exit 7; fi"); + + auto error = WaitFor(p->Spawn()); + EXPECT_FALSE(error.IsOK()); + EXPECT_TRUE(p->IsFinished()); +} + +TEST(TProcessTest, Params2) +{ + auto p = New<TSimpleProcess>("/bin/bash"); + p->AddArgument("-c"); + p->AddArgument("if test 1 -gt 3; then exit 7; fi"); + + auto error = WaitFor(p->Spawn()); + EXPECT_TRUE(error.IsOK()) << ToString(error); + EXPECT_TRUE(p->IsFinished()); +} + +TEST(TProcessTest, InheritEnvironment) +{ + const char* name = "SPAWN_TEST_ENV_VAR"; + const char* value = "42"; + setenv(name, value, 1); + + auto p = New<TSimpleProcess>("/bin/bash"); + p->AddArgument("-c"); + p->AddArgument("if test $SPAWN_TEST_ENV_VAR = 42; then exit 7; fi"); + + auto error = WaitFor(p->Spawn()); + EXPECT_FALSE(error.IsOK()); + EXPECT_TRUE(p->IsFinished()); + + unsetenv(name); +} + +TEST(TProcessTest, Kill) +{ + auto p = New<TSimpleProcess>("/bin/sleep"); + p->AddArgument("5"); + + auto finished = p->Spawn(); + + NConcurrency::TDelayedExecutor::Submit( + BIND([&] () { + p->Kill(SIGKILL); + }), + TDuration::MilliSeconds(100)); + + auto error = WaitFor(finished); + EXPECT_FALSE(error.IsOK()); + EXPECT_TRUE(p->IsFinished()); +} + +TEST(TProcessTest, KillFinished) +{ + auto p = New<TSimpleProcess>("/bin/bash"); + p->AddArgument("-c"); + p->AddArgument("true"); + + auto finished = p->Spawn(); + + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()); + + p->Kill(SIGKILL); +} + +TEST(TProcessTest, KillZombie) +{ + auto p = New<TSimpleProcess>("/bin/bash"); + p->AddArgument("-c"); + p->AddArgument("/bin/sleep 1; /bin/true"); + + auto finished = p->Spawn(); + + siginfo_t infop; + auto res = HandleEintr(::waitid, P_PID, p->GetProcessId(), &infop, WEXITED | WNOWAIT); + EXPECT_EQ(0, res) + << "errno = " << errno; + EXPECT_EQ(p->GetProcessId(), infop.si_pid); + + p->Kill(SIGKILL); + auto error = WaitFor(finished); + EXPECT_TRUE(error.IsOK()) + << ToString(error); +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT diff --git a/yt/yt/library/process/unittests/subprocess_ut.cpp b/yt/yt/library/process/unittests/subprocess_ut.cpp new file mode 100644 index 0000000000..932d250784 --- /dev/null +++ b/yt/yt/library/process/unittests/subprocess_ut.cpp @@ -0,0 +1,101 @@ +#include <yt/yt/core/test_framework/framework.h> + +#include <yt/yt/core/actions/future.h> + +#include <yt/yt/core/concurrency/action_queue.h> + +#include <yt/yt/library/process/subprocess.h> + +namespace NYT { +namespace { + +using namespace NConcurrency; + 
+//////////////////////////////////////////////////////////////////////////////// + +#if defined(_unix_) and not defined(_asan_enabled_) + +TEST(TSubprocessTest, Basic) +{ + TSubprocess subprocess("/bin/bash"); + + subprocess.AddArgument("-c"); + subprocess.AddArgument("true"); + + auto result = subprocess.Execute(); + EXPECT_TRUE(result.Status.IsOK()); +} + + +TEST(TSubprocessTest, PipeOutput) +{ + TSubprocess subprocess("/bin/echo"); + + subprocess.AddArgument("hello"); + + auto result = subprocess.Execute(); + EXPECT_TRUE(result.Status.IsOK()); + TString output(result.Output.Begin(), result.Output.End()); + EXPECT_TRUE(output == "hello\n") << output; +} + +TEST(TSubprocessTest, PipeStdin) +{ + auto queue = New<TActionQueue>(); + + BIND([] () { + TSubprocess subprocess("/bin/cat"); + subprocess.AddArgument("-"); + + auto input = TString("TEST test TEST"); + auto inputRef = TSharedRef::FromString(input); + auto result = subprocess.Execute(inputRef); + EXPECT_TRUE(result.Status.IsOK()); + + TString output(result.Output.Begin(), result.Output.End()); + EXPECT_EQ(input, output); + }).AsyncVia(queue->GetInvoker()).Run().Get().ThrowOnError(); +} + +TEST(TSubprocessTest, PipeBigOutput) +{ + auto queue = New<TActionQueue>(); + + auto result = BIND([] () { + TSubprocess subprocess("/bin/bash"); + + subprocess.AddArgument("-c"); + subprocess.AddArgument("for i in `/usr/bin/seq 100000`; do echo hello; done; echo world"); + + auto result = subprocess.Execute(); + return result.Status.IsOK(); + }).AsyncVia(queue->GetInvoker()).Run().Get().Value(); + + EXPECT_TRUE(result); +} + + +TEST(TSubprocessTest, PipeBigError) +{ + auto queue = New<TActionQueue>(); + + auto result = BIND([] () { + TSubprocess subprocess("/bin/bash"); + + subprocess.AddArgument("-c"); + subprocess.AddArgument("for i in `/usr/bin/seq 100000`; do echo hello 1>&2; done; echo world"); + + auto result = subprocess.Execute(); + return result; + }).AsyncVia(queue->GetInvoker()).Run().Get().Value(); + + EXPECT_TRUE(result.Status.IsOK()); + EXPECT_EQ(6*100000, std::ssize(result.Error)); +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT diff --git a/yt/yt/library/process/unittests/ya.make b/yt/yt/library/process/unittests/ya.make new file mode 100644 index 0000000000..6e476c5702 --- /dev/null +++ b/yt/yt/library/process/unittests/ya.make @@ -0,0 +1,24 @@ +GTEST(unittester-library-process) + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +ALLOCATOR(YT) + +SRCS( + pipes_ut.cpp + process_ut.cpp + subprocess_ut.cpp +) + +INCLUDE(${ARCADIA_ROOT}/yt/opensource_tests.inc) + +PEERDIR( + yt/yt/build + yt/yt/core + yt/yt/core/test_framework + yt/yt/library/process +) + +SIZE(MEDIUM) + +END() diff --git a/yt/yt/library/process/ya.make b/yt/yt/library/process/ya.make new file mode 100644 index 0000000000..79763c7267 --- /dev/null +++ b/yt/yt/library/process/ya.make @@ -0,0 +1,22 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + io_dispatcher.cpp + pipe.cpp + process.cpp + pty.cpp + subprocess.cpp +) + +PEERDIR( + yt/yt/core + contrib/libs/re2 +) + +END() + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/profiling/CMakeLists.darwin-x86_64.txt b/yt/yt/library/profiling/CMakeLists.darwin-x86_64.txt index bafd86d310..0f1318b348 100644 --- a/yt/yt/library/profiling/CMakeLists.darwin-x86_64.txt +++ b/yt/yt/library/profiling/CMakeLists.darwin-x86_64.txt @@ -6,7 +6,10 @@ # original buildsystem will not be accepted. 
+add_subdirectory(perf) add_subdirectory(resource_tracker) +add_subdirectory(solomon) +add_subdirectory(tcmalloc) add_library(yt-library-profiling) target_compile_options(yt-library-profiling PRIVATE diff --git a/yt/yt/library/profiling/CMakeLists.linux-aarch64.txt b/yt/yt/library/profiling/CMakeLists.linux-aarch64.txt index e4524762fc..b1ff65e70f 100644 --- a/yt/yt/library/profiling/CMakeLists.linux-aarch64.txt +++ b/yt/yt/library/profiling/CMakeLists.linux-aarch64.txt @@ -6,7 +6,10 @@ # original buildsystem will not be accepted. +add_subdirectory(perf) add_subdirectory(resource_tracker) +add_subdirectory(solomon) +add_subdirectory(tcmalloc) add_library(yt-library-profiling) target_compile_options(yt-library-profiling PRIVATE diff --git a/yt/yt/library/profiling/CMakeLists.linux-x86_64.txt b/yt/yt/library/profiling/CMakeLists.linux-x86_64.txt index e4524762fc..b1ff65e70f 100644 --- a/yt/yt/library/profiling/CMakeLists.linux-x86_64.txt +++ b/yt/yt/library/profiling/CMakeLists.linux-x86_64.txt @@ -6,7 +6,10 @@ # original buildsystem will not be accepted. +add_subdirectory(perf) add_subdirectory(resource_tracker) +add_subdirectory(solomon) +add_subdirectory(tcmalloc) add_library(yt-library-profiling) target_compile_options(yt-library-profiling PRIVATE diff --git a/yt/yt/library/profiling/CMakeLists.windows-x86_64.txt b/yt/yt/library/profiling/CMakeLists.windows-x86_64.txt index 6a6f1a6dde..34e12beaad 100644 --- a/yt/yt/library/profiling/CMakeLists.windows-x86_64.txt +++ b/yt/yt/library/profiling/CMakeLists.windows-x86_64.txt @@ -6,7 +6,10 @@ # original buildsystem will not be accepted. +add_subdirectory(perf) add_subdirectory(resource_tracker) +add_subdirectory(solomon) +add_subdirectory(tcmalloc) add_library(yt-library-profiling) target_link_libraries(yt-library-profiling PUBLIC diff --git a/yt/yt/library/profiling/perf/CMakeLists.darwin-x86_64.txt b/yt/yt/library/profiling/perf/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..fd65c86abb --- /dev/null +++ b/yt/yt/library/profiling/perf/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,22 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(library-profiling-perf) +target_compile_options(library-profiling-perf PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-perf PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-profiling + yt-yt-core +) +target_sources(library-profiling-perf PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/perf/counters_other.cpp +) diff --git a/yt/yt/library/profiling/perf/CMakeLists.linux-aarch64.txt b/yt/yt/library/profiling/perf/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..8615b211d9 --- /dev/null +++ b/yt/yt/library/profiling/perf/CMakeLists.linux-aarch64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(library-profiling-perf) +target_compile_options(library-profiling-perf PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-perf PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-profiling + yt-yt-core +) +target_sources(library-profiling-perf PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/perf/counters.cpp +) diff --git a/yt/yt/library/profiling/perf/CMakeLists.linux-x86_64.txt b/yt/yt/library/profiling/perf/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..8615b211d9 --- /dev/null +++ b/yt/yt/library/profiling/perf/CMakeLists.linux-x86_64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(library-profiling-perf) +target_compile_options(library-profiling-perf PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-perf PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-profiling + yt-yt-core +) +target_sources(library-profiling-perf PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/perf/counters.cpp +) diff --git a/yt/yt/library/profiling/perf/CMakeLists.txt b/yt/yt/library/profiling/perf/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/profiling/perf/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/profiling/perf/CMakeLists.windows-x86_64.txt b/yt/yt/library/profiling/perf/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..f6897f44de --- /dev/null +++ b/yt/yt/library/profiling/perf/CMakeLists.windows-x86_64.txt @@ -0,0 +1,19 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + + +add_library(library-profiling-perf) +target_link_libraries(library-profiling-perf PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-profiling + yt-yt-core +) +target_sources(library-profiling-perf PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/perf/counters_other.cpp +) diff --git a/yt/yt/library/profiling/solomon/CMakeLists.darwin-x86_64.txt b/yt/yt/library/profiling/solomon/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..b4c6b378dd --- /dev/null +++ b/yt/yt/library/profiling/solomon/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,68 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-profiling-solomon) +target_compile_options(library-profiling-solomon PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-solomon PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-profiling + yt-yt-core + yt-core-http + library-cpp-cgiparam + cpp-monlib-metrics + monlib-encode-prometheus + monlib-encode-spack + monlib-encode-json + cpp-yt-threading + contrib-libs-protobuf +) +target_proto_messages(library-profiling-solomon PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_dump.proto +) +target_sources(library-profiling-solomon PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/cube.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/exporter.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/percpu.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/producer.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/registry.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/remote.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_service.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_set.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/tag_registry.cpp +) +target_proto_addincls(library-profiling-solomon + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-profiling-solomon + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) diff --git a/yt/yt/library/profiling/solomon/CMakeLists.linux-aarch64.txt b/yt/yt/library/profiling/solomon/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..57e8fb3c6e --- /dev/null +++ b/yt/yt/library/profiling/solomon/CMakeLists.linux-aarch64.txt @@ -0,0 +1,69 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. 
Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-profiling-solomon) +target_compile_options(library-profiling-solomon PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-solomon PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-profiling + yt-yt-core + yt-core-http + library-cpp-cgiparam + cpp-monlib-metrics + monlib-encode-prometheus + monlib-encode-spack + monlib-encode-json + cpp-yt-threading + contrib-libs-protobuf +) +target_proto_messages(library-profiling-solomon PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_dump.proto +) +target_sources(library-profiling-solomon PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/cube.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/exporter.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/percpu.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/producer.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/registry.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/remote.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_service.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_set.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/tag_registry.cpp +) +target_proto_addincls(library-profiling-solomon + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-profiling-solomon + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) diff --git a/yt/yt/library/profiling/solomon/CMakeLists.linux-x86_64.txt b/yt/yt/library/profiling/solomon/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..57e8fb3c6e --- /dev/null +++ b/yt/yt/library/profiling/solomon/CMakeLists.linux-x86_64.txt @@ -0,0 +1,69 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
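+# Editor's note (illustrative): get_built_tool_path below resolves the host-built protoc
+# binary and the cpp_styleguide plugin; target_proto_messages then feeds sensor_dump.proto
+# through them, and target_proto_outs emits the generated .pb.h/.pb.cc pair under
+# ${CMAKE_BINARY_DIR}.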
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-profiling-solomon) +target_compile_options(library-profiling-solomon PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-solomon PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-profiling + yt-yt-core + yt-core-http + library-cpp-cgiparam + cpp-monlib-metrics + monlib-encode-prometheus + monlib-encode-spack + monlib-encode-json + cpp-yt-threading + contrib-libs-protobuf +) +target_proto_messages(library-profiling-solomon PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_dump.proto +) +target_sources(library-profiling-solomon PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/cube.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/exporter.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/percpu.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/producer.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/registry.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/remote.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_service.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_set.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/tag_registry.cpp +) +target_proto_addincls(library-profiling-solomon + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-profiling-solomon + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) diff --git a/yt/yt/library/profiling/solomon/CMakeLists.txt b/yt/yt/library/profiling/solomon/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/profiling/solomon/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/profiling/solomon/CMakeLists.windows-x86_64.txt b/yt/yt/library/profiling/solomon/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..958f7e64fc --- /dev/null +++ b/yt/yt/library/profiling/solomon/CMakeLists.windows-x86_64.txt @@ -0,0 +1,65 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. 
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-profiling-solomon) +target_link_libraries(library-profiling-solomon PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-profiling + yt-yt-core + yt-core-http + library-cpp-cgiparam + cpp-monlib-metrics + monlib-encode-prometheus + monlib-encode-spack + monlib-encode-json + cpp-yt-threading + contrib-libs-protobuf +) +target_proto_messages(library-profiling-solomon PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_dump.proto +) +target_sources(library-profiling-solomon PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/cube.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/exporter.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/percpu.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/producer.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/registry.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/remote.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_service.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/sensor_set.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/solomon/tag_registry.cpp +) +target_proto_addincls(library-profiling-solomon + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-profiling-solomon + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) diff --git a/yt/yt/library/profiling/tcmalloc/CMakeLists.darwin-x86_64.txt b/yt/yt/library/profiling/tcmalloc/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..341ff6724c --- /dev/null +++ b/yt/yt/library/profiling/tcmalloc/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,22 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
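+# Editor's note (illustrative): the darwin-x86_64 variant below is identical to the
+# Linux ones except that it omits contrib-libs-linux-headers from the link interface.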
+ + + +add_library(library-profiling-tcmalloc) +target_compile_options(library-profiling-tcmalloc PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-tcmalloc PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-profiling + libs-tcmalloc-malloc_extension +) +target_sources(library-profiling-tcmalloc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/tcmalloc/profiler.cpp +) diff --git a/yt/yt/library/profiling/tcmalloc/CMakeLists.linux-aarch64.txt b/yt/yt/library/profiling/tcmalloc/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..3ebf7b2fd2 --- /dev/null +++ b/yt/yt/library/profiling/tcmalloc/CMakeLists.linux-aarch64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(library-profiling-tcmalloc) +target_compile_options(library-profiling-tcmalloc PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-tcmalloc PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-profiling + libs-tcmalloc-malloc_extension +) +target_sources(library-profiling-tcmalloc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/tcmalloc/profiler.cpp +) diff --git a/yt/yt/library/profiling/tcmalloc/CMakeLists.linux-x86_64.txt b/yt/yt/library/profiling/tcmalloc/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..3ebf7b2fd2 --- /dev/null +++ b/yt/yt/library/profiling/tcmalloc/CMakeLists.linux-x86_64.txt @@ -0,0 +1,23 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(library-profiling-tcmalloc) +target_compile_options(library-profiling-tcmalloc PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-profiling-tcmalloc PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-profiling + libs-tcmalloc-malloc_extension +) +target_sources(library-profiling-tcmalloc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/tcmalloc/profiler.cpp +) diff --git a/yt/yt/library/profiling/tcmalloc/CMakeLists.txt b/yt/yt/library/profiling/tcmalloc/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/profiling/tcmalloc/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
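+# Editor's note (illustrative): the chain below includes exactly one platform file per
+# configure, keyed on CMAKE_SYSTEM_NAME/CMAKE_SYSTEM_PROCESSOR; note that only the
+# Darwin branch carries no HAVE_CUDA guard.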
+ + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/profiling/tcmalloc/CMakeLists.windows-x86_64.txt b/yt/yt/library/profiling/tcmalloc/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..3fcdf372c8 --- /dev/null +++ b/yt/yt/library/profiling/tcmalloc/CMakeLists.windows-x86_64.txt @@ -0,0 +1,19 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(library-profiling-tcmalloc) +target_link_libraries(library-profiling-tcmalloc PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-profiling + libs-tcmalloc-malloc_extension +) +target_sources(library-profiling-tcmalloc PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/profiling/tcmalloc/profiler.cpp +) diff --git a/yt/yt/library/program/CMakeLists.darwin-x86_64.txt b/yt/yt/library/program/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..fc3e796e79 --- /dev/null +++ b/yt/yt/library/program/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,38 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-program) +target_compile_options(yt-library-program PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-program PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + core-service_discovery-yp + yt-library-monitoring + yt-library-containers + library-profiling-solomon + library-profiling-tcmalloc + library-profiling-perf + yt-library-ytprof + library-tracing-jaeger + cpp-yt-mlock + cpp-yt-stockpile + cpp-yt-string +) +target_sources(yt-library-program PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/build_attributes.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/helpers.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_config_mixin.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_pdeathsig_mixin.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_setsid_mixin.cpp +) diff --git a/yt/yt/library/program/CMakeLists.linux-aarch64.txt b/yt/yt/library/program/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..be31ba10cf --- /dev/null +++ b/yt/yt/library/program/CMakeLists.linux-aarch64.txt @@ -0,0 +1,39 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. 
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-program) +target_compile_options(yt-library-program PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-program PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + core-service_discovery-yp + yt-library-monitoring + yt-library-containers + library-profiling-solomon + library-profiling-tcmalloc + library-profiling-perf + yt-library-ytprof + library-tracing-jaeger + cpp-yt-mlock + cpp-yt-stockpile + cpp-yt-string +) +target_sources(yt-library-program PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/build_attributes.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/helpers.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_config_mixin.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_pdeathsig_mixin.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_setsid_mixin.cpp +) diff --git a/yt/yt/library/program/CMakeLists.linux-x86_64.txt b/yt/yt/library/program/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..be31ba10cf --- /dev/null +++ b/yt/yt/library/program/CMakeLists.linux-x86_64.txt @@ -0,0 +1,39 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-program) +target_compile_options(yt-library-program PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(yt-library-program PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-yt-core + core-service_discovery-yp + yt-library-monitoring + yt-library-containers + library-profiling-solomon + library-profiling-tcmalloc + library-profiling-perf + yt-library-ytprof + library-tracing-jaeger + cpp-yt-mlock + cpp-yt-stockpile + cpp-yt-string +) +target_sources(yt-library-program PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/build_attributes.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/helpers.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_config_mixin.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_pdeathsig_mixin.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_setsid_mixin.cpp +) diff --git a/yt/yt/library/program/CMakeLists.txt b/yt/yt/library/program/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/program/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. 
Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/program/CMakeLists.windows-x86_64.txt b/yt/yt/library/program/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..1f2aea4bd0 --- /dev/null +++ b/yt/yt/library/program/CMakeLists.windows-x86_64.txt @@ -0,0 +1,35 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(yt-library-program) +target_link_libraries(yt-library-program PUBLIC + contrib-libs-cxxsupp + yutil + yt-yt-core + core-service_discovery-yp + yt-library-monitoring + yt-library-containers + library-profiling-solomon + library-profiling-tcmalloc + library-profiling-perf + yt-library-ytprof + library-tracing-jaeger + cpp-yt-mlock + cpp-yt-stockpile + cpp-yt-string +) +target_sources(yt-library-program PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/build_attributes.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/config.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/helpers.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_config_mixin.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_pdeathsig_mixin.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/program/program_setsid_mixin.cpp +) diff --git a/yt/yt/library/program/build_attributes.cpp b/yt/yt/library/program/build_attributes.cpp new file mode 100644 index 0000000000..38caf57997 --- /dev/null +++ b/yt/yt/library/program/build_attributes.cpp @@ -0,0 +1,107 @@ +#include "build_attributes.h" + +#include <yt/yt/build/build.h> + +#include <yt/yt/core/ytree/fluent.h> +#include <yt/yt/core/ytree/ypath_client.h> + +#include <yt/yt/core/misc/error_code.h> + +namespace NYT { + +using namespace NYTree; +using namespace NYson; + +static const NLogging::TLogger Logger("Build"); + +//////////////////////////////////////////////////////////////////////////////// + +void TBuildInfo::Register(TRegistrar registrar) +{ + registrar.Parameter("name", &TThis::Name) + .Default(); + + registrar.Parameter("version", &TThis::Version) + .Default(GetVersion()); + + registrar.Parameter("build_host", &TThis::BuildHost) + .Default(GetBuildHost()); + + registrar.Parameter("build_time", &TThis::BuildTime) + .Default(ParseBuildTime()); + + registrar.Parameter("start_time", &TThis::StartTime) + .Default(TInstant::Now()); +} + +std::optional<TInstant> TBuildInfo::ParseBuildTime() +{ + TString rawBuildTime(GetBuildTime()); + + // Build time may be empty if code is building + // without -DBUILD_DATE (for example, in opensource build). 
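+    // Editor's note (illustrative, not part of this commit): a well-formed value looks
+    // like "2023-11-12T21:25:31Z"; TInstant::ParseIso8601 throws on malformed input,
+    // which the catch block below turns into std::nullopt.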
+ if (rawBuildTime.empty()) { + return std::nullopt; + } + + try { + return TInstant::ParseIso8601(rawBuildTime); + } catch (const std::exception& ex) { + YT_LOG_ERROR(ex, "Error parsing build time"); + return std::nullopt; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +TBuildInfoPtr BuildBuildAttributes(const char* serviceName) +{ + auto info = New<TBuildInfo>(); + if (serviceName) { + info->Name = serviceName; + } + return info; +} + +void SetBuildAttributes(IYPathServicePtr orchidRoot, const char* serviceName) +{ + SyncYPathSet( + orchidRoot, + "/service", + BuildYsonStringFluently() + .BeginAttributes() + .Item("opaque").Value(true) + .EndAttributes() + .Value(BuildBuildAttributes(serviceName))); + SyncYPathSet( + orchidRoot, + "/error_codes", + BuildYsonStringFluently() + .BeginAttributes() + .Item("opaque").Value(true) + .EndAttributes() + .DoMapFor(TErrorCodeRegistry::Get()->GetAllErrorCodes(), [] (TFluentMap fluent, const auto& pair) { + fluent + .Item(ToString(pair.first)).BeginMap() + .Item("cpp_literal").Value(ToString(pair.second)) + .EndMap(); + })); + SyncYPathSet( + orchidRoot, + "/error_code_ranges", + BuildYsonStringFluently() + .BeginAttributes() + .Item("opaque").Value(true) + .EndAttributes() + .DoMapFor(TErrorCodeRegistry::Get()->GetAllErrorCodeRanges(), [] (TFluentMap fluent, const TErrorCodeRegistry::TErrorCodeRangeInfo& range) { + fluent + .Item(ToString(range)).BeginMap() + .Item("cpp_enum").Value(range.Namespace) + .EndMap(); + })); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT + diff --git a/yt/yt/library/program/build_attributes.h b/yt/yt/library/program/build_attributes.h new file mode 100644 index 0000000000..e02f86b351 --- /dev/null +++ b/yt/yt/library/program/build_attributes.h @@ -0,0 +1,44 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/public.h> +#include <yt/yt/core/ytree/yson_struct.h> + +#include <yt/yt/core/yson/public.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TBuildInfo + : public NYTree::TYsonStruct +{ +public: + std::optional<TString> Name; + TString Version; + TString BuildHost; + std::optional<TInstant> BuildTime; + TInstant StartTime; + + REGISTER_YSON_STRUCT(TBuildInfo); + + static void Register(TRegistrar registrar); + +private: + static std::optional<TInstant> ParseBuildTime(); +}; + +DEFINE_REFCOUNTED_TYPE(TBuildInfo) + +//////////////////////////////////////////////////////////////////////////////// + +//! Build build (pun intended) attributes as a TBuildInfo a-la /orchid/service. If service name is not provided, +//! it is omitted from the result. 
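+//! Illustrative usage (editor's example; the service name is hypothetical):
+//!   auto info = BuildBuildAttributes("my-service");
+//!   // info->Version, info->BuildHost etc. are prefilled from yt/yt/build/build.h.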
+TBuildInfoPtr BuildBuildAttributes(const char* serviceName = nullptr); + +void SetBuildAttributes(NYTree::IYPathServicePtr orchidRoot, const char* serviceName); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/config.cpp b/yt/yt/library/program/config.cpp new file mode 100644 index 0000000000..0705fb48fc --- /dev/null +++ b/yt/yt/library/program/config.cpp @@ -0,0 +1,210 @@ +#include "config.h" + +namespace NYT { + +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +void TRpcConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("tracing", &TThis::Tracing) + .Default(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void THeapSizeLimit::Register(TRegistrar registrar) +{ + registrar.Parameter("container_memory_ratio", &TThis::ContainerMemoryRatio) + .Optional(); + registrar.Parameter("is_hard", &TThis::IsHard) + .Default(false); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TTCMallocConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("background_release_rate", &TThis::BackgroundReleaseRate) + .Default(32_MB); + registrar.Parameter("max_per_cpu_cache_size", &TThis::MaxPerCpuCacheSize) + .Default(3_MB); + + registrar.Parameter("aggressive_release_threshold", &TThis::AggressiveReleaseThreshold) + .Default(20_GB); + registrar.Parameter("aggressive_release_threshold_ratio", &TThis::AggressiveReleaseThresholdRatio) + .Optional(); + + registrar.Parameter("aggressive_release_size", &TThis::AggressiveReleaseSize) + .Default(128_MB); + registrar.Parameter("aggressive_release_period", &TThis::AggressiveReleasePeriod) + .Default(TDuration::MilliSeconds(100)); + registrar.Parameter("guarded_sampling_rate", &TThis::GuardedSamplingRate) + .Default(128_MB); + + registrar.Parameter("heap_size_limit", &TThis::HeapSizeLimit) + .DefaultNew(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TStockpileConfig::Register(TRegistrar registrar) +{ + registrar.BaseClassParameter("buffer_size", &TThis::BufferSize) + .Default(DefaultBufferSize); + registrar.BaseClassParameter("thread_count", &TThis::ThreadCount) + .Default(DefaultThreadCount); + registrar.BaseClassParameter("period", &TThis::Period) + .Default(DefaultPeriod); +} + +//////////////////////////////////////////////////////////////////////////////// + +void THeapProfilerConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("snapshot_update_period", &TThis::SnapshotUpdatePeriod) + .Default(TDuration::Seconds(5)); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TSingletonsConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("spin_wait_slow_path_logging_threshold", &TThis::SpinWaitSlowPathLoggingThreshold) + .Default(TDuration::MicroSeconds(100)); + registrar.Parameter("yt_alloc", &TThis::YTAlloc) + .DefaultNew(); + registrar.Parameter("fiber_stack_pool_sizes", &TThis::FiberStackPoolSizes) + .Default({}); + registrar.Parameter("address_resolver", &TThis::AddressResolver) + .DefaultNew(); + registrar.Parameter("tcp_dispatcher", &TThis::TcpDispatcher) + .DefaultNew(); + registrar.Parameter("rpc_dispatcher", &TThis::RpcDispatcher) + .DefaultNew(); + registrar.Parameter("grpc_dispatcher", &TThis::GrpcDispatcher) + .DefaultNew(); + registrar.Parameter("yp_service_discovery", 
&TThis::YPServiceDiscovery) + .DefaultNew(); + registrar.Parameter("solomon_exporter", &TThis::SolomonExporter) + .DefaultNew(); + registrar.Parameter("logging", &TThis::Logging) + .DefaultCtor([] () { return NLogging::TLogManagerConfig::CreateDefault(); }); + registrar.Parameter("jaeger", &TThis::Jaeger) + .DefaultNew(); + registrar.Parameter("rpc", &TThis::Rpc) + .DefaultNew(); + registrar.Parameter("tcmalloc", &TThis::TCMalloc) + .DefaultNew(); + registrar.Parameter("stockpile", &TThis::Stockpile) + .DefaultNew(); + registrar.Parameter("enable_ref_counted_tracker_profiling", &TThis::EnableRefCountedTrackerProfiling) + .Default(true); + registrar.Parameter("enable_resource_tracker", &TThis::EnableResourceTracker) + .Default(true); + registrar.Parameter("enable_porto_resource_tracker", &TThis::EnablePortoResourceTracker) + .Default(false); + registrar.Parameter("resource_tracker_vcpu_factor", &TThis::ResourceTrackerVCpuFactor) + .Optional(); + registrar.Parameter("pod_spec", &TThis::PodSpec) + .DefaultNew(); + registrar.Parameter("heap_profiler", &TThis::HeapProfiler) + .DefaultNew(); + + registrar.Postprocessor([] (TThis* config) { + if (config->ResourceTrackerVCpuFactor && !config->EnableResourceTracker) { + THROW_ERROR_EXCEPTION("Option \"resource_tracker_vcpu_factor\" can be specified only if resource tracker is enabled"); + } + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TSingletonsDynamicConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("spin_lock_slow_path_logging_threshold", &TThis::SpinWaitSlowPathLoggingThreshold) + .Optional(); + registrar.Parameter("yt_alloc", &TThis::YTAlloc) + .Optional(); + registrar.Parameter("tcp_dispatcher", &TThis::TcpDispatcher) + .DefaultNew(); + registrar.Parameter("rpc_dispatcher", &TThis::RpcDispatcher) + .DefaultNew(); + registrar.Parameter("logging", &TThis::Logging) + .DefaultNew(); + registrar.Parameter("jaeger", &TThis::Jaeger) + .DefaultNew(); + registrar.Parameter("rpc", &TThis::Rpc) + .DefaultNew(); + registrar.Parameter("tcmalloc", &TThis::TCMalloc) + .Optional(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TDiagnosticDumpConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("yt_alloc_dump_period", &TThis::YTAllocDumpPeriod) + .Default(); + registrar.Parameter("ref_counted_tracker_dump_period", &TThis::RefCountedTrackerDumpPeriod) + .Default(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void WarnForUnrecognizedOptionsImpl( + const NLogging::TLogger& logger, + const IMapNodePtr& unrecognized) +{ + const auto& Logger = logger; + if (unrecognized && unrecognized->GetChildCount() > 0) { + YT_LOG_WARNING("Bootstrap config contains unrecognized options (Unrecognized: %v)", + ConvertToYsonString(unrecognized, NYson::EYsonFormat::Text)); + } +} + +void WarnForUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonStructPtr& config) +{ + WarnForUnrecognizedOptionsImpl(logger, config->GetRecursiveUnrecognized()); +} + +void WarnForUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonSerializablePtr& config) +{ + WarnForUnrecognizedOptionsImpl(logger, config->GetUnrecognizedRecursively()); +} + +void AbortOnUnrecognizedOptionsImpl( + const NLogging::TLogger& logger, + const IMapNodePtr& unrecognized) +{ + const auto& Logger = logger; + if (unrecognized && unrecognized->GetChildCount() > 0) { + YT_LOG_ERROR("Bootstrap config 
contains unrecognized options, terminating (Unrecognized: %v)", + ConvertToYsonString(unrecognized, NYson::EYsonFormat::Text)); + YT_ABORT(); + } +} + +void AbortOnUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonStructPtr& config) +{ + AbortOnUnrecognizedOptionsImpl(logger, config->GetRecursiveUnrecognized()); +} + +void AbortOnUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonSerializablePtr& config) +{ + AbortOnUnrecognizedOptionsImpl(logger, config->GetUnrecognizedRecursively()); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT + diff --git a/yt/yt/library/program/config.h b/yt/yt/library/program/config.h new file mode 100644 index 0000000000..7d92939f1c --- /dev/null +++ b/yt/yt/library/program/config.h @@ -0,0 +1,224 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/yson_serializable.h> +#include <yt/yt/core/ytree/yson_struct.h> + +#include <yt/yt/core/ytalloc/config.h> + +#include <yt/yt/core/net/config.h> + +#include <yt/yt/core/rpc/config.h> +#include <yt/yt/core/rpc/grpc/config.h> + +#include <yt/yt/core/bus/tcp/config.h> + +#include <yt/yt/core/logging/config.h> + +#include <yt/yt/core/tracing/config.h> + +#include <yt/yt/core/service_discovery/yp/config.h> + +#include <yt/yt/library/profiling/solomon/exporter.h> + +#include <yt/yt/library/containers/config.h> + +#include <yt/yt/library/tracing/jaeger/tracer.h> + +#include <library/cpp/yt/stockpile/stockpile.h> + + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TRpcConfig + : public NYTree::TYsonStruct +{ +public: + NTracing::TTracingConfigPtr Tracing; + + REGISTER_YSON_STRUCT(TRpcConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TRpcConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class THeapSizeLimit + : public virtual NYTree::TYsonStruct +{ +public: + //! Limits program memory in terms of container memory. + // If the program heap size exceeds the limit, tcmalloc is instructed to release memory to the kernel. + std::optional<double> ContainerMemoryRatio; + + //! If true, tcmalloc crashes when the system allocates more memory than #ContainerMemoryRatio allows. + bool IsHard; + + REGISTER_YSON_STRUCT(THeapSizeLimit); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(THeapSizeLimit) + +//////////////////////////////////////////////////////////////////////////////// + +class TTCMallocConfig + : public virtual NYTree::TYsonStruct +{ +public: + i64 BackgroundReleaseRate; + int MaxPerCpuCacheSize; + + //! Threshold in bytes. + i64 AggressiveReleaseThreshold; + + //! Threshold as a fraction of the total memory of the container. + std::optional<double> AggressiveReleaseThresholdRatio; + + i64 AggressiveReleaseSize; + TDuration AggressiveReleasePeriod; + + //! Approximately 1/#GuardedSamplingRate of all allocations of + //! size <= 256 KiB will be under GWP-ASAN.
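+    //! (E.g. with the default of 128_MB, roughly one allocation is guarded per 128 MiB
+    //! of small-allocation traffic — editor's illustration of the ratio above.)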
+ std::optional<i64> GuardedSamplingRate; + + THeapSizeLimitPtr HeapSizeLimit; + + REGISTER_YSON_STRUCT(TTCMallocConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TTCMallocConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TStockpileConfig + : public TStockpileOptions + , public NYTree::TYsonStruct +{ +public: + REGISTER_YSON_STRUCT(TStockpileConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TStockpileConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class THeapProfilerConfig + : public NYTree::TYsonStruct +{ +public: + // Period of update snapshot in heap profiler. + std::optional<TDuration> SnapshotUpdatePeriod; + + REGISTER_YSON_STRUCT(THeapProfilerConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(THeapProfilerConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TSingletonsConfig + : public virtual NYTree::TYsonStruct +{ +public: + TDuration SpinWaitSlowPathLoggingThreshold; + NYTAlloc::TYTAllocConfigPtr YTAlloc; + THashMap<TString, int> FiberStackPoolSizes; + NNet::TAddressResolverConfigPtr AddressResolver; + NBus::TTcpDispatcherConfigPtr TcpDispatcher; + NRpc::TDispatcherConfigPtr RpcDispatcher; + NRpc::NGrpc::TDispatcherConfigPtr GrpcDispatcher; + NServiceDiscovery::NYP::TServiceDiscoveryConfigPtr YPServiceDiscovery; + NProfiling::TSolomonExporterConfigPtr SolomonExporter; + NLogging::TLogManagerConfigPtr Logging; + NTracing::TJaegerTracerConfigPtr Jaeger; + TRpcConfigPtr Rpc; + TTCMallocConfigPtr TCMalloc; + TStockpileConfigPtr Stockpile; + bool EnableRefCountedTrackerProfiling; + bool EnableResourceTracker; + bool EnablePortoResourceTracker; + std::optional<double> ResourceTrackerVCpuFactor; + NContainers::TPodSpecConfigPtr PodSpec; + THeapProfilerConfigPtr HeapProfiler; + + REGISTER_YSON_STRUCT(TSingletonsConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TSingletonsConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TSingletonsDynamicConfig + : public virtual NYTree::TYsonStruct +{ +public: + std::optional<TDuration> SpinWaitSlowPathLoggingThreshold; + NYTAlloc::TYTAllocConfigPtr YTAlloc; + NBus::TTcpDispatcherDynamicConfigPtr TcpDispatcher; + NRpc::TDispatcherDynamicConfigPtr RpcDispatcher; + NLogging::TLogManagerDynamicConfigPtr Logging; + NTracing::TJaegerTracerDynamicConfigPtr Jaeger; + TRpcConfigPtr Rpc; + TTCMallocConfigPtr TCMalloc; + + REGISTER_YSON_STRUCT(TSingletonsDynamicConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TSingletonsDynamicConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TDiagnosticDumpConfig + : public virtual NYTree::TYsonStruct +{ +public: + std::optional<TDuration> YTAllocDumpPeriod; + std::optional<TDuration> RefCountedTrackerDumpPeriod; + + REGISTER_YSON_STRUCT(TDiagnosticDumpConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TDiagnosticDumpConfig) + +//////////////////////////////////////////////////////////////////////////////// + +// NB: These functions should not be called from bootstrap +// config validator since logger is not set up yet. 
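+// Illustrative call site (editor's example; "Logger" is whatever logger the caller has in scope):
+//   WarnForUnrecognizedOptions(Logger, config);    // logs a warning and continues
+//   AbortOnUnrecognizedOptions(Logger, config);    // logs an error and calls YT_ABORT()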
+void WarnForUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonStructPtr& config); + +void WarnForUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonSerializablePtr& config); + +void AbortOnUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonStructPtr& config); + +void AbortOnUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonSerializablePtr& config); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/helpers.cpp b/yt/yt/library/program/helpers.cpp new file mode 100644 index 0000000000..5c7ff29db1 --- /dev/null +++ b/yt/yt/library/program/helpers.cpp @@ -0,0 +1,335 @@ +#include "helpers.h" +#include "config.h" +#include "private.h" + +#include <yt/yt/core/ytalloc/bindings.h> + +#include <yt/yt/core/misc/lazy_ptr.h> +#include <yt/yt/core/misc/ref_counted_tracker.h> +#include <yt/yt/core/misc/ref_counted_tracker_profiler.h> + +#include <yt/yt/core/bus/tcp/dispatcher.h> + +#include <yt/yt/library/tracing/jaeger/tracer.h> + +#include <yt/yt/library/profiling/perf/counters.h> + +#include <yt/yt/library/profiling/resource_tracker/resource_tracker.h> + +#include <yt/yt/library/containers/config.h> +#include <yt/yt/library/containers/porto_resource_tracker.h> + +#include <yt/yt/core/logging/log_manager.h> + +#include <yt/yt/core/concurrency/execution_stack.h> +#include <yt/yt/core/concurrency/periodic_executor.h> +#include <yt/yt/core/concurrency/private.h> + +#include <tcmalloc/malloc_extension.h> + +#include <yt/yt/core/net/address.h> +#include <yt/yt/core/net/local_address.h> + +#include <yt/yt/core/rpc/dispatcher.h> +#include <yt/yt/core/rpc/grpc/dispatcher.h> + +#include <yt/yt/core/service_discovery/yp/service_discovery.h> + +#include <yt/yt/core/threading/spin_wait_slow_path_logger.h> + +#include <library/cpp/yt/threading/spin_wait_hook.h> + +#include <library/cpp/yt/memory/atomic_intrusive_ptr.h> + +#include <util/string/split.h> +#include <util/system/thread.h> + +#include <mutex> +#include <thread> + +namespace NYT { + +using namespace NConcurrency; +using namespace NThreading; + +//////////////////////////////////////////////////////////////////////////////// + +static std::once_flag InitAggressiveReleaseThread; +static auto& Logger = ProgramLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TCMallocLimitsAdjuster +{ +public: + void Adjust(const TTCMallocConfigPtr& config) + { + i64 totalMemory = GetContainerMemoryLimit(); + AdjustPageHeapLimit(totalMemory, config); + AdjustAggressiveReleaseThreshold(totalMemory, config); + } + + i64 GetAggressiveReleaseThreshold() + { + return AggressiveReleaseThreshold_; + } + +private: + using TAllocatorMemoryLimit = tcmalloc::MallocExtension::MemoryLimit; + + TAllocatorMemoryLimit AppliedLimit_; + i64 AggressiveReleaseThreshold_ = 0; + + + void AdjustPageHeapLimit(i64 totalMemory, const TTCMallocConfigPtr& config) + { + auto proposed = ProposeHeapMemoryLimit(totalMemory, config); + + if (proposed.limit == AppliedLimit_.limit && proposed.hard == AppliedLimit_.hard) { + // Already applied + return; + } + + YT_LOG_INFO("Changing tcmalloc memory limit (Limit: %v, IsHard: %v)", + proposed.limit, + proposed.hard); + + tcmalloc::MallocExtension::SetMemoryLimit(proposed); + AppliedLimit_ = proposed; + } + + void AdjustAggressiveReleaseThreshold(i64 totalMemory, const TTCMallocConfigPtr& config) + { + if (totalMemory 
&& config->AggressiveReleaseThresholdRatio) { + AggressiveReleaseThreshold_ = *config->AggressiveReleaseThresholdRatio * totalMemory; + } else { + AggressiveReleaseThreshold_ = config->AggressiveReleaseThreshold; + } + } + + i64 GetContainerMemoryLimit() const + { + auto resourceTracker = NProfiling::GetResourceTracker(); + if (!resourceTracker) { + return 0; + } + + return resourceTracker->GetTotalMemoryLimit(); + } + + TAllocatorMemoryLimit ProposeHeapMemoryLimit(i64 totalMemory, const TTCMallocConfigPtr& config) const + { + const auto& heapLimitConfig = config->HeapSizeLimit; + + if (totalMemory == 0 || !heapLimitConfig->ContainerMemoryRatio) { + return {}; + } + + TAllocatorMemoryLimit proposed; + proposed.limit = *heapLimitConfig->ContainerMemoryRatio * totalMemory; + proposed.hard = heapLimitConfig->IsHard; + + return proposed; + } +}; + +void ConfigureTCMalloc(const TTCMallocConfigPtr& config) +{ + tcmalloc::MallocExtension::SetBackgroundReleaseRate( + tcmalloc::MallocExtension::BytesPerSecond{static_cast<size_t>(config->BackgroundReleaseRate)}); + + tcmalloc::MallocExtension::SetMaxPerCpuCacheSize(config->MaxPerCpuCacheSize); + + if (config->GuardedSamplingRate) { + tcmalloc::MallocExtension::SetGuardedSamplingRate(*config->GuardedSamplingRate); + tcmalloc::MallocExtension::ActivateGuardedSampling(); + } + + struct TConfigSingleton + { + TAtomicIntrusivePtr<TTCMallocConfig> Config; + }; + + LeakySingleton<TConfigSingleton>()->Config.Store(config); + + if (tcmalloc::MallocExtension::NeedsProcessBackgroundActions()) { + std::call_once(InitAggressiveReleaseThread, [] { + std::thread([] { + ::TThread::SetCurrentThreadName("TCAllocYT"); + + TCMallocLimitsAdjuster limitsAdjuster; + + while (true) { + auto config = LeakySingleton<TConfigSingleton>()->Config.Acquire(); + limitsAdjuster.Adjust(config); + + auto freeBytes = tcmalloc::MallocExtension::GetNumericProperty("tcmalloc.page_heap_free"); + YT_VERIFY(freeBytes); + + if (static_cast<i64>(*freeBytes) > limitsAdjuster.GetAggressiveReleaseThreshold()) { + + YT_LOG_DEBUG("Aggressively releasing memory (FreeBytes: %v, Threshold: %v)", + static_cast<i64>(*freeBytes), + limitsAdjuster.GetAggressiveReleaseThreshold()); + + tcmalloc::MallocExtension::ReleaseMemoryToSystem(config->AggressiveReleaseSize); + } + + Sleep(config->AggressiveReleasePeriod); + } + }).detach(); + }); + } +} + +template <class TConfig> +void ConfigureSingletonsImpl(const TConfig& config) +{ + SetSpinWaitSlowPathLoggingThreshold(config->SpinWaitSlowPathLoggingThreshold); + + if (!NYTAlloc::ConfigureFromEnv()) { + NYTAlloc::Configure(config->YTAlloc); + } + + for (const auto& [kind, size] : config->FiberStackPoolSizes) { + NConcurrency::SetFiberStackPoolSize(ParseEnum<NConcurrency::EExecutionStackKind>(kind), size); + } + + NLogging::TLogManager::Get()->EnableReopenOnSighup(); + if (!NLogging::TLogManager::Get()->IsConfiguredFromEnv()) { + NLogging::TLogManager::Get()->Configure(config->Logging); + } + + NNet::TAddressResolver::Get()->Configure(config->AddressResolver); + // By default, server components must have a reasonable FQDN. + // Failure to do so may result in issues like YT-4561. 
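+    // Editor's note (illustrative): resolving the local FQDN eagerly here is intended to
+    // surface a broken resolver configuration at startup rather than later as obscure
+    // RPC errors (cf. YT-4561 above).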
+ NNet::TAddressResolver::Get()->EnsureLocalHostName(); + + NBus::TTcpDispatcher::Get()->Configure(config->TcpDispatcher); + + NRpc::TDispatcher::Get()->Configure(config->RpcDispatcher); + + NRpc::NGrpc::TDispatcher::Get()->Configure(config->GrpcDispatcher); + + NRpc::TDispatcher::Get()->SetServiceDiscovery( + NServiceDiscovery::NYP::CreateServiceDiscovery(config->YPServiceDiscovery)); + + NTracing::SetGlobalTracer(New<NTracing::TJaegerTracer>(config->Jaeger)); + + NProfiling::EnablePerfCounters(); + + if (auto tracingConfig = config->Rpc->Tracing) { + NTracing::SetTracingConfig(tracingConfig); + } + + ConfigureTCMalloc(config->TCMalloc); + + ConfigureStockpile(*config->Stockpile); + + if (config->EnableRefCountedTrackerProfiling) { + EnableRefCountedTrackerProfiling(); + } + + if (config->EnableResourceTracker) { + NProfiling::EnableResourceTracker(); + if (config->ResourceTrackerVCpuFactor.has_value()) { + NProfiling::SetVCpuFactor(config->ResourceTrackerVCpuFactor.value()); + } + } + + if (config->EnablePortoResourceTracker) { + NContainers::EnablePortoResourceTracker(config->PodSpec); + } +} + +void ConfigureSingletons(const TSingletonsConfigPtr& config) +{ + ConfigureSingletonsImpl(config); +} + +template <class TStaticConfig, class TDynamicConfig> +void ReconfigureSingletonsImpl(const TStaticConfig& config, const TDynamicConfig& dynamicConfig) +{ + SetSpinWaitSlowPathLoggingThreshold(dynamicConfig->SpinWaitSlowPathLoggingThreshold.value_or(config->SpinWaitSlowPathLoggingThreshold)); + + if (!NYTAlloc::IsConfiguredFromEnv()) { + NYTAlloc::Configure(dynamicConfig->YTAlloc ? dynamicConfig->YTAlloc : config->YTAlloc); + } + + if (!NLogging::TLogManager::Get()->IsConfiguredFromEnv()) { + NLogging::TLogManager::Get()->Configure( + config->Logging->ApplyDynamic(dynamicConfig->Logging), + /*sync*/ false); + } + + auto tracer = NTracing::GetGlobalTracer(); + if (auto jaeger = DynamicPointerCast<NTracing::TJaegerTracer>(tracer); jaeger) { + jaeger->Configure(config->Jaeger->ApplyDynamic(dynamicConfig->Jaeger)); + } + + NBus::TTcpDispatcher::Get()->Configure(config->TcpDispatcher->ApplyDynamic(dynamicConfig->TcpDispatcher)); + + NRpc::TDispatcher::Get()->Configure(config->RpcDispatcher->ApplyDynamic(dynamicConfig->RpcDispatcher)); + + if (dynamicConfig->Rpc->Tracing) { + NTracing::SetTracingConfig(dynamicConfig->Rpc->Tracing); + } else if (config->Rpc->Tracing) { + NTracing::SetTracingConfig(config->Rpc->Tracing); + } + + if (dynamicConfig->TCMalloc) { + ConfigureTCMalloc(dynamicConfig->TCMalloc); + } else if (config->TCMalloc) { + ConfigureTCMalloc(config->TCMalloc); + } +} + +void ReconfigureSingletons(const TSingletonsConfigPtr& config, const TSingletonsDynamicConfigPtr& dynamicConfig) +{ + ReconfigureSingletonsImpl(config, dynamicConfig); +} + +template <class TConfig> +void StartDiagnosticDumpImpl(const TConfig& config) +{ + static NLogging::TLogger Logger("DiagDump"); + + auto logDumpString = [&] (TStringBuf banner, const TString& str) { + for (const auto& line : StringSplitter(str).Split('\n')) { + YT_LOG_DEBUG("%v %v", banner, line.Token()); + } + }; + + if (config->YTAllocDumpPeriod) { + static const TLazyIntrusivePtr<TPeriodicExecutor> Executor(BIND([&] { + return New<TPeriodicExecutor>( + NRpc::TDispatcher::Get()->GetHeavyInvoker(), + BIND([&] { + logDumpString("YTAlloc", NYTAlloc::FormatAllocationCounters()); + })); + })); + Executor->SetPeriod(config->YTAllocDumpPeriod); + Executor->Start(); + } + + if (config->RefCountedTrackerDumpPeriod) { + static const 
TLazyIntrusivePtr<TPeriodicExecutor> Executor(BIND([&] { + return New<TPeriodicExecutor>( + NRpc::TDispatcher::Get()->GetHeavyInvoker(), + BIND([&] { + logDumpString("RCT", TRefCountedTracker::Get()->GetDebugInfo()); + })); + })); + Executor->SetPeriod(config->RefCountedTrackerDumpPeriod); + Executor->Start(); + } +} + +void StartDiagnosticDump(const TDiagnosticDumpConfigPtr& config) +{ + StartDiagnosticDumpImpl(config); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/helpers.h b/yt/yt/library/program/helpers.h new file mode 100644 index 0000000000..be09ec889c --- /dev/null +++ b/yt/yt/library/program/helpers.h @@ -0,0 +1,18 @@ +#pragma once + +#include "public.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +void ConfigureSingletons(const TSingletonsConfigPtr& config); +void ReconfigureSingletons( + const TSingletonsConfigPtr& config, + const TSingletonsDynamicConfigPtr& dynamicConfig); + +void StartDiagnosticDump(const TDiagnosticDumpConfigPtr& config); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/private.h b/yt/yt/library/program/private.h new file mode 100644 index 0000000000..e328f30667 --- /dev/null +++ b/yt/yt/library/program/private.h @@ -0,0 +1,15 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/logging/log.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger ProgramLogger("Program"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program.cpp b/yt/yt/library/program/program.cpp new file mode 100644 index 0000000000..318d998b3a --- /dev/null +++ b/yt/yt/library/program/program.cpp @@ -0,0 +1,385 @@ +#include "program.h" + +#include "build_attributes.h" + +#include <yt/yt/build/build.h> + +#include <yt/yt/core/misc/crash_handler.h> +#include <yt/yt/core/misc/signal_registry.h> +#include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/shutdown.h> + +#include <yt/yt/core/ytalloc/bindings.h> + +#include <yt/yt/core/yson/writer.h> +#include <yt/yt/core/yson/null_consumer.h> + +#include <yt/yt/core/logging/log_manager.h> + +#include <yt/yt/library/ytprof/heap_profiler.h> + +#include <yt/yt/library/profiling/tcmalloc/profiler.h> + +#include <library/cpp/ytalloc/api/ytalloc.h> + +#include <library/cpp/yt/mlock/mlock.h> +#include <library/cpp/yt/stockpile/stockpile.h> + +#include <tcmalloc/malloc_extension.h> + +#include <absl/debugging/stacktrace.h> + +#include <util/system/thread.h> +#include <util/system/sigset.h> + +#include <util/string/subst.h> + +#include <thread> + +#include <stdlib.h> + +#ifdef _unix_ +#include <unistd.h> +#include <sys/types.h> +#include <sys/stat.h> +#endif + +#ifdef _linux_ +#include <grp.h> +#include <sys/prctl.h> +#endif + +#if defined(_linux_) && defined(CLANG_COVERAGE) +extern "C" int __llvm_profile_write_file(void); +extern "C" void __llvm_profile_set_filename(const char* name); +#endif + +namespace NYT { + +using namespace NYson; + +//////////////////////////////////////////////////////////////////////////////// + +class TProgram::TOptsParseResult + : public NLastGetopt::TOptsParseResult +{ +public: + TOptsParseResult(TProgram* owner, int argc, const char** argv) + : Owner_(owner) + { + Init(&Owner_->Opts_, argc, 
argv); + } + + void HandleError() const override + { + Owner_->OnError(CurrentExceptionMessage()); + Cerr << Endl << "Try running '" << Owner_->Argv0_ << " --help' for more information." << Endl; + Owner_->Exit(EProgramExitCode::OptionsError); + } + +private: + TProgram* const Owner_; +}; + +TProgram::TProgram() +{ + Opts_.AddHelpOption(); + Opts_.AddLongOption("yt-version", "print YT version and exit") + .NoArgument() + .StoreValue(&PrintYTVersion_, true); + Opts_.AddLongOption("version", "print version and exit") + .NoArgument() + .StoreValue(&PrintVersion_, true); + Opts_.AddLongOption("yson", "print build information in YSON") + .NoArgument() + .StoreValue(&UseYson_, true); + Opts_.AddLongOption("build", "print build information and exit") + .NoArgument() + .StoreValue(&PrintBuild_, true); + Opts_.SetFreeArgsNum(0); + + ConfigureCoverageOutput(); +} + +void TProgram::SetCrashOnError() +{ + CrashOnError_ = true; +} + +TProgram::~TProgram() = default; + +void TProgram::HandleVersionAndBuild() +{ + if (PrintVersion_) { + PrintVersionAndExit(); + } + if (PrintYTVersion_) { + PrintYTVersionAndExit(); + } + if (PrintBuild_) { + PrintBuildAndExit(); + } +} + +int TProgram::Run(int argc, const char** argv) +{ + ::srand(time(nullptr)); + + auto run = [&] { + Argv0_ = TString(argv[0]); + TOptsParseResult result(this, argc, argv); + + HandleVersionAndBuild(); + + DoRun(result); + }; + + if (!CrashOnError_) { + try { + run(); + Exit(EProgramExitCode::OK); + } catch (...) { + OnError(CurrentExceptionMessage()); + Exit(EProgramExitCode::ProgramError); + } + } else { + run(); + Exit(EProgramExitCode::OK); + } + + // Cannot reach this due to #Exit calls above. + YT_ABORT(); +} + +void TProgram::Abort(EProgramExitCode code) noexcept +{ + Abort(static_cast<int>(code)); +} + +void TProgram::Abort(int code) noexcept +{ + NLogging::TLogManager::Get()->Shutdown(); + + ::_exit(code); +} + +void TProgram::Exit(EProgramExitCode code) noexcept +{ + Exit(static_cast<int>(code)); +} + +void TProgram::Exit(int code) noexcept +{ +#if defined(_linux_) && defined(CLANG_COVERAGE) + __llvm_profile_write_file(); +#endif + + // This explicit call may become obsolete some day; + // cf. the comment section for NYT::Shutdown. + Shutdown({ + .AbortOnHang = ShouldAbortOnHungShutdown(), + .HungExitCode = code + }); + + ::exit(code); +} + +bool TProgram::ShouldAbortOnHungShutdown() noexcept +{ + return true; +} + +void TProgram::OnError(const TString& message) noexcept +{ + try { + Cerr << message << Endl; + } catch (...) { + // Just ignore it; STDERR might be closed already, + // and write() would result in EPIPE. 
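+        // Editor's note: the empty catch is deliberate — once stderr itself is broken,
+        // there is no safer channel left to report the failure.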
+ } +} + +void TProgram::PrintYTVersionAndExit() +{ + if (UseYson_) { + THROW_ERROR_EXCEPTION("--yson is not supported when printing version"); + } + Cout << GetVersion() << Endl; + Exit(0); +} + +void TProgram::PrintBuildAndExit() +{ + if (UseYson_) { + TYsonWriter writer(&Cout, EYsonFormat::Pretty); + Serialize(BuildBuildAttributes(), &writer); + Cout << Endl; + } else { + Cout << "Build Time: " << GetBuildTime() << Endl; + Cout << "Build Host: " << GetBuildHost() << Endl; + } + Exit(0); +} + +void TProgram::PrintVersionAndExit() +{ + PrintYTVersionAndExit(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TProgramException::TProgramException(TString what) + : What_(std::move(what)) +{ } + +const char* TProgramException::what() const noexcept +{ + return What_.c_str(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TString CheckPathExistsArgMapper(const TString& arg) +{ + if (!NFS::Exists(arg)) { + throw TProgramException(Format("File %v does not exist", arg)); + } + return arg; +} + +TGuid CheckGuidArgMapper(const TString& arg) +{ + TGuid result; + if (!TGuid::FromString(arg, &result)) { + throw TProgramException(Format("Error parsing guid %Qv", arg)); + } + return result; +} + +NYson::TYsonString CheckYsonArgMapper(const TString& arg) +{ + ParseYsonStringBuffer(arg, EYsonType::Node, GetNullYsonConsumer()); + return NYson::TYsonString(arg); +} + +void ConfigureUids() +{ +#ifdef _unix_ + uid_t ruid, euid; +#ifdef _linux_ + uid_t suid; + YT_VERIFY(getresuid(&ruid, &euid, &suid) == 0); +#else + ruid = getuid(); + euid = geteuid(); +#endif + if (euid == 0) { + // If the real uid is already root, do not set root as a supplementary id. + if (ruid != 0) { + YT_VERIFY(setgroups(0, nullptr) == 0); + } + // If the effective uid == 0 (e.g. set-uid-root), set saved = effective and effective = real. +#ifdef _linux_ + YT_VERIFY(setresuid(ruid, ruid, euid) == 0); + // Make server suid_dumpable = 1. + YT_VERIFY(prctl(PR_SET_DUMPABLE, 1) == 0); +#else + YT_VERIFY(setuid(euid) == 0); + YT_VERIFY(seteuid(ruid) == 0); + YT_VERIFY(setruid(ruid) == 0); +#endif + } + umask(0000); +#endif +} + +void ConfigureCoverageOutput() +{ +#if defined(_linux_) && defined(CLANG_COVERAGE) + // YT tests use pid namespaces, so the process id cannot serve as a unique identifier for the output file. + if (auto profileFile = getenv("LLVM_PROFILE_FILE")) { + TString fixedProfile{profileFile}; + SubstGlobal(fixedProfile, "%e", "ytserver-all"); + SubstGlobal(fixedProfile, "%p", ToString(TInstant::Now().NanoSeconds())); + __llvm_profile_set_filename(fixedProfile.c_str()); + } +#endif +} + +void ConfigureIgnoreSigpipe() +{ +#ifdef _unix_ + signal(SIGPIPE, SIG_IGN); +#endif +} + +void ConfigureCrashHandler() +{ + TSignalRegistry::Get()->PushCallback(AllCrashSignals, CrashSignalHandler); + TSignalRegistry::Get()->PushDefaultSignalHandler(AllCrashSignals); +} + +namespace { + +void ExitZero(int /*unused*/) +{ +#if defined(_linux_) && defined(CLANG_COVERAGE) + __llvm_profile_write_file(); +#endif + // TODO(babenko): replace with pure "exit" some day. + // Currently this causes some RPC requests to master to be replied with "Promise abandoned" error, + // which is not retriable.
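+    // Editor's note (illustrative): unlike exit(), _exit() skips atexit handlers and
+    // static destructors; per the TODO above, a plain exit() currently triggers the
+    // non-retriable "Promise abandoned" replies, hence _exit() for now.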
+ _exit(0); +} + +} // namespace + +void ConfigureExitZeroOnSigterm() +{ +#ifdef _unix_ + signal(SIGTERM, ExitZero); +#endif +} + +void ConfigureAllocator(const TAllocatorOptions& options) +{ + NYT::MlockFileMappings(); + +#ifdef _linux_ + NYTAlloc::EnableYTLogging(); + NYTAlloc::EnableYTProfiling(); + NYTAlloc::InitializeLibunwindInterop(); + NYTAlloc::SetEnableEagerMemoryRelease(options.YTAllocEagerMemoryRelease); + + if (tcmalloc::MallocExtension::NeedsProcessBackgroundActions()) { + std::thread backgroundThread([] { + TThread::SetCurrentThreadName("TCAllocBack"); + tcmalloc::MallocExtension::ProcessBackgroundActions(); + YT_ABORT(); + }); + backgroundThread.detach(); + } + + NProfiling::EnableTCMallocProfiler(); + + NYTProf::EnableMemoryProfilingTags(options.SnapshotUpdatePeriod); + + absl::SetStackUnwinder(NYTProf::AbslStackUnwinder); + // TODO(prime@): tune parameters. + tcmalloc::MallocExtension::SetProfileSamplingRate(2_MB); + if (options.TCMallocGuardedSamplingRate) { + tcmalloc::MallocExtension::SetGuardedSamplingRate(*options.TCMallocGuardedSamplingRate); + tcmalloc::MallocExtension::ActivateGuardedSampling(); + } + tcmalloc::MallocExtension::SetMaxPerCpuCacheSize(3_MB); + tcmalloc::MallocExtension::SetMaxTotalThreadCacheBytes(24_MB); + tcmalloc::MallocExtension::SetBackgroundReleaseRate(tcmalloc::MallocExtension::BytesPerSecond{32_MB}); + tcmalloc::MallocExtension::EnableForkSupport(); +#else + Y_UNUSED(options); +#endif +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program.h b/yt/yt/library/program/program.h new file mode 100644 index 0000000000..f1f236cfab --- /dev/null +++ b/yt/yt/library/program/program.h @@ -0,0 +1,148 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +#include <library/cpp/yt/stockpile/stockpile.h> + +#include <library/cpp/getopt/last_getopt.h> + +#include <yt/yt/core/yson/string.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(EProgramExitCode, + ((OK)(0)) + ((OptionsError)(1)) + ((ProgramError)(2)) +); + +class TProgram +{ +public: + TProgram(); + ~TProgram(); + + TProgram(const TProgram&) = delete; + TProgram(TProgram&&) = delete; + + // This call actually never returns; + // |int| return type is just for the symmetry with |main|. + [[noreturn]] + int Run(int argc, const char** argv); + + //! Handles --version/--yt-version/--build [--yson] if they are present. + void HandleVersionAndBuild(); + + //! Nongracefully aborts the program. + /*! + * Tries to flush logging messages. + * Aborts via |_exit| call. + */ + [[noreturn]] + static void Abort(EProgramExitCode code) noexcept; + [[noreturn]] + static void Abort(int code) noexcept; + +protected: + NLastGetopt::TOpts Opts_; + TString Argv0_; + bool PrintYTVersion_ = false; + bool PrintVersion_ = false; + bool PrintBuild_ = false; + bool UseYson_ = false; + + virtual void DoRun(const NLastGetopt::TOptsParseResult& parseResult) = 0; + + virtual void OnError(const TString& message) noexcept; + + virtual bool ShouldAbortOnHungShutdown() noexcept; + + void SetCrashOnError(); + + //! Handler for --yt-version command argument. + [[noreturn]] + void PrintYTVersionAndExit(); + + //! Handler for --build command argument. + [[noreturn]] + void PrintBuildAndExit(); + + //! Handler for --version command argument. + //! By default, --version and --yt-version work the same way, + //! but some YT components (e.g. 
CHYT) can override it to provide its own version. + [[noreturn]] + virtual void PrintVersionAndExit(); + + [[noreturn]] + void Exit(EProgramExitCode code) noexcept; + + [[noreturn]] + void Exit(int code) noexcept; + +private: + bool CrashOnError_ = false; + + // Custom handler for option parsing errors. + class TOptsParseResult; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//! The simplest exception possible. +//! Here we refrain from using TErrorException, as it relies on proper configuration of singleton subsystems, +//! which might not be the case during startup. +class TProgramException + : public std::exception +{ +public: + explicit TProgramException(TString what); + + const char* what() const noexcept override; + +private: + const TString What_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//! Helper for TOpt::StoreMappedResult to validate file paths for existence. +TString CheckPathExistsArgMapper(const TString& arg); + +//! Helper for TOpt::StoreMappedResult to parse GUIDs. +TGuid CheckGuidArgMapper(const TString& arg); + +//! Helper for TOpt::StoreMappedResult to parse YSON strings. +NYson::TYsonString CheckYsonArgMapper(const TString& arg); + +//! Drop privileges and save them if running with suid-bit. +void ConfigureUids(); + +void ConfigureCoverageOutput(); + +void ConfigureIgnoreSigpipe(); + +//! Intercepts standard crash signals (see signal_registry.h for full list) with a nice handler. +void ConfigureCrashHandler(); + +//! Intercepts SIGTERM and terminates the process immediately with zero exit code. +void ConfigureExitZeroOnSigterm(); + +//////////////////////////////////////////////////////////////////////////////// + +struct TAllocatorOptions +{ + bool YTAllocEagerMemoryRelease = false; + + bool TCMallocOptimizeSize = false; + std::optional<i64> TCMallocGuardedSamplingRate = 128_MB; + + std::optional<TDuration> SnapshotUpdatePeriod; +}; + +void ConfigureAllocator(const TAllocatorOptions& options = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_config_mixin.cpp b/yt/yt/library/program/program_config_mixin.cpp new file mode 100644 index 0000000000..9ced4de64f --- /dev/null +++ b/yt/yt/library/program/program_config_mixin.cpp @@ -0,0 +1 @@ +#include "program_config_mixin.h" diff --git a/yt/yt/library/program/program_config_mixin.h b/yt/yt/library/program/program_config_mixin.h new file mode 100644 index 0000000000..80f681d06e --- /dev/null +++ b/yt/yt/library/program/program_config_mixin.h @@ -0,0 +1,166 @@ +#pragma once + +#include "program.h" + +#include <library/cpp/yt/string/enum.h> + +#include <yt/yt/core/ytree/convert.h> +#include <yt/yt/core/ytree/yson_serializable.h> + +#include <util/stream/file.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <class TConfig, class TDynamicConfig = void> +class TProgramConfigMixin +{ +protected: + explicit TProgramConfigMixin( + NLastGetopt::TOpts& opts, + bool required = true, + const TString& argumentName = "config") + : ArgumentName_(argumentName) + { + auto opt = opts + .AddLongOption(TString(argumentName), Format("path to %v file", argumentName)) + .StoreMappedResult(&ConfigPath_, &CheckPathExistsArgMapper) + .RequiredArgument("FILE"); + if (required) { + opt.Required(); + } else { + opt.Optional(); + } + opts + .AddLongOption( + Format("%v-template", 
argumentName), + Format("print %v template and exit", argumentName)) + .SetFlag(&ConfigTemplate_); + opts + .AddLongOption( + Format("%v-actual", argumentName), + Format("print actual %v and exit", argumentName)) + .SetFlag(&ConfigActual_); + opts + .AddLongOption( + Format("%v-unrecognized-strategy", argumentName), + Format("configure strategy for unrecognized attributes in %v", argumentName)) + .Handler1T<TStringBuf>([this](TStringBuf value) { + UnrecognizedStrategy_ = ParseEnum<NYTree::EUnrecognizedStrategy>(value); + }); + + if constexpr (std::is_same_v<TDynamicConfig, void>) { + return; + } + + opts + .AddLongOption( + Format("dynamic-%v-template", argumentName), + Format("print dynamic %v template and exit", argumentName)) + .SetFlag(&DynamicConfigTemplate_); + } + + TIntrusivePtr<TConfig> GetConfig(bool returnNullIfNotSupplied = false) + { + if (returnNullIfNotSupplied && !ConfigPath_) { + return nullptr; + } + + if (!Config_) { + LoadConfig(); + } + return Config_; + } + + NYTree::INodePtr GetConfigNode(bool returnNullIfNotSupplied = false) + { + if (returnNullIfNotSupplied && !ConfigPath_) { + return nullptr; + } + + if (!ConfigNode_) { + LoadConfigNode(); + } + return ConfigNode_; + } + + bool HandleConfigOptions() + { + auto print = [] (const auto& config) { + using namespace NYson; + TYsonWriter writer(&Cout, EYsonFormat::Pretty); + config->Save(&writer); + Cout << Flush; + }; + if (ConfigTemplate_) { + print(New<TConfig>()); + return true; + } + if (ConfigActual_) { + print(GetConfig()); + return true; + } + + if constexpr (!std::is_same_v<TDynamicConfig, void>) { + if (DynamicConfigTemplate_) { + print(New<TDynamicConfig>()); + return true; + } + } + return false; + } + +private: + void LoadConfigNode() + { + using namespace NYTree; + + if (!ConfigPath_){ + THROW_ERROR_EXCEPTION("Missing --%v option", ArgumentName_); + } + + try { + TIFStream stream(ConfigPath_); + ConfigNode_ = ConvertToNode(&stream); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Error parsing %v file %v", + ArgumentName_, + ConfigPath_) + << ex; + } + } + + void LoadConfig() + { + if (!ConfigNode_) { + LoadConfigNode(); + } + + try { + Config_ = New<TConfig>(); + Config_->SetUnrecognizedStrategy(UnrecognizedStrategy_); + Config_->Load(ConfigNode_); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Error loading %v file %v", + ArgumentName_, + ConfigPath_) + << ex; + } + } + + const TString ArgumentName_; + + TString ConfigPath_; + bool ConfigTemplate_; + bool ConfigActual_; + bool DynamicConfigTemplate_ = false; + NYTree::EUnrecognizedStrategy UnrecognizedStrategy_ = NYTree::EUnrecognizedStrategy::KeepRecursive; + + TIntrusivePtr<TConfig> Config_; + NYTree::INodePtr ConfigNode_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_pdeathsig_mixin.cpp b/yt/yt/library/program/program_pdeathsig_mixin.cpp new file mode 100644 index 0000000000..34f1f3b9a8 --- /dev/null +++ b/yt/yt/library/program/program_pdeathsig_mixin.cpp @@ -0,0 +1,36 @@ +#include "program_pdeathsig_mixin.h" + +#ifdef _linux_ +#include <sys/prctl.h> +#endif + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +TProgramPdeathsigMixin::TProgramPdeathsigMixin(NLastGetopt::TOpts& opts) +{ + opts.AddLongOption("pdeathsig", "parent death signal") + .StoreResult(&ParentDeathSignal_) + .RequiredArgument("PDEATHSIG"); +} + +bool 
TProgramPdeathsigMixin::HandlePdeathsigOptions() +{ + if (ParentDeathSignal_ > 0) { +#ifdef _linux_ + // Parent death signal is set by testing framework to avoid dangling processes when test runner crashes. + // Unfortunately, setting pdeathsig in preexec_fn in subprocess call in test runner is not working + // when the program has suid bit (pdeath_sig is reset after exec call in this case) + // More details can be found in + // http://linux.die.net/man/2/prctl + // http://www.isec.pl/vulnerabilities/isec-0024-death-signal.txt + YT_VERIFY(prctl(PR_SET_PDEATHSIG, ParentDeathSignal_) == 0); +#endif + } + return false; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_pdeathsig_mixin.h b/yt/yt/library/program/program_pdeathsig_mixin.h new file mode 100644 index 0000000000..3e4bcfd4a6 --- /dev/null +++ b/yt/yt/library/program/program_pdeathsig_mixin.h @@ -0,0 +1,22 @@ +#pragma once + +#include "program.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TProgramPdeathsigMixin +{ +protected: + explicit TProgramPdeathsigMixin(NLastGetopt::TOpts& opts); + + bool HandlePdeathsigOptions(); + +private: + int ParentDeathSignal_ = -1; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_setsid_mixin.cpp b/yt/yt/library/program/program_setsid_mixin.cpp new file mode 100644 index 0000000000..a745fcd3a2 --- /dev/null +++ b/yt/yt/library/program/program_setsid_mixin.cpp @@ -0,0 +1,30 @@ +#include "program_setsid_mixin.h" + +#ifdef _linux_ +#include <unistd.h> +#endif + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +TProgramSetsidMixin::TProgramSetsidMixin(NLastGetopt::TOpts& opts) +{ + opts.AddLongOption("setsid", "create a new session") + .StoreTrue(&Setsid_) + .Optional(); +} + +bool TProgramSetsidMixin::HandleSetsidOptions() +{ + if (Setsid_) { +#ifdef _linux_ + setsid(); +#endif + } + return false; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_setsid_mixin.h b/yt/yt/library/program/program_setsid_mixin.h new file mode 100644 index 0000000000..00b3dff50e --- /dev/null +++ b/yt/yt/library/program/program_setsid_mixin.h @@ -0,0 +1,22 @@ +#pragma once + +#include "program.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TProgramSetsidMixin +{ +protected: + explicit TProgramSetsidMixin(NLastGetopt::TOpts& opts); + + bool HandleSetsidOptions(); + +private: + bool Setsid_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/public.h b/yt/yt/library/program/public.h new file mode 100644 index 0000000000..b10575778e --- /dev/null +++ b/yt/yt/library/program/public.h @@ -0,0 +1,21 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_CLASS(TBuildInfo) +DECLARE_REFCOUNTED_CLASS(TRpcConfig) +DECLARE_REFCOUNTED_CLASS(TTCMallocConfig) +DECLARE_REFCOUNTED_CLASS(TStockpileConfig) +DECLARE_REFCOUNTED_CLASS(TSingletonsConfig) +DECLARE_REFCOUNTED_CLASS(TSingletonsDynamicConfig) 
+DECLARE_REFCOUNTED_CLASS(TDiagnosticDumpConfig) +DECLARE_REFCOUNTED_CLASS(THeapSizeLimit) +DECLARE_REFCOUNTED_CLASS(THeapProfilerConfig) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/ya.make b/yt/yt/library/program/ya.make new file mode 100644 index 0000000000..5742ce9287 --- /dev/null +++ b/yt/yt/library/program/ya.make @@ -0,0 +1,30 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + build_attributes.cpp + config.cpp + helpers.cpp + program.cpp + program_config_mixin.cpp + program_pdeathsig_mixin.cpp + program_setsid_mixin.cpp +) + +PEERDIR( + yt/yt/core + yt/yt/core/service_discovery/yp + yt/yt/library/monitoring + yt/yt/library/containers + yt/yt/library/profiling/solomon + yt/yt/library/profiling/tcmalloc + yt/yt/library/profiling/perf + yt/yt/library/ytprof + yt/yt/library/tracing/jaeger + library/cpp/yt/mlock + library/cpp/yt/stockpile + library/cpp/yt/string +) + +END() diff --git a/yt/yt/library/tracing/CMakeLists.darwin-x86_64.txt b/yt/yt/library/tracing/CMakeLists.darwin-x86_64.txt index 54f67f1ec5..68eb83d171 100644 --- a/yt/yt/library/tracing/CMakeLists.darwin-x86_64.txt +++ b/yt/yt/library/tracing/CMakeLists.darwin-x86_64.txt @@ -6,6 +6,7 @@ # original buildsystem will not be accepted. +add_subdirectory(jaeger) add_library(yt-library-tracing) target_compile_options(yt-library-tracing PRIVATE diff --git a/yt/yt/library/tracing/CMakeLists.linux-aarch64.txt b/yt/yt/library/tracing/CMakeLists.linux-aarch64.txt index 923b49e3e7..5681c6a06e 100644 --- a/yt/yt/library/tracing/CMakeLists.linux-aarch64.txt +++ b/yt/yt/library/tracing/CMakeLists.linux-aarch64.txt @@ -6,6 +6,7 @@ # original buildsystem will not be accepted. +add_subdirectory(jaeger) add_library(yt-library-tracing) target_compile_options(yt-library-tracing PRIVATE diff --git a/yt/yt/library/tracing/CMakeLists.linux-x86_64.txt b/yt/yt/library/tracing/CMakeLists.linux-x86_64.txt index 923b49e3e7..5681c6a06e 100644 --- a/yt/yt/library/tracing/CMakeLists.linux-x86_64.txt +++ b/yt/yt/library/tracing/CMakeLists.linux-x86_64.txt @@ -6,6 +6,7 @@ # original buildsystem will not be accepted. +add_subdirectory(jaeger) add_library(yt-library-tracing) target_compile_options(yt-library-tracing PRIVATE diff --git a/yt/yt/library/tracing/CMakeLists.windows-x86_64.txt b/yt/yt/library/tracing/CMakeLists.windows-x86_64.txt index e33cbed391..0adf7dc67a 100644 --- a/yt/yt/library/tracing/CMakeLists.windows-x86_64.txt +++ b/yt/yt/library/tracing/CMakeLists.windows-x86_64.txt @@ -6,6 +6,7 @@ # original buildsystem will not be accepted. +add_subdirectory(jaeger) add_library(yt-library-tracing) target_link_libraries(yt-library-tracing PUBLIC diff --git a/yt/yt/library/tracing/jaeger/CMakeLists.darwin-x86_64.txt b/yt/yt/library/tracing/jaeger/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..ebc07421e3 --- /dev/null +++ b/yt/yt/library/tracing/jaeger/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,67 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
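+# protoc and its cpp_styleguide plugin are resolved below to tools built from
+# contrib/tools, not to system binaries.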
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-tracing-jaeger) +target_compile_options(library-tracing-jaeger PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-tracing-jaeger PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-tracing + core-rpc-grpc + contrib-libs-protobuf +) +target_proto_messages(library-tracing-jaeger PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/model.proto +) +target_sources(library-tracing-jaeger PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/sampler.cpp +) +target_proto_addincls(library-tracing-jaeger + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-tracing-jaeger + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) + +add_global_library_for(library-tracing-jaeger.global library-tracing-jaeger) +target_compile_options(library-tracing-jaeger.global PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-tracing-jaeger.global PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-tracing + core-rpc-grpc + contrib-libs-protobuf +) +target_sources(library-tracing-jaeger.global PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/tracer.cpp +) diff --git a/yt/yt/library/tracing/jaeger/CMakeLists.linux-aarch64.txt b/yt/yt/library/tracing/jaeger/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..0946a55b6c --- /dev/null +++ b/yt/yt/library/tracing/jaeger/CMakeLists.linux-aarch64.txt @@ -0,0 +1,69 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
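+# Identical to the linux-x86_64 variant that follows (same blob 0946a55b6c);
+# both add the contrib-libs-linux-headers dependency.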
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-tracing-jaeger) +target_compile_options(library-tracing-jaeger PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-tracing-jaeger PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-tracing + core-rpc-grpc + contrib-libs-protobuf +) +target_proto_messages(library-tracing-jaeger PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/model.proto +) +target_sources(library-tracing-jaeger PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/sampler.cpp +) +target_proto_addincls(library-tracing-jaeger + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-tracing-jaeger + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) + +add_global_library_for(library-tracing-jaeger.global library-tracing-jaeger) +target_compile_options(library-tracing-jaeger.global PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-tracing-jaeger.global PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-tracing + core-rpc-grpc + contrib-libs-protobuf +) +target_sources(library-tracing-jaeger.global PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/tracer.cpp +) diff --git a/yt/yt/library/tracing/jaeger/CMakeLists.linux-x86_64.txt b/yt/yt/library/tracing/jaeger/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..0946a55b6c --- /dev/null +++ b/yt/yt/library/tracing/jaeger/CMakeLists.linux-x86_64.txt @@ -0,0 +1,69 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-tracing-jaeger) +target_compile_options(library-tracing-jaeger PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-tracing-jaeger PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-tracing + core-rpc-grpc + contrib-libs-protobuf +) +target_proto_messages(library-tracing-jaeger PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/model.proto +) +target_sources(library-tracing-jaeger PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/sampler.cpp +) +target_proto_addincls(library-tracing-jaeger + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-tracing-jaeger + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) + +add_global_library_for(library-tracing-jaeger.global library-tracing-jaeger) +target_compile_options(library-tracing-jaeger.global PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-tracing-jaeger.global PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + yt-library-tracing + core-rpc-grpc + contrib-libs-protobuf +) +target_sources(library-tracing-jaeger.global PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/tracer.cpp +) diff --git a/yt/yt/library/tracing/jaeger/CMakeLists.txt b/yt/yt/library/tracing/jaeger/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/tracing/jaeger/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/tracing/jaeger/CMakeLists.windows-x86_64.txt b/yt/yt/library/tracing/jaeger/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..e6ee22e85b --- /dev/null +++ b/yt/yt/library/tracing/jaeger/CMakeLists.windows-x86_64.txt @@ -0,0 +1,61 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-tracing-jaeger) +target_link_libraries(library-tracing-jaeger PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-tracing + core-rpc-grpc + contrib-libs-protobuf +) +target_proto_messages(library-tracing-jaeger PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/model.proto +) +target_sources(library-tracing-jaeger PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/sampler.cpp +) +target_proto_addincls(library-tracing-jaeger + ./ + ${CMAKE_SOURCE_DIR}/ + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-tracing-jaeger + --cpp_out=${CMAKE_BINARY_DIR}/ + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/ +) + +add_global_library_for(library-tracing-jaeger.global library-tracing-jaeger) +target_link_libraries(library-tracing-jaeger.global PUBLIC + contrib-libs-cxxsupp + yutil + yt-library-tracing + core-rpc-grpc + contrib-libs-protobuf +) +target_sources(library-tracing-jaeger.global PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/tracing/jaeger/tracer.cpp +) diff --git a/yt/yt/library/ytprof/CMakeLists.darwin-x86_64.txt b/yt/yt/library/ytprof/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..ac3d970f12 --- /dev/null +++ b/yt/yt/library/ytprof/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,41 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(api) +add_subdirectory(proto) + +add_library(yt-library-ytprof) +target_compile_options(yt-library-ytprof PRIVATE + -Wdeprecated-this-capture + -DYTPROF_BUILD_TYPE="RELEASE" +) +target_link_libraries(yt-library-ytprof PUBLIC + contrib-libs-cxxsupp + yutil + cpp-yt-memory + cpp-yt-threading + backtrace-cursors-interop + backtrace-cursors-frame_pointer + backtrace-cursors-libunwind + library-ytprof-api + library-ytprof-proto + contrib-libs-libunwind + libs-tcmalloc-malloc_extension + library-cpp-svnversion + yt-yt-core +) +target_sources(yt-library-ytprof PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/signal_safe_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/cpu_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/heap_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/spinlock_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/profile.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/build_info.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/external_pprof.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/symbolize_other.cpp +) diff --git a/yt/yt/library/ytprof/CMakeLists.linux-aarch64.txt b/yt/yt/library/ytprof/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..c015a8d3da --- /dev/null +++ b/yt/yt/library/ytprof/CMakeLists.linux-aarch64.txt @@ -0,0 +1,43 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. 
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(api) +add_subdirectory(http) +add_subdirectory(proto) + +add_library(yt-library-ytprof) +target_compile_options(yt-library-ytprof PRIVATE + -Wdeprecated-this-capture + -DYTPROF_BUILD_TYPE="RELEASE" +) +target_link_libraries(yt-library-ytprof PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + cpp-yt-memory + cpp-yt-threading + backtrace-cursors-interop + backtrace-cursors-frame_pointer + backtrace-cursors-libunwind + library-ytprof-api + library-ytprof-proto + contrib-libs-libunwind + libs-tcmalloc-malloc_extension + library-cpp-svnversion + yt-yt-core +) +target_sources(yt-library-ytprof PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/signal_safe_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/cpu_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/heap_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/spinlock_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/profile.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/build_info.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/external_pprof.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/symbolize.cpp +) diff --git a/yt/yt/library/ytprof/CMakeLists.linux-x86_64.txt b/yt/yt/library/ytprof/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..c015a8d3da --- /dev/null +++ b/yt/yt/library/ytprof/CMakeLists.linux-x86_64.txt @@ -0,0 +1,43 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(api) +add_subdirectory(http) +add_subdirectory(proto) + +add_library(yt-library-ytprof) +target_compile_options(yt-library-ytprof PRIVATE + -Wdeprecated-this-capture + -DYTPROF_BUILD_TYPE="RELEASE" +) +target_link_libraries(yt-library-ytprof PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + cpp-yt-memory + cpp-yt-threading + backtrace-cursors-interop + backtrace-cursors-frame_pointer + backtrace-cursors-libunwind + library-ytprof-api + library-ytprof-proto + contrib-libs-libunwind + libs-tcmalloc-malloc_extension + library-cpp-svnversion + yt-yt-core +) +target_sources(yt-library-ytprof PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/signal_safe_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/cpu_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/heap_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/spinlock_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/profile.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/build_info.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/external_pprof.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/symbolize.cpp +) diff --git a/yt/yt/library/ytprof/CMakeLists.txt b/yt/yt/library/ytprof/CMakeLists.txt index dbfb934bae..f8b31df0c1 100644 --- a/yt/yt/library/ytprof/CMakeLists.txt +++ b/yt/yt/library/ytprof/CMakeLists.txt @@ -6,4 +6,12 @@ # original buildsystem will not be accepted. 
-add_subdirectory(api) +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/ytprof/CMakeLists.windows-x86_64.txt b/yt/yt/library/ytprof/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..ec5652c9e9 --- /dev/null +++ b/yt/yt/library/ytprof/CMakeLists.windows-x86_64.txt @@ -0,0 +1,40 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +add_subdirectory(api) +add_subdirectory(proto) + +add_library(yt-library-ytprof) +target_compile_options(yt-library-ytprof PRIVATE + -DYTPROF_BUILD_TYPE="RELEASE" +) +target_link_libraries(yt-library-ytprof PUBLIC + contrib-libs-cxxsupp + yutil + cpp-yt-memory + cpp-yt-threading + backtrace-cursors-interop + backtrace-cursors-frame_pointer + backtrace-cursors-libunwind + library-ytprof-api + library-ytprof-proto + contrib-libs-libunwind + libs-tcmalloc-malloc_extension + library-cpp-svnversion + yt-yt-core +) +target_sources(yt-library-ytprof PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/signal_safe_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/cpu_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/heap_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/spinlock_profiler.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/profile.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/build_info.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/external_pprof.cpp + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/symbolize_other.cpp +) diff --git a/yt/yt/library/ytprof/bundle/ya.make b/yt/yt/library/ytprof/bundle/ya.make new file mode 100644 index 0000000000..7f88583494 --- /dev/null +++ b/yt/yt/library/ytprof/bundle/ya.make @@ -0,0 +1,28 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +# Built with ya make -DNO_DEBUGINFO=yes -r --musl contrib/libs/llvm12/tools/llvm-symbolizer +FROM_SANDBOX( + FILE 2531143113 + OUT_NOAUTO llvm-symbolizer +) + +RESOURCE( + yt/yt/library/ytprof/bundle/llvm-symbolizer + /ytprof/llvm-symbolizer +) + +# Built with env CGO_ENABLED=0 ya tool go install github.com/google/pprof@latest +FROM_SANDBOX( + FILE 2531135322 + OUT_NOAUTO pprof +) + +RESOURCE( + yt/yt/library/ytprof/bundle/pprof + /ytprof/pprof +) + +END() + diff --git a/yt/yt/library/ytprof/example/main.cpp b/yt/yt/library/ytprof/example/main.cpp new file mode 100644 index 0000000000..bc9dec690a --- /dev/null +++ b/yt/yt/library/ytprof/example/main.cpp @@ -0,0 +1,70 @@ +#include <yt/yt/core/concurrency/poller.h> +#include <yt/yt/core/concurrency/thread_pool_poller.h> +#include <yt/yt/core/concurrency/action_queue.h> +#include <yt/yt/core/concurrency/thread_pool.h> +#include <yt/yt/core/http/server.h> + +#include <yt/yt/library/ytprof/http/handler.h> +#include 
<yt/yt/library/ytprof/heap_profiler.h> + +#include <absl/debugging/stacktrace.h> + +using namespace NYT; +using namespace NYT::NHttp; +using namespace NYT::NConcurrency; +using namespace NYT::NYTProf; + +int main(int argc, char* argv[]) +{ + absl::SetStackUnwinder(AbslStackUnwinder); + tcmalloc::MallocExtension::SetProfileSamplingRate(2_MB); + + try { + if (argc != 2 && argc != 3) { + throw yexception() << "usage: " << argv[0] << " PORT"; + } + + auto port = FromString<int>(argv[1]); + auto poller = CreateThreadPoolPoller(1, "Example"); + auto server = CreateServer(port, poller); + + Register(server, ""); + server->Start(); + + THashMap<TString, std::vector<int>> data; + for (int i = 0; i < 1024 * 16; i++) { + data[ToString(i)].resize(1024); + } + + auto burnCpu = [] { + ui64 value = 0; + while (true) { + THash<TString> hasher; + for (int i = 0; i < 10000000; i++) { + value += hasher(ToString(i)); + } + + std::vector<TString> data; + for (int i = 0; i < 10000; i++) { + data.push_back(TString(1024, 'x')); + } + + if (value == 1) { + Sleep(TDuration::Seconds(1)); + } + } + }; + + auto pool = CreateThreadPool(64, "Pool"); + for (int i = 0; i < 64; i++) { + pool->GetInvoker()->Invoke(BIND(burnCpu)); + } + + burnCpu(); + } catch (const std::exception& ex) { + Cerr << ex.what() << Endl; + _exit(1); + } + + return 0; +} diff --git a/yt/yt/library/ytprof/example/ya.make b/yt/yt/library/ytprof/example/ya.make new file mode 100644 index 0000000000..2f79922400 --- /dev/null +++ b/yt/yt/library/ytprof/example/ya.make @@ -0,0 +1,19 @@ +PROGRAM(ytprof-example) + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +IF (OS_LINUX) + ALLOCATOR(TCMALLOC_256K) +ENDIF() + +SRCS(main.cpp) + +IF (OS_LINUX) + LDFLAGS("-Wl,--build-id=sha1") +ENDIF() + +PEERDIR( + yt/yt/library/ytprof/http +) + +END() diff --git a/yt/yt/library/ytprof/http/CMakeLists.linux-aarch64.txt b/yt/yt/library/ytprof/http/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..334ca90c02 --- /dev/null +++ b/yt/yt/library/ytprof/http/CMakeLists.linux-aarch64.txt @@ -0,0 +1,25 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(library-ytprof-http) +target_compile_options(library-ytprof-http PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-ytprof-http PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + library-cpp-cgiparam + yt-core-http + yt-library-ytprof + yt-library-process +) +target_sources(library-ytprof-http PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/http/handler.cpp +) diff --git a/yt/yt/library/ytprof/http/CMakeLists.linux-x86_64.txt b/yt/yt/library/ytprof/http/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..334ca90c02 --- /dev/null +++ b/yt/yt/library/ytprof/http/CMakeLists.linux-x86_64.txt @@ -0,0 +1,25 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. 
Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_library(library-ytprof-http) +target_compile_options(library-ytprof-http PRIVATE + -Wdeprecated-this-capture +) +target_link_libraries(library-ytprof-http PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + library-cpp-cgiparam + yt-core-http + yt-library-ytprof + yt-library-process +) +target_sources(library-ytprof-http PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/http/handler.cpp +) diff --git a/yt/yt/library/ytprof/http/CMakeLists.txt b/yt/yt/library/ytprof/http/CMakeLists.txt new file mode 100644 index 0000000000..4d48dcdee6 --- /dev/null +++ b/yt/yt/library/ytprof/http/CMakeLists.txt @@ -0,0 +1,13 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/ytprof/http/handler.cpp b/yt/yt/library/ytprof/http/handler.cpp new file mode 100644 index 0000000000..382ffc1fec --- /dev/null +++ b/yt/yt/library/ytprof/http/handler.cpp @@ -0,0 +1,311 @@ +#include "handler.h" + +#include <yt/yt/core/concurrency/async_stream.h> + +#include <yt/yt/core/http/http.h> +#include <yt/yt/core/http/server.h> + +#include <yt/yt/library/ytprof/cpu_profiler.h> +#include <yt/yt/library/ytprof/spinlock_profiler.h> +#include <yt/yt/library/ytprof/heap_profiler.h> +#include <yt/yt/library/ytprof/profile.h> +#include <yt/yt/library/ytprof/symbolize.h> +#include <yt/yt/library/ytprof/external_pprof.h> + +#include <yt/yt/library/process/subprocess.h> + +#include <yt/yt/core/misc/finally.h> + +#include <library/cpp/cgiparam/cgiparam.h> + +#include <util/system/mutex.h> + +namespace NYT::NYTProf { + +using namespace NHttp; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +class TBaseHandler + : public IHttpHandler +{ +public: + explicit TBaseHandler(const TBuildInfo& buildInfo) + : BuildInfo_(buildInfo) + { } + + virtual NProto::Profile BuildProfile(const TCgiParameters& params) = 0; + + void HandleRequest(const IRequestPtr& req, const IResponseWriterPtr& rsp) override + { + try { + TTryGuard guard(Lock_); + if (!guard) { + rsp->SetStatus(EStatusCode::TooManyRequests); + WaitFor(rsp->WriteBody(TSharedRef::FromString("Profile fetch already running"))) + .ThrowOnError(); + return; + } + + TCgiParameters params(req->GetUrl().RawQuery); + auto profile = BuildProfile(params); + Symbolize(&profile, true); + AddBuildInfo(&profile, BuildInfo_); + + if (auto it = params.Find("symbolize"); it == params.end() || it->second != "0") { + SymbolizeByExternalPProf(&profile, TSymbolizationOptions{ + .RunTool = RunSubprocess, + }); + } + + TStringStream profileBlob; + WriteProfile(&profileBlob, profile); + + rsp->SetStatus(EStatusCode::OK); + WaitFor(rsp->WriteBody(TSharedRef::FromString(profileBlob.Str()))) + .ThrowOnError(); + } catch (const 
std::exception& ex) { + if (rsp->AreHeadersFlushed()) { + throw; + } + + rsp->SetStatus(EStatusCode::InternalServerError); + WaitFor(rsp->WriteBody(TSharedRef::FromString(ex.what()))) + .ThrowOnError(); + + throw; + } + } + +protected: + const TBuildInfo BuildInfo_; + +private: + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, Lock_); +}; + +class TCpuProfilerHandler + : public TBaseHandler +{ +public: + using TBaseHandler::TBaseHandler; + + NProto::Profile BuildProfile(const TCgiParameters& params) override + { + auto duration = TDuration::Seconds(15); + if (auto it = params.Find("d"); it != params.end()) { + duration = TDuration::Parse(it->second); + } + + TCpuProfilerOptions options; + if (auto it = params.Find("freq"); it != params.end()) { + options.SamplingFrequency = FromString<int>(it->second); + } + + if (auto it = params.Find("record_action_run_time"); it != params.end()) { + options.RecordActionRunTime = true; + } + + if (auto it = params.Find("action_min_exec_time"); it != params.end()) { + options.SampleFilters.push_back(GetActionMinExecTimeFilter(TDuration::Parse(it->second))); + } + + TCpuProfiler profiler{options}; + profiler.Start(); + TDelayedExecutor::WaitForDuration(duration); + profiler.Stop(); + + return profiler.ReadProfile(); + } +}; + +class TSpinlockProfilerHandler + : public TBaseHandler +{ +public: + TSpinlockProfilerHandler(const TBuildInfo& buildInfo, bool yt) + : TBaseHandler(buildInfo) + , YT_(yt) + { } + + NProto::Profile BuildProfile(const TCgiParameters& params) override + { + auto duration = TDuration::Seconds(15); + if (auto it = params.Find("d"); it != params.end()) { + duration = TDuration::Parse(it->second); + } + + TSpinlockProfilerOptions options; + if (auto it = params.Find("frac"); it != params.end()) { + options.ProfileFraction = FromString<int>(it->second); + } + + if (YT_) { + TBlockingProfiler profiler{options}; + profiler.Start(); + TDelayedExecutor::WaitForDuration(duration); + profiler.Stop(); + + return profiler.ReadProfile(); + } else { + TSpinlockProfiler profiler{options}; + profiler.Start(); + TDelayedExecutor::WaitForDuration(duration); + profiler.Stop(); + + return profiler.ReadProfile(); + } + } + +private: + const bool YT_; +}; + +class TTCMallocSnapshotProfilerHandler + : public TBaseHandler +{ +public: + TTCMallocSnapshotProfilerHandler(const TBuildInfo& buildInfo, tcmalloc::ProfileType profileType) + : TBaseHandler(buildInfo) + , ProfileType_(profileType) + { } + + NProto::Profile BuildProfile(const TCgiParameters& /*params*/) override + { + return ReadHeapProfile(ProfileType_); + } + +private: + tcmalloc::ProfileType ProfileType_; +}; + +class TTCMallocAllocationProfilerHandler + : public TBaseHandler +{ +public: + using TBaseHandler::TBaseHandler; + + NProto::Profile BuildProfile(const TCgiParameters& params) override + { + auto duration = TDuration::Seconds(15); + if (auto it = params.Find("d"); it != params.end()) { + duration = TDuration::Parse(it->second); + } + + auto token = tcmalloc::MallocExtension::StartAllocationProfiling(); + TDelayedExecutor::WaitForDuration(duration); + return ConvertAllocationProfile(std::move(token).Stop()); + } +}; + +class TTCMallocStatHandler + : public IHttpHandler +{ +public: + void HandleRequest(const IRequestPtr& /* req */, const IResponseWriterPtr& rsp) override + { + auto stat = tcmalloc::MallocExtension::GetStats(); + rsp->SetStatus(EStatusCode::OK); + WaitFor(rsp->WriteBody(TSharedRef::FromString(TString{stat}))) + .ThrowOnError(); + } +}; + +class TBinaryHandler + : public IHttpHandler 
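+    // Serves a copy of the running binary (/proc/self/exe) so that profiles
+    // can be symbolized on another host; the optional ?check_build_id=
+    // parameter guards against fetching a binary that does not match the
+    // profiled process.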
+{
+public:
+    void HandleRequest(const IRequestPtr& req, const IResponseWriterPtr& rsp) override
+    {
+        try {
+            auto buildId = GetBuildId();
+            TCgiParameters params(req->GetUrl().RawQuery);
+
+            if (auto it = params.Find("check_build_id"); it != params.end()) {
+                if (it->second != buildId) {
+                    THROW_ERROR_EXCEPTION("Wrong build id: %v != %v", it->second, buildId);
+                }
+            }
+
+            rsp->SetStatus(EStatusCode::OK);
+
+            TFileInput file{"/proc/self/exe"};
+            auto adapter = CreateBufferedSyncAdapter(rsp);
+            file.ReadAll(*adapter);
+            adapter->Finish();
+
+            WaitFor(rsp->Close())
+                .ThrowOnError();
+        } catch (const std::exception& ex) {
+            if (rsp->AreHeadersFlushed()) {
+                throw;
+            }
+
+            rsp->SetStatus(EStatusCode::InternalServerError);
+            WaitFor(rsp->WriteBody(TSharedRef::FromString(ex.what())))
+                .ThrowOnError();
+
+            throw;
+        }
+    }
+};
+
+class TVersionHandler
+    : public IHttpHandler
+{
+public:
+    void HandleRequest(const IRequestPtr& /* req */, const IResponseWriterPtr& rsp) override
+    {
+        rsp->SetStatus(EStatusCode::OK);
+        WaitFor(rsp->WriteBody(TSharedRef::FromString(GetVersion())))
+            .ThrowOnError();
+    }
+};
+
+class TBuildIdHandler
+    : public IHttpHandler
+{
+public:
+    void HandleRequest(const IRequestPtr& /* req */, const IResponseWriterPtr& rsp) override
+    {
+        // Report the binary's build id here rather than duplicating
+        // TVersionHandler above; assumes GetBuildId() (see symbolize.h)
+        // returns std::optional<TString>, empty when no build id is embedded.
+        auto buildId = GetBuildId();
+        rsp->SetStatus(EStatusCode::OK);
+        WaitFor(rsp->WriteBody(TSharedRef::FromString(buildId.value_or(TString()))))
+            .ThrowOnError();
+    }
+};
+
+void Register(
+    const NHttp::IServerPtr& server,
+    const TString& prefix,
+    const TBuildInfo& buildInfo)
+{
+    Register(server->GetPathMatcher(), prefix, buildInfo);
+}
+
+void Register(
+    const IRequestPathMatcherPtr& handlers,
+    const TString& prefix,
+    const TBuildInfo& buildInfo)
+{
+    handlers->Add(prefix + "/profile", New<TCpuProfilerHandler>(buildInfo));
+
+    handlers->Add(prefix + "/lock", New<TSpinlockProfilerHandler>(buildInfo, false));
+    handlers->Add(prefix + "/block", New<TSpinlockProfilerHandler>(buildInfo, true));
+
+    handlers->Add(prefix + "/heap", New<TTCMallocSnapshotProfilerHandler>(buildInfo, tcmalloc::ProfileType::kHeap));
+    handlers->Add(prefix + "/peak", New<TTCMallocSnapshotProfilerHandler>(buildInfo, tcmalloc::ProfileType::kPeakHeap));
+    handlers->Add(prefix + "/fragmentation", New<TTCMallocSnapshotProfilerHandler>(buildInfo, tcmalloc::ProfileType::kFragmentation));
+    handlers->Add(prefix + "/allocations", New<TTCMallocAllocationProfilerHandler>(buildInfo));
+
+    handlers->Add(prefix + "/tcmalloc", New<TTCMallocStatHandler>());
+
+    handlers->Add(prefix + "/binary", New<TBinaryHandler>());
+
+    handlers->Add(prefix + "/version", New<TVersionHandler>());
+    handlers->Add(prefix + "/buildid", New<TBuildIdHandler>());
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NYTProf
diff --git a/yt/yt/library/ytprof/http/handler.h b/yt/yt/library/ytprof/http/handler.h
new file mode 100644
index 0000000000..fa96412d95
--- /dev/null
+++ b/yt/yt/library/ytprof/http/handler.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <yt/yt/core/http/public.h>
+
+#include <yt/yt/library/ytprof/build_info.h>
+
+namespace NYT::NYTProf {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Register profiling handlers.
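+//! Endpoints are mounted under |prefix|: /profile (CPU), /lock and /block
+//! (spinlock and blocking profiles), /heap, /peak, /fragmentation,
+//! /allocations, /tcmalloc, /binary, /version and /buildid.
+//!
+//! A minimal usage sketch (the port, thread count and prefix are arbitrary):
+//! \code
+//!     auto poller = NConcurrency::CreateThreadPoolPoller(1, "Http");
+//!     auto server = NHttp::CreateServer(8080, poller);
+//!     Register(server, "/ytprof");
+//!     server->Start();
+//! \endcode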
+void Register( + const NHttp::IServerPtr& server, + const TString& prefix, + const TBuildInfo& buildInfo = TBuildInfo::GetDefault()); + +void Register( + const NHttp::IRequestPathMatcherPtr& handlers, + const TString& prefix, + const TBuildInfo& buildInfo = TBuildInfo::GetDefault()); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NYTProf diff --git a/yt/yt/library/ytprof/http/ya.make b/yt/yt/library/ytprof/http/ya.make new file mode 100644 index 0000000000..1a1f3ff20c --- /dev/null +++ b/yt/yt/library/ytprof/http/ya.make @@ -0,0 +1,16 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + handler.cpp +) + +PEERDIR( + library/cpp/cgiparam + yt/yt/core/http + yt/yt/library/ytprof + yt/yt/library/process +) + +END() diff --git a/yt/yt/library/ytprof/integration/test_http.py b/yt/yt/library/ytprof/integration/test_http.py new file mode 100644 index 0000000000..2135d8daf9 --- /dev/null +++ b/yt/yt/library/ytprof/integration/test_http.py @@ -0,0 +1,113 @@ +import pytest +import requests +import time + +import asyncio +import httpx + +import yatest.common +import yatest.common.network + + +TIMEOUT = 5000 + + +@pytest.fixture(scope="session") +def running_example(): + with yatest.common.network.PortManager() as pm: + port = pm.get_port() + + cmd = [ + yatest.common.binary_path("yt/yt/library/ytprof/example/ytprof-example"), + str(port) + ] + + p = yatest.common.execute(cmd, wait=False, env={"YT_LOG_LEVEL": "DEBUG"}) + time.sleep(1) + assert p.running + + try: + yield {"port": port} + finally: + p.kill() + + +def fetch_data(running_example, name): + rsp = requests.get(f"http://localhost:{running_example['port']}/{name}") + if rsp.status_code == 200: + return rsp.content + + if rsp.status_code == 500: + raise Exception(rsp.text) + + rsp.raise_for_status() + + +def test_smoke_tcmalloc(running_example): + fetch_data(running_example, "heap") + fetch_data(running_example, "allocations?d=1") + fetch_data(running_example, "peak") + fetch_data(running_example, "fragmentation") + + +async def get_async(url): + async with httpx.AsyncClient() as client: + return await client.get(url, timeout=TIMEOUT) + + +async def launch(running_example, name): + url = f"http://localhost:{running_example['port']}/{name}" + + urls = [url, url, url] + + resps = await asyncio.gather(*map(get_async, urls)) + data = [resp.status_code for resp in resps] + + assert data == [200, 429, 429] + + +def test_async(running_example): + fetch_data(running_example, "heap") + asyncio.run(launch(running_example, "heap")) + fetch_data(running_example, "heap") + + +def test_status_handlers(running_example): + assert fetch_data(running_example, "buildid") + assert fetch_data(running_example, "version") + + +def test_cpu_profile(running_example): + if yatest.common.context.build_type != "profile": + pytest.skip() + + fetch_data(running_example, "profile?d=1") + fetch_data(running_example, "profile?d=1&freq=1000") + + +def test_spinlock_profile(running_example): + if yatest.common.context.build_type != "profile": + pytest.skip() + + fetch_data(running_example, "lock?d=1") + fetch_data(running_example, "lock?d=1&frac=1") + + +def test_block_profile(running_example): + if yatest.common.context.build_type != "profile": + pytest.skip() + + fetch_data(running_example, "block?d=1") + fetch_data(running_example, "block?d=1&frac=1") + + +def test_binary_handler(running_example): + binary = fetch_data(running_example, "binary") + + with 
open(yatest.common.binary_path("yt/yt/library/ytprof/example/ytprof-example"), "rb") as f: + real_binary = f.read() + + assert binary == real_binary + + with pytest.raises(Exception): + fetch_data(running_example, "binary?check_build_id=1234") diff --git a/yt/yt/library/ytprof/integration/ya.make b/yt/yt/library/ytprof/integration/ya.make new file mode 100644 index 0000000000..889c25c0e1 --- /dev/null +++ b/yt/yt/library/ytprof/integration/ya.make @@ -0,0 +1,20 @@ +PY3TEST() + +SIZE(MEDIUM) + +INCLUDE(${ARCADIA_ROOT}/yt/opensource_tests.inc) + +PEERDIR( + contrib/python/requests + contrib/python/httpx +) + +TEST_SRCS( + test_http.py +) + +DEPENDS( + yt/yt/library/ytprof/example +) + +END() diff --git a/yt/yt/library/ytprof/proto/CMakeLists.darwin-x86_64.txt b/yt/yt/library/ytprof/proto/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..2a3265d317 --- /dev/null +++ b/yt/yt/library/ytprof/proto/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,47 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-ytprof-proto) +target_include_directories(library-ytprof-proto PUBLIC + ${CMAKE_BINARY_DIR}/yt +) +target_link_libraries(library-ytprof-proto PUBLIC + contrib-libs-cxxsupp + yutil + contrib-libs-protobuf +) +target_proto_messages(library-ytprof-proto PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/proto/profile.proto +) +target_proto_addincls(library-ytprof-proto + ./yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-ytprof-proto + --cpp_out=${CMAKE_BINARY_DIR}/yt + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/yt +) diff --git a/yt/yt/library/ytprof/proto/CMakeLists.linux-aarch64.txt b/yt/yt/library/ytprof/proto/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..d16987cc1d --- /dev/null +++ b/yt/yt/library/ytprof/proto/CMakeLists.linux-aarch64.txt @@ -0,0 +1,48 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
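+# Generated .pb.h/.pb.cc files are emitted under ${CMAKE_BINARY_DIR}/yt, the
+# same directory this target exports as a public include path.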
+ + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-ytprof-proto) +target_include_directories(library-ytprof-proto PUBLIC + ${CMAKE_BINARY_DIR}/yt +) +target_link_libraries(library-ytprof-proto PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + contrib-libs-protobuf +) +target_proto_messages(library-ytprof-proto PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/proto/profile.proto +) +target_proto_addincls(library-ytprof-proto + ./yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-ytprof-proto + --cpp_out=${CMAKE_BINARY_DIR}/yt + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/yt +) diff --git a/yt/yt/library/ytprof/proto/CMakeLists.linux-x86_64.txt b/yt/yt/library/ytprof/proto/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..d16987cc1d --- /dev/null +++ b/yt/yt/library/ytprof/proto/CMakeLists.linux-x86_64.txt @@ -0,0 +1,48 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-ytprof-proto) +target_include_directories(library-ytprof-proto PUBLIC + ${CMAKE_BINARY_DIR}/yt +) +target_link_libraries(library-ytprof-proto PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + contrib-libs-protobuf +) +target_proto_messages(library-ytprof-proto PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/proto/profile.proto +) +target_proto_addincls(library-ytprof-proto + ./yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-ytprof-proto + --cpp_out=${CMAKE_BINARY_DIR}/yt + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/yt +) diff --git a/yt/yt/library/ytprof/proto/CMakeLists.txt b/yt/yt/library/ytprof/proto/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/yt/yt/library/ytprof/proto/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. 
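+# Include the variant generated for the current platform; platforms without a
+# matching branch get no target at all.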
+ + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/yt/yt/library/ytprof/proto/CMakeLists.windows-x86_64.txt b/yt/yt/library/ytprof/proto/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..2a3265d317 --- /dev/null +++ b/yt/yt/library/ytprof/proto/CMakeLists.windows-x86_64.txt @@ -0,0 +1,47 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +get_built_tool_path( + TOOL_protoc_bin + TOOL_protoc_dependency + contrib/tools/protoc/bin + protoc +) +get_built_tool_path( + TOOL_cpp_styleguide_bin + TOOL_cpp_styleguide_dependency + contrib/tools/protoc/plugins/cpp_styleguide + cpp_styleguide +) + +add_library(library-ytprof-proto) +target_include_directories(library-ytprof-proto PUBLIC + ${CMAKE_BINARY_DIR}/yt +) +target_link_libraries(library-ytprof-proto PUBLIC + contrib-libs-cxxsupp + yutil + contrib-libs-protobuf +) +target_proto_messages(library-ytprof-proto PRIVATE + ${CMAKE_SOURCE_DIR}/yt/yt/library/ytprof/proto/profile.proto +) +target_proto_addincls(library-ytprof-proto + ./yt + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/yt + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src + ${CMAKE_BINARY_DIR} + ${CMAKE_SOURCE_DIR}/contrib/libs/protobuf/src +) +target_proto_outs(library-ytprof-proto + --cpp_out=${CMAKE_BINARY_DIR}/yt + --cpp_styleguide_out=${CMAKE_BINARY_DIR}/yt +) diff --git a/yt/yt/library/ytprof/unittests/cpu_profiler_ut.cpp b/yt/yt/library/ytprof/unittests/cpu_profiler_ut.cpp new file mode 100644 index 0000000000..3edf6a6f89 --- /dev/null +++ b/yt/yt/library/ytprof/unittests/cpu_profiler_ut.cpp @@ -0,0 +1,343 @@ +#include <dlfcn.h> + +#include <gtest/gtest.h> + +#include <library/cpp/testing/common/env.h> + +#include <library/cpp/yt/memory/new.h> + +#include <yt/yt/core/concurrency/action_queue.h> +#include <yt/yt/core/concurrency/scheduler_api.h> + +#include <yt/yt/core/actions/bind.h> + +#include <yt/yt/core/tracing/trace_context.h> + +#include <yt/yt/library/ytprof/cpu_profiler.h> +#include <yt/yt/library/ytprof/symbolize.h> +#include <yt/yt/library/ytprof/profile.h> +#include <yt/yt/library/ytprof/external_pprof.h> + +#include <util/string/cast.h> +#include <util/stream/file.h> +#include <util/datetime/base.h> +#include <util/system/shellcommand.h> + +#include <yt/yt/core/concurrency/thread_pool.h> + +namespace NYT::NYTProf { +namespace { + +using namespace NConcurrency; +using namespace NTracing; + +using TSampleFilter = TCpuProfilerOptions::TSampleFilter; + +//////////////////////////////////////////////////////////////////////////////// + +template <size_t Index> +Y_NO_INLINE void BurnCpu() +{ + THash<TString> hasher; + ui64 value = 0; + for (int i = 
0; i < 10000000; i++) { + value += hasher(ToString(i)); + } + EXPECT_NE(Index, value); +} + +static std::atomic<int> Counter{0}; + +struct TNoTailCall +{ + ~TNoTailCall() + { + Counter++; + } +}; + +static Y_NO_INLINE void StaticFunction() +{ + TNoTailCall noTail; + BurnCpu<0>(); +} + +void RunUnderProfiler( + const TString& name, + std::function<void()> work, + bool checkSamples = true, + const std::vector<TSampleFilter>& filters = {}, + bool expectEmpty = false) +{ + TCpuProfilerOptions options; + options.SampleFilters = filters; + options.SamplingFrequency = 100000; + options.RecordActionRunTime = true; + +#ifdef YTPROF_DEBUG_BUILD + options.SamplingFrequency = 100; +#endif + + TCpuProfiler profiler(options); + + profiler.Start(); + + work(); + + profiler.Stop(); + + auto profile = profiler.ReadProfile(); + if (checkSamples) { + ASSERT_EQ(expectEmpty, profile.sample_size() == 0); + } + + Symbolize(&profile, true); + AddBuildInfo(&profile, TBuildInfo::GetDefault()); + SymbolizeByExternalPProf(&profile, TSymbolizationOptions{ + .TmpDir = GetOutputPath(), + .KeepTmpDir = true, + .RunTool = [] (const std::vector<TString>& args) { + TShellCommand command{args[0], TList<TString>{args.begin()+1, args.end()}}; + command.Run(); + + EXPECT_TRUE(command.GetExitCode() == 0) + << command.GetError(); + }, + }); + + TFileOutput output(GetOutputPath() / name); + WriteProfile(&output, profile); + output.Finish(); +} + +class TCpuProfilerTest + : public ::testing::Test +{ + void SetUp() override + { + if (!IsProfileBuild()) { + GTEST_SKIP() << "rebuild with --build=profile"; + } + } +}; + +TEST_F(TCpuProfilerTest, SingleThreadRun) +{ + RunUnderProfiler("single_thread.pb.gz", [] { + BurnCpu<0>(); + }); +} + +TEST_F(TCpuProfilerTest, MultipleThreads) +{ + RunUnderProfiler("multiple_threads.pb.gz", [] { + std::thread t1([] { + BurnCpu<1>(); + }); + + std::thread t2([] { + BurnCpu<2>(); + }); + + t1.join(); + t2.join(); + }); +} + +TEST_F(TCpuProfilerTest, StaticFunction) +{ + RunUnderProfiler("static_function.pb.gz", [] { + StaticFunction(); + }); +} + +Y_NO_INLINE void RecursiveFunction(int n) +{ + TNoTailCall noTail; + if (n == 0) { + BurnCpu<0>(); + } else { + RecursiveFunction(n-1); + } +} + +TEST_F(TCpuProfilerTest, DeepRecursion) +{ + RunUnderProfiler("recursive_function.pb.gz", [] { + RecursiveFunction(1024); + }); +} + +TEST_F(TCpuProfilerTest, DlOpen) +{ + RunUnderProfiler("dlopen.pb.gz", [] { + auto libraryPath = BinaryPath("yt/yt/library/ytprof/unittests/testso/libtestso.so"); + + auto dl = dlopen(libraryPath.c_str(), RTLD_LAZY); + ASSERT_TRUE(dl); + + auto sym = dlsym(dl, "CallNext"); + ASSERT_TRUE(sym); + + auto callNext = reinterpret_cast<void(*)(void(*)())>(sym); + callNext(BurnCpu<0>); + }); +} + +TEST_F(TCpuProfilerTest, DlClose) +{ + RunUnderProfiler("dlclose.pb.gz", [] { + auto libraryPath = BinaryPath("yt/yt/library/ytprof/unittests/testso1/libtestso1.so"); + + auto dl = dlopen(libraryPath.c_str(), RTLD_LAZY); + ASSERT_TRUE(dl); + + auto sym = dlsym(dl, "CallOtherNext"); + ASSERT_TRUE(sym); + + auto callNext = reinterpret_cast<void(*)(void(*)())>(sym); + callNext(BurnCpu<0>); + + ASSERT_EQ(dlclose(dl), 0); + }); +} + +void ReadUrandom() +{ + TIFStream input("/dev/urandom"); + + std::array<char, 1 << 20> buffer; + + for (int i = 0; i < 100; i++) { + input.Read(buffer.data(), buffer.size()); + } +} + +TEST_F(TCpuProfilerTest, Syscalls) +{ + RunUnderProfiler("syscalls.pb.gz", [] { + ReadUrandom(); + }); +} + +TEST_F(TCpuProfilerTest, VDSO) +{ + RunUnderProfiler("vdso.pb.gz", [] { + auto now
= TInstant::Now(); + while (TInstant::Now() < now + TDuration::MilliSeconds(100)) + { } + }, false); +} + +TEST_F(TCpuProfilerTest, ProfilerTags) +{ + auto userTag = New<TProfilerTag>("user", "prime"); + auto intTag = New<TProfilerTag>("block_size", 1024); + + RunUnderProfiler("tags.pb.gz", [&] { + { + TCpuProfilerTagGuard guard(userTag); + BurnCpu<0>(); + } + { + TCpuProfilerTagGuard guard(intTag); + BurnCpu<1>(); + } + { + TCpuProfilerTagGuard guard(userTag); + TCpuProfilerTagGuard secondGuard(intTag); + BurnCpu<2>(); + } + }); +} + +TEST_F(TCpuProfilerTest, MultipleProfilers) +{ + TCpuProfiler profiler, secondProfiler; + + profiler.Start(); + EXPECT_THROW(secondProfiler.Start(), std::exception); +} + +TEST_F(TCpuProfilerTest, TraceContext) +{ + RunUnderProfiler("trace_context.pb.gz", [] { + auto actionQueue = New<TActionQueue>("CpuProfileTest"); + + BIND([] { + auto rootTraceContext = TTraceContext::NewRoot(""); + rootTraceContext->AddProfilingTag("user", "prime"); + TCurrentTraceContextGuard guard(rootTraceContext); + + auto asyncSubrequest = BIND([&] { + TChildTraceContextGuard guard(""); + AnnotateTraceContext([] (const auto traceContext) { + traceContext->AddProfilingTag("table", "//foo"); + }); + + BurnCpu<0>(); + }) + .AsyncVia(GetCurrentInvoker()) + .Run(); + + BurnCpu<1>(); + WaitFor(asyncSubrequest) + .ThrowOnError(); + }) + .AsyncVia(actionQueue->GetInvoker()) + .Run() + .Get(); + + actionQueue->Shutdown(); + }); +} + +TEST_F(TCpuProfilerTest, SlowActions) +{ + static const TString WorkerThreadName = "Heavy"; + static const auto TraceThreshold = TDuration::MilliSeconds(50); + + const std::vector<TSampleFilter> filters = { + GetActionMinExecTimeFilter(TraceThreshold), + }; + + auto busyWait = [] (TDuration duration) { + auto now = TInstant::Now(); + while (TInstant::Now() < now + duration) + { } + }; + + auto traceContext = NTracing::TTraceContext::NewRoot("Test"); + + const bool ExpectEmptyTraces = true; + + auto threadPool = CreateThreadPool(2, WorkerThreadName); + + // No slow actions. + RunUnderProfiler("slow_actions_empty.pb.gz", [&] { + NTracing::TTraceContextGuard guard(traceContext); + auto future = BIND(busyWait, TraceThreshold / 2) + .AsyncVia(threadPool->GetInvoker()) + .Run(); + future.Get(); + }, + true, + filters, + ExpectEmptyTraces); + + // Slow actions = non empty traces. 
+ RunUnderProfiler("slow_actions.pb.gz", [&] { + NTracing::TTraceContextGuard guard(traceContext); + auto future = BIND(busyWait, TraceThreshold * 3) + .AsyncVia(threadPool->GetInvoker()) + .Run(); + future.Get(); + }, + true, + filters); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT::NYTProf diff --git a/yt/yt/library/ytprof/unittests/heap_profiler_ut.cpp b/yt/yt/library/ytprof/unittests/heap_profiler_ut.cpp new file mode 100644 index 0000000000..42dcf6e802 --- /dev/null +++ b/yt/yt/library/ytprof/unittests/heap_profiler_ut.cpp @@ -0,0 +1,226 @@ +#include <gtest/gtest.h> + +#include <yt/yt/library/ytprof/heap_profiler.h> +#include <yt/yt/library/ytprof/symbolize.h> +#include <yt/yt/library/ytprof/profile.h> + +#include <yt/yt/core/actions/current_invoker.h> +#include <yt/yt/core/actions/invoker_detail.h> + +#include <yt/yt/core/concurrency/action_queue.h> + +#include <yt/yt/core/misc/lazy_ptr.h> + +#include <yt/yt/core/tracing/allocation_tags.h> +#include <yt/yt/core/tracing/trace_context.h> + +#include <library/cpp/testing/common/env.h> + +#include <util/string/cast.h> +#include <util/stream/file.h> +#include <util/generic/hash_set.h> +#include <util/datetime/base.h> +#include <util/generic/size_literals.h> + +#include <tcmalloc/common.h> + +#include <absl/debugging/stacktrace.h> + +namespace NYT::NYTProf { +namespace { + +using namespace NTracing; + +//////////////////////////////////////////////////////////////////////////////// + +constexpr auto MemoryAllocationTag = "memory_allocation_tag"; +const std::vector<TString> MemoryAllocationTags = {"0", "1", "2", "3", "4", "5", "6", "7"}; + +//////////////////////////////////////////////////////////////////////////////// + +template <size_t Index> +Y_NO_INLINE auto BlowHeap() +{ + std::vector<TString> data; + for (int i = 0; i < 10240; i++) { + data.push_back(TString(1024, 'x')); + } + return data; +} + +TEST(HeapProfiler, ReadProfile) +{ + absl::SetStackUnwinder(AbslStackUnwinder); + tcmalloc::MallocExtension::SetProfileSamplingRate(256_KB); + + auto token = tcmalloc::MallocExtension::StartAllocationProfiling(); + + EnableMemoryProfilingTags(); + auto traceContext = TTraceContext::NewRoot("Root"); + TTraceContextGuard guard(traceContext); + + traceContext->SetAllocationTags({{"user", "first"}, {"sometag", "my"}}); + + auto h0 = BlowHeap<0>(); + + auto tag = TMemoryTag(1); + traceContext->SetAllocationTags({{"user", "second"}, {"sometag", "notmy"}, {MemoryAllocationTag, ToString(tag)}}); + auto currentTag = traceContext->FindAllocationTag<TMemoryTag>(MemoryAllocationTag); + ASSERT_EQ(currentTag, tag); + + auto h1 = BlowHeap<1>(); + + traceContext->ClearAllocationTagsPtr(); + + auto h2 = BlowHeap<2>(); + h2.clear(); + + auto usage = CollectMemoryUsageSnapshot()->GetUsage(MemoryAllocationTag, ToString(tag)); + ASSERT_GE(usage, 5_MB); + + auto dumpProfile = [] (auto name, auto type) { + auto profile = ReadHeapProfile(type); + + TFileOutput output(GetOutputPath() / name); + WriteProfile(&output, profile); + output.Finish(); + }; + + dumpProfile("heap.pb.gz", tcmalloc::ProfileType::kHeap); + dumpProfile("peak.pb.gz", tcmalloc::ProfileType::kPeakHeap); + dumpProfile("fragmentation.pb.gz", tcmalloc::ProfileType::kFragmentation); + dumpProfile("allocations.pb.gz", tcmalloc::ProfileType::kAllocations); + + auto profile = std::move(token).Stop(); + + TFileOutput output(GetOutputPath() / "allocations.pb.gz"); + WriteProfile(&output, ConvertAllocationProfile(profile)); + 
output.Finish(); +} + +TEST(HeapProfiler, AllocationTagsWithMemoryTag) +{ + EnableMemoryProfilingTags(); + auto traceContext = TTraceContext::NewRoot("Root"); + TTraceContextGuard guard(traceContext); + + ASSERT_EQ(traceContext->FindAllocationTag<TString>(MemoryAllocationTag), std::nullopt); + traceContext->SetAllocationTags({{"user", "first user"}, {MemoryAllocationTag, MemoryAllocationTags[0]}}); + ASSERT_EQ(traceContext->FindAllocationTag<TString>("user"), "first user"); + ASSERT_EQ(traceContext->FindAllocationTag<TString>(MemoryAllocationTag), MemoryAllocationTags[0]); + + std::vector<std::vector<TString>> heap; + heap.push_back(BlowHeap<0>()); + + traceContext->SetAllocationTags({{"user", "second user"}, {MemoryAllocationTag, MemoryAllocationTags[1]}}); + ASSERT_EQ(traceContext->FindAllocationTag<TMemoryTag>(MemoryAllocationTag), 1); + + heap.push_back(BlowHeap<1>()); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[0]); + + auto usage1 = CollectMemoryUsageSnapshot()->GetUsage(MemoryAllocationTag, MemoryAllocationTags[1]); + + ASSERT_NEAR(usage1, 12_MB, 8_MB); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[2]); + ASSERT_EQ(traceContext->FindAllocationTag<TString>(MemoryAllocationTag), MemoryAllocationTags[2]); + + { + volatile auto h = BlowHeap<2>(); + } + + traceContext->ClearAllocationTagsPtr(); + ASSERT_EQ(traceContext->FindAllocationTag<TString>(MemoryAllocationTag), std::nullopt); + + heap.push_back(BlowHeap<0>()); + + { + auto snapshot = CollectMemoryUsageSnapshot()->GetUsage(MemoryAllocationTag); + ASSERT_EQ(snapshot[MemoryAllocationTags[1]], usage1); + ASSERT_LE(snapshot[MemoryAllocationTags[2]], 1_MB); + } + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[6]); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[3]); + heap.push_back(BlowHeap<3>()); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[4]); + heap.push_back(BlowHeap<4>()); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[7]); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[5]); + heap.push_back(BlowHeap<5>()); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[4]); + heap.push_back(BlowHeap<4>()); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[7]); + + traceContext->SetAllocationTagsPtr(nullptr); + + auto snapshot = CollectMemoryUsageSnapshot()->GetUsage(MemoryAllocationTag); + + constexpr auto maxDifference = 10_MB; + ASSERT_NEAR(snapshot[MemoryAllocationTags[1]], snapshot[MemoryAllocationTags[3]], maxDifference); + ASSERT_NEAR(snapshot[MemoryAllocationTags[3]], snapshot[MemoryAllocationTags[5]], maxDifference); + ASSERT_NEAR(snapshot[MemoryAllocationTags[1]], snapshot[MemoryAllocationTags[5]], maxDifference); + + ASSERT_NEAR(snapshot[MemoryAllocationTags[4]], 20_MB, 15_MB); + + ASSERT_NEAR(snapshot[MemoryAllocationTags[4]], snapshot[MemoryAllocationTags[1]] + snapshot[MemoryAllocationTags[3]], 2 * maxDifference); + ASSERT_NEAR(snapshot[MemoryAllocationTags[4]], snapshot[MemoryAllocationTags[1]] + snapshot[MemoryAllocationTags[5]], 2 * maxDifference); + ASSERT_NEAR(snapshot[MemoryAllocationTags[4]], snapshot[MemoryAllocationTags[3]] + snapshot[MemoryAllocationTags[5]], 2 * maxDifference); + + ASSERT_LE(snapshot[MemoryAllocationTags[6]], 1_MB); + ASSERT_LE(snapshot[MemoryAllocationTags[7]], 1_MB); +} + +template <size_t Index> +Y_NO_INLINE auto BlowHeap(int64_t 
megabytes) +{ + std::vector<TString> data; + megabytes <<= 10; + for (int64_t i = 0; i < megabytes; i++) { + data.push_back(TString( 1024, 'x')); + } + return data; +} + +TEST(HeapProfiler, HugeAllocationsTagsWithMemoryTag) +{ + EnableMemoryProfilingTags(); + auto traceContext = TTraceContext::NewRoot("Root"); + TCurrentTraceContextGuard guard(traceContext); + + std::vector<std::vector<TString>> heap; + + heap.push_back(BlowHeap<0>()); + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[1]); + ASSERT_EQ(traceContext->FindAllocationTag<TMemoryTag>(MemoryAllocationTag), 1); + + heap.push_back(BlowHeap<1>(100)); + + { + traceContext->SetAllocationTagsPtr(nullptr); + auto usage = CollectMemoryUsageSnapshot()->GetUsage(MemoryAllocationTag, MemoryAllocationTags[1]); + ASSERT_GE(usage, 100_MB); + ASSERT_LE(usage, 150_MB); + } + + traceContext->SetAllocationTag(MemoryAllocationTag, MemoryAllocationTags[2]); + heap.push_back(BlowHeap<1>(1000)); + + traceContext->SetAllocationTagsPtr(nullptr); + auto usage = CollectMemoryUsageSnapshot()->GetUsage(MemoryAllocationTag, MemoryAllocationTags[2]); + ASSERT_GE(usage, 1000_MB); + ASSERT_LE(usage, 1300_MB); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT::NYTProf diff --git a/yt/yt/library/ytprof/unittests/queue_ut.cpp b/yt/yt/library/ytprof/unittests/queue_ut.cpp new file mode 100644 index 0000000000..f44ae8da32 --- /dev/null +++ b/yt/yt/library/ytprof/unittests/queue_ut.cpp @@ -0,0 +1,66 @@ +#include <gtest/gtest.h> + +#include <yt/yt/library/ytprof/queue.h> + +namespace NYT::NYTProf { +namespace { + +//////////////////////////////////////////////////////////////////////////////// + +TEST(StaticQueue, PushPop) +{ + TStaticQueue queue(10); + + int a, b, c; + void* aptr = reinterpret_cast<void*>(&a); + void* bptr = reinterpret_cast<void*>(&b); + void* cptr = reinterpret_cast<void*>(&c); + + std::vector<void*> backtrace; + auto getBacktrace = [&] () -> std::pair<void*, bool> { + if (backtrace.empty()) { + return {nullptr, false}; + } + + auto ip = backtrace.front(); + backtrace.erase(backtrace.begin()); + return {ip, true}; + }; + + ASSERT_TRUE(queue.TryPush(getBacktrace)); + + backtrace = {aptr, bptr, cptr}; + ASSERT_TRUE(queue.TryPush(getBacktrace)); + + backtrace.push_back(aptr); + ASSERT_TRUE(queue.TryPush(getBacktrace)); + + auto readBacktrace = [&] (void *ip) { + backtrace.push_back(ip); + }; + + ASSERT_TRUE(queue.TryPop(readBacktrace)); + ASSERT_EQ(backtrace, std::vector<void*>{}); + + backtrace.clear(); + ASSERT_TRUE(queue.TryPop(readBacktrace)); + ASSERT_EQ(backtrace, (std::vector<void*>{aptr, bptr, cptr})); + + backtrace.clear(); + ASSERT_TRUE(queue.TryPop(readBacktrace)); + ASSERT_EQ(backtrace, (std::vector<void*>{aptr})); + + ASSERT_FALSE(queue.TryPop(readBacktrace)); +} + +TEST(StaticQueue, Overflow) +{ + TStaticQueue queue(10); + + ASSERT_FALSE(queue.TryPush([] () -> std::pair<void*, bool> { + return {nullptr, true}; + })); +} + +} // namespace +} // namespace NYT::NYTProf diff --git a/yt/yt/library/ytprof/unittests/spinlock_profiler_ut.cpp b/yt/yt/library/ytprof/unittests/spinlock_profiler_ut.cpp new file mode 100644 index 0000000000..1cf315155c --- /dev/null +++ b/yt/yt/library/ytprof/unittests/spinlock_profiler_ut.cpp @@ -0,0 +1,172 @@ +#include <gtest/gtest.h> + +#include <library/cpp/testing/common/env.h> + +#include <yt/yt/library/ytprof/spinlock_profiler.h> +#include <yt/yt/library/ytprof/symbolize.h> +#include 
<yt/yt/library/ytprof/profile.h> +#include <yt/yt/library/ytprof/external_pprof.h> + +#include <tcmalloc/malloc_extension.h> + +#include <library/cpp/yt/threading/spin_lock.h> + +#include <util/string/cast.h> +#include <util/stream/file.h> +#include <util/datetime/base.h> +#include <util/system/compiler.h> +#include <util/system/shellcommand.h> +#include <util/thread/lfstack.h> +#include <util/generic/size_literals.h> + +namespace NYT::NYTProf { +namespace { + +//////////////////////////////////////////////////////////////////////////////// + +template <class TProfiler> +void RunUnderProfiler(const TString& name, std::function<void()> work, bool checkSamples = true) +{ + TSpinlockProfilerOptions options; + options.ProfileFraction = 10; + + TProfiler profiler(options); + + profiler.Start(); + + work(); + + profiler.Stop(); + + auto profile = profiler.ReadProfile(); + if (checkSamples) { + ASSERT_NE(0, profile.sample_size()); + } + + Symbolize(&profile, true); + AddBuildInfo(&profile, TBuildInfo::GetDefault()); + SymbolizeByExternalPProf(&profile, TSymbolizationOptions{ + .TmpDir = GetOutputPath(), + .KeepTmpDir = true, + .RunTool = [] (const std::vector<TString>& args) { + TShellCommand command{args[0], TList<TString>{args.begin()+1, args.end()}}; + command.Run(); + + EXPECT_TRUE(command.GetExitCode() == 0) + << command.GetError(); + }, + }); + + TFileOutput output(GetOutputPath() / name); + WriteProfile(&output, profile); + output.Finish(); +} + +class TSpinlockProfilerTest + : public ::testing::Test +{ + void SetUp() override + { + if (!IsProfileBuild()) { + GTEST_SKIP() << "rebuild with --build=profile"; + } + } +}; + +TEST_F(TSpinlockProfilerTest, PageHeapLock) +{ + RunUnderProfiler<TSpinlockProfiler>("pageheap_lock.pb.gz", [] { + std::atomic<bool> Stop = false; + + std::thread release([&] { + while (!Stop) { + tcmalloc::MallocExtension::ReleaseMemoryToSystem(1_GB); + } + }); + + std::vector<std::thread> allocators; + for (int i = 0; i < 16; i++) { + allocators.emplace_back([&] { + while (!Stop) { + auto ptr = malloc(4_MB); + DoNotOptimizeAway(ptr); + free(ptr); + } + }); + } + + Sleep(TDuration::Seconds(5)); + Stop = true; + + release.join(); + for (auto& t : allocators) { + t.join(); + } + }); +} + + +TEST_F(TSpinlockProfilerTest, TransferCacheLock) +{ + RunUnderProfiler<TSpinlockProfiler>("transfer_cache_lock.pb.gz", [] { + std::atomic<bool> Stop = false; + + TLockFreeStack<int> stack; + std::thread producer([&] { + while (!Stop) { + stack.Enqueue(1); + } + }); + + std::thread consumer([&] { + while (!Stop) { + int value; + stack.Dequeue(&value); + } + }); + + Sleep(TDuration::Seconds(5)); + + Stop = true; + producer.join(); + consumer.join(); + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST_F(TSpinlockProfilerTest, YTLocks) +{ + RunUnderProfiler<TBlockingProfiler>("ytlock.pb.gz", [] { + std::atomic<bool> Stop = false; + + NThreading::TSpinLock lock; + std::thread slow([&] { + while (!Stop) { + lock.Acquire(); + Sleep(TDuration::MilliSeconds(10)); + lock.Release(); + Sleep(TDuration::MilliSeconds(10)); + } + }); + + std::thread fast([&] { + while (!Stop) { + lock.Acquire(); + lock.Release(); + } + }); + + Sleep(TDuration::Seconds(5)); + + Stop = true; + slow.join(); + fast.join(); + }); +} + +//////////////////////////////////////////////////////////////////////////////// + + +} // namespace +} // namespace NYT::NYTProf diff --git a/yt/yt/library/ytprof/unittests/symbolizer_ut.cpp b/yt/yt/library/ytprof/unittests/symbolizer_ut.cpp 
new file mode 100644 index 0000000000..646ecaafad --- /dev/null +++ b/yt/yt/library/ytprof/unittests/symbolizer_ut.cpp @@ -0,0 +1,81 @@ +#include <gtest/gtest.h> + +#include <yt/yt/library/ytprof/symbolize.h> +#include <yt/yt/library/ytprof/build_info.h> + +namespace NYT::NYTProf { +namespace { + +//////////////////////////////////////////////////////////////////////////////// + +Y_NO_INLINE void* GetIP() +{ + return __builtin_return_address(0); +} + +TEST(Symbolize, EmptyProfile) +{ + NProto::Profile profile; + profile.add_string_table(); + + Symbolize(&profile); + AddBuildInfo(&profile, TBuildInfo::GetDefault()); +} + +TEST(Symbolize, SingleLocation) +{ + NProto::Profile profile; + profile.add_string_table(); + + auto thisIP = GetIP(); + + { + auto location = profile.add_location(); + location->set_address(reinterpret_cast<ui64>(thisIP)); + + auto line = location->add_line(); + line->set_function_id(reinterpret_cast<ui64>(thisIP)); + + auto function = profile.add_function(); + function->set_id(reinterpret_cast<ui64>(thisIP)); + } + + Symbolize(&profile); + + ASSERT_EQ(1, profile.function_size()); + auto function = profile.function(0); + + auto name = profile.string_table(function.name()); + ASSERT_TRUE(name.find("SingleLocation") != TString::npos) + << "function name is " << name; +} + +TEST(Symbolize, GetBuildId) +{ + if (!IsProfileBuild()) { + GTEST_SKIP(); + } + + auto buildId = GetBuildId(); + ASSERT_TRUE(buildId); + ASSERT_NE(*buildId, TString{""}); +} + +TEST(BuildInfo, Test) +{ + if (!IsProfileBuild()) { + GTEST_SKIP(); + } + + auto info = TBuildInfo::GetDefault(); + if (IsProfileBuild()) { + ASSERT_EQ(info.BuildType, "profile"); + } + + ASSERT_NE(info.ArcRevision, ""); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT::NYTProf diff --git a/yt/yt/library/ytprof/unittests/testso/testso.cpp b/yt/yt/library/ytprof/unittests/testso/testso.cpp new file mode 100644 index 0000000000..4209d5f6cc --- /dev/null +++ b/yt/yt/library/ytprof/unittests/testso/testso.cpp @@ -0,0 +1,9 @@ +#include <atomic> + +static std::atomic<int> CallCount; + +extern "C" void CallNext(void (*next)()) +{ + next(); + CallCount++; +} diff --git a/yt/yt/library/ytprof/unittests/testso/ya.make b/yt/yt/library/ytprof/unittests/testso/ya.make new file mode 100644 index 0000000000..8851d8eed6 --- /dev/null +++ b/yt/yt/library/ytprof/unittests/testso/ya.make @@ -0,0 +1,7 @@ +DLL(testso) + +SRCS( + testso.cpp +) + +END() diff --git a/yt/yt/library/ytprof/unittests/testso1/testso.cpp b/yt/yt/library/ytprof/unittests/testso1/testso.cpp new file mode 100644 index 0000000000..56e348c020 --- /dev/null +++ b/yt/yt/library/ytprof/unittests/testso1/testso.cpp @@ -0,0 +1,9 @@ +#include <atomic> + +static std::atomic<int> CallCount; + +extern "C" void CallOtherNext(void (*next)()) +{ + next(); + CallCount++; +} diff --git a/yt/yt/library/ytprof/unittests/testso1/ya.make b/yt/yt/library/ytprof/unittests/testso1/ya.make new file mode 100644 index 0000000000..d0d2f90143 --- /dev/null +++ b/yt/yt/library/ytprof/unittests/testso1/ya.make @@ -0,0 +1,7 @@ +DLL(testso1) + +SRCS( + testso.cpp +) + +END() diff --git a/yt/yt/library/ytprof/unittests/ya.make b/yt/yt/library/ytprof/unittests/ya.make new file mode 100644 index 0000000000..8c8bce98f5 --- /dev/null +++ b/yt/yt/library/ytprof/unittests/ya.make @@ -0,0 +1,43 @@ +GTEST() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + symbolizer_ut.cpp + cpu_profiler_ut.cpp + heap_profiler_ut.cpp + 
spinlock_profiler_ut.cpp + queue_ut.cpp +) + +INCLUDE(${ARCADIA_ROOT}/yt/opensource_tests.inc) + +PEERDIR( + yt/yt/library/ytprof + yt/yt/library/profiling + yt/yt/core +) + +IF (OS_LINUX) + LDFLAGS("-Wl,--build-id=sha1") +ENDIF() + +DEPENDS( + yt/yt/library/ytprof/unittests/testso + yt/yt/library/ytprof/unittests/testso1 +) + +IF (BUILD_TYPE != "release" AND BUILD_TYPE != "relwithdebinfo") + CFLAGS(-DYTPROF_DEBUG_BUILD) +ENDIF() + +SIZE(MEDIUM) + +ALLOCATOR(TCMALLOC_256K) + +END() + +RECURSE( + testso + testso1 +) |