diff --git a/src/pc/lua/smlua.c b/src/pc/lua/smlua.c index acf20e70c..69e5adaca 100644 --- a/src/pc/lua/smlua.c +++ b/src/pc/lua/smlua.c @@ -1,5 +1,4 @@ #include "smlua.h" -#include "smlua_cobject_map.h" #include "game/hardcoded.h" #include "pc/mods/mods.h" #include "pc/mods/mods_utils.h" @@ -282,7 +281,6 @@ static void smlua_load_script(struct Mod* mod, struct ModFile* file, u16 remoteI void smlua_init(void) { smlua_shutdown(); - smlua_pointer_user_data_init(); gLuaState = luaL_newstate(); lua_State* L = gLuaState; @@ -362,7 +360,6 @@ void smlua_shutdown(void) { smlua_text_utils_reset_all(); smlua_audio_utils_reset_all(); audio_custom_shutdown(); - smlua_pointer_user_data_shutdown(); smlua_clear_hooks(); smlua_model_util_clear(); smlua_level_util_reset(); diff --git a/src/pc/lua/smlua.h b/src/pc/lua/smlua.h index 3ed42aab5..3704460b1 100644 --- a/src/pc/lua/smlua.h +++ b/src/pc/lua/smlua.h @@ -24,9 +24,11 @@ #define LOG_LUA_LINE_WARNING(...) { if (!gLuaActiveMod->showedScriptWarning) { gLuaActiveMod->showedScriptWarning = true; smlua_mod_warning(); snprintf(gDjuiConsoleTmpBuffer, CONSOLE_MAX_TMP_BUFFER, __VA_ARGS__), sys_swap_backslashes(gDjuiConsoleTmpBuffer), djui_console_message_create(gDjuiConsoleTmpBuffer, CONSOLE_MESSAGE_WARNING); } } #ifdef DEVELOPMENT -#define LUA_STACK_CHECK_BEGIN() int __LUA_STACK_TOP = lua_gettop(gLuaState) +#define LUA_STACK_CHECK_BEGIN_NUM(n) int __LUA_STACK_TOP = lua_gettop(gLuaState) + (n) +#define LUA_STACK_CHECK_BEGIN() LUA_STACK_CHECK_BEGIN_NUM(0) #define LUA_STACK_CHECK_END() if ((__LUA_STACK_TOP) != lua_gettop(gLuaState)) { smlua_dump_stack(); fflush(stdout); } assert((__LUA_STACK_TOP) == lua_gettop(gLuaState)) #else +#define LUA_STACK_CHECK_BEGIN_NUM(n) #define LUA_STACK_CHECK_BEGIN() #define LUA_STACK_CHECK_END() #endif diff --git a/src/pc/lua/smlua_cobject.c b/src/pc/lua/smlua_cobject.c index 900641246..d0e35e99a 100644 --- a/src/pc/lua/smlua_cobject.c +++ b/src/pc/lua/smlua_cobject.c @@ -11,7 +11,6 @@ #include "object_fields.h" #include "pc/djui/djui_hud_utils.h" #include "pc/lua/smlua.h" -#include "pc/lua/smlua_cobject_map.h" #include "pc/lua/utils/smlua_anim_utils.h" #include "pc/lua/utils/smlua_collision_utils.h" #include "pc/lua/utils/smlua_obj_utils.h" @@ -19,6 +18,11 @@ extern struct LuaObjectTable sLuaObjectTable[LOT_MAX]; +int gSmLuaCObjects = 0; +int gSmLuaCPointers = 0; +int gSmLuaCObjectMetatable = 0; +int gSmLuaCPointerMetatable = 0; + struct LuaObjectField* smlua_get_object_field_from_ot(struct LuaObjectTable* ot, const char* key) { // binary search s32 min = 0; @@ -324,26 +328,28 @@ struct LuaObjectField* smlua_get_custom_field(lua_State* L, u32 lot, int keyInde ///////////////////// static int smlua__get_field(lua_State* L) { - LUA_STACK_CHECK_BEGIN(); + LUA_STACK_CHECK_BEGIN_NUM(1); - CObject *cobj = lua_touserdata(L, 1); + const CObject *cobj = lua_touserdata(L, 1); enum LuaObjectType lot = cobj->lot; u64 pointer = (u64)(intptr_t) cobj->pointer; - const char *key = smlua_to_string(L, 2); - if (!gSmLuaConvertSuccess) { + const char *key = lua_tostring(L, 2); + if (!key) { LOG_LUA_LINE("Tried to get a non-string field of cobject"); return 0; } // Legacy support - if (strcmp(key, "_pointer") == 0) { - lua_pushinteger(L, pointer); - return 1; - } - if (strcmp(key, "_lot") == 0) { - lua_pushinteger(L, cobj->lot); - return 1; + if (key[0] == '_') { + if (strcmp(key, "_lot") == 0) { + lua_pushinteger(L, lot); + return 1; + } + if (strcmp(key, "_pointer") == 0) { + lua_pushinteger(L, pointer); + return 1; + } } if (cobj->freed) { @@ -360,8 +366,6 @@ static int smlua__get_field(lua_State* L) { return 0; } - LUA_STACK_CHECK_END(); - u8* p = ((u8*)(intptr_t)pointer) + data->valueOffset; switch (data->valueType) { case LVT_BOOL: lua_pushboolean(L, *(u8* )p); break; @@ -406,18 +410,19 @@ static int smlua__get_field(lua_State* L) { return 0; } + LUA_STACK_CHECK_END(); return 1; } static int smlua__set_field(lua_State* L) { LUA_STACK_CHECK_BEGIN(); - CObject *cobj = lua_touserdata(L, 1); + const CObject *cobj = lua_touserdata(L, 1); enum LuaObjectType lot = cobj->lot; u64 pointer = (u64)(intptr_t) cobj->pointer; - const char *key = smlua_to_string(L, 2); - if (!gSmLuaConvertSuccess) { + const char *key = lua_tostring(L, 2); + if (!key) { LOG_LUA_LINE("Tried to set a non-string field of cobject"); return 0; } @@ -496,37 +501,27 @@ static int smlua__set_field(lua_State* L) { } int smlua__eq(lua_State *L) { - CObject *a = lua_touserdata(L, 1); - CObject *b = lua_touserdata(L, 2); - lua_pushboolean(L, a->lot == b->lot && a->pointer == b->pointer); + const CObject *a = lua_touserdata(L, 1); + const CObject *b = lua_touserdata(L, 2); + lua_pushboolean(L, a && b && a->lot == b->lot && a->pointer == b->pointer); return 1; } -int smlua__gc(lua_State *L) { - CObject *cobj = lua_touserdata(L, 1); - if (!cobj->freed) { - switch (cobj->lot) { - case LOT_SURFACE: { - smlua_pointer_user_data_delete((uintptr_t) cobj->pointer); - } - } - } - return 0; -} - static int smlua_cpointer_get(lua_State* L) { - CPointer *cptr = lua_touserdata(L, 1); - const char *key = smlua_to_string(L, 2); + const CPointer *cptr = lua_touserdata(L, 1); + const char *key = lua_tostring(L, 2); if (key == NULL) { return 0; } // Legacy support - if (strcmp(key, "_pointer") == 0) { - lua_pushinteger(L, (u64)(intptr_t) cptr->pointer); - return 1; - } - if (strcmp(key, "_lot") == 0) { - lua_pushinteger(L, cptr->lvt); - return 1; + if (key[0] == '_') { + if (strcmp(key, "_pointer") == 0) { + lua_pushinteger(L, (u64)(intptr_t) cptr->pointer); + return 1; + } + if (strcmp(key, "_lot") == 0) { + lua_pushinteger(L, cptr->lvt); + return 1; + } } return 0; @@ -540,26 +535,33 @@ static int smlua_cpointer_set(UNUSED lua_State* L) { return 0; } void smlua_cobject_init_globals(void) { lua_State* L = gLuaState; + // Create object pools + lua_newtable(L); + gSmLuaCObjects = luaL_ref(L, LUA_REGISTRYINDEX); + lua_newtable(L); + gSmLuaCPointers = luaL_ref(L, LUA_REGISTRYINDEX); + // Create metatables luaL_newmetatable(L, "CObject"); luaL_Reg cObjectMethods[] = { { "__index", smlua__get_field }, { "__newindex", smlua__set_field }, { "__eq", smlua__eq }, - { "__gc", smlua__gc }, + { "__metatable", NULL }, { NULL, NULL } }; luaL_setfuncs(L, cObjectMethods, 0); - lua_pop(L, 1); + gSmLuaCObjectMetatable = luaL_ref(L, LUA_REGISTRYINDEX); luaL_newmetatable(L, "CPointer"); luaL_Reg cPointerMethods[] = { { "__index", smlua_cpointer_get }, { "__newindex", smlua_cpointer_set }, { "__eq", smlua__eq }, + { "__metatable", NULL }, { NULL, NULL } }; luaL_setfuncs(L, cPointerMethods, 0); - lua_pop(L, 1); + gSmLuaCPointerMetatable = luaL_ref(L, LUA_REGISTRYINDEX); #define EXPOSE_GLOBAL_ARRAY(lot, ptr, iterator) \ { \ diff --git a/src/pc/lua/smlua_cobject.h b/src/pc/lua/smlua_cobject.h index 7f26bb754..1cb3e14e8 100644 --- a/src/pc/lua/smlua_cobject.h +++ b/src/pc/lua/smlua_cobject.h @@ -65,6 +65,11 @@ typedef struct { bool freed; } CPointer; +extern int gSmLuaCObjects; +extern int gSmLuaCPointers; +extern int gSmLuaCObjectMetatable; +extern int gSmLuaCPointerMetatable; + bool smlua_valid_lot(u16 lot); bool smlua_valid_lvt(u16 lvt); struct LuaObjectField* smlua_get_object_field_from_ot(struct LuaObjectTable* ot, const char* key); diff --git a/src/pc/lua/smlua_cobject_map.c b/src/pc/lua/smlua_cobject_map.c deleted file mode 100644 index f3d0a0d69..000000000 --- a/src/pc/lua/smlua_cobject_map.c +++ /dev/null @@ -1,32 +0,0 @@ -#include -#include "smlua.h" -#include "data/dynos_cmap.cpp.h" - -static void* sPointers = NULL; - -void smlua_pointer_user_data_shutdown(void) { - hmap_clear(sPointers); -} - -void smlua_pointer_user_data_init(void) { - smlua_pointer_user_data_shutdown(); -} - -void smlua_pointer_user_data_add(uintptr_t pointer, CObject *obj) { - if (pointer == 0) { return; } - - if (!sPointers) { - sPointers = hmap_create(true); - } - hmap_put(sPointers, pointer, obj); -} - -void smlua_pointer_user_data_delete(uintptr_t pointer) { - if (pointer == 0) { return; } - hmap_del(sPointers, pointer); -} - -CObject *smlua_pointer_user_data_get(uintptr_t pointer) { - if (pointer == 0) { return NULL; } - return hmap_get(sPointers, pointer); -} diff --git a/src/pc/lua/smlua_cobject_map.h b/src/pc/lua/smlua_cobject_map.h deleted file mode 100644 index e5e22a37b..000000000 --- a/src/pc/lua/smlua_cobject_map.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef SMLUA_COBJECT_MAP_H -#define SMLUA_COBJECT_MAP_H - -void smlua_pointer_user_data_init(void); -void smlua_pointer_user_data_shutdown(void); -void smlua_pointer_user_data_add(uintptr_t pointer, CObject *obj); -void smlua_pointer_user_data_delete(uintptr_t pointer); -CObject *smlua_pointer_user_data_get(uintptr_t pointer); - -#endif diff --git a/src/pc/lua/smlua_utils.c b/src/pc/lua/smlua_utils.c index 9ae8d8ecd..94086c5af 100644 --- a/src/pc/lua/smlua_utils.c +++ b/src/pc/lua/smlua_utils.c @@ -1,5 +1,4 @@ #include "smlua.h" -#include "smlua_cobject_map.h" #include "pc/mods/mods.h" #include "audio/external.h" @@ -354,19 +353,30 @@ void smlua_push_object(lua_State* L, u16 lot, void* p) { lua_pushnil(L); return; } + LUA_STACK_CHECK_BEGIN_NUM(1); + + uintptr_t key = lot ^ (uintptr_t) p; + lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCObjects); + lua_pushinteger(L, key); + lua_gettable(L, -2); + if (lua_isuserdata(L, -1)) { + lua_remove(L, -2); // Remove gSmLuaCObjects table + return; + } + lua_pop(L, 1); CObject *cobject = lua_newuserdata(L, sizeof(CObject)); cobject->pointer = p; cobject->lot = lot; cobject->freed = false; - luaL_getmetatable(L, "CObject"); + lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCObjectMetatable); lua_setmetatable(L, -2); + lua_pushinteger(L, key); + lua_pushvalue(L, -2); // Duplicate userdata + lua_settable(L, -4); + lua_remove(L, -2); // Remove gSmLuaCObjects table - switch (lot) { - case LOT_SURFACE: { - smlua_pointer_user_data_add((uintptr_t) p, cobject); - } - } + LUA_STACK_CHECK_END(); } void smlua_push_pointer(lua_State* L, u16 lvt, void* p) { @@ -374,13 +384,29 @@ void smlua_push_pointer(lua_State* L, u16 lvt, void* p) { lua_pushnil(L); return; } + LUA_STACK_CHECK_BEGIN_NUM(1); + + uintptr_t key = lvt ^ (uintptr_t) p; + lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCPointers); + lua_pushinteger(L, key); + lua_gettable(L, -2); + if (lua_isuserdata(L, -1)) { + lua_remove(L, -2); // Remove gSmLuaCPointers table + return; + } + lua_pop(L, 1); CPointer *cpointer = lua_newuserdata(L, sizeof(CPointer)); cpointer->pointer = p; cpointer->lvt = lvt; cpointer->freed = false; - luaL_getmetatable(L, "CPointer"); + lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCPointerMetatable); lua_setmetatable(L, -2); + lua_pushinteger(L, key); + lua_pushvalue(L, -2); // Duplicate userdata + lua_settable(L, -4); + lua_remove(L, -2); // Remove gSmLuaCPointers table + LUA_STACK_CHECK_END(); } void smlua_push_integer_field(int index, const char* name, lua_Integer val) { @@ -710,7 +736,7 @@ void smlua_logline(void) { while (lua_getstack(L, level, &info)) { lua_getinfo(L, "nSl", &info); - // Get the folder and file of the crash + // Get the folder and file // in the format: "folder/file.lua" const char* src = info.source; int slashCount = 0; @@ -733,13 +759,28 @@ void smlua_logline(void) { // If an object is freed that Lua has a CObject to, // Lua is able to use-after-free that pointer +// todo figure out a better way to do this void smlua_free(void *ptr) { if (ptr && gLuaState) { - CObject *obj = smlua_pointer_user_data_get((uintptr_t) ptr); - if (obj) { + lua_State *L = gLuaState; + LUA_STACK_CHECK_BEGIN(); + u16 lot = LOT_SURFACE; // Assuming this is a surface + uintptr_t key = lot ^ (uintptr_t) ptr; + lua_rawgeti(L, LUA_REGISTRYINDEX, gSmLuaCObjects); + lua_pushinteger(L, key); + lua_gettable(L, -2); + CObject *obj = (CObject *) lua_touserdata(L, -1); + if (obj && obj->pointer == ptr) { obj->freed = true; - smlua_pointer_user_data_delete((uintptr_t) ptr); + lua_pop(L, 1); + lua_pushinteger(L, key); + lua_pushnil(L); + lua_settable(L, -3); + } else { + lua_pop(L, 1); } + lua_pop(L, 1); + LUA_STACK_CHECK_END(); } free(ptr); }