// Source: library/cpp/pybind/module.h
// blob: 41dcb4dfec8cfd36d9510b2d6189762efb2966c1














































































































































































                                                                                                                                                                                                        
#pragma once

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "ptr.h"
#include "cast.h"
#include "exceptions.h"

#include <util/generic/function.h>

namespace NPyBind {
#if PY_MAJOR_VERSION >= 3
    namespace NPrivate {
        using TFinalizationCallBack = std::function<void()>;
        void AddFinalizationCallBack(TFinalizationCallBack);
        /// One-shot registrar that is handed the module object.
        ///
        /// The constructor is private and only declared in this header, so the actual
        /// registration work lives in the implementation file (presumably it installs
        /// an exit/finalization hook, judging by the name -- confirm there).
        class TAtExitRegistrar: private TNonCopyable {
        private:
            TAtExitRegistrar(TPyObjectPtr module);

        public:
            /// Creates the single function-local static instance on the first call.
            /// Subsequent calls are no-ops: only the `module` passed to the first call
            /// ever reaches the constructor.
            static void Instantiate(TPyObjectPtr module) {
                static TAtExitRegistrar instance(module);
                Y_UNUSED(instance);
            }
        };

        /// One-shot registrar that owns a Python object in `Module`.
        ///
        /// The constructor is private and defined outside this header, so what it
        /// stores in `Module` cannot be told from here.
        class TPyBindModuleRegistrar: private TNonCopyable {
        private:
            TPyBindModuleRegistrar();
            TPyObjectPtr Module;

        public:
            /// Creates the single function-local static instance on the first call;
            /// every later call does nothing.
            static void Instantiate() {
                static TPyBindModuleRegistrar instance;
                Y_UNUSED(instance);
            }
        };
    } //NPrivate
#endif

    class TModuleHolder {
    private:
        TModuleHolder(const TModuleHolder&);
        TModuleHolder& operator=(const TModuleHolder&);

        TModuleHolder();
    private:
        typedef PyCFunction TModuleMethod;
#if PY_MAJOR_VERSION >= 3
        typedef PyObject* (*TModuleInitFunc)();
#else
        typedef void (*TModuleInitFunc)();
#endif

        /// Owning description of one module-level function.
        ///
        /// Unlike PyMethodDef, which only carries raw C-string pointers, this struct
        /// stores the name and the docstring as TString values.
        struct TMethodDef {
            TString Name;         // function name as exposed to Python
            TModuleMethod Method; // C entry point
            TString Description;  // docstring
            int Flags;            // calling-convention flags (METH_*)

            TMethodDef(const TString& name, TModuleMethod method, const TString& descr, int flags)
                : Name(name)
                , Method(method)
                , Description(descr)
                , Flags(flags)
            {
            }

            /// Builds the PyMethodDef view of this entry.
            /// The returned structure borrows the character data of `Name` and
            /// `Description`, so it must not be used after this object is destroyed.
            operator PyMethodDef() const {
                return PyMethodDef{Name.c_str(), Method, Flags, Description.c_str()};
            }
        };

        typedef TSimpleSharedPtr<TVector<TMethodDef>> TMethodDefVecPtr;
        typedef TSimpleSharedPtr<TVector<PyMethodDef>> TPyMethodDefVecPtr;

        TVector<TMethodDefVecPtr> Methods;
        TVector<TPyMethodDefVecPtr> Defs;
#if PY_MAJOR_VERSION >= 3
        // Bundles a PyModuleDef with the TString backing its name pointer,
        // because the md_name will leak otherwise.
        class TPyModuleDefWithName {
            PyModuleDef Def;
            TString Name;

        public:
            /// @param name        module name; stored in this object, `Def` points into it.
            /// @param moduleDefs  method table; only its data() pointer is recorded, the
            ///                    shared pointer itself is not retained here, so the
            ///                    vector has to be kept alive by the caller.
            explicit TPyModuleDefWithName(TString name, TPyMethodDefVecPtr moduleDefs)
                : Name(std::move(name))
            {
                // Def is filled in here instead of in the initializer list: it is declared
                // before Name, so a member initializer would run before Name is constructed.
                const PyModuleDef definition = {
                    PyModuleDef_HEAD_INIT,
                    Name.c_str(),       // module name
                    nullptr,            // no docstring
                    -1,                 // module state size
                    moduleDefs->data(), // method table
                    nullptr, nullptr, nullptr, nullptr
                };
                Def = definition;
            }

            /// Address of the embedded definition; valid for the lifetime of this object.
            PyModuleDef* GetDefPtr() {
                return &Def;
            }
        };
        TVector<TSimpleSharedPtr<TPyModuleDefWithName>> ModuleDefs;
#endif

        /// Exception barrier placed between Python and a module method.
        ///
        /// Calls `method(obj, args)` and returns its result. A NULL result without a
        /// pending Python error is treated as a bug in the method and reported as an
        /// exception. Any C++ exception is converted into a Python error and NULL is
        /// returned, so no C++ exception leaves this function.
        template <TModuleMethod method>
        static PyObject* MethodWrapper(PyObject* obj, PyObject* args) {
            try {
                PyObject* const result = method(obj, args);
                // Either a real value, or NULL accompanied by a Python error: pass it through.
                if (result || PyErr_Occurred()) {
                    return result;
                }
                ythrow yexception() << "\nModule method exited with NULL, but didn't set Error.\n Options:\n -- Return correct value or None;\n -- Set python exception;\n -- Throw c++ exception.";
            } catch (const std::exception& ex) {
                // Map the C++ exception onto the Python exception type chosen by TExceptionsHolder.
                PyErr_SetString(TExceptionsHolder::Instance().ToPyException(ex).Get(), ex.what());
            } catch (...) {
                PyErr_SetString(PyExc_RuntimeError, "Unknown error occurred while trying to call module method");
            }
            return nullptr;
        }

    public:
        /// Returns the process-wide holder, constructed on first use as a
        /// function-local static.
        static TModuleHolder& Instance() {
            static TModuleHolder instance;
            return instance;
        }

        /// Registers `initFunc` as the initializer of a built-in module called `name`,
        /// imports that module and publishes it as attribute `name` of `module`.
        ///
        /// @param module    module that receives the new attribute.
        /// @param name      name of the module to register and import.
        ///                  NOTE(review): the inittab presumably keeps this pointer rather
        ///                  than copying the string -- pass a string that outlives the
        ///                  interpreter (e.g. a literal).
        /// @param initFunc  module initialization function.
        ///
        /// If the import or the attribute assignment fails, the function returns without
        /// adding anything and leaves the Python error set by the failing call pending.
        void ImportModule(TPyObjectPtr module, const char* const name, TModuleInitFunc initFunc) {
            // const_cast: presumably needed for Python versions that declare a non-const name.
            // A failure here is not handled separately: the import below then fails and
            // reports the error.
            PyImport_AppendInittab(const_cast<char*>(name), initFunc);

            // NOTE(review): assumes the `true` flag makes the holder adopt the new reference
            // returned by PyImport_ImportModule -- confirm in ptr.h.
            TPyObjectPtr importedModule(PyImport_ImportModule(name), true);
            if (!importedModule.Get()) {
                // Import failed; do not hand NULL to PyModule_AddObject.
                return;
            }

            // PyModule_AddObject steals one reference, but only on success. The holder above
            // still releases its own reference on scope exit, so give the call a reference
            // of its own and take it back if the call fails.
            Py_INCREF(importedModule.Get());
            if (PyModule_AddObject(module.Get(), name, importedModule.Get()) != 0) {
                Py_DECREF(importedModule.Get());
            }
        }

        template <TModuleMethod method>
        void AddModuleMethod(const TString& name, const TString& descr = "") {
            Methods.back()->push_back(TMethodDef(name, MethodWrapper<method>, descr, METH_VARARGS));
        }

        /// Creates the Python module `name` from the methods queued so far.
        ///
        /// The queued TMethodDef entries are converted into a PyMethodDef table that is
        /// terminated by an all-null sentinel and stored in `Defs`, so it stays alive as
        /// long as this holder does. Afterwards a fresh, empty method list is started for
        /// the next module.
        ///
        /// @return holder of the module object produced by the Python API call.
        TPyObjectPtr InitModule(const TString& name) {
            const TVector<TMethodDef>& queued = *Methods.back();
            Defs.push_back(new TVector<PyMethodDef>(queued.begin(), queued.end()));

            // Terminate the table with the sentinel entry the Python API expects.
            TVector<PyMethodDef>& table = *Defs.back();
            table.push_back(PyMethodDef{nullptr, nullptr, 0, nullptr});
#if PY_MAJOR_VERSION >= 3
            // The module definition must outlive the module, hence it is stored in ModuleDefs.
            ModuleDefs.push_back(MakeSimpleShared<TPyModuleDefWithName>(name, Defs.back()));
            PyModuleDef* const definition = ModuleDefs.back()->GetDefPtr();
            TPyObjectPtr res(PyModule_Create(definition));
            NPrivate::TAtExitRegistrar::Instantiate(res);
            NPrivate::TPyBindModuleRegistrar::Instantiate();
#else
            TPyObjectPtr res(Py_InitModule(name.c_str(), &(table.at(0))));
#endif
            // Start collecting methods for the next module.
            Methods.push_back(new TVector<TMethodDef>);
            return res;
        }
    };

    /// Adapter that exposes the free function `method` (of type `TMethodSignature`)
    /// through the PyCFunction-style entry point `Call`.
    ///
    /// The Python argument tuple is unpacked with GetArguments, the function is invoked
    /// through Apply, and the result is converted back with BuildPyObject.
    template <typename TMethodSignature, TMethodSignature method>
    class TModuleMethodCaller {
    private:
        // Not referenced anywhere in this class; kept as is.
        template <typename TResult, typename... Args>
        struct TCaller {
            static PyObject* Call(PyObject* args) {
                return BuildPyObject(Apply(method, GetArguments<Args...>(args)));
            }
        };

        // The unnamed function-pointer parameter only serves to deduce the argument
        // types of `method`.
        template <typename TResult, typename... Args>
        static PyObject* InternalCall(TResult (*)(Args...), PyObject* args) {
            return BuildPyObject(Apply(method, GetArguments<Args...>(args)));
        }

    public:
        /// Validates `args` and calls `method` with the unpacked arguments.
        ///
        /// @param args  argument tuple; a null pointer is passed on unchecked.
        /// @return the value produced by BuildPyObject for the result of `method`.
        /// @throws yexception if `args` is not a tuple or has the wrong number of elements.
        static PyObject* Call(PyObject*, PyObject* args) {
            if (args) {
                if (!PyTuple_Check(args)) {
                    // PyTuple_Size must not be called on a non-tuple: it would raise an
                    // internal Python error and report a size of -1.
                    ythrow yexception() << "Method takes " << (size_t)(TFunctionArgs<TMethodSignature>::Length) << " arguments, but the argument object is not a tuple";
                }

                const Py_ssize_t provided = PyTuple_Size(args);
                if (provided != TFunctionArgs<TMethodSignature>::Length) {
                    ythrow yexception() << "Method takes " << (size_t)(TFunctionArgs<TMethodSignature>::Length) << " arguments, " << provided << " provided";
                }
            }

            return InternalCall(method, args);
        }
    };

}