author    nkozlovskiy <nmk@ydb.tech>  2023-09-29 12:24:06 +0300
committer nkozlovskiy <nmk@ydb.tech>  2023-09-29 12:41:34 +0300
commit    e0e3e1717e3d33762ce61950504f9637a6e669ed (patch)
tree      bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/mypy-protobuf/mypy_protobuf/main.py
parent    38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff)
add ydb deps
Diffstat (limited to 'contrib/python/mypy-protobuf/mypy_protobuf/main.py')
-rw-r--r--  contrib/python/mypy-protobuf/mypy_protobuf/main.py  1022
1 file changed, 1022 insertions, 0 deletions
diff --git a/contrib/python/mypy-protobuf/mypy_protobuf/main.py b/contrib/python/mypy-protobuf/mypy_protobuf/main.py
new file mode 100644
index 0000000000..ea4635cb44
--- /dev/null
+++ b/contrib/python/mypy-protobuf/mypy_protobuf/main.py
@@ -0,0 +1,1022 @@
+#!/usr/bin/env python
+"""Protoc Plugin to generate mypy stubs."""
+from __future__ import annotations
+
+import sys
+from collections import defaultdict
+from contextlib import contextmanager
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Sequence,
+ Tuple,
+)
+
+import google.protobuf.descriptor_pb2 as d
+from google.protobuf.compiler import plugin_pb2 as plugin_pb2
+from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
+from google.protobuf.internal.well_known_types import WKTBASES
+from . import extensions_pb2
+
+__version__ = "3.3.0"
+
+# SourceCodeLocation is defined by `message Location` here
+# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
+SourceCodeLocation = List[int]
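+# Paths index into the descriptor by field number and list position; e.g. the
+# path [4, 0, 2, 1] addresses field index 1 of message index 0 of a file
+# (4 = FileDescriptorProto.message_type, 2 = DescriptorProto.field).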
+
+# So phabricator doesn't think mypy_protobuf.py is generated
+GENERATED = "@ge" + "nerated"
+HEADER = f"""
+{GENERATED} by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+
+# See https://github.com/nipunn1313/mypy-protobuf/issues/73 for details
+PYTHON_RESERVED = {
+ "False",
+ "None",
+ "True",
+ "and",
+ "as",
+ "async",
+ "await",
+ "assert",
+ "break",
+ "class",
+ "continue",
+ "def",
+ "del",
+ "elif",
+ "else",
+ "except",
+ "finally",
+ "for",
+ "from",
+ "global",
+ "if",
+ "import",
+ "in",
+ "is",
+ "lambda",
+ "nonlocal",
+ "not",
+ "or",
+ "pass",
+ "raise",
+ "return",
+ "try",
+ "while",
+ "with",
+ "yield",
+}
+
+PROTO_ENUM_RESERVED = {
+ "Name",
+ "Value",
+ "keys",
+ "values",
+ "items",
+}
+
+
+def _mangle_global_identifier(name: str) -> str:
+ """
+ Module level identifiers are mangled and aliased so that they can be disambiguated
+ from fields/enum variants with the same name within the file.
+
+ Eg:
+ Enum variant `Name` or message field `Name` might conflict with a top level
+ message or enum named `Name`, so mangle it with a global___ prefix for
+ internal references. Note that this doesn't affect inner enums/messages
+    because they get fully qualified when referenced within a file"""
+ return f"global___{name}"
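+# e.g. _mangle_global_identifier("Name") -> "global___Name"; write_enums and
+# write_messages then also emit `global___Name = Name` as the module-level alias.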
+
+
+class Descriptors(object):
+ def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
+ files = {f.name: f for f in request.proto_file}
+ to_generate = {n: files[n] for n in request.file_to_generate}
+ self.files: Dict[str, d.FileDescriptorProto] = files
+ self.to_generate: Dict[str, d.FileDescriptorProto] = to_generate
+ self.messages: Dict[str, d.DescriptorProto] = {}
+ self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}
+
+ def _add_enums(
+ enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
+ prefix: str,
+ _fd: d.FileDescriptorProto,
+ ) -> None:
+ for enum in enums:
+ self.message_to_fd[prefix + enum.name] = _fd
+ self.message_to_fd[prefix + enum.name + ".ValueType"] = _fd
+
+ def _add_messages(
+ messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
+ prefix: str,
+ _fd: d.FileDescriptorProto,
+ ) -> None:
+ for message in messages:
+ self.messages[prefix + message.name] = message
+ self.message_to_fd[prefix + message.name] = _fd
+ sub_prefix = prefix + message.name + "."
+ _add_messages(message.nested_type, sub_prefix, _fd)
+ _add_enums(message.enum_type, sub_prefix, _fd)
+
+ for fd in request.proto_file:
+ start_prefix = "." + fd.package + "." if fd.package else "."
+ _add_messages(fd.message_type, start_prefix, fd)
+ _add_enums(fd.enum_type, start_prefix, fd)
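+        # All keys carry a leading dot and match FieldDescriptorProto.type_name.
+        # For a hypothetical `package pkg; message Outer { message Inner {} enum E {} }`,
+        # message_to_fd gains ".pkg.Outer", ".pkg.Outer.Inner", ".pkg.Outer.E"
+        # and ".pkg.Outer.E.ValueType"; messages holds only the two message entries.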
+
+
+class PkgWriter(object):
+ """Writes a single pyi file"""
+
+ def __init__(
+ self,
+ fd: d.FileDescriptorProto,
+ descriptors: Descriptors,
+ readable_stubs: bool,
+ relax_strict_optional_primitives: bool,
+ grpc: bool,
+ ) -> None:
+ self.fd = fd
+ self.descriptors = descriptors
+ self.readable_stubs = readable_stubs
+ self.relax_strict_optional_primitives = relax_strict_optional_primitives
+ self.grpc = grpc
+ self.lines: List[str] = []
+ self.indent = ""
+
+        # Set of {x}, where {x} corresponds to `import {x}`
+ self.imports: Set[str] = set()
+ # dictionary of x->(y,z) for `from {x} import {y} as {z}`
+ # if {z} is None, then it shortens to `from {x} import {y}`
+ self.from_imports: Dict[str, Set[Tuple[str, str | None]]] = defaultdict(set)
+ self.typing_extensions_min: Optional[Tuple[int, int]] = None
+
+ # Comments
+ self.source_code_info_by_scl = {tuple(location.path): location for location in fd.source_code_info.location}
+
+ def _import(self, path: str, name: str) -> str:
+ """Imports a stdlib path and returns a handle to it
+ eg. self._import("typing", "Literal") -> "Literal"
+ """
+ if path == "typing_extensions":
+ stabilization = {
+ "Literal": (3, 8),
+ "TypeAlias": (3, 10),
+ }
+ assert name in stabilization
+ if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
+ self.typing_extensions_min = stabilization[name]
+ return "typing_extensions." + name
+
+ imp = path.replace("/", ".")
+ if self.readable_stubs:
+ self.from_imports[imp].add((name, None))
+ return name
+ else:
+ self.imports.add(imp)
+ return imp + "." + name
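+        # Note: for typing_extensions names, write() later renders a
+        # version-gated import using the recorded minimum - roughly:
+        #   if sys.version_info >= (3, 8):  # (3, 8) when only Literal was used
+        #       import typing as typing_extensions
+        #   else:
+        #       import typing_extensions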
+
+ def _import_message(self, name: str) -> str:
+ """Import a referenced message and return a handle"""
+ message_fd = self.descriptors.message_to_fd[name]
+ assert message_fd.name.endswith(".proto")
+
+ # Strip off package name
+ if message_fd.package:
+ assert name.startswith("." + message_fd.package + ".")
+ name = name[len("." + message_fd.package + ".") :]
+ else:
+ assert name.startswith(".")
+ name = name[1:]
+
+ # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
+ split = name.split(".")
+ for i, part in enumerate(split):
+ if part in PYTHON_RESERVED:
+ split[i] = "_r_" + part
+ name = ".".join(split)
+
+        # Message defined in this file. Note: gRPC stubs for the same .proto are generated into separate files
+ if not self.grpc and message_fd.name == self.fd.name:
+ return name if self.readable_stubs else _mangle_global_identifier(name)
+
+ # Not in file. Must import
+ # Python generated code ignores proto packages, so the only relevant factor is
+ # whether it is in the file or not.
+ import_name = self._import(message_fd.name[:-6].replace("-", "_") + "_pb2", split[0])
+
+ remains = ".".join(split[1:])
+ if not remains:
+ return import_name
+
+ # remains could either be a direct import of a nested enum or message
+ # from another package.
+ return import_name + "." + remains
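+        # Resolution sketch (hypothetical proto): ".pkg.Outer.Inner" defined in
+        # "path/to/file.proto" becomes "path.to.file_pb2.Outer.Inner" when
+        # referenced from another file, or "global___Outer.Inner" within its own
+        # file (unless readable_stubs is set).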
+
+ def _builtin(self, name: str) -> str:
+ return self._import("builtins", name)
+
+ @contextmanager
+ def _indent(self) -> Iterator[None]:
+ self.indent = self.indent + " "
+ yield
+ self.indent = self.indent[:-4]
+
+ def _write_line(self, line: str, *args: Any) -> None:
+ if args:
+ line = line.format(*args)
+ if line == "":
+ self.lines.append(line)
+ else:
+ self.lines.append(self.indent + line)
+
+ def _break_text(self, text_block: str) -> List[str]:
+ if text_block == "":
+ return []
+ return [line[1:] if line.startswith(" ") else line for line in text_block.rstrip().split("\n")]
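+        # e.g. _break_text(" A comment\n spanning lines\n") returns
+        # ["A comment", "spanning lines"] - the single leading space that
+        # protoc keeps after "//" is stripped from each line.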
+
+ def _has_comments(self, scl: SourceCodeLocation) -> bool:
+ sci_loc = self.source_code_info_by_scl.get(tuple(scl))
+ return sci_loc is not None and bool(sci_loc.leading_detached_comments or sci_loc.leading_comments or sci_loc.trailing_comments)
+
+ def _write_comments(self, scl: SourceCodeLocation) -> bool:
+ """Return true if any comments were written"""
+ if not self._has_comments(scl):
+ return False
+
+ sci_loc = self.source_code_info_by_scl.get(tuple(scl))
+ assert sci_loc is not None
+
+ leading_detached_lines = []
+ leading_lines = []
+ trailing_lines = []
+ for leading_detached_comment in sci_loc.leading_detached_comments:
+ leading_detached_lines = self._break_text(leading_detached_comment)
+ if sci_loc.leading_comments is not None:
+ leading_lines = self._break_text(sci_loc.leading_comments)
+        # Trailing comments also go in the header - to make sure they get into the docstring
+ if sci_loc.trailing_comments is not None:
+ trailing_lines = self._break_text(sci_loc.trailing_comments)
+
+ lines = leading_detached_lines
+ if leading_detached_lines and (leading_lines or trailing_lines):
+ lines.append("")
+ lines.extend(leading_lines)
+ lines.extend(trailing_lines)
+
+ lines = [
+ # Escape triple-quotes that would otherwise end the docstring early.
+ line.replace("\\", "\\\\").replace('"""', r"\"\"\"")
+ for line in lines
+ ]
+ if len(lines) == 1:
+ line = lines[0]
+ if line.endswith(('"', "\\")):
+ # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
+ # insert some whitespace to separate it from the closing quotes.
+ # This is not necessary with multiline comments
+ # because in that case we always insert a newline before the trailing triple-quotes.
+ line = line + " "
+ self._write_line(f'"""{line}"""')
+ else:
+ for i, line in enumerate(lines):
+ if i == 0:
+ self._write_line(f'"""{line}')
+ else:
+ self._write_line(f"{line}")
+ self._write_line('"""')
+
+ return True
+
+ def write_enum_values(
+ self,
+ values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
+ value_type: str,
+ scl_prefix: SourceCodeLocation,
+ ) -> None:
+ for i, val in values:
+ if val.name in PYTHON_RESERVED:
+ continue
+
+ scl = scl_prefix + [i]
+ self._write_line(
+ f"{val.name}: {value_type} # {val.number}",
+ )
+ self._write_comments(scl)
+
+ def write_module_attributes(self) -> None:
+ wl = self._write_line
+ fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
+ wl(f"DESCRIPTOR: {fd_type}")
+ wl("")
+
+ def write_enums(
+ self,
+ enums: Iterable[d.EnumDescriptorProto],
+ prefix: str,
+ scl_prefix: SourceCodeLocation,
+ ) -> None:
+ wl = self._write_line
+ for i, enum in enumerate(enums):
+ class_name = enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
+ value_type_fq = prefix + class_name + ".ValueType"
+ enum_helper_class = "_" + enum.name
+ value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
+ etw_helper_class = "_" + enum.name + "EnumTypeWrapper"
+ scl = scl_prefix + [i]
+
+ wl(f"class {enum_helper_class}:")
+ with self._indent():
+ wl(
+ 'ValueType = {}("ValueType", {})',
+ self._import("typing", "NewType"),
+ self._builtin("int"),
+ )
+ # Alias to the classic shorter definition "V"
+ wl("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
+ wl("")
+ wl(
+ "class {}({}[{}], {}): # noqa: F821",
+ etw_helper_class,
+ self._import("google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"),
+ value_type_helper_fq,
+ self._builtin("type"),
+ )
+ with self._indent():
+ ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
+ wl(f"DESCRIPTOR: {ed}")
+ self.write_enum_values(
+ [(i, v) for i, v in enumerate(enum.value) if v.name not in PROTO_ENUM_RESERVED],
+ value_type_helper_fq,
+ scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
+ )
+ wl("")
+
+ if self._has_comments(scl):
+ wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
+ with self._indent():
+ self._write_comments(scl)
+ wl("")
+ else:
+ wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}): ...")
+ if prefix == "":
+ wl("")
+
+ self.write_enum_values(
+ enumerate(enum.value),
+ value_type_fq,
+ scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
+ )
+ if prefix == "" and not self.readable_stubs:
+ wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
+ wl("")
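+    # For a top-level `enum Color { RED = 0; }` this emits roughly (sketch):
+    #   class _Color:
+    #       ValueType = typing.NewType("ValueType", builtins.int)
+    #       V: typing_extensions.TypeAlias = ValueType
+    #   class _ColorEnumTypeWrapper(_EnumTypeWrapper[_Color.ValueType], builtins.type):
+    #       DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+    #       RED: _Color.ValueType # 0
+    #   class Color(_Color, metaclass=_ColorEnumTypeWrapper): ...
+    #   RED: Color.ValueType # 0
+    #   global___Color = Color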
+
+ def write_messages(
+ self,
+ messages: Iterable[d.DescriptorProto],
+ prefix: str,
+ scl_prefix: SourceCodeLocation,
+ ) -> None:
+ wl = self._write_line
+
+ for i, desc in enumerate(messages):
+ qualified_name = prefix + desc.name
+
+            # Reproduce some hardcoded logic from the protobuf implementation - where
+            # some specific "well_known_types" cause the generated protos to have
+            # additional base classes
+ addl_base = ""
+ if self.fd.package + "." + desc.name in WKTBASES:
+                # chop off the .proto - and import the well known type
+                # eg `from google.protobuf.internal.well_known_types import Duration`
+ well_known_type = WKTBASES[self.fd.package + "." + desc.name]
+ addl_base = ", " + self._import(
+ "google.protobuf.internal.well_known_types",
+ well_known_type.__name__,
+ )
+
+ class_name = desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
+ message_class = self._import("google.protobuf.message", "Message")
+ wl(f"class {class_name}({message_class}{addl_base}):")
+ with self._indent():
+ scl = scl_prefix + [i]
+ if self._write_comments(scl):
+ wl("")
+
+ desc_type = self._import("google.protobuf.descriptor", "Descriptor")
+ wl(f"DESCRIPTOR: {desc_type}")
+ wl("")
+
+ # Nested enums/messages
+ self.write_enums(
+ desc.enum_type,
+ qualified_name + ".",
+ scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
+ )
+ self.write_messages(
+ desc.nested_type,
+ qualified_name + ".",
+ scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
+ )
+
+ # integer constants for field numbers
+ for f in desc.field:
+ wl(f"{f.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")
+
+ for idx, field in enumerate(desc.field):
+ if field.name in PYTHON_RESERVED:
+ continue
+ field_type = self.python_type(field)
+
+ if is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED:
+                        # Scalar non-repeated fields are r/w
+ wl(f"{field.name}: {field_type}")
+ self._write_comments(scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx])
+ else:
+ # r/o Getters for non-scalar fields and scalar-repeated fields
+ scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
+ wl("@property")
+ body = " ..." if not self._has_comments(scl_field) else ""
+ wl(f"def {field.name}(self) -> {field_type}:{body}")
+ if self._has_comments(scl_field):
+ with self._indent():
+ self._write_comments(scl_field)
+
+ self.write_extensions(desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER])
+
+ # Constructor
+ wl("def __init__(")
+ with self._indent():
+ if any(f.name == "self" for f in desc.field):
+ wl("# pyright: reportSelfClsParameterName=false")
+ wl("self_,")
+ else:
+ wl("self,")
+ with self._indent():
+ constructor_fields = [f for f in desc.field if f.name not in PYTHON_RESERVED]
+ if len(constructor_fields) > 0:
+ # Only positional args allowed
+ # See https://github.com/nipunn1313/mypy-protobuf/issues/71
+ wl("*,")
+ for field in constructor_fields:
+ field_type = self.python_type(field, generic_container=True)
+                    if (
+                        self.fd.syntax == "proto3"
+                        and is_scalar(field)
+                        and field.label != d.FieldDescriptorProto.LABEL_REPEATED
+                        and not self.relax_strict_optional_primitives
+                        and not field.proto3_optional
+                    ):
+ wl(f"{field.name}: {field_type} = ...,")
+ else:
+ wl(f"{field.name}: {field_type} | None = ...,")
+ wl(") -> None: ...")
+
+ self.write_stringly_typed_fields(desc)
+
+ if prefix == "" and not self.readable_stubs:
+ wl("")
+ wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
+ wl("")
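+    # For a proto3 `message Foo { int32 x = 1; Bar b = 2; }` the constructor
+    # comes out roughly as (sketch):
+    #   def __init__(
+    #       self,
+    #       *,
+    #       x: builtins.int = ...,
+    #       b: global___Bar | None = ...,
+    #   ) -> None: ...
+    # Scalar non-repeated proto3 fields stay non-optional unless the
+    # relax_strict_optional_primitives flag is set.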
+
+ def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
+ """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
+ wl = self._write_line
+        # HasField, ClearField, WhichOneof accept both bytes/str
+ # HasField only supports singular. ClearField supports repeated as well
+ # In proto3, HasField only supports message fields and optional fields
+ # HasField always supports oneof fields
+        hf_fields = [
+            f.name
+            for f in desc.field
+            if f.HasField("oneof_index")
+            or (
+                f.label != d.FieldDescriptorProto.LABEL_REPEATED
+                and (self.fd.syntax != "proto3" or f.type == d.FieldDescriptorProto.TYPE_MESSAGE or f.proto3_optional)
+            )
+        ]
+        cf_fields = [f.name for f in desc.field]
+        wo_fields = {
+            oneof.name: [f.name for f in desc.field if f.HasField("oneof_index") and f.oneof_index == idx]
+            for idx, oneof in enumerate(desc.oneof_decl)
+        }
+
+ hf_fields.extend(wo_fields.keys())
+ cf_fields.extend(wo_fields.keys())
+
+ hf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in hf_fields))
+ cf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in cf_fields))
+
+ if not hf_fields and not cf_fields and not wo_fields:
+ return
+
+ if hf_fields:
+ wl(
+ "def HasField(self, field_name: {}[{}]) -> {}: ...",
+ self._import("typing_extensions", "Literal"),
+ hf_fields_text,
+ self._builtin("bool"),
+ )
+ if cf_fields:
+ wl(
+ "def ClearField(self, field_name: {}[{}]) -> None: ...",
+ self._import("typing_extensions", "Literal"),
+ cf_fields_text,
+ )
+
+ for wo_field, members in sorted(wo_fields.items()):
+ if len(wo_fields) > 1:
+ wl("@{}", self._import("typing", "overload"))
+ wl(
+ "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}] | None: ...",
+ self._import("typing_extensions", "Literal"),
+ # Accepts both str and bytes
+ f'"{wo_field}", b"{wo_field}"',
+ self._import("typing_extensions", "Literal"),
+ # Returns `str`
+ ", ".join(f'"{m}"' for m in members),
+ )
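+    # For `message M { oneof kind { int32 a = 1; string b = 2; } }` this emits
+    # roughly (sketch):
+    #   def HasField(self, field_name: typing_extensions.Literal["a", b"a", "b", b"b", "kind", b"kind"]) -> builtins.bool: ...
+    #   def ClearField(self, field_name: typing_extensions.Literal["a", b"a", "b", b"b", "kind", b"kind"]) -> None: ...
+    #   def WhichOneof(self, oneof_group: typing_extensions.Literal["kind", b"kind"]) -> typing_extensions.Literal["a", "b"] | None: ...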
+
+ def write_extensions(
+ self,
+ extensions: Sequence[d.FieldDescriptorProto],
+ scl_prefix: SourceCodeLocation,
+ ) -> None:
+ wl = self._write_line
+
+ for ext in extensions:
+ wl(f"{ext.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")
+
+ for i, ext in enumerate(extensions):
+ scl = scl_prefix + [i]
+
+ wl(
+ "{}: {}[{}, {}]",
+ ext.name,
+ self._import(
+ "google.protobuf.internal.extension_dict",
+ "_ExtensionFieldDescriptor",
+ ),
+ self._import_message(ext.extendee),
+ self.python_type(ext),
+ )
+ self._write_comments(scl)
+
+ def write_methods(
+ self,
+ service: d.ServiceDescriptorProto,
+ class_name: str,
+ is_abstract: bool,
+ scl_prefix: SourceCodeLocation,
+ ) -> None:
+ wl = self._write_line
+ wl(
+ "DESCRIPTOR: {}",
+ self._import("google.protobuf.descriptor", "ServiceDescriptor"),
+ )
+ methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
+ if not methods:
+ wl("...")
+ for i, method in methods:
+ if is_abstract:
+ wl("@{}", self._import("abc", "abstractmethod"))
+ wl(f"def {method.name}(")
+ with self._indent():
+ wl(f"inst: {class_name},")
+ wl(
+ "rpc_controller: {},",
+ self._import("google.protobuf.service", "RpcController"),
+ )
+ wl("request: {},", self._import_message(method.input_type))
+ wl(
+ "callback: {}[[{}], None] | None{},",
+ self._import("collections.abc", "Callable"),
+ self._import_message(method.output_type),
+ "" if is_abstract else " = ...",
+ )
+
+ scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
+ wl(
+ ") -> {}[{}]:{}",
+ self._import("concurrent.futures", "Future"),
+ self._import_message(method.output_type),
+ " ..." if not self._has_comments(scl_method) else "",
+ )
+ if self._has_comments(scl_method):
+ with self._indent():
+ if not self._write_comments(scl_method):
+ wl("...")
+
+ def write_services(
+ self,
+ services: Iterable[d.ServiceDescriptorProto],
+ scl_prefix: SourceCodeLocation,
+ ) -> None:
+ wl = self._write_line
+ for i, service in enumerate(services):
+ scl = scl_prefix + [i]
+ class_name = service.name if service.name not in PYTHON_RESERVED else "_r_" + service.name
+ # The service definition interface
+ wl(
+ "class {}({}, metaclass={}):",
+ class_name,
+ self._import("google.protobuf.service", "Service"),
+ self._import("abc", "ABCMeta"),
+ )
+ with self._indent():
+ if self._write_comments(scl):
+ wl("")
+ self.write_methods(service, class_name, is_abstract=True, scl_prefix=scl)
+ wl("")
+
+ # The stub client
+ stub_class_name = service.name + "_Stub"
+ wl("class {}({}):", stub_class_name, class_name)
+ with self._indent():
+ if self._write_comments(scl):
+ wl("")
+ wl(
+ "def __init__(self, rpc_channel: {}) -> None: ...",
+ self._import("google.protobuf.service", "RpcChannel"),
+ )
+ self.write_methods(service, stub_class_name, is_abstract=False, scl_prefix=scl)
+ wl("")
+
+ def _import_casttype(self, casttype: str) -> str:
+ split = casttype.split(".")
+ assert len(split) == 2, "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
+ pkg = split[0].replace("/", ".")
+ return self._import(pkg, split[1])
+
+ def _map_key_value_types(
+ self,
+ map_field: d.FieldDescriptorProto,
+ key_field: d.FieldDescriptorProto,
+ value_field: d.FieldDescriptorProto,
+ ) -> Tuple[str, str]:
+ oldstyle_keytype = map_field.options.Extensions[extensions_pb2.keytype]
+ if oldstyle_keytype:
+ print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.keytype) is deprecated. Prefer (mypy_protobuf.options).keytype", file=sys.stderr)
+ key_casttype = map_field.options.Extensions[extensions_pb2.options].keytype or oldstyle_keytype
+ ktype = self._import_casttype(key_casttype) if key_casttype else self.python_type(key_field)
+
+ oldstyle_valuetype = map_field.options.Extensions[extensions_pb2.valuetype]
+ if oldstyle_valuetype:
+ print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.valuetype) is deprecated. Prefer (mypy_protobuf.options).valuetype", file=sys.stderr)
+        value_casttype = map_field.options.Extensions[extensions_pb2.options].valuetype or oldstyle_valuetype
+ vtype = self._import_casttype(value_casttype) if value_casttype else self.python_type(value_field)
+
+ return ktype, vtype
+
+ def _callable_type(self, method: d.MethodDescriptorProto) -> str:
+ if method.client_streaming:
+ if method.server_streaming:
+ return self._import("grpc", "StreamStreamMultiCallable")
+ else:
+ return self._import("grpc", "StreamUnaryMultiCallable")
+ else:
+ if method.server_streaming:
+ return self._import("grpc", "UnaryStreamMultiCallable")
+ else:
+ return self._import("grpc", "UnaryUnaryMultiCallable")
+
+ def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
+ result = self._import_message(method.input_type)
+ if use_stream_iterator and method.client_streaming:
+ result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
+ return result
+
+ def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
+ result = self._import_message(method.output_type)
+ if use_stream_iterator and method.server_streaming:
+ result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
+ return result
+
+ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
+ wl = self._write_line
+ methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
+ if not methods:
+ wl("...")
+ wl("")
+ for i, method in methods:
+ scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
+
+ wl("@{}", self._import("abc", "abstractmethod"))
+ wl("def {}(", method.name)
+ with self._indent():
+ wl("self,")
+ input_name = "request_iterator" if method.client_streaming else "request"
+ input_type = self._input_type(method)
+ wl(f"{input_name}: {input_type},")
+ wl("context: {},", self._import("grpc", "ServicerContext"))
+ wl(
+ ") -> {}:{}",
+ self._output_type(method),
+ " ..." if not self._has_comments(scl) else "",
+            )
+ if self._has_comments(scl):
+ with self._indent():
+ if not self._write_comments(scl):
+ wl("...")
+
+ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
+ wl = self._write_line
+ methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
+ if not methods:
+ wl("...")
+ wl("")
+ for i, method in methods:
+ scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
+
+ wl("{}: {}[", method.name, self._callable_type(method))
+ with self._indent():
+ wl("{},", self._input_type(method, False))
+ wl("{},", self._output_type(method, False))
+ wl("]")
+ self._write_comments(scl)
+
+ def write_grpc_services(
+ self,
+ services: Iterable[d.ServiceDescriptorProto],
+ scl_prefix: SourceCodeLocation,
+ ) -> None:
+ wl = self._write_line
+ for i, service in enumerate(services):
+ if service.name in PYTHON_RESERVED:
+ continue
+
+ scl = scl_prefix + [i]
+
+ # The stub client
+ wl(f"class {service.name}Stub:")
+ with self._indent():
+ if self._write_comments(scl):
+ wl("")
+ wl(
+ "def __init__(self, channel: {}) -> None: ...",
+ self._import("grpc", "Channel"),
+ )
+ self.write_grpc_stub_methods(service, scl)
+ wl("")
+
+ # The service definition interface
+ wl(
+ "class {}Servicer(metaclass={}):",
+ service.name,
+ self._import("abc", "ABCMeta"),
+ )
+ with self._indent():
+ if self._write_comments(scl):
+ wl("")
+ self.write_grpc_methods(service, scl)
+ wl("")
+ wl(
+ "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
+ service.name,
+ service.name,
+ self._import("grpc", "Server"),
+ )
+ wl("")
+
+ def python_type(self, field: d.FieldDescriptorProto, generic_container: bool = False) -> str:
+ """
+ generic_container
+ if set, type the field with generic interfaces. Eg.
+ - Iterable[int] rather than RepeatedScalarFieldContainer[int]
+ - Mapping[k, v] rather than MessageMap[k, v]
+ Can be useful for input types (eg constructor)
+ """
+ oldstyle_casttype = field.options.Extensions[extensions_pb2.casttype]
+ if oldstyle_casttype:
+ print(f"Warning: Field {field.name}: (mypy_protobuf.casttype) is deprecated. Prefer (mypy_protobuf.options).casttype", file=sys.stderr)
+ casttype = field.options.Extensions[extensions_pb2.options].casttype or oldstyle_casttype
+ if casttype:
+ return self._import_casttype(casttype)
+
+ mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
+ d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
+ d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
+ d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
+ d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
+ d.FieldDescriptorProto.TYPE_STRING: lambda: self._builtin("str"),
+ d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
+ d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(field.type_name + ".ValueType"),
+ d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(field.type_name),
+ d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(field.type_name),
+ }
+
+ assert field.type in mapping, "Unrecognized type: " + repr(field.type)
+ field_type = mapping[field.type]()
+
+ # For non-repeated fields, we're done!
+ if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
+ return field_type
+
+ # Scalar repeated fields go in RepeatedScalarFieldContainer
+ if is_scalar(field):
+ container = (
+ self._import("collections.abc", "Iterable")
+ if generic_container
+ else self._import(
+ "google.protobuf.internal.containers",
+ "RepeatedScalarFieldContainer",
+ )
+ )
+ return f"{container}[{field_type}]"
+
+        # Map fields (repeated entry messages) go in ScalarMap/MessageMap
+ msg = self.descriptors.messages[field.type_name]
+ if msg.options.map_entry:
+ # map generates a special Entry wrapper message
+ if generic_container:
+ container = self._import("collections.abc", "Mapping")
+ elif is_scalar(msg.field[1]):
+ container = self._import("google.protobuf.internal.containers", "ScalarMap")
+ else:
+ container = self._import("google.protobuf.internal.containers", "MessageMap")
+ ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
+ return f"{container}[{ktype}, {vtype}]"
+
+        # non-scalar repeated fields go in RepeatedCompositeFieldContainer
+ container = (
+ self._import("collections.abc", "Iterable")
+ if generic_container
+ else self._import(
+ "google.protobuf.internal.containers",
+ "RepeatedCompositeFieldContainer",
+ )
+ )
+ return f"{container}[{field_type}]"
+
+ def write(self) -> str:
+ # save current module content, so that imports and module docstring can be inserted
+ saved_lines = self.lines
+ self.lines = []
+
+ # module docstring may exist as comment before syntax (optional) or package name
+ if not self._write_comments([d.FileDescriptorProto.PACKAGE_FIELD_NUMBER]):
+ self._write_comments([d.FileDescriptorProto.SYNTAX_FIELD_NUMBER])
+
+ if self.lines:
+ assert self.lines[0].startswith('"""')
+ self.lines[0] = f'"""{HEADER}{self.lines[0][3:]}'
+ else:
+ self._write_line(f'"""{HEADER}"""')
+
+ for reexport_idx in self.fd.public_dependency:
+ reexport_file = self.fd.dependency[reexport_idx]
+ reexport_fd = self.descriptors.files[reexport_file]
+ reexport_imp = reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
+            names = (
+                [m.name for m in reexport_fd.message_type]
+                + [m.name for m in reexport_fd.enum_type]
+                + [v.name for m in reexport_fd.enum_type for v in m.value]
+                + [m.name for m in reexport_fd.extension]
+            )
+ if reexport_fd.options.py_generic_services:
+ names.extend(m.name for m in reexport_fd.service)
+
+ if names:
+ # n,n to force a reexport (from x import y as y)
+ self.from_imports[reexport_imp].update((n, n) for n in names)
+
+ if self.typing_extensions_min:
+ self.imports.add("sys")
+ for pkg in sorted(self.imports):
+ self._write_line(f"import {pkg}")
+ if self.typing_extensions_min:
+ self._write_line("")
+ self._write_line(f"if sys.version_info >= {self.typing_extensions_min}:")
+ self._write_line(" import typing as typing_extensions")
+ self._write_line("else:")
+ self._write_line(" import typing_extensions")
+
+ for pkg, items in sorted(self.from_imports.items()):
+ self._write_line(f"from {pkg} import (")
+ for (name, reexport_name) in sorted(items):
+ if reexport_name is None:
+ self._write_line(f" {name},")
+ else:
+ self._write_line(f" {name} as {reexport_name},")
+ self._write_line(")")
+ self._write_line("")
+
+ # restore module content
+ self.lines += saved_lines
+
+ content = "\n".join(self.lines)
+ if not content.endswith("\n"):
+ content = content + "\n"
+ return content
+
+
+def is_scalar(fd: d.FieldDescriptorProto) -> bool:
+ return not (fd.type == d.FieldDescriptorProto.TYPE_MESSAGE or fd.type == d.FieldDescriptorProto.TYPE_GROUP)
+
+
+def generate_mypy_stubs(
+ descriptors: Descriptors,
+ response: plugin_pb2.CodeGeneratorResponse,
+ quiet: bool,
+ readable_stubs: bool,
+ relax_strict_optional_primitives: bool,
+) -> None:
+ for name, fd in descriptors.to_generate.items():
+ pkg_writer = PkgWriter(
+ fd,
+ descriptors,
+ readable_stubs,
+ relax_strict_optional_primitives,
+ grpc=False,
+ )
+
+ pkg_writer.write_module_attributes()
+ pkg_writer.write_enums(fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER])
+ pkg_writer.write_messages(fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER])
+ pkg_writer.write_extensions(fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER])
+ if fd.options.py_generic_services:
+ pkg_writer.write_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
+
+ assert name == fd.name
+ assert fd.name.endswith(".proto")
+ output = response.file.add()
+ output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2.pyi"
+ output.content = pkg_writer.write()
+
+
+def generate_mypy_grpc_stubs(
+ descriptors: Descriptors,
+ response: plugin_pb2.CodeGeneratorResponse,
+ quiet: bool,
+ readable_stubs: bool,
+ relax_strict_optional_primitives: bool,
+) -> None:
+ for name, fd in descriptors.to_generate.items():
+ pkg_writer = PkgWriter(
+ fd,
+ descriptors,
+ readable_stubs,
+ relax_strict_optional_primitives,
+ grpc=True,
+ )
+ pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
+
+ assert name == fd.name
+ assert fd.name.endswith(".proto")
+ output = response.file.add()
+ output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi"
+ output.content = pkg_writer.write()
+
+
+@contextmanager
+def code_generation() -> Iterator[
+ Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
+]:
+ if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"):
+ print("mypy-protobuf " + __version__)
+ sys.exit(0)
+
+ # Read request message from stdin
+ data = sys.stdin.buffer.read()
+
+ # Parse request
+ request = plugin_pb2.CodeGeneratorRequest()
+ request.ParseFromString(data)
+
+ # Create response
+ response = plugin_pb2.CodeGeneratorResponse()
+
+ # Declare support for optional proto3 fields
+ response.supported_features |= plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
+
+ yield request, response
+
+ # Serialise response message
+ output = response.SerializeToString()
+
+ # Write to stdout
+ sys.stdout.buffer.write(output)
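+# code_generation implements the protoc plugin wire protocol: protoc writes a
+# serialized CodeGeneratorRequest to the plugin's stdin and reads a serialized
+# CodeGeneratorResponse back from its stdout. Assuming the usual
+# `protoc-gen-mypy` entry point from this package is on PATH, an invocation
+# looks like:
+#   protoc --mypy_out=. --proto_path=. example.proto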
+
+
+def main() -> None:
+ # Generate mypy
+ with code_generation() as (request, response):
+ generate_mypy_stubs(
+ Descriptors(request),
+ response,
+ "quiet" in request.parameter,
+ "readable_stubs" in request.parameter,
+ "relax_strict_optional_primitives" in request.parameter,
+ )
+
+
+def grpc() -> None:
+ # Generate grpc mypy
+ with code_generation() as (request, response):
+ generate_mypy_grpc_stubs(
+ Descriptors(request),
+ response,
+ "quiet" in request.parameter,
+ "readable_stubs" in request.parameter,
+ "relax_strict_optional_primitives" in request.parameter,
+ )
+
+
+if __name__ == "__main__":
+ main()