aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/mypy-protobuf/mypy_protobuf/main.py
diff options
context:
space:
mode:
authormonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
committermonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
commit06e5c21a835c0e923506c4ff27929f34e00761c2 (patch)
tree75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /contrib/python/mypy-protobuf/mypy_protobuf/main.py
parent03f024c4412e3aa613bb543cf1660176320ba8f4 (diff)
downloadydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz
fix ya.make
Diffstat (limited to 'contrib/python/mypy-protobuf/mypy_protobuf/main.py')
-rw-r--r--contrib/python/mypy-protobuf/mypy_protobuf/main.py1086
1 files changed, 0 insertions, 1086 deletions
diff --git a/contrib/python/mypy-protobuf/mypy_protobuf/main.py b/contrib/python/mypy-protobuf/mypy_protobuf/main.py
deleted file mode 100644
index 6e825d8280..0000000000
--- a/contrib/python/mypy-protobuf/mypy_protobuf/main.py
+++ /dev/null
@@ -1,1086 +0,0 @@
-#!/usr/bin/env python
-"""Protoc Plugin to generate mypy stubs. Loosely based on @zbarsky's go implementation"""
-import os
-
-import sys
-from collections import defaultdict
-from contextlib import contextmanager
-from functools import wraps
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Set,
- Sequence,
- Tuple,
-)
-
-import google.protobuf.descriptor_pb2 as d
-from google.protobuf.compiler import plugin_pb2 as plugin_pb2
-from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
-from google.protobuf.internal.well_known_types import WKTBASES
-from . import extensions_pb2
-
-__version__ = "3.2.0"
-
-# SourceCodeLocation is defined by `message Location` here
-# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
-SourceCodeLocation = List[int]
-
-# So phabricator doesn't think mypy_protobuf.py is generated
-GENERATED = "@ge" + "nerated"
-HEADER = f"""\"\"\"
-{GENERATED} by mypy-protobuf. Do not edit manually!
-isort:skip_file
-\"\"\"
-"""
-
-# See https://github.com/dropbox/mypy-protobuf/issues/73 for details
-PYTHON_RESERVED = {
- "False",
- "None",
- "True",
- "and",
- "as",
- "async",
- "await",
- "assert",
- "break",
- "class",
- "continue",
- "def",
- "del",
- "elif",
- "else",
- "except",
- "finally",
- "for",
- "from",
- "global",
- "if",
- "import",
- "in",
- "is",
- "lambda",
- "nonlocal",
- "not",
- "or",
- "pass",
- "raise",
- "return",
- "try",
- "while",
- "with",
- "yield",
-}
-
-PROTO_ENUM_RESERVED = {
- "Name",
- "Value",
- "keys",
- "values",
- "items",
-}
-
-
-def _mangle_global_identifier(name: str) -> str:
- """
- Module level identifiers are mangled and aliased so that they can be disambiguated
- from fields/enum variants with the same name within the file.
-
- Eg:
- Enum variant `Name` or message field `Name` might conflict with a top level
- message or enum named `Name`, so mangle it with a global___ prefix for
- internal references. Note that this doesn't affect inner enums/messages
- because they get fuly qualified when referenced within a file"""
- return f"global___{name}"
-
-
-class Descriptors(object):
- def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
- files = {f.name: f for f in request.proto_file}
- to_generate = {n: files[n] for n in request.file_to_generate}
- self.files: Dict[str, d.FileDescriptorProto] = files
- self.to_generate: Dict[str, d.FileDescriptorProto] = to_generate
- self.messages: Dict[str, d.DescriptorProto] = {}
- self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}
-
- def _add_enums(
- enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
- prefix: str,
- _fd: d.FileDescriptorProto,
- ) -> None:
- for enum in enums:
- self.message_to_fd[prefix + enum.name] = _fd
- self.message_to_fd[prefix + enum.name + ".ValueType"] = _fd
-
- def _add_messages(
- messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
- prefix: str,
- _fd: d.FileDescriptorProto,
- ) -> None:
- for message in messages:
- self.messages[prefix + message.name] = message
- self.message_to_fd[prefix + message.name] = _fd
- sub_prefix = prefix + message.name + "."
- _add_messages(message.nested_type, sub_prefix, _fd)
- _add_enums(message.enum_type, sub_prefix, _fd)
-
- for fd in request.proto_file:
- start_prefix = "." + fd.package + "." if fd.package else "."
- _add_messages(fd.message_type, start_prefix, fd)
- _add_enums(fd.enum_type, start_prefix, fd)
-
-
-class PkgWriter(object):
- """Writes a single pyi file"""
-
- def __init__(
- self,
- fd: d.FileDescriptorProto,
- descriptors: Descriptors,
- readable_stubs: bool,
- relax_strict_optional_primitives: bool,
- grpc: bool,
- ) -> None:
- self.fd = fd
- self.descriptors = descriptors
- self.readable_stubs = readable_stubs
- self.relax_strict_optional_primitives = relax_strict_optional_primitives
- self.grpc = grpc
- self.lines: List[str] = []
- self.indent = ""
-
- # Set of {x}, where {x} corresponds to to `import {x}`
- self.imports: Set[str] = set()
- # dictionary of x->(y,z) for `from {x} import {y} as {z}`
- # if {z} is None, then it shortens to `from {x} import {y}`
- self.from_imports: Dict[str, Set[Tuple[str, Optional[str]]]] = defaultdict(set)
-
- # Comments
- self.source_code_info_by_scl = {
- tuple(location.path): location for location in fd.source_code_info.location
- }
-
- def _import(self, path: str, name: str) -> str:
- """Imports a stdlib path and returns a handle to it
- eg. self._import("typing", "Optional") -> "Optional"
- """
- imp = path.replace("/", ".")
- if self.readable_stubs:
- self.from_imports[imp].add((name, None))
- return name
- else:
- self.imports.add(imp)
- return imp + "." + name
-
- def _import_message(self, name: str) -> str:
- """Import a referenced message and return a handle"""
- message_fd = self.descriptors.message_to_fd[name]
- assert message_fd.name.endswith(".proto")
-
- # Strip off package name
- if message_fd.package:
- assert name.startswith("." + message_fd.package + ".")
- name = name[len("." + message_fd.package + ".") :]
- else:
- assert name.startswith(".")
- name = name[1:]
-
- # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
- split = name.split(".")
- for i, part in enumerate(split):
- if part in PYTHON_RESERVED:
- split[i] = "_r_" + part
- name = ".".join(split)
-
- # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
- if not self.grpc and message_fd.name == self.fd.name:
- return name if self.readable_stubs else _mangle_global_identifier(name)
-
- # Not in file. Must import
- # Python generated code ignores proto packages, so the only relevant factor is
- # whether it is in the file or not.
- import_name = self._import(
- message_fd.name[:-6].replace("-", "_") + "_pb2", split[0]
- )
-
- remains = ".".join(split[1:])
- if not remains:
- return import_name
-
- # remains could either be a direct import of a nested enum or message
- # from another package.
- return import_name + "." + remains
-
- def _builtin(self, name: str) -> str:
- return self._import("builtins", name)
-
- @contextmanager
- def _indent(self) -> Iterator[None]:
- self.indent = self.indent + " "
- yield
- self.indent = self.indent[:-4]
-
- def _write_line(self, line: str, *args: Any) -> None:
- if args:
- line = line.format(*args)
- if line == "":
- self.lines.append(line)
- else:
- self.lines.append(self.indent + line)
-
- def _break_text(self, text_block: str) -> List[str]:
- if text_block == "":
- return []
- return [
- l[1:] if l.startswith(" ") else l for l in text_block.rstrip().split("\n")
- ]
-
- def _has_comments(self, scl: SourceCodeLocation) -> bool:
- sci_loc = self.source_code_info_by_scl.get(tuple(scl))
- return sci_loc is not None and bool(
- sci_loc.leading_detached_comments
- or sci_loc.leading_comments
- or sci_loc.trailing_comments
- )
-
- def _write_comments(self, scl: SourceCodeLocation) -> bool:
- """Return true if any comments were written"""
- if not self._has_comments(scl):
- return False
-
- sci_loc = self.source_code_info_by_scl.get(tuple(scl))
- assert sci_loc is not None
-
- lines = []
- for leading_detached_comment in sci_loc.leading_detached_comments:
- lines.extend(self._break_text(leading_detached_comment))
- lines.append("")
- if sci_loc.leading_comments is not None:
- lines.extend(self._break_text(sci_loc.leading_comments))
- # Trailing comments also go in the header - to make sure it gets into the docstring
- if sci_loc.trailing_comments is not None:
- lines.extend(self._break_text(sci_loc.trailing_comments))
-
- lines = [
- # Escape triple-quotes that would otherwise end the docstring early.
- line.replace("\\", "\\\\").replace('"""', r"\"\"\"")
- for line in lines
- ]
- if len(lines) == 1:
- line = lines[0]
- if line.endswith(('"', "\\")):
- # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
- # insert some whitespace to separate it from the closing quotes.
- # This is not necessary with multiline comments
- # because in that case we always insert a newline before the trailing triple-quotes.
- line = line + " "
- self._write_line(f'"""{line}"""')
- else:
- for i, line in enumerate(lines):
- if i == 0:
- self._write_line(f'"""{line}')
- else:
- self._write_line(f"{line}")
- self._write_line('"""')
-
- return True
-
- def write_enum_values(
- self,
- values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
- value_type: str,
- scl_prefix: SourceCodeLocation,
- ) -> None:
- for i, val in values:
- if val.name in PYTHON_RESERVED:
- continue
-
- scl = scl_prefix + [i]
- self._write_line(
- f"{val.name}: {value_type} # {val.number}",
- )
- if self._write_comments(scl):
- self._write_line("") # Extra newline to separate
-
- def write_module_attributes(self) -> None:
- l = self._write_line
- fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
- l(f"DESCRIPTOR: {fd_type}")
- l("")
-
- def write_enums(
- self,
- enums: Iterable[d.EnumDescriptorProto],
- prefix: str,
- scl_prefix: SourceCodeLocation,
- ) -> None:
- l = self._write_line
- for i, enum in enumerate(enums):
- class_name = (
- enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
- )
- value_type_fq = prefix + class_name + ".ValueType"
- enum_helper_class = "_" + enum.name
- value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
- etw_helper_class = "_" + enum.name + "EnumTypeWrapper"
- scl = scl_prefix + [i]
-
- l(f"class {enum_helper_class}:")
- with self._indent():
- l(
- "ValueType = {}('ValueType', {})",
- self._import("typing", "NewType"),
- self._builtin("int"),
- )
- # Alias to the classic shorter definition "V"
- l("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
- l(
- "class {}({}[{}], {}):",
- etw_helper_class,
- self._import(
- "google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"
- ),
- value_type_helper_fq,
- self._builtin("type"),
- )
- with self._indent():
- ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
- l(f"DESCRIPTOR: {ed}")
- self.write_enum_values(
- [
- (i, v)
- for i, v in enumerate(enum.value)
- if v.name not in PROTO_ENUM_RESERVED
- ],
- value_type_helper_fq,
- scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
- )
- l(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
- with self._indent():
- self._write_comments(scl)
- l("pass")
- l("")
-
- self.write_enum_values(
- enumerate(enum.value),
- value_type_fq,
- scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
- )
- if prefix == "" and not self.readable_stubs:
- l(f"{_mangle_global_identifier(class_name)} = {class_name}")
- l("")
- l("")
-
- def write_messages(
- self,
- messages: Iterable[d.DescriptorProto],
- prefix: str,
- scl_prefix: SourceCodeLocation,
- ) -> None:
- l = self._write_line
-
- for i, desc in enumerate(messages):
- qualified_name = prefix + desc.name
-
- # Reproduce some hardcoded logic from the protobuf implementation - where
- # some specific "well_known_types" generated protos to have additional
- # base classes
- addl_base = ""
- if self.fd.package + "." + desc.name in WKTBASES:
- # chop off the .proto - and import the well known type
- # eg `from google.protobuf.duration import Duration`
- well_known_type = WKTBASES[self.fd.package + "." + desc.name]
- addl_base = ", " + self._import(
- "google.protobuf.internal.well_known_types",
- well_known_type.__name__,
- )
-
- class_name = (
- desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
- )
- message_class = self._import("google.protobuf.message", "Message")
- l(f"class {class_name}({message_class}{addl_base}):")
- with self._indent():
- scl = scl_prefix + [i]
- self._write_comments(scl)
-
- desc_type = self._import("google.protobuf.descriptor", "Descriptor")
- l(f"DESCRIPTOR: {desc_type}")
-
- # Nested enums/messages
- self.write_enums(
- desc.enum_type,
- qualified_name + ".",
- scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
- )
- self.write_messages(
- desc.nested_type,
- qualified_name + ".",
- scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
- )
-
- # integer constants for field numbers
- for f in desc.field:
- l(f"{f.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")
-
- for idx, field in enumerate(desc.field):
- if field.name in PYTHON_RESERVED:
- continue
- field_type = self.python_type(field)
-
- if (
- is_scalar(field)
- and field.label != d.FieldDescriptorProto.LABEL_REPEATED
- ):
- # Scalar non repeated fields are r/w
- l(f"{field.name}: {field_type}")
- if self._write_comments(
- scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
- ):
- l("")
- else:
- # r/o Getters for non-scalar fields and scalar-repeated fields
- scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
- l("@property")
- body = " ..." if not self._has_comments(scl_field) else ""
- l(f"def {field.name}(self) -> {field_type}:{body}")
- if self._has_comments(scl_field):
- with self._indent():
- self._write_comments(scl_field)
- l("pass")
-
- self.write_extensions(
- desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER]
- )
-
- # Constructor
- if any(f.name == "self" for f in desc.field):
- l("# pyright: reportSelfClsParameterName=false")
- l(f"def __init__(self_,")
- else:
- l(f"def __init__(self,")
- with self._indent():
- constructor_fields = [
- f for f in desc.field if f.name not in PYTHON_RESERVED
- ]
- if len(constructor_fields) > 0:
- # Only positional args allowed
- # See https://github.com/dropbox/mypy-protobuf/issues/71
- l("*,")
- for field in constructor_fields:
- field_type = self.python_type(field, generic_container=True)
- if (
- self.fd.syntax == "proto3"
- and is_scalar(field)
- and field.label != d.FieldDescriptorProto.LABEL_REPEATED
- and not self.relax_strict_optional_primitives
- and not field.proto3_optional
- ):
- l(f"{field.name}: {field_type} = ...,")
- else:
- opt = self._import("typing", "Optional")
- l(f"{field.name}: {opt}[{field_type}] = ...,")
- l(") -> None: ...")
-
- self.write_stringly_typed_fields(desc)
-
- if prefix == "" and not self.readable_stubs:
- l(f"{_mangle_global_identifier(class_name)} = {class_name}")
- l("")
-
- def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
- """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
- l = self._write_line
- # HasField, ClearField, WhichOneof accepts both bytes/str
- # HasField only supports singular. ClearField supports repeated as well
- # In proto3, HasField only supports message fields and optional fields
- # HasField always supports oneof fields
- hf_fields = [
- f.name
- for f in desc.field
- if f.HasField("oneof_index")
- or (
- f.label != d.FieldDescriptorProto.LABEL_REPEATED
- and (
- self.fd.syntax != "proto3"
- or f.type == d.FieldDescriptorProto.TYPE_MESSAGE
- or f.proto3_optional
- )
- )
- ]
- cf_fields = [f.name for f in desc.field]
- wo_fields = {
- oneof.name: [
- f.name
- for f in desc.field
- if f.HasField("oneof_index") and f.oneof_index == idx
- ]
- for idx, oneof in enumerate(desc.oneof_decl)
- }
-
- hf_fields.extend(wo_fields.keys())
- cf_fields.extend(wo_fields.keys())
-
- hf_fields_text = ",".join(sorted(f'"{name}",b"{name}"' for name in hf_fields))
- cf_fields_text = ",".join(sorted(f'"{name}",b"{name}"' for name in cf_fields))
-
- if not hf_fields and not cf_fields and not wo_fields:
- return
-
- if hf_fields:
- l(
- "def HasField(self, field_name: {}[{}]) -> {}: ...",
- self._import("typing_extensions", "Literal"),
- hf_fields_text,
- self._builtin("bool"),
- )
- if cf_fields:
- l(
- "def ClearField(self, field_name: {}[{}]) -> None: ...",
- self._import("typing_extensions", "Literal"),
- cf_fields_text,
- )
-
- for wo_field, members in sorted(wo_fields.items()):
- if len(wo_fields) > 1:
- l("@{}", self._import("typing", "overload"))
- l(
- "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}[{}]]: ...",
- self._import("typing_extensions", "Literal"),
- # Accepts both str and bytes
- f'"{wo_field}",b"{wo_field}"',
- self._import("typing", "Optional"),
- self._import("typing_extensions", "Literal"),
- # Returns `str`
- ",".join(f'"{m}"' for m in members),
- )
-
- def write_extensions(
- self,
- extensions: Sequence[d.FieldDescriptorProto],
- scl_prefix: SourceCodeLocation,
- ) -> None:
- l = self._write_line
-
- for ext in extensions:
- l(f"{ext.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")
-
- for i, ext in enumerate(extensions):
- scl = scl_prefix + [i]
-
- l(
- "{}: {}[{}, {}]",
- ext.name,
- self._import(
- "google.protobuf.internal.extension_dict",
- "_ExtensionFieldDescriptor",
- ),
- self._import_message(ext.extendee),
- self.python_type(ext),
- )
- self._write_comments(scl)
- l("")
-
- def write_methods(
- self,
- service: d.ServiceDescriptorProto,
- class_name: str,
- is_abstract: bool,
- scl_prefix: SourceCodeLocation,
- ) -> None:
- l = self._write_line
- l(
- "DESCRIPTOR: {}",
- self._import("google.protobuf.descriptor", "ServiceDescriptor"),
- )
- methods = [
- (i, m)
- for i, m in enumerate(service.method)
- if m.name not in PYTHON_RESERVED
- ]
- if not methods:
- l("pass")
- for i, method in methods:
- if is_abstract:
- l("@{}", self._import("abc", "abstractmethod"))
- l(f"def {method.name}(")
- with self._indent():
- l(f"inst: {class_name},")
- l(
- "rpc_controller: {},",
- self._import("google.protobuf.service", "RpcController"),
- )
- l("request: {},", self._import_message(method.input_type))
- l(
- "callback: {}[{}[[{}], None]]{},",
- self._import("typing", "Optional"),
- self._import("typing", "Callable"),
- self._import_message(method.output_type),
- "" if is_abstract else " = None",
- )
-
- scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
- l(
- ") -> {}[{}]:{}",
- self._import("concurrent.futures", "Future"),
- self._import_message(method.output_type),
- " ..." if not self._has_comments(scl_method) else "",
- )
- if self._has_comments(scl_method):
- with self._indent():
- self._write_comments(scl_method)
- l("pass")
-
- def write_services(
- self,
- services: Iterable[d.ServiceDescriptorProto],
- scl_prefix: SourceCodeLocation,
- ) -> None:
- l = self._write_line
- for i, service in enumerate(services):
- scl = scl_prefix + [i]
- class_name = (
- service.name
- if service.name not in PYTHON_RESERVED
- else "_r_" + service.name
- )
- # The service definition interface
- l(
- "class {}({}, metaclass={}):",
- class_name,
- self._import("google.protobuf.service", "Service"),
- self._import("abc", "ABCMeta"),
- )
- with self._indent():
- self._write_comments(scl)
- self.write_methods(
- service, class_name, is_abstract=True, scl_prefix=scl
- )
-
- # The stub client
- stub_class_name = service.name + "_Stub"
- l("class {}({}):", stub_class_name, class_name)
- with self._indent():
- self._write_comments(scl)
- l(
- "def __init__(self, rpc_channel: {}) -> None: ...",
- self._import("google.protobuf.service", "RpcChannel"),
- )
- self.write_methods(
- service, stub_class_name, is_abstract=False, scl_prefix=scl
- )
-
- def _import_casttype(self, casttype: str) -> str:
- split = casttype.split(".")
- assert (
- len(split) == 2
- ), "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
- pkg = split[0].replace("/", ".")
- return self._import(pkg, split[1])
-
- def _map_key_value_types(
- self,
- map_field: d.FieldDescriptorProto,
- key_field: d.FieldDescriptorProto,
- value_field: d.FieldDescriptorProto,
- ) -> Tuple[str, str]:
- key_casttype = map_field.options.Extensions[extensions_pb2.keytype]
- ktype = (
- self._import_casttype(key_casttype)
- if key_casttype
- else self.python_type(key_field)
- )
- value_casttype = map_field.options.Extensions[extensions_pb2.valuetype]
- vtype = (
- self._import_casttype(value_casttype)
- if value_casttype
- else self.python_type(value_field)
- )
- return ktype, vtype
-
- def _callable_type(self, method: d.MethodDescriptorProto) -> str:
- if method.client_streaming:
- if method.server_streaming:
- return self._import("grpc", "StreamStreamMultiCallable")
- else:
- return self._import("grpc", "StreamUnaryMultiCallable")
- else:
- if method.server_streaming:
- return self._import("grpc", "UnaryStreamMultiCallable")
- else:
- return self._import("grpc", "UnaryUnaryMultiCallable")
-
- def _input_type(
- self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
- ) -> str:
- result = self._import_message(method.input_type)
- if use_stream_iterator and method.client_streaming:
- result = f"{self._import('typing', 'Iterator')}[{result}]"
- return result
-
- def _output_type(
- self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
- ) -> str:
- result = self._import_message(method.output_type)
- if use_stream_iterator and method.server_streaming:
- result = f"{self._import('typing', 'Iterator')}[{result}]"
- return result
-
- def write_grpc_methods(
- self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
- ) -> None:
- l = self._write_line
- methods = [
- (i, m)
- for i, m in enumerate(service.method)
- if m.name not in PYTHON_RESERVED
- ]
- if not methods:
- l("pass")
- l("")
- for i, method in methods:
- scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
-
- l("@{}", self._import("abc", "abstractmethod"))
- l("def {}(self,", method.name)
- with self._indent():
- input_name = (
- "request_iterator" if method.client_streaming else "request"
- )
- input_type = self._input_type(method)
- l(f"{input_name}: {input_type},")
- l("context: {},", self._import("grpc", "ServicerContext"))
- l(
- ") -> {}:{}",
- self._output_type(method),
- " ..." if not self._has_comments(scl) else "",
- ),
- if self._has_comments(scl):
- with self._indent():
- self._write_comments(scl)
- l("pass")
- l("")
-
- def write_grpc_stub_methods(
- self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
- ) -> None:
- l = self._write_line
- methods = [
- (i, m)
- for i, m in enumerate(service.method)
- if m.name not in PYTHON_RESERVED
- ]
- if not methods:
- l("pass")
- l("")
- for i, method in methods:
- scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
-
- l("{}: {}[", method.name, self._callable_type(method))
- with self._indent():
- l("{},", self._input_type(method, False))
- l("{}]", self._output_type(method, False))
- self._write_comments(scl)
- l("")
-
- def write_grpc_services(
- self,
- services: Iterable[d.ServiceDescriptorProto],
- scl_prefix: SourceCodeLocation,
- ) -> None:
- l = self._write_line
- for i, service in enumerate(services):
- if service.name in PYTHON_RESERVED:
- continue
-
- scl = scl_prefix + [i]
-
- # The stub client
- l(f"class {service.name}Stub:")
- with self._indent():
- self._write_comments(scl)
- l(
- "def __init__(self, channel: {}) -> None: ...",
- self._import("grpc", "Channel"),
- )
- self.write_grpc_stub_methods(service, scl)
- l("")
-
- # The service definition interface
- l(
- "class {}Servicer(metaclass={}):",
- service.name,
- self._import("abc", "ABCMeta"),
- )
- with self._indent():
- self._write_comments(scl)
- self.write_grpc_methods(service, scl)
- l("")
- l(
- "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
- service.name,
- service.name,
- self._import("grpc", "Server"),
- )
- l("")
-
- def python_type(
- self, field: d.FieldDescriptorProto, generic_container: bool = False
- ) -> str:
- """
- generic_container
- if set, type the field with generic interfaces. Eg.
- - Iterable[int] rather than RepeatedScalarFieldContainer[int]
- - Mapping[k, v] rather than MessageMap[k, v]
- Can be useful for input types (eg constructor)
- """
- casttype = field.options.Extensions[extensions_pb2.casttype]
- if casttype:
- return self._import_casttype(casttype)
-
- mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
- d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
- d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
- d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
- d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
- d.FieldDescriptorProto.TYPE_STRING: lambda: self._import("typing", "Text"),
- d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
- d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(
- field.type_name + ".ValueType"
- ),
- d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(
- field.type_name
- ),
- d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(
- field.type_name
- ),
- }
-
- assert field.type in mapping, "Unrecognized type: " + repr(field.type)
- field_type = mapping[field.type]()
-
- # For non-repeated fields, we're done!
- if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
- return field_type
-
- # Scalar repeated fields go in RepeatedScalarFieldContainer
- if is_scalar(field):
- container = (
- self._import("typing", "Iterable")
- if generic_container
- else self._import(
- "google.protobuf.internal.containers",
- "RepeatedScalarFieldContainer",
- )
- )
- return f"{container}[{field_type}]"
-
- # non-scalar repeated map fields go in ScalarMap/MessageMap
- msg = self.descriptors.messages[field.type_name]
- if msg.options.map_entry:
- # map generates a special Entry wrapper message
- if generic_container:
- container = self._import("typing", "Mapping")
- elif is_scalar(msg.field[1]):
- container = self._import(
- "google.protobuf.internal.containers", "ScalarMap"
- )
- else:
- container = self._import(
- "google.protobuf.internal.containers", "MessageMap"
- )
- ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
- return f"{container}[{ktype}, {vtype}]"
-
- # non-scalar repetated fields go in RepeatedCompositeFieldContainer
- container = (
- self._import("typing", "Iterable")
- if generic_container
- else self._import(
- "google.protobuf.internal.containers",
- "RepeatedCompositeFieldContainer",
- )
- )
- return f"{container}[{field_type}]"
-
- def write(self) -> str:
- for reexport_idx in self.fd.public_dependency:
- reexport_file = self.fd.dependency[reexport_idx]
- reexport_fd = self.descriptors.files[reexport_file]
- reexport_imp = (
- reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
- )
- names = (
- [m.name for m in reexport_fd.message_type]
- + [m.name for m in reexport_fd.enum_type]
- + [v.name for m in reexport_fd.enum_type for v in m.value]
- + [m.name for m in reexport_fd.extension]
- )
- if reexport_fd.options.py_generic_services:
- names.extend(m.name for m in reexport_fd.service)
-
- if names:
- # n,n to force a reexport (from x import y as y)
- self.from_imports[reexport_imp].update((n, n) for n in names)
-
- import_lines = []
- for pkg in sorted(self.imports):
- import_lines.append(f"import {pkg}")
-
- for pkg, items in sorted(self.from_imports.items()):
- import_lines.append(f"from {pkg} import (")
- for (name, reexport_name) in sorted(items):
- if reexport_name is None:
- import_lines.append(f" {name},")
- else:
- import_lines.append(f" {name} as {reexport_name},")
- import_lines.append(")\n")
- import_lines.append("")
-
- return "\n".join(import_lines + self.lines)
-
-
-def is_scalar(fd: d.FieldDescriptorProto) -> bool:
- return not (
- fd.type == d.FieldDescriptorProto.TYPE_MESSAGE
- or fd.type == d.FieldDescriptorProto.TYPE_GROUP
- )
-
-
-def generate_mypy_stubs(
- descriptors: Descriptors,
- response: plugin_pb2.CodeGeneratorResponse,
- quiet: bool,
- readable_stubs: bool,
- relax_strict_optional_primitives: bool,
-) -> None:
- for name, fd in descriptors.to_generate.items():
- pkg_writer = PkgWriter(
- fd,
- descriptors,
- readable_stubs,
- relax_strict_optional_primitives,
- grpc=False,
- )
-
- pkg_writer.write_module_attributes()
- pkg_writer.write_enums(
- fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER]
- )
- pkg_writer.write_messages(
- fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER]
- )
- pkg_writer.write_extensions(
- fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER]
- )
- if fd.options.py_generic_services:
- pkg_writer.write_services(
- fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
- )
-
- assert name == fd.name
- assert fd.name.endswith(".proto")
- output = response.file.add()
- output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2.pyi"
- output.content = HEADER + pkg_writer.write()
-
-
-def generate_mypy_grpc_stubs(
- descriptors: Descriptors,
- response: plugin_pb2.CodeGeneratorResponse,
- quiet: bool,
- readable_stubs: bool,
- relax_strict_optional_primitives: bool,
-) -> None:
- for name, fd in descriptors.to_generate.items():
- pkg_writer = PkgWriter(
- fd,
- descriptors,
- readable_stubs,
- relax_strict_optional_primitives,
- grpc=True,
- )
- pkg_writer.write_grpc_services(
- fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
- )
-
- assert name == fd.name
- assert fd.name.endswith(".proto")
- output = response.file.add()
- output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi"
- output.content = HEADER + pkg_writer.write()
-
-
-@contextmanager
-def code_generation() -> Iterator[
- Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
-]:
- if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"):
- print("mypy-protobuf " + __version__)
- sys.exit(0)
-
- # Read request message from stdin
- data = sys.stdin.buffer.read()
-
- # Parse request
- request = plugin_pb2.CodeGeneratorRequest()
- request.ParseFromString(data)
-
- # Create response
- response = plugin_pb2.CodeGeneratorResponse()
-
- # Declare support for optional proto3 fields
- response.supported_features |= (
- plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
- )
-
- yield request, response
-
- # Serialise response message
- output = response.SerializeToString()
-
- # Write to stdout
- sys.stdout.buffer.write(output)
-
-
-def main() -> None:
- # Generate mypy
- with code_generation() as (request, response):
- generate_mypy_stubs(
- Descriptors(request),
- response,
- "quiet" in request.parameter,
- "readable_stubs" in request.parameter,
- "relax_strict_optional_primitives" in request.parameter,
- )
-
-
-def grpc() -> None:
- # Generate grpc mypy
- with code_generation() as (request, response):
- generate_mypy_grpc_stubs(
- Descriptors(request),
- response,
- "quiet" in request.parameter,
- "readable_stubs" in request.parameter,
- "relax_strict_optional_primitives" in request.parameter,
- )
-
-
-if __name__ == "__main__":
- main()