aboutsummaryrefslogtreecommitdiffstats
path: root/library/python/runtime_py3/test/subinterpreter/py3_subinterpreters.cpp
blob: 0a934d4db50ac878631e1d1ecf26da972bb8f2c2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include "stdout_interceptor.h"

#include <util/stream/str.h>

#include <library/cpp/testing/gtest/gtest.h>

#include <Python.h>

#include <algorithm>
#include <array>
#include <functional>
#include <thread>

// Test fixture: owns the lifetime of the embedded CPython runtime for the whole
// suite and provides a helper for running Python code on worker threads.
struct TSubinterpreters: ::testing::Test {
    // Starts the embedded interpreter once for the suite (initsigs = 0, so no
    // Python signal handlers are installed) and enables stdout interception.
    static void SetUpTestSuite() {
        Py_InitializeEx(0);
        EXPECT_TRUE(TPyStdoutInterceptor::SetupInterceptionSupport());
    }

    // Shuts the embedded interpreter down after every test in the suite ran.
    static void TearDownTestSuite() {
        Py_Finalize();
    }

    // Worker-thread body: runs `pycode` inside interpreter `interp` on a fresh
    // thread state, with Python's stdout redirected into `pyout`.
    //
    // The calling thread is expected to have no current Python thread state:
    // PyEval_RestoreThread acquires `interp`'s GIL for this thread, and
    // PyThreadState_DeleteCurrent destroys the state and releases it again.
    static void ThreadPyRun(PyInterpreterState* interp, IOutputStream& pyout, const char* pycode) {
        PyThreadState* state = PyThreadState_New(interp);
        if (state == nullptr) {
            // Nothing has been acquired yet, so there is nothing to undo.
            ADD_FAILURE() << "PyThreadState_New failed";
            return;
        }
        PyEval_RestoreThread(state);

        {
            // Scoped so the interceptor is destroyed before the thread state is
            // cleared, i.e. while this thread still holds the interpreter's GIL.
            TPyStdoutInterceptor interceptor{pyout};
            // PyRun_SimpleString returns -1 if the code raised; report that
            // directly rather than relying only on the caller's output check.
            EXPECT_EQ(PyRun_SimpleString(pycode), 0) << "Python code failed: " << pycode;
        }

        PyThreadState_Clear(state);
        PyThreadState_DeleteCurrent();
    }
};

// Stdout interception must keep working for the main interpreter alone, with
// no sub-interpreters involved.
TEST_F(TSubinterpreters, NonSubinterpreterFlowStillWorks) {
    TStringStream pyout;
    TPyStdoutInterceptor interceptor{pyout};

    // PyRun_SimpleString returns -1 if the snippet raised an exception.
    EXPECT_EQ(PyRun_SimpleString("print('Hello World')"), 0);
    EXPECT_EQ(pyout.Str(), "Hello World\n");
}

// Runs Python code concurrently in two isolated sub-interpreters, each with its
// own GIL, on two worker threads, and checks that every interpreter's stdout is
// captured into its own stream.
TEST_F(TSubinterpreters, ThreadedSubinterpretersFlowWorks) {
    TStringStream pyout[2];

    PyInterpreterConfig cfg = {
        .use_main_obmalloc = 0,
        .allow_fork = 0,
        .allow_exec = 0,
        .allow_threads = 1,
        .allow_daemon_threads = 0,
        .check_multi_interp_extensions = 1,
        .gil = PyInterpreterConfig_OWN_GIL,
    };

    // Human-readable message carried by a failed PyStatus (may be absent).
    auto errMsg = [](const PyStatus& status) -> const char* {
        return status.err_msg != nullptr ? status.err_msg : "(no message)";
    };

    PyThreadState* mainState = PyThreadState_Get();
    PyThreadState *sub[2] = {nullptr, nullptr};

    // Each successful creation makes the new interpreter's thread state current.
    // On failure we must leave mainState current before returning from the
    // test; otherwise TearDownTestSuite would finalize with the wrong state.
    const PyStatus status0 = Py_NewInterpreterFromConfig(&sub[0], &cfg);
    if (PyStatus_Exception(status0) || sub[0] == nullptr) {
        PyThreadState_Swap(mainState);
        FAIL() << "Py_NewInterpreterFromConfig failed for sub[0]: " << errMsg(status0);
    }

    const PyStatus status1 = Py_NewInterpreterFromConfig(&sub[1], &cfg);
    if (PyStatus_Exception(status1) || sub[1] == nullptr) {
        // Tear down the already-created sub[0] exactly like the success path
        // below does, then return control to the main interpreter.
        PyThreadState_Swap(sub[0]);
        Py_EndInterpreter(sub[0]);
        PyThreadState_Swap(mainState);
        FAIL() << "Py_NewInterpreterFromConfig failed for sub[1]: " << errMsg(status1);
    }
    PyThreadState_Swap(mainState);

    // Release the main interpreter's GIL while the workers run; each worker
    // creates and destroys its own thread state in its sub-interpreter.
    PyThreadState* savedState = PyEval_SaveThread();
    std::array<std::thread, 2> threads{
        std::thread{ThreadPyRun, sub[0]->interp, std::ref(pyout[0]), "print('Hello Thread 0')"},
        std::thread{ThreadPyRun, sub[1]->interp, std::ref(pyout[1]), "print('Hello Thread 1')"}
    };
    std::ranges::for_each(threads, &std::thread::join);
    PyEval_RestoreThread(savedState);

    // Py_EndInterpreter requires the interpreter's thread state to be current.
    PyThreadState_Swap(sub[0]);
    Py_EndInterpreter(sub[0]);

    PyThreadState_Swap(sub[1]);
    Py_EndInterpreter(sub[1]);

    PyThreadState_Swap(mainState);

    EXPECT_EQ(pyout[0].Str(), "Hello Thread 0\n");
    EXPECT_EQ(pyout[1].Str(), "Hello Thread 1\n");
}