#include "wrapper.h"
#include <util/datetime/cputimer.h>
#include <util/stream/buffered.h>
#include <util/stream/buffer.h>
#include <util/stream/format.h>
#include <util/stream/input.h>
#include <util/stream/mem.h>
#include <util/stream/output.h>
#include <util/system/sys_alloc.h>
namespace {
// Scoped guard that installs a Lua count hook (LUA_MASKCOUNT) on construction
// and disables the hook again on destruction. The hook raises a Lua error the
// first time it fires.
class TLuaCountLimit {
public:
    // state - Lua state the hook is installed on.
    // count - instruction interval handed to lua_sethook.
    TLuaCountLimit(lua_State* state, int count)
        : State(state)
    {
        lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count);
    }

    ~TLuaCountLimit() {
        // A zero mask switches the hook off.
        lua_sethook(State, LuaHookCallback, 0, 0);
    }

    // The guard owns the hook registration: a copy would remove the hook in its
    // own destructor while the original guard is still alive, so copying is
    // forbidden.
    TLuaCountLimit(const TLuaCountLimit&) = delete;
    TLuaCountLimit& operator=(const TLuaCountLimit&) = delete;

    // Hook body: aborts the running chunk with a Lua error.
    static void LuaHookCallback(lua_State* L, lua_Debug*) {
        luaL_error(L, "Lua instruction count limit exceeded");
    }

private:
    lua_State* State;
}; // class TLuaCountLimit
// Scoped guard that limits the wall time of Lua code. On construction it
// stores a pointer to itself in the Lua registry (keyed by the address of the
// hook function) and installs a count hook; the hook looks the guard up and
// raises a Lua error once the guard's timer reading exceeds the limit.
// The destructor disables the hook.
class TLuaTimeLimit {
public:
    // state - Lua state the hook is installed on.
    // limit - maximum value of Timer.Get() tolerated by the hook.
    // count - instruction interval between hook invocations.
    TLuaTimeLimit(lua_State* state, TDuration limit, int count)
        : State(state)
        , Limit(limit)
    {
        // registry[<hook address>] = this
        lua_pushlightuserdata(State, RegistryKey());
        lua_pushlightuserdata(State, static_cast<void*>(this));
        lua_settable(State, LUA_REGISTRYINDEX);
        lua_sethook(State, LuaHookCallback, LUA_MASKCOUNT, count);
    }

    ~TLuaTimeLimit() {
        // A zero mask switches the hook off. The registry entry is left in
        // place; it is only read by the hook, which is disabled here.
        lua_sethook(State, LuaHookCallback, 0, 0);
    }

    // The registry holds the address of one particular guard object and the
    // destructor removes the hook, so a copy would both dangle and disarm the
    // original. Copying is therefore forbidden.
    TLuaTimeLimit(const TLuaTimeLimit&) = delete;
    TLuaTimeLimit& operator=(const TLuaTimeLimit&) = delete;

    // True when the timer reading is greater than the configured limit.
    bool Exceeded() {
        return Timer.Get() > Limit;
    }

    // Hook body: fetches the guard from the registry and raises a Lua error
    // when its limit is exceeded.
    static void LuaHookCallback(lua_State* L, lua_Debug*) {
        lua_pushlightuserdata(L, RegistryKey());
        lua_gettable(L, LUA_REGISTRYINDEX);
        TLuaTimeLimit* guard = static_cast<TLuaTimeLimit*>(lua_touserdata(L, -1));
        lua_pop(L, 1);
        // lua_touserdata yields NULL when the registry slot does not hold
        // userdata; do not dereference in that case.
        if (guard != nullptr && guard->Exceeded()) {
            luaL_error(L, "time limit exceeded");
        }
    }

private:
    // Registry key: the address of the hook function, as light userdata.
    static void* RegistryKey() {
        return reinterpret_cast<void*>(&LuaHookCallback);
    }

private:
    lua_State* State;
    const TDuration Limit;
    TSimpleTimer Timer;
}; // class TLuaTimeLimit
// Adapter that feeds a Lua chunk to lua_load from an IZeroCopyInput.
class TLuaReader {
public:
    TLuaReader(IZeroCopyInput* in)
        : In_(in)
    {
    }

    // Loads a chunk called `name` into `state`, pulling the data through
    // ReadCallback. Throws TLuaStateHolder::TError carrying the message found
    // on top of the Lua stack when lua_load reports a failure.
    inline void Load(lua_State* state, const char* name) {
        // Lua versions above 5.1 take an additional "mode" argument.
#if LUA_VERSION_NUM > 501
        const int status = lua_load(state, ReadCallback, this, name, nullptr);
#else
        const int status = lua_load(state, ReadCallback, this, name);
#endif
        if (status == 0) {
            return;
        }
        ythrow TLuaStateHolder::TError() << "can not parse lua chunk " << name << ": " << lua_tostring(state, -1);
    }

    // lua_Reader trampoline: `data` is the TLuaReader handed to lua_load.
    static const char* ReadCallback(lua_State*, void* data, size_t* size) {
        TLuaReader* const self = static_cast<TLuaReader*>(data);
        return self->Read(size);
    }

private:
    // Fetches the next piece of input. Stores its length in *readed and
    // returns its start, or nullptr when the length is zero.
    inline const char* Read(size_t* readed) {
        const char* chunk;
        *readed = In_->Next(&chunk);
        return *readed != 0 ? chunk : nullptr;
    }

private:
    IZeroCopyInput* In_;
}; // class TLuaReader
// Adapter that forwards the output of lua_dump to an IOutputStream.
class TLuaWriter {
public:
    TLuaWriter(IOutputStream* out)
        : Out_(out)
    {
    }

    // Runs lua_dump on `state`, routing the produced bytes through
    // WriteCallback. Throws TLuaStateHolder::TError when lua_dump returns a
    // non-zero status.
    inline void Dump(lua_State* state) {
        const int status = lua_dump(state, WriteCallback, this);
        if (status == 0) {
            return;
        }
        ythrow TLuaStateHolder::TError() << "can not dump lua state: " << lua_tostring(state, -1);
    }

    // lua_Writer trampoline: `user` is the TLuaWriter handed to lua_dump.
    static int WriteCallback(lua_State*, const void* data, size_t size, void* user) {
        TLuaWriter* const self = static_cast<TLuaWriter*>(user);
        return self->Write(data, size);
    }

private:
    // Writes `size` bytes starting at `data` to the stream; always returns 0.
    inline int Write(const void* data, size_t size) {
        const char* const bytes = static_cast<const char*>(data);
        Out_->Write(bytes, size);
        return 0;
    }

private:
    IOutputStream* Out_;
}; // class TLuaWriter
} //namespace
// Loads a Lua chunk called `name` from `in` into this state. The stream is
// wrapped in a TBufferedInput constructed with 8192 and read via TLuaReader.
void TLuaStateHolder::Load(IInputStream* in, TZtStringBuf name) {
    TBufferedInput buffered(in, 8192);
    TLuaReader reader(&buffered);
    reader.Load(State_, name.c_str());
}
// Dumps this state through lua_dump into `out` (via TLuaWriter).
void TLuaStateHolder::Dump(IOutputStream* out) {
    TLuaWriter writer(out);
    writer.Dump(State_);
}
// Prints one line per Lua stack slot to `out`, addressing slots by negative
// index from -lua_gettop() up to -1. Each line holds the index, the Lua type
// name and, in parentheses, the value: numbers as long long, strings as-is,
// anything else as the hex address returned by lua_topointer.
void TLuaStateHolder::DumpStack(IOutputStream* out) {
    IOutputStream& sink = *out;
    for (int depth = lua_gettop(State_); depth > 0; --depth) {
        const int idx = -depth;
        sink << idx << " is " << lua_typename(State_, lua_type(State_, idx)) << " (";
        if (is_number(idx)) {
            sink << to_number<long long>(idx);
        } else if (is_string(idx)) {
            sink << to_string(idx);
        } else {
            const uintptr_t address = (uintptr_t)lua_topointer(State_, idx);
            sink << Hex(address, HF_ADDX);
        }
        sink << ')' << Endl;
    }
}
// Lua allocator callback without accounting: a request with nsize == 0
// releases `ptr` and yields nullptr, any other request is forwarded to
// y_reallocate. The user-data pointer is not used.
void* TLuaStateHolder::Alloc(void* ud, void* ptr, size_t /*osize*/, size_t nsize) {
    (void)ud;
    if (nsize != 0) {
        return y_reallocate(ptr, nsize);
    }
    y_deallocate(ptr);
    return nullptr;
}
// Lua allocator callback that enforces a memory budget kept in
// TLuaStateHolder::AllocFree.
//
// ud    - the TLuaStateHolder whose budget is charged.
// ptr   - block to resize or free, or nullptr for a fresh allocation.
// osize - size of the block at `ptr`; only meaningful when `ptr` is not null.
// nsize - requested size; 0 means "free".
//
// Returns the (re)allocated block, or nullptr after a free or when the
// request does not fit into the remaining budget.
void* TLuaStateHolder::AllocLimit(void* ud, void* ptr, size_t osize, size_t nsize) {
    TLuaStateHolder& state = *static_cast<TLuaStateHolder*>(ud);

    // Since Lua 5.2, when `ptr` is NULL `osize` does not hold a byte count: it
    // encodes the kind of object being allocated (LUA_TSTRING, LUA_TTABLE, ...).
    // Treating that tag as freed bytes would inflate the budget on every fresh
    // allocation, so a null block always counts as zero bytes.
    if (ptr == nullptr) {
        osize = 0;
    }

    if (nsize == 0) {
        y_deallocate(ptr);
        state.AllocFree += osize;
        return nullptr;
    }

    // The old block is given back as part of the resize, so its size counts
    // towards what is available for the new one.
    if (state.AllocFree + osize < nsize) {
        return nullptr;
    }

    ptr = y_reallocate(ptr, nsize);
    if (ptr) {
        state.AllocFree += osize;
        state.AllocFree -= nsize;
    }
    return ptr;
}
// Runs call(args, rets) while a TLuaCountLimit guard built from `count` is
// alive; the guard is destroyed when this function exits.
void TLuaStateHolder::call(int args, int rets, int count) {
    TLuaCountLimit instructionGuard(State_, count);
    call(args, rets);
}
// Runs call(args, rets) while a TLuaTimeLimit guard built from `time_limit`
// and `count` is alive; the guard is destroyed when this function exits.
void TLuaStateHolder::call(int args, int rets, TDuration time_limit, int count) {
    TLuaTimeLimit timeGuard(State_, time_limit, count);
    call(args, rets);
}
// Stream output for NLua::TStackDumper: prints the Lua stack of the wrapped
// state via DumpStack.
template <>
void Out<NLua::TStackDumper>(IOutputStream& out, const NLua::TStackDumper& sd) {
    IOutputStream* const sink = &out;
    sd.State.DumpStack(sink);
}
// Stream output for NLua::TMarkedStackDumper: prints the Lua stack of the
// wrapped state via DumpStack, with a line holding the mark before and after.
template <>
void Out<NLua::TMarkedStackDumper>(IOutputStream& out, const NLua::TMarkedStackDumper& sd) {
    const auto printMark = [&out, &sd]() {
        out << sd.Mark << Endl;
    };

    printMark();
    sd.State.DumpStack(&out);
    printMark();
}
namespace NLua {
    // Loads `script` as a chunk named "main" into a fresh TLuaStateHolder and
    // dumps that state into `buffer` through a TBufferOutput.
    // Returns the same `buffer` reference that was passed in.
    TBuffer& Compile(TStringBuf script, TBuffer& buffer) {
        TMemoryInput source(script.data(), script.size());
        TLuaStateHolder compiler;
        compiler.Load(&source, "main");

        TBufferOutput sink(buffer);
        compiler.Dump(&sink);

        return buffer;
    }
}