aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/mypy-protobuf/mypy_protobuf/main.py
diff options
context:
space:
mode:
authorrobot-contrib <robot-contrib@yandex-team.com>2024-10-02 15:07:20 +0300
committerrobot-contrib <robot-contrib@yandex-team.com>2024-10-02 15:17:02 +0300
commit91a4451afcbafd41dd6c49c0a7b6b1d701ab4c9c (patch)
tree675adb650d5d1f37f8cd415b39f69ea4660d7334 /contrib/python/mypy-protobuf/mypy_protobuf/main.py
parent6c39e2242d0ff2aa6ea07080256c950fb9ef5ab3 (diff)
downloadydb-91a4451afcbafd41dd6c49c0a7b6b1d701ab4c9c.tar.gz
Update contrib/python/mypy-protobuf to 3.5.0
commit_hash:bfda51486cc75a7834db8b71f122c4f26dde8b37
Diffstat (limited to 'contrib/python/mypy-protobuf/mypy_protobuf/main.py')
-rw-r--r--contrib/python/mypy-protobuf/mypy_protobuf/main.py131
1 files changed, 97 insertions, 34 deletions
diff --git a/contrib/python/mypy-protobuf/mypy_protobuf/main.py b/contrib/python/mypy-protobuf/mypy_protobuf/main.py
index ea4635cb44..f0b3dbc7e8 100644
--- a/contrib/python/mypy-protobuf/mypy_protobuf/main.py
+++ b/contrib/python/mypy-protobuf/mypy_protobuf/main.py
@@ -24,7 +24,7 @@ from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from google.protobuf.internal.well_known_types import WKTBASES
from . import extensions_pb2
-__version__ = "3.3.0"
+__version__ = "3.5.0"
# SourceCodeLocation is defined by `message Location` here
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
@@ -171,6 +171,7 @@ class PkgWriter(object):
stabilization = {
"Literal": (3, 8),
"TypeAlias": (3, 10),
+ "final": (3, 8),
}
assert name in stabilization
if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
@@ -345,7 +346,7 @@ class PkgWriter(object):
wl("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
wl("")
wl(
- "class {}({}[{}], {}): # noqa: F821",
+ "class {}({}[{}], {}):",
etw_helper_class,
self._import("google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"),
value_type_helper_fq,
@@ -406,6 +407,7 @@ class PkgWriter(object):
class_name = desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
message_class = self._import("google.protobuf.message", "Message")
+ wl("@{}", self._import("typing_extensions", "final"))
wl(f"class {class_name}({message_class}{addl_base}):")
with self._indent():
scl = scl_prefix + [i]
@@ -457,8 +459,7 @@ class PkgWriter(object):
wl("def __init__(")
with self._indent():
if any(f.name == "self" for f in desc.field):
- wl("# pyright: reportSelfClsParameterName=false")
- wl("self_,")
+ wl("self_, # pyright: ignore[reportSelfClsParameterName]")
else:
wl("self,")
with self._indent():
@@ -574,7 +575,7 @@ class PkgWriter(object):
wl("@{}", self._import("abc", "abstractmethod"))
wl(f"def {method.name}(")
with self._indent():
- wl(f"inst: {class_name},")
+ wl(f"inst: {class_name}, # pyright: ignore[reportSelfClsParameterName]")
wl(
"rpc_controller: {},",
self._import("google.protobuf.service", "RpcController"),
@@ -660,30 +661,77 @@ class PkgWriter(object):
return ktype, vtype
- def _callable_type(self, method: d.MethodDescriptorProto) -> str:
+ def _callable_type(self, method: d.MethodDescriptorProto, is_async: bool = False) -> str:
+ module = "grpc.aio" if is_async else "grpc"
if method.client_streaming:
if method.server_streaming:
- return self._import("grpc", "StreamStreamMultiCallable")
+ return self._import(module, "StreamStreamMultiCallable")
else:
- return self._import("grpc", "StreamUnaryMultiCallable")
+ return self._import(module, "StreamUnaryMultiCallable")
else:
if method.server_streaming:
- return self._import("grpc", "UnaryStreamMultiCallable")
+ return self._import(module, "UnaryStreamMultiCallable")
else:
- return self._import("grpc", "UnaryUnaryMultiCallable")
+ return self._import(module, "UnaryUnaryMultiCallable")
- def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
+ def _input_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.input_type)
- if use_stream_iterator and method.client_streaming:
- result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
return result
- def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
+ def _servicer_input_type(self, method: d.MethodDescriptorProto) -> str:
+ result = self._import_message(method.input_type)
+ if method.client_streaming:
+ # See write_grpc_async_hacks().
+ result = f"_MaybeAsyncIterator[{result}]"
+ return result
+
+ def _output_type(self, method: d.MethodDescriptorProto) -> str:
result = self._import_message(method.output_type)
- if use_stream_iterator and method.server_streaming:
- result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
return result
+ def _servicer_output_type(self, method: d.MethodDescriptorProto) -> str:
+ result = self._import_message(method.output_type)
+ if method.server_streaming:
+ # Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp].
+ # So both can be used in the covariant function return position.
+ iterator = f"{self._import('collections.abc', 'Iterator')}[{result}]"
+ aiterator = f"{self._import('collections.abc', 'AsyncIterator')}[{result}]"
+ result = f"{self._import('typing', 'Union')}[{iterator}, {aiterator}]"
+ else:
+ # Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp].
+ # So both can be used in the covariant function return position.
+ # Awaitable[Resp] is equivalent to async def.
+ awaitable = f"{self._import('collections.abc', 'Awaitable')}[{result}]"
+ result = f"{self._import('typing', 'Union')}[{result}, {awaitable}]"
+ return result
+
+ def write_grpc_async_hacks(self) -> None:
+ wl = self._write_line
+ # _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req].
+ # So both can be used in the contravariant function parameter position.
+ wl("_T = {}('_T')", self._import("typing", "TypeVar"))
+ wl("")
+ wl(
+ "class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}):",
+ self._import("collections.abc", "AsyncIterator"),
+ self._import("collections.abc", "Iterator"),
+ self._import("abc", "ABCMeta"),
+ )
+ with self._indent():
+ wl("...")
+ wl("")
+
+ # _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext
+ # So both can be used in the contravariant function parameter position.
+ wl(
+ "class _ServicerContext({}, {}): # type: ignore",
+ self._import("grpc", "ServicerContext"),
+ self._import("grpc.aio", "ServicerContext"),
+ )
+ with self._indent():
+ wl("...")
+ wl("")
+
def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
@@ -698,20 +746,20 @@ class PkgWriter(object):
with self._indent():
wl("self,")
input_name = "request_iterator" if method.client_streaming else "request"
- input_type = self._input_type(method)
+ input_type = self._servicer_input_type(method)
wl(f"{input_name}: {input_type},")
- wl("context: {},", self._import("grpc", "ServicerContext"))
+ wl("context: _ServicerContext,")
wl(
") -> {}:{}",
- self._output_type(method),
+ self._servicer_output_type(method),
" ..." if not self._has_comments(scl) else "",
- ),
+ )
if self._has_comments(scl):
with self._indent():
if not self._write_comments(scl):
wl("...")
- def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
+ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, is_async: bool = False) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
if not methods:
@@ -720,10 +768,10 @@ class PkgWriter(object):
for i, method in methods:
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
- wl("{}: {}[", method.name, self._callable_type(method))
+ wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
with self._indent():
- wl("{},", self._input_type(method, False))
- wl("{},", self._output_type(method, False))
+ wl("{},", self._input_type(method))
+ wl("{},", self._output_type(method))
wl("]")
self._write_comments(scl)
@@ -740,17 +788,31 @@ class PkgWriter(object):
scl = scl_prefix + [i]
# The stub client
- wl(f"class {service.name}Stub:")
+ wl(
+ "class {}Stub:",
+ service.name,
+ )
with self._indent():
if self._write_comments(scl):
wl("")
- wl(
- "def __init__(self, channel: {}) -> None: ...",
- self._import("grpc", "Channel"),
- )
+ # To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
+ channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
+ wl("def __init__(self, channel: {}) -> None: ...", channel)
self.write_grpc_stub_methods(service, scl)
wl("")
+ # The (fake) async stub client
+ wl(
+ "class {}AsyncStub:",
+ service.name,
+ )
+ with self._indent():
+ if self._write_comments(scl):
+ wl("")
+ # No __init__ since this isn't a real class (yet), and requires manual casting to work.
+ self.write_grpc_stub_methods(service, scl, is_async=True)
+ wl("")
+
# The service definition interface
wl(
"class {}Servicer(metaclass={}):",
@@ -762,11 +824,13 @@ class PkgWriter(object):
wl("")
self.write_grpc_methods(service, scl)
wl("")
+ server = self._import("grpc", "Server")
+ aserver = self._import("grpc.aio", "Server")
wl(
"def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
service.name,
service.name,
- self._import("grpc", "Server"),
+ f"{self._import('typing', 'Union')}[{server}, {aserver}]",
)
wl("")
@@ -889,7 +953,7 @@ class PkgWriter(object):
for pkg, items in sorted(self.from_imports.items()):
self._write_line(f"from {pkg} import (")
- for (name, reexport_name) in sorted(items):
+ for name, reexport_name in sorted(items):
if reexport_name is None:
self._write_line(f" {name},")
else:
@@ -955,6 +1019,7 @@ def generate_mypy_grpc_stubs(
relax_strict_optional_primitives,
grpc=True,
)
+ pkg_writer.write_grpc_async_hacks()
pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
assert name == fd.name
@@ -965,9 +1030,7 @@ def generate_mypy_grpc_stubs(
@contextmanager
-def code_generation() -> Iterator[
- Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
-]:
+def code_generation() -> Iterator[Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],]:
if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"):
print("mypy-protobuf " + __version__)
sys.exit(0)