summaryrefslogtreecommitdiffstats
path: root/contrib/python/protobuf/py3/patches/pr15999-register-as-virtual-subclasses.patch
blob: 05d00c02fe7a220e92d3777d62e5517ef53d5f8b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index 7d690c1..2d75c09 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -53,7 +53,7 @@ namespace python {
 class MapReflectionFriend {
  public:
   // Methods that are in common between the map types.
-  static PyObject* Contains(PyObject* _self, PyObject* key);
+  static int Contains(PyObject* _self, PyObject* key);
   static Py_ssize_t Length(PyObject* _self);
   static PyObject* GetIterator(PyObject *_self);
   static PyObject* IterNext(PyObject* _self);
@@ -352,7 +352,7 @@ PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
   Py_RETURN_NONE;
 }
 
-PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
+int MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
   MapContainer* self = GetMap(_self);
 
   const Message* message = self->parent->message;
@@ -360,14 +360,14 @@ PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
   MapKey map_key;
 
   if (!PythonToMapKey(self, key, &map_key)) {
-    return nullptr;
+    return -1;
   }
 
   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
                                  map_key)) {
-    Py_RETURN_TRUE;
+    return 1;
   } else {
-    Py_RETURN_FALSE;
+    return 0;
   }
 }
 
@@ -466,12 +466,12 @@ static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
     return nullptr;
   }
 
-  ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
-  if (is_present.get() == nullptr) {
+  auto is_present = MapReflectionFriend::Contains(self, key);
+  if (is_present < 0) {
     return nullptr;
   }
 
-  if (PyObject_IsTrue(is_present.get())) {
+  if (is_present == 1) {
     return MapReflectionFriend::ScalarMapGetItem(self, key);
   } else {
     if (default_value != nullptr) {
@@ -525,8 +525,6 @@ static void ScalarMapDealloc(PyObject* _self) {
 }
 
 static PyMethodDef ScalarMapMethods[] = {
-    {"__contains__", MapReflectionFriend::Contains, METH_O,
-     "Tests whether a key is a member of the map."},
     {"clear", (PyCFunction)Clear, METH_NOARGS,
      "Removes all elements from the map."},
     {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
@@ -551,6 +549,7 @@ static PyType_Slot ScalarMapContainer_Type_slots[] = {
     {Py_mp_subscript, (void*)MapReflectionFriend::ScalarMapGetItem},
     {Py_mp_ass_subscript, (void*)MapReflectionFriend::ScalarMapSetItem},
     {Py_tp_methods, (void*)ScalarMapMethods},
+    {Py_sq_contains, (void*)MapReflectionFriend::Contains},
     {Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
     {Py_tp_repr, (void*)MapReflectionFriend::ScalarMapToStr},
     {0, nullptr},
@@ -710,12 +709,12 @@ PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
     return nullptr;
   }
 
-  ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
-  if (is_present.get() == nullptr) {
+  auto is_present = MapReflectionFriend::Contains(self, key);
+  if (is_present < 0) {
     return nullptr;
   }
 
-  if (PyObject_IsTrue(is_present.get())) {
+  if (is_present == 1) {
     return MapReflectionFriend::MessageMapGetItem(self, key);
   } else {
     if (default_value != nullptr) {
@@ -740,8 +739,6 @@ static void MessageMapDealloc(PyObject* _self) {
 }
 
 static PyMethodDef MessageMapMethods[] = {
-    {"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
-     "Tests whether the map contains this element."},
     {"clear", (PyCFunction)Clear, METH_NOARGS,
      "Removes all elements from the map."},
     {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
@@ -768,6 +765,7 @@ static PyType_Slot MessageMapContainer_Type_slots[] = {
     {Py_mp_subscript, (void*)MapReflectionFriend::MessageMapGetItem},
     {Py_mp_ass_subscript, (void*)MapReflectionFriend::MessageMapSetItem},
     {Py_tp_methods, (void*)MessageMapMethods},
+    {Py_sq_contains, (void*)MapReflectionFriend::Contains},
     {Py_tp_iter, (void*)MapReflectionFriend::GetIterator},
     {Py_tp_repr, (void*)MapReflectionFriend::MessageMapToStr},
     {0, nullptr}};
@@ -893,6 +891,33 @@ PyTypeObject MapIterator_Type = {
     nullptr,                        //  tp_init
 };
 
+
+PyTypeObject* PyUpb_AddClassWithRegister(PyType_Spec* spec,
+                                         PyObject* virtual_base,
+                                         const char** methods) {
+  PyObject* type = PyType_FromSpec(spec);
+  PyObject* ret1 = PyObject_CallMethod(virtual_base, "register", "O", type);
+  if (!ret1) {
+    Py_XDECREF(type);
+    return NULL;
+  }
+  for (size_t i = 0; methods[i] != NULL; i++) {
+    PyObject* method = PyObject_GetAttrString(virtual_base, methods[i]);
+    if (!method) {
+      Py_XDECREF(type);
+      return NULL;
+    }
+    int ret2 = PyObject_SetAttrString(type, methods[i], method);
+    if (ret2 < 0) {
+      Py_XDECREF(type);
+      return NULL;
+    }
+  }
+
+  return (PyTypeObject*)type;
+}
+
+
 bool InitMapContainers() {
   // ScalarMapContainer_Type derives from our MutableMapping type.
   ScopedPyObjectPtr abc(PyImport_ImportModule("collections.abc"));
@@ -907,20 +932,23 @@ bool InitMapContainers() {
   }
 
   Py_INCREF(mutable_mapping.get());
-  ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
+  ScopedPyObjectPtr bases(Py_BuildValue("O", mutable_mapping.get()));
   if (bases == nullptr) {
     return false;
   }
 
+  const char* methods[] = {"keys", "items",   "values", "__eq__",     "__ne__",
+                           "pop",  "popitem", "update", "setdefault", NULL};
+
   ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
-      PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
+      PyUpb_AddClassWithRegister(&ScalarMapContainer_Type_spec, bases.get(), methods));
 
   if (PyType_Ready(&MapIterator_Type) < 0) {
     return false;
   }
 
   MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
-      PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
+      PyUpb_AddClassWithRegister(&MessageMapContainer_Type_spec, bases.get(), methods));
   return true;
 }