aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/lua
diff options
context:
space:
mode:
authormonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
committermonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
commit06e5c21a835c0e923506c4ff27929f34e00761c2 (patch)
tree75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/lua
parent03f024c4412e3aa613bb543cf1660176320ba8f4 (diff)
downloadydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz
fix ya.make
Diffstat (limited to 'library/cpp/lua')
-rw-r--r--library/cpp/lua/eval.cpp178
-rw-r--r--library/cpp/lua/eval.h65
-rw-r--r--library/cpp/lua/json.cpp62
-rw-r--r--library/cpp/lua/json.h14
-rw-r--r--library/cpp/lua/wrapper.cpp229
-rw-r--r--library/cpp/lua/wrapper.h565
6 files changed, 1113 insertions, 0 deletions
diff --git a/library/cpp/lua/eval.cpp b/library/cpp/lua/eval.cpp
new file mode 100644
index 0000000000..5c78d0ad3f
--- /dev/null
+++ b/library/cpp/lua/eval.cpp
@@ -0,0 +1,178 @@
+#include "eval.h"
+#include "json.h"
+#include <util/string/cast.h>
+#include <util/system/guard.h>
+#include <util/stream/mem.h>
+#include <util/string/builder.h>
+
+TLuaEval::TLuaEval()
+ : FunctionNameCounter_(0)
+{
+ LuaState_.BootStrap();
+}
+
+void TLuaEval::SetVariable(TZtStringBuf name, const NJson::TJsonValue& value) {
+ TGuard<TMutex> guard(LuaMutex_);
+
+ NLua::PushJsonValue(&LuaState_, value);
+ LuaState_.set_global(name.c_str());
+}
+
+void TLuaEval::RunExpressionLocked(const TGuard<TMutex>&, const TExpression& expr) {
+ LuaState_.push_global(expr.Name.c_str());
+ LuaState_.call(0, 1);
+}
+
+TString TLuaEval::EvalCompiled(const TExpression& expr) {
+ TGuard<TMutex> guard(LuaMutex_);
+ RunExpressionLocked(guard, expr);
+ return LuaState_.pop_value();
+}
+
+void TLuaEval::EvalCompiledRaw(const TExpression& expr) {
+ TGuard<TMutex> guard(LuaMutex_);
+ RunExpressionLocked(guard, expr);
+}
+
+bool TLuaEval::EvalCompiledCondition(const TExpression& expr) {
+ TGuard<TMutex> guard(LuaMutex_);
+ RunExpressionLocked(guard, expr);
+ return LuaState_.pop_bool_strict();
+}
+
+TString TLuaEval::EvalRaw(TStringBuf code) {
+ TMemoryInput bodyIn(code.data(), code.size());
+
+ LuaState_.Load(&bodyIn, "main");
+ LuaState_.call(0, 1);
+
+ return LuaState_.pop_value();
+}
+
+void TLuaEval::ParseChunk(TStringBuf code) {
+ TMemoryInput in(code.data(), code.size());
+
+ LuaState_.Load(&in, "chunk_" + GenerateName());
+ LuaState_.call(0, 0);
+}
+
+TString TLuaEval::EvalExpression(TStringBuf expression) {
+ const auto expr = Compile(expression);
+ try {
+ return EvalCompiled(expr);
+ } catch (const yexception& e) {
+ throw yexception(e) << '\n' << expression;
+ }
+}
+
+TLuaEval::TExpression TLuaEval::Compile(TStringBuf expression) {
+ TGuard<TMutex> guard(LuaMutex_);
+
+ TString name = GenerateName();
+
+ TString body = "function ";
+ body += name;
+ body += "()\n\treturn (";
+ body += expression;
+ body += ")\nend\n";
+
+ try {
+ TMemoryInput bodyIn(body.c_str(), body.size());
+ LuaState_.Load(&bodyIn, "chunk_" + name);
+ LuaState_.call(0, 0);
+ } catch (const yexception& e) {
+ ythrow yexception(e) << "\n"
+ << body;
+ }
+ return {name};
+}
+
+TLuaEval::TExpression TLuaEval::CompileFunction(TStringBuf expression) {
+ TString name = GenerateName();
+ TStringBuilder body;
+ body << "function " << name << "()" << Endl
+ << expression << Endl
+ << "end";
+
+ return CompileRaw(TStringBuf(body.data(), body.size()), name);
+}
+
+TLuaEval::TExpression TLuaEval::CompileRaw(TStringBuf body, const TString& name) {
+ TGuard<TMutex> guard(LuaMutex_);
+ try {
+ TMemoryInput bodyIn(body.data(), body.size());
+ LuaState_.Load(&bodyIn, "chunk_" + name);
+ LuaState_.call(0, 0);
+ } catch (const yexception& e) {
+ ythrow yexception(e) << "\n" << body;
+ }
+ return { name };
+}
+
+TString TLuaEval::GenerateName() {
+ TGuard<TMutex> guard(LuaMutex_);
+ return "dummy_" + ToString(FunctionNameCounter_++);
+}
+
+template <class T>
+static inline T FindEnd(T b, T e) {
+ size_t cnt = 0;
+
+ while (b < e) {
+ switch (*b) {
+ case '{':
+ ++cnt;
+ break;
+
+ case '}':
+ if (cnt == 0) {
+ return b;
+ }
+
+ --cnt;
+ break;
+ }
+
+ ++b;
+ }
+
+ return b;
+}
+
+TString TLuaEval::PreprocessOne(TStringBuf line) {
+ const size_t pos = line.find("${");
+
+ if (pos == TStringBuf::npos) {
+ return EvalExpression(line);
+ }
+
+ const char* rpos = FindEnd(line.data() + pos + 2, line.end());
+
+ if (rpos == line.end()) {
+ ythrow yexception() << TStringBuf("can not parse ") << line;
+ }
+
+ const TStringBuf before = line.SubStr(0, pos);
+ const TStringBuf after = TStringBuf(rpos + 1, line.end());
+ const TStringBuf code = TStringBuf(line.data() + pos + 2, rpos);
+
+ TString res;
+
+ if (code.find("${") == TStringBuf::npos) {
+ res = EvalExpression(code);
+ } else {
+ res = EvalExpression(Preprocess(code));
+ }
+
+ return ToString(before) + res + ToString(after);
+}
+
+TString TLuaEval::Preprocess(TStringBuf line) {
+ TString res = ToString(line);
+
+ while (res.find("${") != TString::npos) {
+ res = PreprocessOne(res);
+ }
+
+ return res;
+}
diff --git a/library/cpp/lua/eval.h b/library/cpp/lua/eval.h
new file mode 100644
index 0000000000..77c59d7efd
--- /dev/null
+++ b/library/cpp/lua/eval.h
@@ -0,0 +1,65 @@
+#pragma once
+
+#include "wrapper.h"
+
+#include <library/cpp/json/json_value.h>
+
+#include <util/system/mutex.h>
+
+class TLuaEval {
+public:
+ TLuaEval();
+
+ template <class C>
+ inline TLuaEval& SetVars(const C& container) {
+ for (auto& [k, v] : container) {
+ SetVariable(k, v);
+ }
+
+ return *this;
+ }
+
+ inline TLuaEval& Parse(TStringBuf chunk) {
+ ParseChunk(chunk);
+
+ return *this;
+ }
+
+ void SetVariable(TZtStringBuf name, const NJson::TJsonValue& value);
+ template <typename T>
+ void SetUserdata(TZtStringBuf name, T&& userdata) {
+ LuaState_.push_userdata(std::forward<T>(userdata));
+ LuaState_.set_global(name.c_str());
+ }
+ TString EvalExpression(TStringBuf expression);
+ TString EvalRaw(TStringBuf code);
+ void ParseChunk(TStringBuf code);
+ TString Preprocess(TStringBuf line);
+ TString PreprocessOne(TStringBuf line);
+
+ struct TExpression {
+ TString Name;
+ };
+
+ TExpression Compile(TStringBuf expression);
+ TExpression CompileFunction(TStringBuf expression);
+ TExpression CompileRaw(TStringBuf body, const TString& name);
+ TString EvalCompiled(const TExpression& compiled);
+ void EvalCompiledRaw(const TExpression& compiled);
+ bool EvalCompiledCondition(const TExpression& compiled);
+ template <typename TNumber>
+ TNumber EvalCompiledNumeric(const TExpression& compiled) {
+ TGuard<TMutex> guard(LuaMutex_);
+ RunExpressionLocked(guard, compiled);
+ return LuaState_.pop_number<TNumber>();
+ }
+
+private:
+ TString GenerateName();
+ TString Evaluate(const TString& name, const TString& body);
+ void RunExpressionLocked(const TGuard<TMutex>& lock, const TExpression& compiled);
+
+ TLuaStateHolder LuaState_;
+ ui64 FunctionNameCounter_;
+ TMutex LuaMutex_;
+};
diff --git a/library/cpp/lua/json.cpp b/library/cpp/lua/json.cpp
new file mode 100644
index 0000000000..da7d228459
--- /dev/null
+++ b/library/cpp/lua/json.cpp
@@ -0,0 +1,62 @@
+#include "json.h"
+#include "wrapper.h"
+
+#include <library/cpp/json/json_value.h>
+
+using namespace NJson;
+
+void NLua::PushJsonValue(TLuaStateHolder* state, const TJsonValue& json) {
+ // each recursive call will explicitly push only a single element to stack relying on subcalls to reserve stack space for themselves
+ // I.e. for a map {"a": "b"} the first call will ensure stack space for create_table, then call PushJsonValue for the string,
+ // this PushJsonValue will ensure stack space for the string. Thus only a single ensure_stack at the start of the function is enough.
+ state->ensure_stack(1);
+ switch (json.GetType()) {
+ case JSON_UNDEFINED:
+ ythrow yexception() << "cannot push undefined json value";
+
+ case JSON_NULL:
+ state->push_nil();
+ break;
+
+ case JSON_BOOLEAN:
+ state->push_bool(json.GetBoolean());
+ break;
+
+ case JSON_INTEGER:
+ state->push_number(json.GetInteger());
+ break;
+
+ case JSON_UINTEGER:
+ state->push_number(json.GetUInteger());
+ break;
+
+ case JSON_DOUBLE:
+ state->push_number(json.GetDouble());
+ break;
+
+ case JSON_STRING:
+ state->push_string(json.GetString());
+ break;
+
+ case JSON_MAP:
+ state->create_table();
+ for (const auto& pair : json.GetMap()) {
+ PushJsonValue(state, pair.second); // Recursive call tests for stack space on its own
+ state->set_field(-2, pair.first.data());
+ }
+ break;
+
+ case JSON_ARRAY: {
+ state->create_table();
+ int index = 1; // lua arrays start from 1
+ for (const auto& element : json.GetArray()) {
+ PushJsonValue(state, element); // Recursive call tests for stack space on its own, no need to double check
+ state->rawseti(-2, index++);
+ }
+ break;
+ }
+
+ default:
+ ythrow yexception() << "Unexpected json value type";
+ }
+}
diff --git a/library/cpp/lua/json.h b/library/cpp/lua/json.h
new file mode 100644
index 0000000000..eead596756
--- /dev/null
+++ b/library/cpp/lua/json.h
@@ -0,0 +1,14 @@
+#pragma once
+
+class TLuaStateHolder;
+
+namespace NJson {
+ class TJsonValue;
+}
+
+namespace NLua {
+ // Try to push TJsonValue to lua stack.
+ // Lua stack state is undefined if there's not enough memory to grow stack appropriately
+ // Exception is thrown in this case
+ void PushJsonValue(TLuaStateHolder* state, const NJson::TJsonValue& json);
+}
diff --git a/library/cpp/lua/wrapper.cpp b/library/cpp/lua/wrapper.cpp
new file mode 100644
index 0000000000..ad0fb7537a
--- /dev/null
+++ b/library/cpp/lua/wrapper.cpp
@@ -0,0 +1,229 @@
+#include "wrapper.h"
+
+#include <util/datetime/cputimer.h>
+#include <util/stream/buffered.h>
+#include <util/stream/buffer.h>
+#include <util/stream/format.h>
+#include <util/stream/input.h>
+#include <util/stream/mem.h>
+#include <util/stream/output.h>
+#include <util/system/sys_alloc.h>
+
+namespace {
+ class TLuaCountLimit {
+ public:
+ TLuaCountLimit(lua_State* state, int count)
+ : State(state)
+ {
+ lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count);
+ }
+
+ ~TLuaCountLimit() {
+ lua_sethook(State, LuaHookCallback, 0, 0);
+ }
+
+ static void LuaHookCallback(lua_State* L, lua_Debug*) {
+ luaL_error(L, "Lua instruction count limit exceeded");
+ }
+
+ private:
+ lua_State* State;
+ }; // class TLuaCountLimit
+
+ class TLuaTimeLimit {
+ public:
+ TLuaTimeLimit(lua_State* state, TDuration limit, int count)
+ : State(state)
+ , Limit(limit)
+ {
+ lua_pushlightuserdata(State, (void*)LuaHookCallback); //key
+ lua_pushlightuserdata(State, (void*)this); //value
+ lua_settable(State, LUA_REGISTRYINDEX);
+
+ lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count);
+ }
+
+ ~TLuaTimeLimit() {
+ lua_sethook(State, LuaHookCallback, 0, 0);
+ }
+
+ bool Exceeded() {
+ return Timer.Get() > Limit;
+ }
+
+ static void LuaHookCallback(lua_State* L, lua_Debug*) {
+ lua_pushlightuserdata(L, (void*)LuaHookCallback);
+ lua_gettable(L, LUA_REGISTRYINDEX);
+ TLuaTimeLimit* t = static_cast<TLuaTimeLimit*>(lua_touserdata(L, -1));
+ lua_pop(L, 1);
+ if (t->Exceeded()) {
+ luaL_error(L, "time limit exceeded");
+ }
+ }
+
+ private:
+ lua_State* State;
+ const TDuration Limit;
+ TSimpleTimer Timer;
+ }; // class TLuaTimeLimit
+
+ class TLuaReader {
+ public:
+ TLuaReader(IZeroCopyInput* in)
+ : In_(in)
+ {
+ }
+
+ inline void Load(lua_State* state, const char* name) {
+ if (lua_load(state, ReadCallback, this, name
+#if LUA_VERSION_NUM > 501
+ ,
+ nullptr
+#endif
+ ))
+ {
+ ythrow TLuaStateHolder::TError() << "can not parse lua chunk " << name << ": " << lua_tostring(state, -1);
+ }
+ }
+
+ static const char* ReadCallback(lua_State*, void* data, size_t* size) {
+ return ((TLuaReader*)(data))->Read(size);
+ }
+
+ private:
+ inline const char* Read(size_t* readed) {
+ const char* ret;
+
+ if (*readed = In_->Next(&ret)) {
+ return ret;
+ }
+
+ return nullptr;
+ }
+
+ private:
+ IZeroCopyInput* In_;
+ }; // class TLuaReader
+
+ class TLuaWriter {
+ public:
+ TLuaWriter(IOutputStream* out)
+ : Out_(out)
+ {
+ }
+
+ inline void Dump(lua_State* state) {
+ if (lua_dump(state, WriteCallback, this)) {
+ ythrow TLuaStateHolder::TError() << "can not dump lua state: " << lua_tostring(state, -1);
+ }
+ }
+
+ static int WriteCallback(lua_State*, const void* data, size_t size, void* user) {
+ return ((TLuaWriter*)(user))->Write(data, size);
+ }
+
+ private:
+ inline int Write(const void* data, size_t size) {
+ Out_->Write(static_cast<const char*>(data), size);
+ return 0;
+ }
+
+ private:
+ IOutputStream* Out_;
+ }; // class TLuaWriter
+
+} //namespace
+
+void TLuaStateHolder::Load(IInputStream* in, TZtStringBuf name) {
+ TBufferedInput wi(in, 8192);
+ return TLuaReader(&wi).Load(State_, name.c_str());
+}
+
+void TLuaStateHolder::Dump(IOutputStream* out) {
+ return TLuaWriter(out).Dump(State_);
+}
+
+void TLuaStateHolder::DumpStack(IOutputStream* out) {
+ for (int i = lua_gettop(State_) * -1; i < 0; ++i) {
+ *out << i << " is " << lua_typename(State_, lua_type(State_, i)) << " (";
+ if (is_number(i)) {
+ *out << to_number<long long>(i);
+ } else if (is_string(i)) {
+ *out << to_string(i);
+ } else {
+ *out << Hex((uintptr_t)lua_topointer(State_, i), HF_ADDX);
+ }
+ *out << ')' << Endl;
+ }
+}
+
+void* TLuaStateHolder::Alloc(void* ud, void* ptr, size_t /*osize*/, size_t nsize) {
+ (void)ud;
+
+ if (nsize == 0) {
+ y_deallocate(ptr);
+
+ return nullptr;
+ }
+
+ return y_reallocate(ptr, nsize);
+}
+
+void* TLuaStateHolder::AllocLimit(void* ud, void* ptr, size_t osize, size_t nsize) {
+ TLuaStateHolder& state = *static_cast<TLuaStateHolder*>(ud);
+
+ if (nsize == 0) {
+ y_deallocate(ptr);
+ state.AllocFree += osize;
+
+ return nullptr;
+ }
+
+ if (state.AllocFree + osize < nsize) {
+ return nullptr;
+ }
+
+ ptr = y_reallocate(ptr, nsize);
+
+ if (ptr) {
+ state.AllocFree += osize;
+ state.AllocFree -= nsize;
+ }
+
+ return ptr;
+}
+
+void TLuaStateHolder::call(int args, int rets, int count) {
+ TLuaCountLimit limit(State_, count);
+ return call(args, rets);
+}
+
+void TLuaStateHolder::call(int args, int rets, TDuration time_limit, int count) {
+ TLuaTimeLimit limit(State_, time_limit, count);
+ return call(args, rets);
+}
+
+template <>
+void Out<NLua::TStackDumper>(IOutputStream& out, const NLua::TStackDumper& sd) {
+ sd.State.DumpStack(&out);
+}
+
+template <>
+void Out<NLua::TMarkedStackDumper>(IOutputStream& out, const NLua::TMarkedStackDumper& sd) {
+ out << sd.Mark << Endl;
+ sd.State.DumpStack(&out);
+ out << sd.Mark << Endl;
+}
+
+namespace NLua {
+ TBuffer& Compile(TStringBuf script, TBuffer& buffer) {
+ TMemoryInput input(script.data(), script.size());
+ TLuaStateHolder state;
+ state.Load(&input, "main");
+
+ TBufferOutput out(buffer);
+ state.Dump(&out);
+ return buffer;
+ }
+
+}
diff --git a/library/cpp/lua/wrapper.h b/library/cpp/lua/wrapper.h
new file mode 100644
index 0000000000..0d568f049a
--- /dev/null
+++ b/library/cpp/lua/wrapper.h
@@ -0,0 +1,565 @@
+#pragma once
+
+#include <library/cpp/string_utils/ztstrbuf/ztstrbuf.h>
+
+#include <util/memory/alloc.h>
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/generic/yexception.h>
+#include <util/generic/buffer.h>
+#include <util/datetime/base.h>
+#include <functional>
+
+#include <contrib/libs/lua/lua.h>
+
+class IInputStream;
+class IOutputStream;
+
+class TLuaStateHolder {
+ struct TDeleteState {
+ static inline void Destroy(lua_State* state) {
+ lua_close(state);
+ }
+ };
+
+public:
+ class TError: public yexception {
+ };
+
+ inline TLuaStateHolder(size_t memory_limit = 0)
+ : AllocFree(memory_limit)
+ , MyState_(lua_newstate(memory_limit ? AllocLimit : Alloc, (void*)this))
+ , State_(MyState_.Get())
+ {
+ if (!State_) {
+ ythrow TError() << "can not construct lua state: not enough memory";
+ }
+ }
+
+ inline TLuaStateHolder(lua_State* state) noexcept
+ : State_(state)
+ {
+ }
+
+ inline operator lua_State*() noexcept {
+ return State_;
+ }
+
+ inline void BootStrap() {
+ luaL_openlibs(State_);
+ }
+
+ inline void error() {
+ ythrow TError() << "lua error: " << pop_string();
+ }
+
+ inline bool is_string(int index) {
+ return lua_isstring(State_, index);
+ }
+
+ inline void is_string_strict(int index) {
+ if (!is_string(index)) {
+ ythrow TError() << "internal lua error (not a string)";
+ }
+ }
+
+ inline TStringBuf to_string(int index) {
+ size_t len = 0;
+ const char* data = lua_tolstring(State_, index, &len);
+ return TStringBuf(data, len);
+ }
+
+ inline TStringBuf to_string(int index, TStringBuf defaultValue) {
+ return is_string(index) ? to_string(index) : defaultValue;
+ }
+
+ inline TStringBuf to_string_strict(int index) {
+ is_string_strict(index);
+ return to_string(index);
+ }
+
+ inline TString pop_string() {
+ TString ret(to_string(-1));
+ pop();
+ return ret;
+ }
+
+ inline TString pop_string(TStringBuf defaultValue) {
+ TString ret(to_string(-1, defaultValue));
+ pop();
+ return ret;
+ }
+
+ inline TString pop_string_strict() {
+ require(1);
+ TString ret(to_string_strict(-1));
+ pop();
+ return ret;
+ }
+
+ inline TString pop_value() {
+ require(1);
+ if (is_bool(-1)) {
+ return pop_bool() ? "true" : "false";
+ }
+ return pop_string_strict();
+ }
+
+ inline void push_string(const char* st) {
+ lua_pushstring(State_, st ? st : "");
+ }
+
+ inline void push_string(TStringBuf st) {
+ lua_pushlstring(State_, st.data(), st.size());
+ }
+
+ inline bool is_number(int index) {
+ return lua_isnumber(State_, index);
+ }
+
+ inline void is_number_strict(int index) {
+ if (!is_number(index)) {
+ ythrow TError() << "internal lua error (not a number)";
+ }
+ }
+
+ template <typename T>
+ inline T to_number(int index) {
+ return static_cast<T>(lua_tonumber(State_, index));
+ }
+
+ template <typename T>
+ inline T to_number(int index, T defaultValue) {
+ return is_number(index) ? to_number<T>(index) : defaultValue;
+ }
+
+ template <typename T>
+ inline T to_number_strict(int index) {
+ is_number_strict(index);
+ return to_number<T>(index);
+ }
+
+ template <typename T>
+ inline T pop_number() {
+ const T ret = to_number<T>(-1);
+ pop();
+ return ret;
+ }
+
+ template <typename T>
+ inline T pop_number(T defaultValue) {
+ const T ret = to_number<T>(-1, defaultValue);
+ pop();
+ return ret;
+ }
+
+ template <typename T>
+ inline T pop_number_strict() {
+ require(1);
+ const T ret = to_number_strict<T>(-1);
+ pop();
+ return ret;
+ }
+
+ template <typename T>
+ inline void push_number(T val) {
+ lua_pushnumber(State_, static_cast<lua_Number>(val));
+ }
+
+ inline bool is_bool(int index) {
+ return lua_isboolean(State_, index);
+ }
+
+ inline void is_bool_strict(int index) {
+ if (!is_bool(index)) {
+ ythrow TError() << "internal lua error (not a boolean)";
+ }
+ }
+
+ inline bool to_bool(int index) {
+ return lua_toboolean(State_, index);
+ }
+
+ inline bool to_bool(int index, bool defaultValue) {
+ return is_bool(index) ? to_bool(index) : defaultValue;
+ }
+
+ inline bool to_bool_strict(int index) {
+ is_bool_strict(index);
+ return to_bool(index);
+ }
+
+ inline bool pop_bool() {
+ const bool ret = to_bool(-1);
+ pop();
+ return ret;
+ }
+
+ inline bool pop_bool(bool defaultValue) {
+ const bool ret = to_bool(-1, defaultValue);
+ pop();
+ return ret;
+ }
+
+ inline bool pop_bool_strict() {
+ require(1);
+ const bool ret = to_bool_strict(-1);
+ pop();
+ return ret;
+ }
+
+ inline void push_bool(bool val) {
+ lua_pushboolean(State_, val);
+ }
+
+ inline bool is_nil(int index) {
+ return lua_isnil(State_, index);
+ }
+
+ inline void is_nil_strict(int index) {
+ if (!is_nil(index)) {
+ ythrow TError() << "internal lua error (not a nil)";
+ }
+ }
+
+ inline bool pop_nil() {
+ const bool ret = is_nil(-1);
+ pop();
+ return ret;
+ }
+
+ inline void pop_nil_strict() {
+ require(1);
+ is_nil_strict(-1);
+ pop();
+ }
+
+ inline void push_nil() {
+ lua_pushnil(State_);
+ }
+
+ inline bool is_void(int index) {
+ return lua_islightuserdata(State_, index);
+ }
+
+ inline void is_void_strict(int index) {
+ if (!is_void(index)) {
+ ythrow TError() << "internal lua error (not a void*)";
+ }
+ }
+
+ inline void* to_void(int index) {
+ return lua_touserdata(State_, index);
+ }
+
+ inline void* to_void(int index, void* defaultValue) {
+ return is_void(index) ? to_void(index) : defaultValue;
+ }
+
+ inline void* to_void_strict(int index) {
+ is_void_strict(index);
+ return to_void(index);
+ }
+
+ inline void* pop_void() {
+ void* ret = to_void(-1);
+ pop();
+ return ret;
+ }
+
+ inline void* pop_void(void* defaultValue) {
+ void* ret = to_void(-1, defaultValue);
+ pop();
+ return ret;
+ }
+
+ inline void* pop_void_strict() {
+ require(1);
+ void* ret = to_void_strict(-1);
+ pop();
+ return ret;
+ }
+
+ inline void push_void(void* ptr) {
+ lua_pushlightuserdata(State_, ptr);
+ }
+
+ template <typename T>
+ inline bool is_userdata(int index) {
+ return to_userdata<T>(index) != NULL;
+ }
+
+ template <typename T>
+ inline void is_userdata_strict(int index) {
+ to_userdata_strict<T>(index);
+ }
+
+ template <typename T>
+ inline T* to_userdata(int index) {
+ return static_cast<T*>(luaL_testudata(State_, index, T::LUA_METATABLE_NAME));
+ }
+
+ template <typename T>
+ inline T* to_userdata_strict(int index) {
+ T* ret = to_userdata<T>(index);
+ if (ret == nullptr) {
+ ythrow TError() << "internal error (not a userdata '" << T::LUA_METATABLE_NAME << "')";
+ }
+ return ret;
+ }
+
+ template <typename T>
+ inline T pop_userdata_strict() {
+ require(1);
+ const T ret(*to_userdata_strict<T>(-1));
+ pop();
+ return ret;
+ }
+
+ template <typename T>
+ inline T* push_userdata(const T& x) {
+ // copy constructor
+ return new (new_userdata<T>()) T(x);
+ }
+
+ template <typename T, typename... R>
+ inline T* push_userdata(const R&... r) {
+ return new (new_userdata<T>()) T(r...);
+ }
+
+ inline void push_global(const char* name) {
+ lua_getglobal(State_, name);
+ }
+
+ inline void set_global(const char* name, const char* value) {
+ lua_pushstring(State_, value);
+ set_global(name);
+ }
+
+ inline void set_global(const char* name, const double value) {
+ lua_pushnumber(State_, value);
+ set_global(name);
+ }
+
+ inline void set_global(const char* name) {
+ lua_setglobal(State_, name);
+ }
+
+ inline void register_function(const char* name, lua_CFunction func) {
+ lua_register(State_, name, func);
+ }
+
+ inline bool is_table(int index) {
+ return lua_istable(State_, index);
+ }
+
+ inline void is_table_strict(int index) {
+ if (!is_table(index)) {
+ ythrow TError() << "internal lua error (not a table)";
+ }
+ }
+
+ inline void create_table(int narr = 0, int nrec = 0) {
+ lua_createtable(State_, narr, nrec);
+ }
+
+ inline void set_table(int index) {
+ lua_settable(State_, index);
+ }
+
+ inline void get_field(int index, const char* key) {
+ lua_getfield(State_, index, key);
+ }
+
+ inline void set_field(int index, const char* key) {
+ lua_setfield(State_, index, key);
+ }
+
+ inline void rawseti(int index, int arr_index) {
+ lua_rawseti(State_, index, arr_index);
+ }
+
+ inline int check_stack(int extra) {
+ return lua_checkstack(State_, extra);
+ }
+
+ inline void ensure_stack(int extra) {
+ if (!check_stack(extra)) {
+ ythrow TError() << "cannot allocate more lua stack space";
+ };
+ }
+
+ inline void require(int n) {
+ if (on_stack() < n) {
+ ythrow TError() << "lua requirement failed";
+ }
+ }
+
+ inline void call(int args, int rets) {
+ if (lua_pcall(State_, args, rets, 0)) {
+ error();
+ }
+ }
+
+ void call(int args, int rets, TDuration time_limit, int count = 1000);
+ void call(int args, int rets, int limit);
+
+ inline void remove(int index) {
+ lua_remove(State_, index);
+ }
+
+ inline int next(int index) {
+ return lua_next(State_, index);
+ }
+
+ inline void pop(int n = 1) {
+ lua_pop(State_, Min(n, on_stack()));
+ }
+
+ inline void push_value(int index) {
+ lua_pushvalue(State_, index);
+ }
+
+ inline int on_stack() {
+ return lua_gettop(State_);
+ }
+
+ inline void gc() {
+ lua_gc(State_, LUA_GCCOLLECT, 0);
+ }
+
+ inline TLuaStateHolder new_thread() {
+ return lua_newthread(State_);
+ }
+
+ inline bool is_thread(int index) {
+ return lua_isthread(State_, index);
+ }
+
+ inline void is_thread_strict(int index) {
+ if (!is_thread(index)) {
+ ythrow TError() << "internal lua error (not a thread)";
+ }
+ }
+
+ inline TLuaStateHolder to_thread(int index) {
+ return lua_tothread(State_, index);
+ }
+
+ inline TLuaStateHolder to_thread_strict(int index) {
+ is_thread_strict(index);
+ return to_thread(index);
+ }
+
+ void Load(IInputStream* in, TZtStringBuf name);
+ void Dump(IOutputStream* out);
+ void DumpStack(IOutputStream* out);
+
+private:
+ template <typename T>
+ inline void set_metatable() {
+ if (luaL_newmetatable(State_, T::LUA_METATABLE_NAME)) {
+ // metatable isn't registered yet
+ push_string("__index");
+ push_value(-2); // pushes the metatable
+ set_table(-3); // metatable.__index = metatable
+ luaL_setfuncs(State_, T::LUA_FUNCTIONS, 0);
+ }
+ lua_setmetatable(State_, -2);
+ }
+
+ template <typename T>
+ inline void* new_userdata() {
+ void* p = lua_newuserdata(State_, sizeof(T));
+ set_metatable<T>();
+ return p;
+ }
+
+private:
+ static void* Alloc(void* ud, void* ptr, size_t osize, size_t nsize);
+ static void* AllocLimit(void* ud, void* ptr, size_t osize, size_t nsize);
+
+private:
+ size_t AllocFree = 0;
+ THolder<lua_State, TDeleteState> MyState_;
+ lua_State* State_ = nullptr;
+};
+
+namespace NLua {
+ template <int func(TLuaStateHolder&)>
+ int FunctionHandler(lua_State* L) {
+ try {
+ TLuaStateHolder state(L);
+ return func(state);
+ } catch (const yexception& e) {
+ lua_pushstring(L, e.what());
+ }
+ return lua_error(L);
+ }
+
+ template <class T, int (T::*Method)(TLuaStateHolder&)>
+ int MethodHandler(lua_State* L) {
+ T* x = static_cast<T*>(luaL_checkudata(L, 1, T::LUA_METATABLE_NAME));
+ try {
+ TLuaStateHolder state(L);
+ return (x->*Method)(state);
+ } catch (const yexception& e) {
+ lua_pushstring(L, e.what());
+ }
+ return lua_error(L);
+ }
+
+ template <class T, int (T::*Method)(TLuaStateHolder&) const>
+ int MethodConstHandler(lua_State* L) {
+ const T* x = static_cast<const T*>(luaL_checkudata(L, 1, T::LUA_METATABLE_NAME));
+ try {
+ TLuaStateHolder state(L);
+ return (x->*Method)(state);
+ } catch (const yexception& e) {
+ lua_pushstring(L, e.what());
+ }
+ return lua_error(L);
+ }
+
+ template <class T>
+ int Destructor(lua_State* L) {
+ T* x = static_cast<T*>(luaL_checkudata(L, 1, T::LUA_METATABLE_NAME));
+ try {
+ x->~T();
+ return 0;
+ } catch (const yexception& e) {
+ lua_pushstring(L, e.what());
+ }
+ return lua_error(L);
+ }
+
+ TBuffer& Compile(TStringBuf script, TBuffer& buffer);
+
+ struct TStackDumper {
+ TStackDumper(TLuaStateHolder& state)
+ : State(state)
+ {
+ }
+
+ TLuaStateHolder& State;
+ };
+
+ struct TMarkedStackDumper: public TStackDumper {
+ TMarkedStackDumper(TLuaStateHolder& state, TStringBuf mark)
+ : TStackDumper(state)
+ , Mark(mark)
+ {
+ }
+
+ TStringBuf Mark;
+ };
+
+ inline TMarkedStackDumper DumpStack(TLuaStateHolder& state, TStringBuf mark) {
+ return TMarkedStackDumper(state, mark);
+ }
+
+ inline TStackDumper DumpStack(TLuaStateHolder& state) {
+ return TStackDumper(state);
+ }
+
+}