// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 31#include <unordered_map> 32 33#include <Python.h> 34 35#include <google/protobuf/dynamic_message.h> 36#include <google/protobuf/pyext/descriptor.h> 37#include <google/protobuf/pyext/message.h> 38#include <google/protobuf/pyext/message_factory.h> 39#include <google/protobuf/pyext/scoped_pyobject_ptr.h> 40 41#if PY_MAJOR_VERSION >= 3 42 #if PY_VERSION_HEX < 0x03030000 43 #error "Python 3.0 - 3.2 are not supported." 44 #endif 45 #define PyString_AsStringAndSize(ob, charpp, sizep) \ 46 (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \ 47 PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ 48 ? -1 \ 49 : 0) \ 50 : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) 51#endif 52 53namespace google { 54namespace protobuf { 55namespace python { 56 57namespace message_factory { 58 59PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) { 60 PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>( 61 PyType_GenericAlloc(type, 0)); 62 if (factory == NULL) { 63 return NULL; 64 } 65 66 DynamicMessageFactory* message_factory = new DynamicMessageFactory(); 67 // This option might be the default some day. 
68 message_factory->SetDelegateToGeneratedFactory(true); 69 factory->message_factory = message_factory; 70 71 factory->pool = pool; 72 Py_INCREF(pool); 73 74 factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap(); 75 76 return factory; 77} 78 79PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) { 80 static char* kwlist[] = {"pool", 0}; 81 PyObject* pool = NULL; 82 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) { 83 return NULL; 84 } 85 ScopedPyObjectPtr owned_pool; 86 if (pool == NULL || pool == Py_None) { 87 owned_pool.reset(PyObject_CallFunction( 88 reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL)); 89 if (owned_pool == NULL) { 90 return NULL; 91 } 92 pool = owned_pool.get(); 93 } else { 94 if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) { 95 PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s", 96 pool->ob_type->tp_name); 97 return NULL; 98 } 99 } 100 101 return reinterpret_cast<PyObject*>( 102 NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool))); 103} 104 105static void Dealloc(PyObject* pself) { 106 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself); 107 108 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; 109 for (iterator it = self->classes_by_descriptor->begin(); 110 it != self->classes_by_descriptor->end(); ++it) { 111 Py_CLEAR(it->second); 112 } 113 delete self->classes_by_descriptor; 114 delete self->message_factory; 115 Py_CLEAR(self->pool); 116 Py_TYPE(self)->tp_free(pself); 117} 118 119static int GcTraverse(PyObject* pself, visitproc visit, void* arg) { 120 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself); 121 Py_VISIT(self->pool); 122 for (const auto& desc_and_class : *self->classes_by_descriptor) { 123 Py_VISIT(desc_and_class.second); 124 } 125 return 0; 126} 127 128static int GcClear(PyObject* pself) { 129 PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself); 130 // Here it's 
important to not clear self->pool, so that the C++ DescriptorPool 131 // is still alive when self->message_factory is destructed. 132 for (auto& desc_and_class : *self->classes_by_descriptor) { 133 Py_CLEAR(desc_and_class.second); 134 } 135 136 return 0; 137} 138 139// Add a message class to our database. 140int RegisterMessageClass(PyMessageFactory* self, 141 const Descriptor* message_descriptor, 142 CMessageClass* message_class) { 143 Py_INCREF(message_class); 144 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; 145 std::pair<iterator, bool> ret = self->classes_by_descriptor->insert( 146 std::make_pair(message_descriptor, message_class)); 147 if (!ret.second) { 148 // Update case: DECREF the previous value. 149 Py_DECREF(ret.first->second); 150 ret.first->second = message_class; 151 } 152 return 0; 153} 154 155CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, 156 const Descriptor* descriptor) { 157 // This is the same implementation as MessageFactory.GetPrototype(). 158 159 // Do not create a MessageClass that already exists. 160 std::unordered_map<const Descriptor*, CMessageClass*>::iterator it = 161 self->classes_by_descriptor->find(descriptor); 162 if (it != self->classes_by_descriptor->end()) { 163 Py_INCREF(it->second); 164 return it->second; 165 } 166 ScopedPyObjectPtr py_descriptor( 167 PyMessageDescriptor_FromDescriptor(descriptor)); 168 if (py_descriptor == NULL) { 169 return NULL; 170 } 171 // Create a new message class. 
172 ScopedPyObjectPtr args(Py_BuildValue( 173 "s(){sOsOsO}", descriptor->name().c_str(), 174 "DESCRIPTOR", py_descriptor.get(), 175 "__module__", Py_None, 176 "message_factory", self)); 177 if (args == NULL) { 178 return NULL; 179 } 180 ScopedPyObjectPtr message_class(PyObject_CallObject( 181 reinterpret_cast<PyObject*>(CMessageClass_Type), args.get())); 182 if (message_class == NULL) { 183 return NULL; 184 } 185 // Create messages class for the messages used by the fields, and registers 186 // all extensions for these messages during the recursion. 187 for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) { 188 const Descriptor* sub_descriptor = 189 descriptor->field(field_idx)->message_type(); 190 // It is NULL if the field type is not a message. 191 if (sub_descriptor != NULL) { 192 CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor); 193 if (result == NULL) { 194 return NULL; 195 } 196 Py_DECREF(result); 197 } 198 } 199 200 // Register extensions defined in this message. 201 for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) { 202 const FieldDescriptor* extension = descriptor->extension(ext_idx); 203 ScopedPyObjectPtr py_extended_class( 204 GetOrCreateMessageClass(self, extension->containing_type()) 205 ->AsPyObject()); 206 if (py_extended_class == NULL) { 207 return NULL; 208 } 209 ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension)); 210 if (py_extension == NULL) { 211 return NULL; 212 } 213 ScopedPyObjectPtr result(cmessage::RegisterExtension( 214 py_extended_class.get(), py_extension.get())); 215 if (result == NULL) { 216 return NULL; 217 } 218 } 219 return reinterpret_cast<CMessageClass*>(message_class.release()); 220} 221 222// Retrieve the message class added to our database. 
223CMessageClass* GetMessageClass(PyMessageFactory* self, 224 const Descriptor* message_descriptor) { 225 typedef PyMessageFactory::ClassesByMessageMap::iterator iterator; 226 iterator ret = self->classes_by_descriptor->find(message_descriptor); 227 if (ret == self->classes_by_descriptor->end()) { 228 PyErr_Format(PyExc_TypeError, "No message class registered for '%s'", 229 message_descriptor->full_name().c_str()); 230 return NULL; 231 } else { 232 return ret->second; 233 } 234} 235 236static PyMethodDef Methods[] = { 237 {NULL}}; 238 239static PyObject* GetPool(PyMessageFactory* self, void* closure) { 240 Py_INCREF(self->pool); 241 return reinterpret_cast<PyObject*>(self->pool); 242} 243 244static PyGetSetDef Getters[] = { 245 {"pool", (getter)GetPool, NULL, "DescriptorPool"}, 246 {NULL} 247}; 248 249} // namespace message_factory 250 251PyTypeObject PyMessageFactory_Type = { 252 PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME 253 ".MessageFactory", // tp_name 254 sizeof(PyMessageFactory), // tp_basicsize 255 0, // tp_itemsize 256 message_factory::Dealloc, // tp_dealloc 257 0, // tp_print 258 0, // tp_getattr 259 0, // tp_setattr 260 0, // tp_compare 261 0, // tp_repr 262 0, // tp_as_number 263 0, // tp_as_sequence 264 0, // tp_as_mapping 265 0, // tp_hash 266 0, // tp_call 267 0, // tp_str 268 0, // tp_getattro 269 0, // tp_setattro 270 0, // tp_as_buffer 271 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, // tp_flags 272 "A static Message Factory", // tp_doc 273 message_factory::GcTraverse, // tp_traverse 274 message_factory::GcClear, // tp_clear 275 0, // tp_richcompare 276 0, // tp_weaklistoffset 277 0, // tp_iter 278 0, // tp_iternext 279 message_factory::Methods, // tp_methods 280 0, // tp_members 281 message_factory::Getters, // tp_getset 282 0, // tp_base 283 0, // tp_dict 284 0, // tp_descr_get 285 0, // tp_descr_set 286 0, // tp_dictoffset 287 0, // tp_init 288 0, // tp_alloc 289 message_factory::New, // tp_new 290 PyObject_GC_Del, 
// tp_free 291}; 292 293bool InitMessageFactory() { 294 if (PyType_Ready(&PyMessageFactory_Type) < 0) { 295 return false; 296 } 297 298 return true; 299} 300 301} // namespace python 302} // namespace protobuf 303} // namespace google 304