1// Protocol Buffers - Google's data interchange format
2// Copyright 2008 Google Inc.  All rights reserved.
3// https://developers.google.com/protocol-buffers/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are
7// met:
8//
9//     * Redistributions of source code must retain the above copyright
10// notice, this list of conditions and the following disclaimer.
11//     * Redistributions in binary form must reproduce the above
12// copyright notice, this list of conditions and the following disclaimer
13// in the documentation and/or other materials provided with the
14// distribution.
15//     * Neither the name of Google Inc. nor the names of its
16// contributors may be used to endorse or promote products derived from
17// this software without specific prior written permission.
18//
19// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31#include <unordered_map>
32
33#include <Python.h>
34
35#include <google/protobuf/dynamic_message.h>
36#include <google/protobuf/pyext/descriptor.h>
37#include <google/protobuf/pyext/message.h>
38#include <google/protobuf/pyext/message_factory.h>
39#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
40
// Python 3 compatibility shim: provides the Python 2 API
// PyString_AsStringAndSize(ob, charpp, sizep) in terms of Python 3 calls.
//   * str (unicode) objects: *charpp receives the UTF-8 buffer returned by
//     PyUnicode_AsUTF8AndSize; the expression yields -1 if that call returns
//     NULL and 0 otherwise.
//   * any other object: forwarded unchanged to PyBytes_AsStringAndSize.
// NOTE(review): the macro is not referenced elsewhere in this file; it may be
// vestigial here — confirm before relying on it or removing it.
#if PY_MAJOR_VERSION >= 3
  // Python 3.0 - 3.2 are rejected; PyUnicode_AsUTF8AndSize (used below) is
  // only available from Python 3.3 on.
  #if PY_VERSION_HEX < 0x03030000
    #error "Python 3.0 - 3.2 are not supported."
  #endif
  // The const_cast lets the UTF-8 buffer returned by PyUnicode_AsUTF8AndSize
  // be stored through the char** out-parameter; callers must still treat that
  // buffer as read-only.
  #define PyString_AsStringAndSize(ob, charpp, sizep) \
    (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>(                   \
                               PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
                              ? -1                                            \
                              : 0)                                            \
                        : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
52
53namespace google {
54namespace protobuf {
55namespace python {
56
57namespace message_factory {
58
59PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
60  PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
61      PyType_GenericAlloc(type, 0));
62  if (factory == NULL) {
63    return NULL;
64  }
65
66  DynamicMessageFactory* message_factory = new DynamicMessageFactory();
67  // This option might be the default some day.
68  message_factory->SetDelegateToGeneratedFactory(true);
69  factory->message_factory = message_factory;
70
71  factory->pool = pool;
72  Py_INCREF(pool);
73
74  factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
75
76  return factory;
77}
78
79PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
80  static char* kwlist[] = {"pool", 0};
81  PyObject* pool = NULL;
82  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) {
83    return NULL;
84  }
85  ScopedPyObjectPtr owned_pool;
86  if (pool == NULL || pool == Py_None) {
87    owned_pool.reset(PyObject_CallFunction(
88        reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL));
89    if (owned_pool == NULL) {
90      return NULL;
91    }
92    pool = owned_pool.get();
93  } else {
94    if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
95      PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
96                   pool->ob_type->tp_name);
97      return NULL;
98    }
99  }
100
101  return reinterpret_cast<PyObject*>(
102      NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
103}
104
105static void Dealloc(PyObject* pself) {
106  PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
107
108  typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
109  for (iterator it = self->classes_by_descriptor->begin();
110       it != self->classes_by_descriptor->end(); ++it) {
111    Py_CLEAR(it->second);
112  }
113  delete self->classes_by_descriptor;
114  delete self->message_factory;
115  Py_CLEAR(self->pool);
116  Py_TYPE(self)->tp_free(pself);
117}
118
119static int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
120  PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
121  Py_VISIT(self->pool);
122  for (const auto& desc_and_class : *self->classes_by_descriptor) {
123    Py_VISIT(desc_and_class.second);
124  }
125  return 0;
126}
127
128static int GcClear(PyObject* pself) {
129  PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
130  // Here it's important to not clear self->pool, so that the C++ DescriptorPool
131  // is still alive when self->message_factory is destructed.
132  for (auto& desc_and_class : *self->classes_by_descriptor) {
133    Py_CLEAR(desc_and_class.second);
134  }
135
136  return 0;
137}
138
139// Add a message class to our database.
140int RegisterMessageClass(PyMessageFactory* self,
141                         const Descriptor* message_descriptor,
142                         CMessageClass* message_class) {
143  Py_INCREF(message_class);
144  typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
145  std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
146      std::make_pair(message_descriptor, message_class));
147  if (!ret.second) {
148    // Update case: DECREF the previous value.
149    Py_DECREF(ret.first->second);
150    ret.first->second = message_class;
151  }
152  return 0;
153}
154
155CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
156                                       const Descriptor* descriptor) {
157  // This is the same implementation as MessageFactory.GetPrototype().
158
159  // Do not create a MessageClass that already exists.
160  std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
161      self->classes_by_descriptor->find(descriptor);
162  if (it != self->classes_by_descriptor->end()) {
163    Py_INCREF(it->second);
164    return it->second;
165  }
166  ScopedPyObjectPtr py_descriptor(
167      PyMessageDescriptor_FromDescriptor(descriptor));
168  if (py_descriptor == NULL) {
169    return NULL;
170  }
171  // Create a new message class.
172  ScopedPyObjectPtr args(Py_BuildValue(
173      "s(){sOsOsO}", descriptor->name().c_str(),
174      "DESCRIPTOR", py_descriptor.get(),
175      "__module__", Py_None,
176      "message_factory", self));
177  if (args == NULL) {
178    return NULL;
179  }
180  ScopedPyObjectPtr message_class(PyObject_CallObject(
181      reinterpret_cast<PyObject*>(CMessageClass_Type), args.get()));
182  if (message_class == NULL) {
183    return NULL;
184  }
185  // Create messages class for the messages used by the fields, and registers
186  // all extensions for these messages during the recursion.
187  for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
188    const Descriptor* sub_descriptor =
189        descriptor->field(field_idx)->message_type();
190    // It is NULL if the field type is not a message.
191    if (sub_descriptor != NULL) {
192      CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
193      if (result == NULL) {
194        return NULL;
195      }
196      Py_DECREF(result);
197    }
198  }
199
200  // Register extensions defined in this message.
201  for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
202    const FieldDescriptor* extension = descriptor->extension(ext_idx);
203    ScopedPyObjectPtr py_extended_class(
204        GetOrCreateMessageClass(self, extension->containing_type())
205            ->AsPyObject());
206    if (py_extended_class == NULL) {
207      return NULL;
208    }
209    ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
210    if (py_extension == NULL) {
211      return NULL;
212    }
213    ScopedPyObjectPtr result(cmessage::RegisterExtension(
214        py_extended_class.get(), py_extension.get()));
215    if (result == NULL) {
216      return NULL;
217    }
218  }
219  return reinterpret_cast<CMessageClass*>(message_class.release());
220}
221
222// Retrieve the message class added to our database.
223CMessageClass* GetMessageClass(PyMessageFactory* self,
224                               const Descriptor* message_descriptor) {
225  typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
226  iterator ret = self->classes_by_descriptor->find(message_descriptor);
227  if (ret == self->classes_by_descriptor->end()) {
228    PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
229                 message_descriptor->full_name().c_str());
230    return NULL;
231  } else {
232    return ret->second;
233  }
234}
235
236static PyMethodDef Methods[] = {
237    {NULL}};
238
239static PyObject* GetPool(PyMessageFactory* self, void* closure) {
240  Py_INCREF(self->pool);
241  return reinterpret_cast<PyObject*>(self->pool);
242}
243
244static PyGetSetDef Getters[] = {
245    {"pool", (getter)GetPool, NULL, "DescriptorPool"},
246    {NULL}
247};
248
249}  // namespace message_factory
250
// Python type object for MessageFactory.
//
// Py_TPFLAGS_HAVE_GC is set because the factory holds strong references to its
// pool and to the cached message classes, and those classes are created with
// `message_factory` pointing back at the factory, so cycles can form.
// Py_TPFLAGS_BASETYPE allows Python subclasses.
// The initializer is positional: entries must stay in PyTypeObject slot order.
// The slot names in the trailing comments follow the older layout; some
// positions were renamed in later Python versions, and 0 is used for them.
PyTypeObject PyMessageFactory_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
    ".MessageFactory",         // tp_name
    sizeof(PyMessageFactory),  // tp_basicsize
    0,                         // tp_itemsize
    message_factory::Dealloc,  // tp_dealloc
    0,                         // tp_print (tp_vectorcall_offset on 3.8+)
    0,                         // tp_getattr
    0,                         // tp_setattr
    0,                         // tp_compare (tp_as_async on 3.5+)
    0,                         // tp_repr
    0,                         // tp_as_number
    0,                         // tp_as_sequence
    0,                         // tp_as_mapping
    0,                         // tp_hash
    0,                         // tp_call
    0,                         // tp_str
    0,                         // tp_getattro
    0,                         // tp_setattro
    0,                         // tp_as_buffer
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
    "A static Message Factory",                                     // tp_doc
    message_factory::GcTraverse,  // tp_traverse
    message_factory::GcClear,     // tp_clear
    0,                            // tp_richcompare
    0,                            // tp_weaklistoffset
    0,                            // tp_iter
    0,                            // tp_iternext
    message_factory::Methods,     // tp_methods
    0,                            // tp_members
    message_factory::Getters,     // tp_getset
    0,                            // tp_base
    0,                            // tp_dict
    0,                            // tp_descr_get
    0,                            // tp_descr_set
    0,                            // tp_dictoffset
    0,                            // tp_init
    0,                            // tp_alloc (inherited in PyType_Ready)
    message_factory::New,         // tp_new
    PyObject_GC_Del,              // tp_free (GC-aware, matches HAVE_GC)
};
292
293bool InitMessageFactory() {
294  if (PyType_Ready(&PyMessageFactory_Type) < 0) {
295    return false;
296  }
297
298  return true;
299}
300
301}  // namespace python
302}  // namespace protobuf
303}  // namespace google
304