diff options
author | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
---|---|---|
committer | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
commit | 06e5c21a835c0e923506c4ff27929f34e00761c2 (patch) | |
tree | 75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/lua | |
parent | 03f024c4412e3aa613bb543cf1660176320ba8f4 (diff) | |
download | ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz |
fix ya.make
Diffstat (limited to 'library/cpp/lua')
-rw-r--r-- | library/cpp/lua/eval.cpp | 178 | ||||
-rw-r--r-- | library/cpp/lua/eval.h | 65 | ||||
-rw-r--r-- | library/cpp/lua/json.cpp | 62 | ||||
-rw-r--r-- | library/cpp/lua/json.h | 14 | ||||
-rw-r--r-- | library/cpp/lua/wrapper.cpp | 229 | ||||
-rw-r--r-- | library/cpp/lua/wrapper.h | 565 |
6 files changed, 1113 insertions, 0 deletions
diff --git a/library/cpp/lua/eval.cpp b/library/cpp/lua/eval.cpp new file mode 100644 index 0000000000..5c78d0ad3f --- /dev/null +++ b/library/cpp/lua/eval.cpp @@ -0,0 +1,178 @@ +#include "eval.h" +#include "json.h" +#include <util/string/cast.h> +#include <util/system/guard.h> +#include <util/stream/mem.h> +#include <util/string/builder.h> + +TLuaEval::TLuaEval() + : FunctionNameCounter_(0) +{ + LuaState_.BootStrap(); +} + +void TLuaEval::SetVariable(TZtStringBuf name, const NJson::TJsonValue& value) { + TGuard<TMutex> guard(LuaMutex_); + + NLua::PushJsonValue(&LuaState_, value); + LuaState_.set_global(name.c_str()); +} + +void TLuaEval::RunExpressionLocked(const TGuard<TMutex>&, const TExpression& expr) { + LuaState_.push_global(expr.Name.c_str()); + LuaState_.call(0, 1); +} + +TString TLuaEval::EvalCompiled(const TExpression& expr) { + TGuard<TMutex> guard(LuaMutex_); + RunExpressionLocked(guard, expr); + return LuaState_.pop_value(); +} + +void TLuaEval::EvalCompiledRaw(const TExpression& expr) { + TGuard<TMutex> guard(LuaMutex_); + RunExpressionLocked(guard, expr); +} + +bool TLuaEval::EvalCompiledCondition(const TExpression& expr) { + TGuard<TMutex> guard(LuaMutex_); + RunExpressionLocked(guard, expr); + return LuaState_.pop_bool_strict(); +} + +TString TLuaEval::EvalRaw(TStringBuf code) { + TMemoryInput bodyIn(code.data(), code.size()); + + LuaState_.Load(&bodyIn, "main"); + LuaState_.call(0, 1); + + return LuaState_.pop_value(); +} + +void TLuaEval::ParseChunk(TStringBuf code) { + TMemoryInput in(code.data(), code.size()); + + LuaState_.Load(&in, "chunk_" + GenerateName()); + LuaState_.call(0, 0); +} + +TString TLuaEval::EvalExpression(TStringBuf expression) { + const auto expr = Compile(expression); + try { + return EvalCompiled(expr); + } catch (const yexception& e) { + throw yexception(e) << '\n' << expression; + } +} + +TLuaEval::TExpression TLuaEval::Compile(TStringBuf expression) { + TGuard<TMutex> guard(LuaMutex_); + + TString name = GenerateName(); + + TString body = "function "; + body += name; + body += "()\n\treturn ("; + body += expression; + body += ")\nend\n"; + + try { + TMemoryInput bodyIn(body.c_str(), body.size()); + LuaState_.Load(&bodyIn, "chunk_" + name); + LuaState_.call(0, 0); + } catch (const yexception& e) { + ythrow yexception(e) << "\n" + << body; + } + return {name}; +} + +TLuaEval::TExpression TLuaEval::CompileFunction(TStringBuf expression) { + TString name = GenerateName(); + TStringBuilder body; + body << "function " << name << "()" << Endl + << expression << Endl + << "end"; + + return CompileRaw(TStringBuf(body.data(), body.size()), name); +} + +TLuaEval::TExpression TLuaEval::CompileRaw(TStringBuf body, const TString& name) { + TGuard<TMutex> guard(LuaMutex_); + try { + TMemoryInput bodyIn(body.data(), body.size()); + LuaState_.Load(&bodyIn, "chunk_" + name); + LuaState_.call(0, 0); + } catch (const yexception& e) { + ythrow yexception(e) << "\n" << body; + } + return { name }; +} + +TString TLuaEval::GenerateName() { + TGuard<TMutex> guard(LuaMutex_); + return "dummy_" + ToString(FunctionNameCounter_++); +} + +template <class T> +static inline T FindEnd(T b, T e) { + size_t cnt = 0; + + while (b < e) { + switch (*b) { + case '{': + ++cnt; + break; + + case '}': + if (cnt == 0) { + return b; + } + + --cnt; + break; + } + + ++b; + } + + return b; +} + +TString TLuaEval::PreprocessOne(TStringBuf line) { + const size_t pos = line.find("${"); + + if (pos == TStringBuf::npos) { + return EvalExpression(line); + } + + const char* rpos = FindEnd(line.data() + pos + 2, line.end()); + + if (rpos == line.end()) { + ythrow yexception() << TStringBuf("can not parse ") << line; + } + + const TStringBuf before = line.SubStr(0, pos); + const TStringBuf after = TStringBuf(rpos + 1, line.end()); + const TStringBuf code = TStringBuf(line.data() + pos + 2, rpos); + + TString res; + + if (code.find("${") == TStringBuf::npos) { + res = EvalExpression(code); + } else { + res = EvalExpression(Preprocess(code)); + } + + return ToString(before) + res + ToString(after); +} + +TString TLuaEval::Preprocess(TStringBuf line) { + TString res = ToString(line); + + while (res.find("${") != TString::npos) { + res = PreprocessOne(res); + } + + return res; +} diff --git a/library/cpp/lua/eval.h b/library/cpp/lua/eval.h new file mode 100644 index 0000000000..77c59d7efd --- /dev/null +++ b/library/cpp/lua/eval.h @@ -0,0 +1,65 @@ +#pragma once + +#include "wrapper.h" + +#include <library/cpp/json/json_value.h> + +#include <util/system/mutex.h> + +class TLuaEval { +public: + TLuaEval(); + + template <class C> + inline TLuaEval& SetVars(const C& container) { + for (auto& [k, v] : container) { + SetVariable(k, v); + } + + return *this; + } + + inline TLuaEval& Parse(TStringBuf chunk) { + ParseChunk(chunk); + + return *this; + } + + void SetVariable(TZtStringBuf name, const NJson::TJsonValue& value); + template <typename T> + void SetUserdata(TZtStringBuf name, T&& userdata) { + LuaState_.push_userdata(std::forward<T>(userdata)); + LuaState_.set_global(name.c_str()); + } + TString EvalExpression(TStringBuf expression); + TString EvalRaw(TStringBuf code); + void ParseChunk(TStringBuf code); + TString Preprocess(TStringBuf line); + TString PreprocessOne(TStringBuf line); + + struct TExpression { + TString Name; + }; + + TExpression Compile(TStringBuf expression); + TExpression CompileFunction(TStringBuf expression); + TExpression CompileRaw(TStringBuf body, const TString& name); + TString EvalCompiled(const TExpression& compiled); + void EvalCompiledRaw(const TExpression& compiled); + bool EvalCompiledCondition(const TExpression& compiled); + template <typename TNumber> + TNumber EvalCompiledNumeric(const TExpression& compiled) { + TGuard<TMutex> guard(LuaMutex_); + RunExpressionLocked(guard, compiled); + return LuaState_.pop_number<TNumber>(); + } + +private: + TString GenerateName(); + TString Evaluate(const TString& name, const TString& body); + void RunExpressionLocked(const TGuard<TMutex>& lock, const TExpression& compiled); + + TLuaStateHolder LuaState_; + ui64 FunctionNameCounter_; + TMutex LuaMutex_; +}; diff --git a/library/cpp/lua/json.cpp b/library/cpp/lua/json.cpp new file mode 100644 index 0000000000..da7d228459 --- /dev/null +++ b/library/cpp/lua/json.cpp @@ -0,0 +1,62 @@ +#include "json.h" +#include "wrapper.h" + +#include <library/cpp/json/json_value.h> + +using namespace NJson; + +void NLua::PushJsonValue(TLuaStateHolder* state, const TJsonValue& json) { + // each recursive call will explicitly push only a single element to stack relying on subcalls to reserve stack space for themselves + // I.e. for a map {"a": "b"} the first call will ensure stack space for create_table, then call PushJsonValue for the string, + // this PushJsonValue will ensure stack space for the string. Thus only a single ensure_stack at the start of the function is enough. + state->ensure_stack(1); + switch (json.GetType()) { + case JSON_UNDEFINED: + ythrow yexception() << "cannot push undefined json value"; + + case JSON_NULL: + state->push_nil(); + break; + + case JSON_BOOLEAN: + state->push_bool(json.GetBoolean()); + break; + + case JSON_INTEGER: + state->push_number(json.GetInteger()); + break; + + case JSON_UINTEGER: + state->push_number(json.GetUInteger()); + break; + + case JSON_DOUBLE: + state->push_number(json.GetDouble()); + break; + + case JSON_STRING: + state->push_string(json.GetString()); + break; + + case JSON_MAP: + state->create_table(); + for (const auto& pair : json.GetMap()) { + PushJsonValue(state, pair.second); // Recursive call tests for stack space on its own + state->set_field(-2, pair.first.data()); + } + break; + + case JSON_ARRAY: { + state->create_table(); + int index = 1; // lua arrays start from 1 + for (const auto& element : json.GetArray()) { + PushJsonValue(state, element); // Recursive call tests for stack space on its own, no need to double check + state->rawseti(-2, index++); + } + break; + } + + default: + ythrow yexception() << "Unexpected json value type"; + } +} diff --git a/library/cpp/lua/json.h b/library/cpp/lua/json.h new file mode 100644 index 0000000000..eead596756 --- /dev/null +++ b/library/cpp/lua/json.h @@ -0,0 +1,14 @@ +#pragma once + +class TLuaStateHolder; + +namespace NJson { + class TJsonValue; +} + +namespace NLua { + // Try to push TJsonValue to lua stack. + // Lua stack state is undefined if there's not enough memory to grow stack appropriately + // Exception is thrown in this case + void PushJsonValue(TLuaStateHolder* state, const NJson::TJsonValue& json); +} diff --git a/library/cpp/lua/wrapper.cpp b/library/cpp/lua/wrapper.cpp new file mode 100644 index 0000000000..ad0fb7537a --- /dev/null +++ b/library/cpp/lua/wrapper.cpp @@ -0,0 +1,229 @@ +#include "wrapper.h" + +#include <util/datetime/cputimer.h> +#include <util/stream/buffered.h> +#include <util/stream/buffer.h> +#include <util/stream/format.h> +#include <util/stream/input.h> +#include <util/stream/mem.h> +#include <util/stream/output.h> +#include <util/system/sys_alloc.h> + +namespace { + class TLuaCountLimit { + public: + TLuaCountLimit(lua_State* state, int count) + : State(state) + { + lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count); + } + + ~TLuaCountLimit() { + lua_sethook(State, LuaHookCallback, 0, 0); + } + + static void LuaHookCallback(lua_State* L, lua_Debug*) { + luaL_error(L, "Lua instruction count limit exceeded"); + } + + private: + lua_State* State; + }; // class TLuaCountLimit + + class TLuaTimeLimit { + public: + TLuaTimeLimit(lua_State* state, TDuration limit, int count) + : State(state) + , Limit(limit) + { + lua_pushlightuserdata(State, (void*)LuaHookCallback); //key + lua_pushlightuserdata(State, (void*)this); //value + lua_settable(State, LUA_REGISTRYINDEX); + + lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count); + } + + ~TLuaTimeLimit() { + lua_sethook(State, LuaHookCallback, 0, 0); + } + + bool Exceeded() { + return Timer.Get() > Limit; + } + + static void LuaHookCallback(lua_State* L, lua_Debug*) { + lua_pushlightuserdata(L, (void*)LuaHookCallback); + lua_gettable(L, LUA_REGISTRYINDEX); + TLuaTimeLimit* t = static_cast<TLuaTimeLimit*>(lua_touserdata(L, -1)); + lua_pop(L, 1); + if (t->Exceeded()) { + luaL_error(L, "time limit exceeded"); + } + } + + private: + lua_State* State; + const TDuration Limit; + TSimpleTimer Timer; + }; // class TLuaTimeLimit + + class TLuaReader { + public: + TLuaReader(IZeroCopyInput* in) + : In_(in) + { + } + + inline void Load(lua_State* state, const char* name) { + if (lua_load(state, ReadCallback, this, name +#if LUA_VERSION_NUM > 501 + , + nullptr +#endif + )) + { + ythrow TLuaStateHolder::TError() << "can not parse lua chunk " << name << ": " << lua_tostring(state, -1); + } + } + + static const char* ReadCallback(lua_State*, void* data, size_t* size) { + return ((TLuaReader*)(data))->Read(size); + } + + private: + inline const char* Read(size_t* readed) { + const char* ret; + + if (*readed = In_->Next(&ret)) { + return ret; + } + + return nullptr; + } + + private: + IZeroCopyInput* In_; + }; // class TLuaReader + + class TLuaWriter { + public: + TLuaWriter(IOutputStream* out) + : Out_(out) + { + } + + inline void Dump(lua_State* state) { + if (lua_dump(state, WriteCallback, this)) { + ythrow TLuaStateHolder::TError() << "can not dump lua state: " << lua_tostring(state, -1); + } + } + + static int WriteCallback(lua_State*, const void* data, size_t size, void* user) { + return ((TLuaWriter*)(user))->Write(data, size); + } + + private: + inline int Write(const void* data, size_t size) { + Out_->Write(static_cast<const char*>(data), size); + return 0; + } + + private: + IOutputStream* Out_; + }; // class TLuaWriter + +} //namespace + +void TLuaStateHolder::Load(IInputStream* in, TZtStringBuf name) { + TBufferedInput wi(in, 8192); + return TLuaReader(&wi).Load(State_, name.c_str()); +} + +void TLuaStateHolder::Dump(IOutputStream* out) { + return TLuaWriter(out).Dump(State_); +} + +void TLuaStateHolder::DumpStack(IOutputStream* out) { + for (int i = lua_gettop(State_) * -1; i < 0; ++i) { + *out << i << " is " << lua_typename(State_, lua_type(State_, i)) << " ("; + if (is_number(i)) { + *out << to_number<long long>(i); + } else if (is_string(i)) { + *out << to_string(i); + } else { + *out << Hex((uintptr_t)lua_topointer(State_, i), HF_ADDX); + } + *out << ')' << Endl; + } +} + +void* TLuaStateHolder::Alloc(void* ud, void* ptr, size_t /*osize*/, size_t nsize) { + (void)ud; + + if (nsize == 0) { + y_deallocate(ptr); + + return nullptr; + } + + return y_reallocate(ptr, nsize); +} + +void* TLuaStateHolder::AllocLimit(void* ud, void* ptr, size_t osize, size_t nsize) { + TLuaStateHolder& state = *static_cast<TLuaStateHolder*>(ud); + + if (nsize == 0) { + y_deallocate(ptr); + state.AllocFree += osize; + + return nullptr; + } + + if (state.AllocFree + osize < nsize) { + return nullptr; + } + + ptr = y_reallocate(ptr, nsize); + + if (ptr) { + state.AllocFree += osize; + state.AllocFree -= nsize; + } + + return ptr; +} + +void TLuaStateHolder::call(int args, int rets, int count) { + TLuaCountLimit limit(State_, count); + return call(args, rets); +} + +void TLuaStateHolder::call(int args, int rets, TDuration time_limit, int count) { + TLuaTimeLimit limit(State_, time_limit, count); + return call(args, rets); +} + +template <> +void Out<NLua::TStackDumper>(IOutputStream& out, const NLua::TStackDumper& sd) { + sd.State.DumpStack(&out); +} + +template <> +void Out<NLua::TMarkedStackDumper>(IOutputStream& out, const NLua::TMarkedStackDumper& sd) { + out << sd.Mark << Endl; + sd.State.DumpStack(&out); + out << sd.Mark << Endl; +} + +namespace NLua { + TBuffer& Compile(TStringBuf script, TBuffer& buffer) { + TMemoryInput input(script.data(), script.size()); + TLuaStateHolder state; + state.Load(&input, "main"); + + TBufferOutput out(buffer); + state.Dump(&out); + return buffer; + } + +} diff --git a/library/cpp/lua/wrapper.h b/library/cpp/lua/wrapper.h new file mode 100644 index 0000000000..0d568f049a --- /dev/null +++ b/library/cpp/lua/wrapper.h @@ -0,0 +1,565 @@ +#pragma once + +#include <library/cpp/string_utils/ztstrbuf/ztstrbuf.h> + +#include <util/memory/alloc.h> +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/generic/yexception.h> +#include <util/generic/buffer.h> +#include <util/datetime/base.h> +#include <functional> + +#include <contrib/libs/lua/lua.h> + +class IInputStream; +class IOutputStream; + +class TLuaStateHolder { + struct TDeleteState { + static inline void Destroy(lua_State* state) { + lua_close(state); + } + }; + +public: + class TError: public yexception { + }; + + inline TLuaStateHolder(size_t memory_limit = 0) + : AllocFree(memory_limit) + , MyState_(lua_newstate(memory_limit ? AllocLimit : Alloc, (void*)this)) + , State_(MyState_.Get()) + { + if (!State_) { + ythrow TError() << "can not construct lua state: not enough memory"; + } + } + + inline TLuaStateHolder(lua_State* state) noexcept + : State_(state) + { + } + + inline operator lua_State*() noexcept { + return State_; + } + + inline void BootStrap() { + luaL_openlibs(State_); + } + + inline void error() { + ythrow TError() << "lua error: " << pop_string(); + } + + inline bool is_string(int index) { + return lua_isstring(State_, index); + } + + inline void is_string_strict(int index) { + if (!is_string(index)) { + ythrow TError() << "internal lua error (not a string)"; + } + } + + inline TStringBuf to_string(int index) { + size_t len = 0; + const char* data = lua_tolstring(State_, index, &len); + return TStringBuf(data, len); + } + + inline TStringBuf to_string(int index, TStringBuf defaultValue) { + return is_string(index) ? to_string(index) : defaultValue; + } + + inline TStringBuf to_string_strict(int index) { + is_string_strict(index); + return to_string(index); + } + + inline TString pop_string() { + TString ret(to_string(-1)); + pop(); + return ret; + } + + inline TString pop_string(TStringBuf defaultValue) { + TString ret(to_string(-1, defaultValue)); + pop(); + return ret; + } + + inline TString pop_string_strict() { + require(1); + TString ret(to_string_strict(-1)); + pop(); + return ret; + } + + inline TString pop_value() { + require(1); + if (is_bool(-1)) { + return pop_bool() ? "true" : "false"; + } + return pop_string_strict(); + } + + inline void push_string(const char* st) { + lua_pushstring(State_, st ? st : ""); + } + + inline void push_string(TStringBuf st) { + lua_pushlstring(State_, st.data(), st.size()); + } + + inline bool is_number(int index) { + return lua_isnumber(State_, index); + } + + inline void is_number_strict(int index) { + if (!is_number(index)) { + ythrow TError() << "internal lua error (not a number)"; + } + } + + template <typename T> + inline T to_number(int index) { + return static_cast<T>(lua_tonumber(State_, index)); + } + + template <typename T> + inline T to_number(int index, T defaultValue) { + return is_number(index) ? to_number<T>(index) : defaultValue; + } + + template <typename T> + inline T to_number_strict(int index) { + is_number_strict(index); + return to_number<T>(index); + } + + template <typename T> + inline T pop_number() { + const T ret = to_number<T>(-1); + pop(); + return ret; + } + + template <typename T> + inline T pop_number(T defaultValue) { + const T ret = to_number<T>(-1, defaultValue); + pop(); + return ret; + } + + template <typename T> + inline T pop_number_strict() { + require(1); + const T ret = to_number_strict<T>(-1); + pop(); + return ret; + } + + template <typename T> + inline void push_number(T val) { + lua_pushnumber(State_, static_cast<lua_Number>(val)); + } + + inline bool is_bool(int index) { + return lua_isboolean(State_, index); + } + + inline void is_bool_strict(int index) { + if (!is_bool(index)) { + ythrow TError() << "internal lua error (not a boolean)"; + } + } + + inline bool to_bool(int index) { + return lua_toboolean(State_, index); + } + + inline bool to_bool(int index, bool defaultValue) { + return is_bool(index) ? to_bool(index) : defaultValue; + } + + inline bool to_bool_strict(int index) { + is_bool_strict(index); + return to_bool(index); + } + + inline bool pop_bool() { + const bool ret = to_bool(-1); + pop(); + return ret; + } + + inline bool pop_bool(bool defaultValue) { + const bool ret = to_bool(-1, defaultValue); + pop(); + return ret; + } + + inline bool pop_bool_strict() { + require(1); + const bool ret = to_bool_strict(-1); + pop(); + return ret; + } + + inline void push_bool(bool val) { + lua_pushboolean(State_, val); + } + + inline bool is_nil(int index) { + return lua_isnil(State_, index); + } + + inline void is_nil_strict(int index) { + if (!is_nil(index)) { + ythrow TError() << "internal lua error (not a nil)"; + } + } + + inline bool pop_nil() { + const bool ret = is_nil(-1); + pop(); + return ret; + } + + inline void pop_nil_strict() { + require(1); + is_nil_strict(-1); + pop(); + } + + inline void push_nil() { + lua_pushnil(State_); + } + + inline bool is_void(int index) { + return lua_islightuserdata(State_, index); + } + + inline void is_void_strict(int index) { + if (!is_void(index)) { + ythrow TError() << "internal lua error (not a void*)"; + } + } + + inline void* to_void(int index) { + return lua_touserdata(State_, index); + } + + inline void* to_void(int index, void* defaultValue) { + return is_void(index) ? to_void(index) : defaultValue; + } + + inline void* to_void_strict(int index) { + is_void_strict(index); + return to_void(index); + } + + inline void* pop_void() { + void* ret = to_void(-1); + pop(); + return ret; + } + + inline void* pop_void(void* defaultValue) { + void* ret = to_void(-1, defaultValue); + pop(); + return ret; + } + + inline void* pop_void_strict() { + require(1); + void* ret = to_void_strict(-1); + pop(); + return ret; + } + + inline void push_void(void* ptr) { + lua_pushlightuserdata(State_, ptr); + } + + template <typename T> + inline bool is_userdata(int index) { + return to_userdata<T>(index) != NULL; + } + + template <typename T> + inline void is_userdata_strict(int index) { + to_userdata_strict<T>(index); + } + + template <typename T> + inline T* to_userdata(int index) { + return static_cast<T*>(luaL_testudata(State_, index, T::LUA_METATABLE_NAME)); + } + + template <typename T> + inline T* to_userdata_strict(int index) { + T* ret = to_userdata<T>(index); + if (ret == nullptr) { + ythrow TError() << "internal error (not a userdata '" << T::LUA_METATABLE_NAME << "')"; + } + return ret; + } + + template <typename T> + inline T pop_userdata_strict() { + require(1); + const T ret(*to_userdata_strict<T>(-1)); + pop(); + return ret; + } + + template <typename T> + inline T* push_userdata(const T& x) { + // copy constructor + return new (new_userdata<T>()) T(x); + } + + template <typename T, typename... R> + inline T* push_userdata(const R&... r) { + return new (new_userdata<T>()) T(r...); + } + + inline void push_global(const char* name) { + lua_getglobal(State_, name); + } + + inline void set_global(const char* name, const char* value) { + lua_pushstring(State_, value); + set_global(name); + } + + inline void set_global(const char* name, const double value) { + lua_pushnumber(State_, value); + set_global(name); + } + + inline void set_global(const char* name) { + lua_setglobal(State_, name); + } + + inline void register_function(const char* name, lua_CFunction func) { + lua_register(State_, name, func); + } + + inline bool is_table(int index) { + return lua_istable(State_, index); + } + + inline void is_table_strict(int index) { + if (!is_table(index)) { + ythrow TError() << "internal lua error (not a table)"; + } + } + + inline void create_table(int narr = 0, int nrec = 0) { + lua_createtable(State_, narr, nrec); + } + + inline void set_table(int index) { + lua_settable(State_, index); + } + + inline void get_field(int index, const char* key) { + lua_getfield(State_, index, key); + } + + inline void set_field(int index, const char* key) { + lua_setfield(State_, index, key); + } + + inline void rawseti(int index, int arr_index) { + lua_rawseti(State_, index, arr_index); + } + + inline int check_stack(int extra) { + return lua_checkstack(State_, extra); + } + + inline void ensure_stack(int extra) { + if (!check_stack(extra)) { + ythrow TError() << "cannot allocate more lua stack space"; + }; + } + + inline void require(int n) { + if (on_stack() < n) { + ythrow TError() << "lua requirement failed"; + } + } + + inline void call(int args, int rets) { + if (lua_pcall(State_, args, rets, 0)) { + error(); + } + } + + void call(int args, int rets, TDuration time_limit, int count = 1000); + void call(int args, int rets, int limit); + + inline void remove(int index) { + lua_remove(State_, index); + } + + inline int next(int index) { + return lua_next(State_, index); + } + + inline void pop(int n = 1) { + lua_pop(State_, Min(n, on_stack())); + } + + inline void push_value(int index) { + lua_pushvalue(State_, index); + } + + inline int on_stack() { + return lua_gettop(State_); + } + + inline void gc() { + lua_gc(State_, LUA_GCCOLLECT, 0); + } + + inline TLuaStateHolder new_thread() { + return lua_newthread(State_); + } + + inline bool is_thread(int index) { + return lua_isthread(State_, index); + } + + inline void is_thread_strict(int index) { + if (!is_thread(index)) { + ythrow TError() << "internal lua error (not a thread)"; + } + } + + inline TLuaStateHolder to_thread(int index) { + return lua_tothread(State_, index); + } + + inline TLuaStateHolder to_thread_strict(int index) { + is_thread_strict(index); + return to_thread(index); + } + + void Load(IInputStream* in, TZtStringBuf name); + void Dump(IOutputStream* out); + void DumpStack(IOutputStream* out); + +private: + template <typename T> + inline void set_metatable() { + if (luaL_newmetatable(State_, T::LUA_METATABLE_NAME)) { + // metatable isn't registered yet + push_string("__index"); + push_value(-2); // pushes the metatable + set_table(-3); // metatable.__index = metatable + luaL_setfuncs(State_, T::LUA_FUNCTIONS, 0); + } + lua_setmetatable(State_, -2); + } + + template <typename T> + inline void* new_userdata() { + void* p = lua_newuserdata(State_, sizeof(T)); + set_metatable<T>(); + return p; + } + +private: + static void* Alloc(void* ud, void* ptr, size_t osize, size_t nsize); + static void* AllocLimit(void* ud, void* ptr, size_t osize, size_t nsize); + +private: + size_t AllocFree = 0; + THolder<lua_State, TDeleteState> MyState_; + lua_State* State_ = nullptr; +}; + +namespace NLua { + template <int func(TLuaStateHolder&)> + int FunctionHandler(lua_State* L) { + try { + TLuaStateHolder state(L); + return func(state); + } catch (const yexception& e) { + lua_pushstring(L, e.what()); + } + return lua_error(L); + } + + template <class T, int (T::*Method)(TLuaStateHolder&)> + int MethodHandler(lua_State* L) { + T* x = static_cast<T*>(luaL_checkudata(L, 1, T::LUA_METATABLE_NAME)); + try { + TLuaStateHolder state(L); + return (x->*Method)(state); + } catch (const yexception& e) { + lua_pushstring(L, e.what()); + } + return lua_error(L); + } + + template <class T, int (T::*Method)(TLuaStateHolder&) const> + int MethodConstHandler(lua_State* L) { + const T* x = static_cast<const T*>(luaL_checkudata(L, 1, T::LUA_METATABLE_NAME)); + try { + TLuaStateHolder state(L); + return (x->*Method)(state); + } catch (const yexception& e) { + lua_pushstring(L, e.what()); + } + return lua_error(L); + } + + template <class T> + int Destructor(lua_State* L) { + T* x = static_cast<T*>(luaL_checkudata(L, 1, T::LUA_METATABLE_NAME)); + try { + x->~T(); + return 0; + } catch (const yexception& e) { + lua_pushstring(L, e.what()); + } + return lua_error(L); + } + + TBuffer& Compile(TStringBuf script, TBuffer& buffer); + + struct TStackDumper { + TStackDumper(TLuaStateHolder& state) + : State(state) + { + } + + TLuaStateHolder& State; + }; + + struct TMarkedStackDumper: public TStackDumper { + TMarkedStackDumper(TLuaStateHolder& state, TStringBuf mark) + : TStackDumper(state) + , Mark(mark) + { + } + + TStringBuf Mark; + }; + + inline TMarkedStackDumper DumpStack(TLuaStateHolder& state, TStringBuf mark) { + return TMarkedStackDumper(state, mark); + } + + inline TStackDumper DumpStack(TLuaStateHolder& state) { + return TStackDumper(state); + } + +} |