#include "wrapper.h"

#include <util/datetime/cputimer.h>
#include <util/stream/buffered.h>
#include <util/stream/buffer.h>
#include <util/stream/format.h>
#include <util/stream/input.h>
#include <util/stream/mem.h>
#include <util/stream/output.h>
#include <util/system/sys_alloc.h>

#include <exception>

namespace {
    // Snapshot of the debug hook currently installed on a lua_State, so a
    // scoped limit can reinstall whatever was active before it (for example an
    // enclosing limit when limited calls are nested).
    struct THookState {
        lua_Hook Hook;
        int Mask;
        int Count;

        static THookState Capture(lua_State* state) {
            return {lua_gethook(state), lua_gethookmask(state), lua_gethookcount(state)};
        }

        void Restore(lua_State* state) const {
            lua_sethook(state, Hook, Mask, Count);
        }
    };

    // Scoped guard: while alive, installs a LUA_MASKCOUNT hook that raises a
    // Lua error the first time it fires, i.e. after `count` VM instructions.
    class TLuaCountLimit {
    public:
        TLuaCountLimit(lua_State* state, int count)
            : State(state)
            , Previous(THookState::Capture(state))
        {
            lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count);
        }

        ~TLuaCountLimit() {
            // Reinstall the previous hook instead of clearing it, so an
            // enclosing limit keeps working once this one ends.
            Previous.Restore(State);
        }

        TLuaCountLimit(const TLuaCountLimit&) = delete;
        TLuaCountLimit& operator=(const TLuaCountLimit&) = delete;

        static void LuaHookCallback(lua_State* L, lua_Debug*) {
            luaL_error(L, "Lua instruction count limit exceeded");
        }

    private:
        lua_State* State;
        const THookState Previous;
    }; // class TLuaCountLimit

    // Scoped guard: every `count` VM instructions, compares the time reported
    // by a TSimpleTimer started at construction against `limit`, and raises a
    // Lua error once it is exceeded. The static hook locates the active guard
    // through a registry entry keyed by the hook's own address.
    class TLuaTimeLimit {
    public:
        TLuaTimeLimit(lua_State* state, TDuration limit, int count)
            : State(state)
            , Limit(limit)
            , PreviousHook(THookState::Capture(state))
            , PreviousOwner(nullptr)
        {
            // Remember a guard registered by an enclosing limit (nullptr if
            // the key is unset) so it can be reinstated on destruction.
            PushRegistryKey(State);
            lua_rawget(State, LUA_REGISTRYINDEX);
            PreviousOwner = lua_touserdata(State, -1);
            lua_pop(State, 1);

            SetOwner(this);
            lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count);
        }

        ~TLuaTimeLimit() {
            PreviousHook.Restore(State);
            // Do not leave a dangling pointer to this object in the registry.
            SetOwner(PreviousOwner);
        }

        TLuaTimeLimit(const TLuaTimeLimit&) = delete;
        TLuaTimeLimit& operator=(const TLuaTimeLimit&) = delete;

        bool Exceeded() {
            return Timer.Get() > Limit;
        }

        static void LuaHookCallback(lua_State* L, lua_Debug*) {
            PushRegistryKey(L);
            lua_rawget(L, LUA_REGISTRYINDEX);
            TLuaTimeLimit* t = static_cast<TLuaTimeLimit*>(lua_touserdata(L, -1));
            lua_pop(L, 1);
            // Guard against a missing registry entry instead of dereferencing nullptr.
            if (t && t->Exceeded()) {
                luaL_error(L, "time limit exceeded");
            }
        }

    private:
        // The hook function's address doubles as a unique registry key.
        static void PushRegistryKey(lua_State* L) {
            lua_pushlightuserdata(L, reinterpret_cast<void*>(&LuaHookCallback));
        }

        // Stores `owner` under the registry key, or clears the entry when
        // `owner` is nullptr.
        void SetOwner(void* owner) {
            PushRegistryKey(State);
            if (owner) {
                lua_pushlightuserdata(State, owner);
            } else {
                lua_pushnil(State);
            }
            lua_rawset(State, LUA_REGISTRYINDEX);
        }

    private:
        lua_State* State;
        const TDuration Limit;
        const THookState PreviousHook;
        void* PreviousOwner;
        TSimpleTimer Timer;
    }; // class TLuaTimeLimit

    // Adapts an IZeroCopyInput to the lua_Reader callback used by lua_load.
    class TLuaReader {
    public:
        explicit TLuaReader(IZeroCopyInput* in)
            : In_(in)
        {
        }

        // Compiles the whole stream as a chunk named `name`; on success the
        // compiled function is left on top of the stack.
        // Throws TLuaStateHolder::TError on a parse error (after popping the
        // error message), or rethrows any exception raised by the input stream.
        inline void Load(lua_State* state, const char* name) {
            const int status = lua_load(state, ReadCallback, this, name
#if LUA_VERSION_NUM > 501
                , nullptr
#endif
            );

            if (Error_) {
                // lua_load pushed one value (a function or an error message),
                // but the real cause is the stream exception.
                lua_pop(state, 1);
                std::rethrow_exception(Error_);
            }

            if (status) {
                TLuaStateHolder::TError error;
                error << "can not parse lua chunk " << name << ": " << lua_tostring(state, -1);
                lua_pop(state, 1);
                ythrow error;
            }
        }

        static const char* ReadCallback(lua_State*, void* data, size_t* size) {
            return static_cast<TLuaReader*>(data)->Read(size);
        }

    private:
        // lua_Reader contract: return the next block and set its size, or
        // return nullptr / size 0 to signal the end of the chunk.
        inline const char* Read(size_t* readed) {
            if (Error_) {
                *readed = 0;
                return nullptr;
            }

            const char* ret = nullptr;
            try {
                *readed = In_->Next(&ret);
            } catch (...) {
                // Never unwind through Lua's C frames: end the chunk here and
                // let Load() rethrow once lua_load has returned.
                Error_ = std::current_exception();
                *readed = 0;
                return nullptr;
            }
            return *readed ? ret : nullptr;
        }

    private:
        IZeroCopyInput* In_;
        std::exception_ptr Error_;
    }; // class TLuaReader

    // Adapts an IOutputStream to the lua_Writer callback used by lua_dump.
    class TLuaWriter {
    public:
        explicit TLuaWriter(IOutputStream* out)
            : Out_(out)
        {
        }

        // Serializes the function on top of the stack into the output stream.
        // Rethrows any exception raised by the stream; throws
        // TLuaStateHolder::TError if lua_dump reports another failure.
        inline void Dump(lua_State* state) {
            const int status = lua_dump(state, WriteCallback, this
#if LUA_VERSION_NUM > 502
                , 0
#endif
            );

            if (Error_) {
                std::rethrow_exception(Error_);
            }

            if (status) {
                // lua_dump does not push an error message, so describe the
                // failure from the status and the value being dumped.
                ythrow TLuaStateHolder::TError() << "can not dump lua state: lua_dump failed with status " << status
                                                 << " (value on top of the stack: "
                                                 << lua_typename(state, lua_type(state, -1)) << ")";
            }
        }

        static int WriteCallback(lua_State*, const void* data, size_t size, void* user) {
            return static_cast<TLuaWriter*>(user)->Write(data, size);
        }

    private:
        // lua_Writer contract: 0 means success; any other value stops lua_dump
        // from calling the writer again.
        inline int Write(const void* data, size_t size) {
            try {
                Out_->Write(static_cast<const char*>(data), size);
            } catch (...) {
                // Do not unwind through Lua's C frames; Dump() rethrows.
                Error_ = std::current_exception();
                return 1;
            }
            return 0;
        }

    private:
        IOutputStream* Out_;
        std::exception_ptr Error_;
    }; // class TLuaWriter
} //namespace

// Compiles the stream as a chunk called `name`, leaving the resulting function
// on top of State_. The stream is wrapped in a TBufferedInput (8192-byte
// buffer) because TLuaReader consumes the zero-copy interface.
void TLuaStateHolder::Load(IInputStream* in, TZtStringBuf name) {
    TBufferedInput wi(in, 8192);
    TLuaReader(&wi).Load(State_, name.c_str());
}

// Writes the function on top of State_ to `out` via lua_dump.
void TLuaStateHolder::Dump(IOutputStream* out) {
    TLuaWriter(out).Dump(State_);
}

// Prints every stack slot, bottom to top, as "<negative index> is <type> (<value>)".
// Numbers and strings are printed by value; other values by pointer in hex.
void TLuaStateHolder::DumpStack(IOutputStream* out) {
    for (int i = -lua_gettop(State_); i < 0; ++i) {
        *out << i << " is " << lua_typename(State_, lua_type(State_, i)) << " (";
        if (is_number(i)) {
            *out << to_number<long long>(i);
        } else if (is_string(i)) {
            *out << to_string(i);
        } else {
            *out << Hex((uintptr_t)lua_topointer(State_, i), HF_ADDX);
        }
        *out << ')' << Endl;
    }
}

// Unlimited lua_Alloc: frees when nsize == 0, otherwise (re)allocates with
// y_reallocate.
void* TLuaStateHolder::Alloc(void* ud, void* ptr, size_t /*osize*/, size_t nsize) {
    (void)ud;

    if (nsize == 0) {
        y_deallocate(ptr);
        return nullptr;
    }

    return y_reallocate(ptr, nsize);
}

// Budgeted lua_Alloc: `ud` is the owning TLuaStateHolder, whose AllocFree is
// debited on growth and credited on shrink/free. Requests that would exceed
// the remaining budget fail by returning nullptr.
void* TLuaStateHolder::AllocLimit(void* ud, void* ptr, size_t osize, size_t nsize) {
    TLuaStateHolder& state = *static_cast<TLuaStateHolder*>(ud);

    // Since Lua 5.2, when ptr is NULL osize holds an object-type code rather
    // than a size; treat it as 0 so the budget is not credited spuriously.
    const size_t oldSize = ptr ? osize : 0;

    if (nsize == 0) {
        y_deallocate(ptr);
        state.AllocFree += oldSize;
        return nullptr;
    }

    if (state.AllocFree + oldSize < nsize) {
        return nullptr;
    }

    void* result = y_reallocate(ptr, nsize);
    if (result) {
        state.AllocFree += oldSize;
        state.AllocFree -= nsize;
    }

    return result;
}

// Calls the function on the stack with an instruction-count limit that is
// active only for the duration of this call.
void TLuaStateHolder::call(int args, int rets, int count) {
    TLuaCountLimit limit(State_, count);
    call(args, rets);
}

// Calls the function on the stack with a time limit, checked every `count`
// VM instructions, that is active only for the duration of this call.
void TLuaStateHolder::call(int args, int rets, TDuration time_limit, int count) {
    TLuaTimeLimit limit(State_, time_limit, count);
    call(args, rets);
}

template <>
void Out<NLua::TStackDumper>(IOutputStream& out, const NLua::TStackDumper& sd) {
    sd.State.DumpStack(&out);
}

// Same as TStackDumper, but wraps the dump between two lines containing Mark.
template <>
void Out<NLua::TMarkedStackDumper>(IOutputStream& out, const NLua::TMarkedStackDumper& sd) {
    out << sd.Mark << Endl;
    sd.State.DumpStack(&out);
    out << sd.Mark << Endl;
}

namespace NLua {
    // Compiles `script` (chunk name "main") in a fresh Lua state and writes
    // the dumped function into `buffer`; returns `buffer`.
    TBuffer& Compile(TStringBuf script, TBuffer& buffer) {
        TMemoryInput input(script.data(), script.size());
        TLuaStateHolder state;
        state.Load(&input, "main");
        TBufferOutput out(buffer);
        state.Dump(&out);
        return buffer;
    }
}