1ffe3c632Sopenharmony_ci// Protocol Buffers - Google's data interchange format
2ffe3c632Sopenharmony_ci// Copyright 2008 Google Inc.  All rights reserved.
3ffe3c632Sopenharmony_ci// https://developers.google.com/protocol-buffers/
4ffe3c632Sopenharmony_ci//
5ffe3c632Sopenharmony_ci// Redistribution and use in source and binary forms, with or without
6ffe3c632Sopenharmony_ci// modification, are permitted provided that the following conditions are
7ffe3c632Sopenharmony_ci// met:
8ffe3c632Sopenharmony_ci//
9ffe3c632Sopenharmony_ci//     * Redistributions of source code must retain the above copyright
10ffe3c632Sopenharmony_ci// notice, this list of conditions and the following disclaimer.
11ffe3c632Sopenharmony_ci//     * Redistributions in binary form must reproduce the above
12ffe3c632Sopenharmony_ci// copyright notice, this list of conditions and the following disclaimer
13ffe3c632Sopenharmony_ci// in the documentation and/or other materials provided with the
14ffe3c632Sopenharmony_ci// distribution.
15ffe3c632Sopenharmony_ci//     * Neither the name of Google Inc. nor the names of its
16ffe3c632Sopenharmony_ci// contributors may be used to endorse or promote products derived from
17ffe3c632Sopenharmony_ci// this software without specific prior written permission.
18ffe3c632Sopenharmony_ci//
19ffe3c632Sopenharmony_ci// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20ffe3c632Sopenharmony_ci// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21ffe3c632Sopenharmony_ci// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22ffe3c632Sopenharmony_ci// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23ffe3c632Sopenharmony_ci// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24ffe3c632Sopenharmony_ci// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25ffe3c632Sopenharmony_ci// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26ffe3c632Sopenharmony_ci// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27ffe3c632Sopenharmony_ci// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28ffe3c632Sopenharmony_ci// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29ffe3c632Sopenharmony_ci// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30ffe3c632Sopenharmony_ci
31ffe3c632Sopenharmony_ci#include <unordered_map>
32ffe3c632Sopenharmony_ci
33ffe3c632Sopenharmony_ci#include <Python.h>
34ffe3c632Sopenharmony_ci
35ffe3c632Sopenharmony_ci#include <google/protobuf/dynamic_message.h>
36ffe3c632Sopenharmony_ci#include <google/protobuf/pyext/descriptor.h>
37ffe3c632Sopenharmony_ci#include <google/protobuf/pyext/message.h>
38ffe3c632Sopenharmony_ci#include <google/protobuf/pyext/message_factory.h>
39ffe3c632Sopenharmony_ci#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
40ffe3c632Sopenharmony_ci
// Python 3 compatibility shim: Python 3 has no PyString_AsStringAndSize, so it
// is emulated here. For unicode (str) objects the UTF-8 buffer obtained from
// PyUnicode_AsUTF8AndSize() is stored through `charpp`; every other object is
// forwarded to PyBytes_AsStringAndSize(). The expression yields 0 on success
// and -1 on failure.
// NOTE(review): the 3.0-3.2 rejection is presumably because
// PyUnicode_AsUTF8AndSize() first appeared in Python 3.3 — confirm.
#if PY_MAJOR_VERSION >= 3
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
#define PyString_AsStringAndSize(ob, charpp, sizep)                   \
  (!PyUnicode_Check(ob)                                               \
       ? PyBytes_AsStringAndSize(ob, (charpp), (sizep))               \
       : ((*(charpp) = const_cast<char*>(                             \
               PyUnicode_AsUTF8AndSize(ob, (sizep)))) != NULL         \
              ? 0                                                     \
              : -1))
#endif
52ffe3c632Sopenharmony_ci
53ffe3c632Sopenharmony_cinamespace google {
54ffe3c632Sopenharmony_cinamespace protobuf {
55ffe3c632Sopenharmony_cinamespace python {
56ffe3c632Sopenharmony_ci
57ffe3c632Sopenharmony_cinamespace message_factory {
58ffe3c632Sopenharmony_ci
59ffe3c632Sopenharmony_ciPyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
60ffe3c632Sopenharmony_ci  PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
61ffe3c632Sopenharmony_ci      PyType_GenericAlloc(type, 0));
62ffe3c632Sopenharmony_ci  if (factory == NULL) {
63ffe3c632Sopenharmony_ci    return NULL;
64ffe3c632Sopenharmony_ci  }
65ffe3c632Sopenharmony_ci
66ffe3c632Sopenharmony_ci  DynamicMessageFactory* message_factory = new DynamicMessageFactory();
67ffe3c632Sopenharmony_ci  // This option might be the default some day.
68ffe3c632Sopenharmony_ci  message_factory->SetDelegateToGeneratedFactory(true);
69ffe3c632Sopenharmony_ci  factory->message_factory = message_factory;
70ffe3c632Sopenharmony_ci
71ffe3c632Sopenharmony_ci  factory->pool = pool;
72ffe3c632Sopenharmony_ci  Py_INCREF(pool);
73ffe3c632Sopenharmony_ci
74ffe3c632Sopenharmony_ci  factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
75ffe3c632Sopenharmony_ci
76ffe3c632Sopenharmony_ci  return factory;
77ffe3c632Sopenharmony_ci}
78ffe3c632Sopenharmony_ci
79ffe3c632Sopenharmony_ciPyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
80ffe3c632Sopenharmony_ci  static char* kwlist[] = {"pool", 0};
81ffe3c632Sopenharmony_ci  PyObject* pool = NULL;
82ffe3c632Sopenharmony_ci  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) {
83ffe3c632Sopenharmony_ci    return NULL;
84ffe3c632Sopenharmony_ci  }
85ffe3c632Sopenharmony_ci  ScopedPyObjectPtr owned_pool;
86ffe3c632Sopenharmony_ci  if (pool == NULL || pool == Py_None) {
87ffe3c632Sopenharmony_ci    owned_pool.reset(PyObject_CallFunction(
88ffe3c632Sopenharmony_ci        reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL));
89ffe3c632Sopenharmony_ci    if (owned_pool == NULL) {
90ffe3c632Sopenharmony_ci      return NULL;
91ffe3c632Sopenharmony_ci    }
92ffe3c632Sopenharmony_ci    pool = owned_pool.get();
93ffe3c632Sopenharmony_ci  } else {
94ffe3c632Sopenharmony_ci    if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
95ffe3c632Sopenharmony_ci      PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
96ffe3c632Sopenharmony_ci                   pool->ob_type->tp_name);
97ffe3c632Sopenharmony_ci      return NULL;
98ffe3c632Sopenharmony_ci    }
99ffe3c632Sopenharmony_ci  }
100ffe3c632Sopenharmony_ci
101ffe3c632Sopenharmony_ci  return reinterpret_cast<PyObject*>(
102ffe3c632Sopenharmony_ci      NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
103ffe3c632Sopenharmony_ci}
104ffe3c632Sopenharmony_ci
105ffe3c632Sopenharmony_cistatic void Dealloc(PyObject* pself) {
106ffe3c632Sopenharmony_ci  PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
107ffe3c632Sopenharmony_ci
108ffe3c632Sopenharmony_ci  typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
109ffe3c632Sopenharmony_ci  for (iterator it = self->classes_by_descriptor->begin();
110ffe3c632Sopenharmony_ci       it != self->classes_by_descriptor->end(); ++it) {
111ffe3c632Sopenharmony_ci    Py_CLEAR(it->second);
112ffe3c632Sopenharmony_ci  }
113ffe3c632Sopenharmony_ci  delete self->classes_by_descriptor;
114ffe3c632Sopenharmony_ci  delete self->message_factory;
115ffe3c632Sopenharmony_ci  Py_CLEAR(self->pool);
116ffe3c632Sopenharmony_ci  Py_TYPE(self)->tp_free(pself);
117ffe3c632Sopenharmony_ci}
118ffe3c632Sopenharmony_ci
119ffe3c632Sopenharmony_cistatic int GcTraverse(PyObject* pself, visitproc visit, void* arg) {
120ffe3c632Sopenharmony_ci  PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
121ffe3c632Sopenharmony_ci  Py_VISIT(self->pool);
122ffe3c632Sopenharmony_ci  for (const auto& desc_and_class : *self->classes_by_descriptor) {
123ffe3c632Sopenharmony_ci    Py_VISIT(desc_and_class.second);
124ffe3c632Sopenharmony_ci  }
125ffe3c632Sopenharmony_ci  return 0;
126ffe3c632Sopenharmony_ci}
127ffe3c632Sopenharmony_ci
128ffe3c632Sopenharmony_cistatic int GcClear(PyObject* pself) {
129ffe3c632Sopenharmony_ci  PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
130ffe3c632Sopenharmony_ci  // Here it's important to not clear self->pool, so that the C++ DescriptorPool
131ffe3c632Sopenharmony_ci  // is still alive when self->message_factory is destructed.
132ffe3c632Sopenharmony_ci  for (auto& desc_and_class : *self->classes_by_descriptor) {
133ffe3c632Sopenharmony_ci    Py_CLEAR(desc_and_class.second);
134ffe3c632Sopenharmony_ci  }
135ffe3c632Sopenharmony_ci
136ffe3c632Sopenharmony_ci  return 0;
137ffe3c632Sopenharmony_ci}
138ffe3c632Sopenharmony_ci
139ffe3c632Sopenharmony_ci// Add a message class to our database.
140ffe3c632Sopenharmony_ciint RegisterMessageClass(PyMessageFactory* self,
141ffe3c632Sopenharmony_ci                         const Descriptor* message_descriptor,
142ffe3c632Sopenharmony_ci                         CMessageClass* message_class) {
143ffe3c632Sopenharmony_ci  Py_INCREF(message_class);
144ffe3c632Sopenharmony_ci  typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
145ffe3c632Sopenharmony_ci  std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
146ffe3c632Sopenharmony_ci      std::make_pair(message_descriptor, message_class));
147ffe3c632Sopenharmony_ci  if (!ret.second) {
148ffe3c632Sopenharmony_ci    // Update case: DECREF the previous value.
149ffe3c632Sopenharmony_ci    Py_DECREF(ret.first->second);
150ffe3c632Sopenharmony_ci    ret.first->second = message_class;
151ffe3c632Sopenharmony_ci  }
152ffe3c632Sopenharmony_ci  return 0;
153ffe3c632Sopenharmony_ci}
154ffe3c632Sopenharmony_ci
155ffe3c632Sopenharmony_ciCMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
156ffe3c632Sopenharmony_ci                                       const Descriptor* descriptor) {
157ffe3c632Sopenharmony_ci  // This is the same implementation as MessageFactory.GetPrototype().
158ffe3c632Sopenharmony_ci
159ffe3c632Sopenharmony_ci  // Do not create a MessageClass that already exists.
160ffe3c632Sopenharmony_ci  std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
161ffe3c632Sopenharmony_ci      self->classes_by_descriptor->find(descriptor);
162ffe3c632Sopenharmony_ci  if (it != self->classes_by_descriptor->end()) {
163ffe3c632Sopenharmony_ci    Py_INCREF(it->second);
164ffe3c632Sopenharmony_ci    return it->second;
165ffe3c632Sopenharmony_ci  }
166ffe3c632Sopenharmony_ci  ScopedPyObjectPtr py_descriptor(
167ffe3c632Sopenharmony_ci      PyMessageDescriptor_FromDescriptor(descriptor));
168ffe3c632Sopenharmony_ci  if (py_descriptor == NULL) {
169ffe3c632Sopenharmony_ci    return NULL;
170ffe3c632Sopenharmony_ci  }
171ffe3c632Sopenharmony_ci  // Create a new message class.
172ffe3c632Sopenharmony_ci  ScopedPyObjectPtr args(Py_BuildValue(
173ffe3c632Sopenharmony_ci      "s(){sOsOsO}", descriptor->name().c_str(),
174ffe3c632Sopenharmony_ci      "DESCRIPTOR", py_descriptor.get(),
175ffe3c632Sopenharmony_ci      "__module__", Py_None,
176ffe3c632Sopenharmony_ci      "message_factory", self));
177ffe3c632Sopenharmony_ci  if (args == NULL) {
178ffe3c632Sopenharmony_ci    return NULL;
179ffe3c632Sopenharmony_ci  }
180ffe3c632Sopenharmony_ci  ScopedPyObjectPtr message_class(PyObject_CallObject(
181ffe3c632Sopenharmony_ci      reinterpret_cast<PyObject*>(CMessageClass_Type), args.get()));
182ffe3c632Sopenharmony_ci  if (message_class == NULL) {
183ffe3c632Sopenharmony_ci    return NULL;
184ffe3c632Sopenharmony_ci  }
185ffe3c632Sopenharmony_ci  // Create messages class for the messages used by the fields, and registers
186ffe3c632Sopenharmony_ci  // all extensions for these messages during the recursion.
187ffe3c632Sopenharmony_ci  for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
188ffe3c632Sopenharmony_ci    const Descriptor* sub_descriptor =
189ffe3c632Sopenharmony_ci        descriptor->field(field_idx)->message_type();
190ffe3c632Sopenharmony_ci    // It is NULL if the field type is not a message.
191ffe3c632Sopenharmony_ci    if (sub_descriptor != NULL) {
192ffe3c632Sopenharmony_ci      CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
193ffe3c632Sopenharmony_ci      if (result == NULL) {
194ffe3c632Sopenharmony_ci        return NULL;
195ffe3c632Sopenharmony_ci      }
196ffe3c632Sopenharmony_ci      Py_DECREF(result);
197ffe3c632Sopenharmony_ci    }
198ffe3c632Sopenharmony_ci  }
199ffe3c632Sopenharmony_ci
200ffe3c632Sopenharmony_ci  // Register extensions defined in this message.
201ffe3c632Sopenharmony_ci  for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
202ffe3c632Sopenharmony_ci    const FieldDescriptor* extension = descriptor->extension(ext_idx);
203ffe3c632Sopenharmony_ci    ScopedPyObjectPtr py_extended_class(
204ffe3c632Sopenharmony_ci        GetOrCreateMessageClass(self, extension->containing_type())
205ffe3c632Sopenharmony_ci            ->AsPyObject());
206ffe3c632Sopenharmony_ci    if (py_extended_class == NULL) {
207ffe3c632Sopenharmony_ci      return NULL;
208ffe3c632Sopenharmony_ci    }
209ffe3c632Sopenharmony_ci    ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
210ffe3c632Sopenharmony_ci    if (py_extension == NULL) {
211ffe3c632Sopenharmony_ci      return NULL;
212ffe3c632Sopenharmony_ci    }
213ffe3c632Sopenharmony_ci    ScopedPyObjectPtr result(cmessage::RegisterExtension(
214ffe3c632Sopenharmony_ci        py_extended_class.get(), py_extension.get()));
215ffe3c632Sopenharmony_ci    if (result == NULL) {
216ffe3c632Sopenharmony_ci      return NULL;
217ffe3c632Sopenharmony_ci    }
218ffe3c632Sopenharmony_ci  }
219ffe3c632Sopenharmony_ci  return reinterpret_cast<CMessageClass*>(message_class.release());
220ffe3c632Sopenharmony_ci}
221ffe3c632Sopenharmony_ci
222ffe3c632Sopenharmony_ci// Retrieve the message class added to our database.
223ffe3c632Sopenharmony_ciCMessageClass* GetMessageClass(PyMessageFactory* self,
224ffe3c632Sopenharmony_ci                               const Descriptor* message_descriptor) {
225ffe3c632Sopenharmony_ci  typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
226ffe3c632Sopenharmony_ci  iterator ret = self->classes_by_descriptor->find(message_descriptor);
227ffe3c632Sopenharmony_ci  if (ret == self->classes_by_descriptor->end()) {
228ffe3c632Sopenharmony_ci    PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
229ffe3c632Sopenharmony_ci                 message_descriptor->full_name().c_str());
230ffe3c632Sopenharmony_ci    return NULL;
231ffe3c632Sopenharmony_ci  } else {
232ffe3c632Sopenharmony_ci    return ret->second;
233ffe3c632Sopenharmony_ci  }
234ffe3c632Sopenharmony_ci}
235ffe3c632Sopenharmony_ci
236ffe3c632Sopenharmony_cistatic PyMethodDef Methods[] = {
237ffe3c632Sopenharmony_ci    {NULL}};
238ffe3c632Sopenharmony_ci
239ffe3c632Sopenharmony_cistatic PyObject* GetPool(PyMessageFactory* self, void* closure) {
240ffe3c632Sopenharmony_ci  Py_INCREF(self->pool);
241ffe3c632Sopenharmony_ci  return reinterpret_cast<PyObject*>(self->pool);
242ffe3c632Sopenharmony_ci}
243ffe3c632Sopenharmony_ci
244ffe3c632Sopenharmony_cistatic PyGetSetDef Getters[] = {
245ffe3c632Sopenharmony_ci    {"pool", (getter)GetPool, NULL, "DescriptorPool"},
246ffe3c632Sopenharmony_ci    {NULL}
247ffe3c632Sopenharmony_ci};
248ffe3c632Sopenharmony_ci
249ffe3c632Sopenharmony_ci}  // namespace message_factory
250ffe3c632Sopenharmony_ci
// Python type object for FULL_MODULE_NAME ".MessageFactory".
//
// Slots are initialized positionally, so the order below must match the
// PyTypeObject layout exactly.
//   * Py_TPFLAGS_HAVE_GC: instances take part in cyclic garbage collection
//     via message_factory::GcTraverse / GcClear, and are freed with
//     PyObject_GC_Del.
//   * Py_TPFLAGS_BASETYPE: Python code may subclass the type.
// NOTE(review): the slot names in the trailing comments follow an older
// layout. In newer CPython releases the positions labelled "tp_print" and
// "tp_compare" hold other slots (e.g. tp_vectorcall_offset / tp_as_async).
// Both are 0 here, so this is harmless — re-check if they ever become non-zero.
PyTypeObject PyMessageFactory_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
    ".MessageFactory",         // tp_name
    sizeof(PyMessageFactory),  // tp_basicsize
    0,                         // tp_itemsize
    message_factory::Dealloc,  // tp_dealloc
    0,                         // tp_print
    0,                         // tp_getattr
    0,                         // tp_setattr
    0,                         // tp_compare
    0,                         // tp_repr
    0,                         // tp_as_number
    0,                         // tp_as_sequence
    0,                         // tp_as_mapping
    0,                         // tp_hash
    0,                         // tp_call
    0,                         // tp_str
    0,                         // tp_getattro
    0,                         // tp_setattro
    0,                         // tp_as_buffer
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,  // tp_flags
    "A static Message Factory",                                     // tp_doc
    message_factory::GcTraverse,  // tp_traverse
    message_factory::GcClear,     // tp_clear
    0,                            // tp_richcompare
    0,                            // tp_weaklistoffset
    0,                            // tp_iter
    0,                            // tp_iternext
    message_factory::Methods,     // tp_methods
    0,                            // tp_members
    message_factory::Getters,     // tp_getset
    0,                            // tp_base
    0,                            // tp_dict
    0,                            // tp_descr_get
    0,                            // tp_descr_set
    0,                            // tp_dictoffset
    0,                            // tp_init
    0,                            // tp_alloc
    message_factory::New,         // tp_new
    PyObject_GC_Del,              // tp_free
};
292ffe3c632Sopenharmony_ci
293ffe3c632Sopenharmony_cibool InitMessageFactory() {
294ffe3c632Sopenharmony_ci  if (PyType_Ready(&PyMessageFactory_Type) < 0) {
295ffe3c632Sopenharmony_ci    return false;
296ffe3c632Sopenharmony_ci  }
297ffe3c632Sopenharmony_ci
298ffe3c632Sopenharmony_ci  return true;
299ffe3c632Sopenharmony_ci}
300ffe3c632Sopenharmony_ci
301ffe3c632Sopenharmony_ci}  // namespace python
302ffe3c632Sopenharmony_ci}  // namespace protobuf
303ffe3c632Sopenharmony_ci}  // namespace google
304