LibWasm+LibWeb+test-wasm: Refcount Wasm::Module for function references

Prior to funcref, a partial chunk of an invalid module was never needed,
but funcref allows a partially instantiated module to modify imported
tables with references to its own functions, which means we need to keep
the second module alive while that function reference is present within
the imported table.
This was tested by the spectests, but very rarely caught as our GC does
not behave particularly predictably, making it so the offending module
remains in memory just long enough to let the tests pass.

This commit makes it so all function references keep their respective
modules alive.

(cherry picked from commit a60ecea16abe62aae988ba877fdb98466d2919d3)
This commit is contained in:
Ali Mohammad Pur 2024-08-22 01:13:37 +02:00 committed by Nico Weber
parent 1b0ed2d74a
commit 7c96d9a96d
9 changed files with 54 additions and 37 deletions

View file

@ -53,7 +53,7 @@ public:
Wasm::Module& module() { return *m_module; } Wasm::Module& module() { return *m_module; }
Wasm::ModuleInstance& module_instance() { return *m_module_instance; } Wasm::ModuleInstance& module_instance() { return *m_module_instance; }
static JS::ThrowCompletionOr<WebAssemblyModule*> create(JS::Realm& realm, Wasm::Module module, HashMap<Wasm::Linker::Name, Wasm::ExternValue> const& imports) static JS::ThrowCompletionOr<WebAssemblyModule*> create(JS::Realm& realm, NonnullRefPtr<Wasm::Module> module, HashMap<Wasm::Linker::Name, Wasm::ExternValue> const& imports)
{ {
auto& vm = realm.vm(); auto& vm = realm.vm();
auto instance = realm.heap().allocate<WebAssemblyModule>(realm, realm.intrinsics().object_prototype()); auto instance = realm.heap().allocate<WebAssemblyModule>(realm, realm.intrinsics().object_prototype());
@ -148,7 +148,7 @@ private:
static HashMap<Wasm::Linker::Name, Wasm::ExternValue> s_spec_test_namespace; static HashMap<Wasm::Linker::Name, Wasm::ExternValue> s_spec_test_namespace;
static Wasm::AbstractMachine m_machine; static Wasm::AbstractMachine m_machine;
Optional<Wasm::Module> m_module; RefPtr<Wasm::Module> m_module;
OwnPtr<Wasm::ModuleInstance> m_module_instance; OwnPtr<Wasm::ModuleInstance> m_module_instance;
}; };
@ -379,13 +379,15 @@ JS_DEFINE_NATIVE_FUNCTION(WebAssemblyModule::wasm_invoke)
arguments.append(Wasm::Value(bits)); arguments.append(Wasm::Value(bits));
break; break;
} }
case Wasm::ValueType::Kind::FunctionReference: case Wasm::ValueType::Kind::FunctionReference: {
if (argument.is_null()) { if (argument.is_null()) {
arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::FunctionReference) } })); arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::FunctionReference) } }));
break; break;
} }
arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { static_cast<u64>(double_value) } })); Wasm::FunctionAddress addr = static_cast<u64>(double_value);
arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Func { addr, machine().store().get_module_for(addr) } }));
break; break;
}
case Wasm::ValueType::Kind::ExternReference: case Wasm::ValueType::Kind::ExternReference:
if (argument.is_null()) { if (argument.is_null()) {
arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::ExternReference) } })); arguments.append(Wasm::Value(Wasm::Reference { Wasm::Reference::Null { Wasm::ValueType(Wasm::ValueType::Kind::ExternReference) } }));

View file

@ -14,14 +14,14 @@
namespace Wasm { namespace Wasm {
Optional<FunctionAddress> Store::allocate(ModuleInstance& module, CodeSection::Code const& code, TypeIndex type_index) Optional<FunctionAddress> Store::allocate(ModuleInstance& instance, Module const& module, CodeSection::Code const& code, TypeIndex type_index)
{ {
FunctionAddress address { m_functions.size() }; FunctionAddress address { m_functions.size() };
if (type_index.value() > module.types().size()) if (type_index.value() > instance.types().size())
return {}; return {};
auto& type = module.types()[type_index.value()]; auto& type = instance.types()[type_index.value()];
m_functions.empend(WasmFunction { type, module, code }); m_functions.empend(WasmFunction { type, instance, module, code });
return address; return address;
} }
@ -81,6 +81,14 @@ FunctionInstance* Store::get(FunctionAddress address)
return &m_functions[value]; return &m_functions[value];
} }
Module const* Store::get_module_for(Wasm::FunctionAddress address)
{
auto* function = get(address);
if (!function || function->has<HostFunction>())
return nullptr;
return function->get<WasmFunction>().module_ref().ptr();
}
TableInstance* Store::get(TableAddress address) TableInstance* Store::get(TableAddress address)
{ {
auto value = address.value(); auto value = address.value();
@ -220,7 +228,7 @@ InstantiationResult AbstractMachine::instantiate(Module const& module, Vector<Ex
size_t i = 0; size_t i = 0;
for (auto& code : module.code_section().functions()) { for (auto& code : module.code_section().functions()) {
auto type_index = module.function_section().types()[i]; auto type_index = module.function_section().types()[i];
auto address = m_store.allocate(main_module_instance, code, type_index); auto address = m_store.allocate(main_module_instance, module, code, type_index);
VERIFY(address.has_value()); VERIFY(address.has_value());
auxiliary_instance.functions().append(*address); auxiliary_instance.functions().append(*address);
module_functions.append(*address); module_functions.append(*address);

View file

@ -51,6 +51,7 @@ public:
}; };
struct Func { struct Func {
FunctionAddress address; FunctionAddress address;
RefPtr<Module> source_module; // null if host function.
}; };
struct Extern { struct Extern {
ExternAddress address; ExternAddress address;
@ -123,7 +124,7 @@ public:
// 2: null funcref // 2: null funcref
// 3: null externref // 3: null externref
ref.ref().visit( ref.ref().visit(
[&](Reference::Func const& func) { m_value = u128(bit_cast<u64>(func.address), 0); }, [&](Reference::Func const& func) { m_value = u128(bit_cast<u64>(func.address), bit_cast<u64>(func.source_module.ptr())); },
[&](Reference::Extern const& func) { m_value = u128(bit_cast<u64>(func.address), 1); }, [&](Reference::Extern const& func) { m_value = u128(bit_cast<u64>(func.address), 1); },
[&](Reference::Null const& null) { m_value = u128(0, null.type.kind() == ValueType::Kind::FunctionReference ? 2 : 3); }); [&](Reference::Null const& null) { m_value = u128(0, null.type.kind() == ValueType::Kind::FunctionReference ? 2 : 3); });
} }
@ -161,17 +162,15 @@ public:
return bit_cast<f64>(m_value.low()); return bit_cast<f64>(m_value.low());
} }
if constexpr (IsSame<T, Reference>) { if constexpr (IsSame<T, Reference>) {
switch (m_value.high()) { switch (m_value.high() & 3) {
case 0: case 0:
return Reference { Reference::Func { bit_cast<FunctionAddress>(m_value.low()) } }; return Reference { Reference::Func { bit_cast<FunctionAddress>(m_value.low()), bit_cast<Wasm::Module*>(m_value.high()) } };
case 1: case 1:
return Reference { Reference::Extern { bit_cast<ExternAddress>(m_value.low()) } }; return Reference { Reference::Extern { bit_cast<ExternAddress>(m_value.low()) } };
case 2: case 2:
return Reference { Reference::Null { ValueType(ValueType::Kind::FunctionReference) } }; return Reference { Reference::Null { ValueType(ValueType::Kind::FunctionReference) } };
case 3: case 3:
return Reference { Reference::Null { ValueType(ValueType::Kind::ExternReference) } }; return Reference { Reference::Null { ValueType(ValueType::Kind::ExternReference) } };
default:
VERIFY_NOT_REACHED();
} }
} }
VERIFY_NOT_REACHED(); VERIFY_NOT_REACHED();
@ -325,20 +324,23 @@ private:
class WasmFunction { class WasmFunction {
public: public:
explicit WasmFunction(FunctionType const& type, ModuleInstance const& module, CodeSection::Code const& code) explicit WasmFunction(FunctionType const& type, ModuleInstance const& instance, Module const& module, CodeSection::Code const& code)
: m_type(type) : m_type(type)
, m_module(module) , m_module(module.make_weak_ptr())
, m_module_instance(instance)
, m_code(code) , m_code(code)
{ {
} }
auto& type() const { return m_type; } auto& type() const { return m_type; }
auto& module() const { return m_module; } auto& module() const { return m_module_instance; }
auto& code() const { return m_code; } auto& code() const { return m_code; }
RefPtr<Module const> module_ref() const { return m_module.strong_ref(); }
private: private:
FunctionType m_type; FunctionType m_type;
ModuleInstance const& m_module; WeakPtr<Module const> m_module;
ModuleInstance const& m_module_instance;
CodeSection::Code const& m_code; CodeSection::Code const& m_code;
}; };
@ -535,7 +537,7 @@ class Store {
public: public:
Store() = default; Store() = default;
Optional<FunctionAddress> allocate(ModuleInstance&, CodeSection::Code const&, TypeIndex); Optional<FunctionAddress> allocate(ModuleInstance&, Module const&, CodeSection::Code const&, TypeIndex);
Optional<FunctionAddress> allocate(HostFunction&&); Optional<FunctionAddress> allocate(HostFunction&&);
Optional<TableAddress> allocate(TableType const&); Optional<TableAddress> allocate(TableType const&);
Optional<MemoryAddress> allocate(MemoryType const&); Optional<MemoryAddress> allocate(MemoryType const&);
@ -543,6 +545,7 @@ public:
Optional<GlobalAddress> allocate(GlobalType const&, Value); Optional<GlobalAddress> allocate(GlobalType const&, Value);
Optional<ElementAddress> allocate(ValueType const&, Vector<Reference>); Optional<ElementAddress> allocate(ValueType const&, Vector<Reference>);
Module const* get_module_for(FunctionAddress);
FunctionInstance* get(FunctionAddress); FunctionInstance* get(FunctionAddress);
TableInstance* get(TableAddress); TableInstance* get(TableAddress);
MemoryInstance* get(MemoryAddress); MemoryInstance* get(MemoryAddress);

View file

@ -864,7 +864,7 @@ ALWAYS_INLINE void BytecodeInterpreter::interpret_instruction(Configuration& con
auto index = instruction.arguments().get<FunctionIndex>().value(); auto index = instruction.arguments().get<FunctionIndex>().value();
auto& functions = configuration.frame().module().functions(); auto& functions = configuration.frame().module().functions();
auto address = functions[index]; auto address = functions[index];
configuration.value_stack().append(Value(address.value())); configuration.value_stack().append(Value(Reference { Reference::Func { address, configuration.store().get_module_for(address) } }));
return; return;
} }
case Instructions::ref_is_null.value(): { case Instructions::ref_is_null.value(): {

View file

@ -1248,7 +1248,7 @@ ParseResult<SectionId> SectionId::parse(Stream& stream)
} }
} }
ParseResult<Module> Module::parse(Stream& stream) ParseResult<NonnullRefPtr<Module>> Module::parse(Stream& stream)
{ {
ScopeLogger<WASM_BINPARSER_DEBUG> logger("Module"sv); ScopeLogger<WASM_BINPARSER_DEBUG> logger("Module"sv);
u8 buf[4]; u8 buf[4];
@ -1263,7 +1263,9 @@ ParseResult<Module> Module::parse(Stream& stream)
return with_eof_check(stream, ParseError::InvalidModuleVersion); return with_eof_check(stream, ParseError::InvalidModuleVersion);
auto last_section_id = SectionId::SectionIdKind::Custom; auto last_section_id = SectionId::SectionIdKind::Custom;
Module module; auto module_ptr = make_ref_counted<Module>();
auto& module = *module_ptr;
while (!stream.is_eof()) { while (!stream.is_eof()) {
auto section_id = TRY(SectionId::parse(stream)); auto section_id = TRY(SectionId::parse(stream));
size_t section_size = TRY_READ(stream, LEB128<u32>, ParseError::ExpectedSize); size_t section_size = TRY_READ(stream, LEB128<u32>, ParseError::ExpectedSize);
@ -1324,7 +1326,7 @@ ParseResult<Module> Module::parse(Stream& stream)
return ParseError::SectionSizeMismatch; return ParseError::SectionSizeMismatch;
} }
return module; return module_ptr;
} }
ByteString parse_error_to_byte_string(ParseError error) ByteString parse_error_to_byte_string(ParseError error)

View file

@ -14,6 +14,7 @@
#include <AK/String.h> #include <AK/String.h>
#include <AK/UFixedBigInt.h> #include <AK/UFixedBigInt.h>
#include <AK/Variant.h> #include <AK/Variant.h>
#include <AK/WeakPtr.h>
#include <LibWasm/Constants.h> #include <LibWasm/Constants.h>
#include <LibWasm/Forward.h> #include <LibWasm/Forward.h>
#include <LibWasm/Opcode.h> #include <LibWasm/Opcode.h>
@ -982,7 +983,8 @@ private:
Optional<u32> m_count; Optional<u32> m_count;
}; };
class Module { class Module : public RefCounted<Module>
, public Weakable<Module> {
public: public:
enum class ValidationStatus { enum class ValidationStatus {
Unchecked, Unchecked,
@ -1027,7 +1029,7 @@ public:
StringView validation_error() const { return *m_validation_error; } StringView validation_error() const { return *m_validation_error; }
void set_validation_error(ByteString error) { m_validation_error = move(error); } void set_validation_error(ByteString error) { m_validation_error = move(error); }
static ParseResult<Module> parse(Stream& stream); static ParseResult<NonnullRefPtr<Module>> parse(Stream& stream);
private: private:
void set_validation_status(ValidationStatus status) { m_validation_status = status; } void set_validation_status(ValidationStatus status) { m_validation_status = status; }

View file

@ -434,7 +434,7 @@ JS::ThrowCompletionOr<Wasm::Value> to_webassembly_value(JS::VM& vm, JS::Value va
auto& cache = get_cache(*vm.current_realm()); auto& cache = get_cache(*vm.current_realm());
for (auto& entry : cache.function_instances()) { for (auto& entry : cache.function_instances()) {
if (entry.value == &function) if (entry.value == &function)
return Wasm::Value { Wasm::Reference { Wasm::Reference::Func { entry.key } } }; return Wasm::Value { Wasm::Reference { Wasm::Reference::Func { entry.key, cache.abstract_machine().store().get_module_for(entry.key) } } };
} }
} }

View file

@ -29,12 +29,12 @@ WebIDL::ExceptionOr<JS::Value> instantiate(JS::VM&, Module const& module_object,
namespace Detail { namespace Detail {
struct CompiledWebAssemblyModule : public RefCounted<CompiledWebAssemblyModule> { struct CompiledWebAssemblyModule : public RefCounted<CompiledWebAssemblyModule> {
explicit CompiledWebAssemblyModule(Wasm::Module&& module) explicit CompiledWebAssemblyModule(NonnullRefPtr<Wasm::Module> module)
: module(move(module)) : module(move(module))
{ {
} }
Wasm::Module module; NonnullRefPtr<Wasm::Module> module;
}; };
class WebAssemblyCache { class WebAssemblyCache {

View file

@ -491,7 +491,7 @@ static bool pre_interpret_hook(Wasm::Configuration& config, Wasm::InstructionPoi
} }
} }
static Optional<Wasm::Module> parse(StringView filename) static RefPtr<Wasm::Module> parse(StringView filename)
{ {
auto result = Core::MappedFile::map(filename); auto result = Core::MappedFile::map(filename);
if (result.is_error()) { if (result.is_error()) {
@ -603,7 +603,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
attempt_instantiate = true; attempt_instantiate = true;
auto parse_result = parse(filename); auto parse_result = parse(filename);
if (!parse_result.has_value()) if (parse_result.is_null())
return 1; return 1;
g_stdout = TRY(Core::File::standard_output()); g_stdout = TRY(Core::File::standard_output());
@ -611,7 +611,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
if (print && !attempt_instantiate) { if (print && !attempt_instantiate) {
Wasm::Printer printer(*g_stdout); Wasm::Printer printer(*g_stdout);
printer.print(parse_result.value()); printer.print(*parse_result);
} }
if (attempt_instantiate) { if (attempt_instantiate) {
@ -653,14 +653,14 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
// First, resolve the linked modules // First, resolve the linked modules
Vector<NonnullOwnPtr<Wasm::ModuleInstance>> linked_instances; Vector<NonnullOwnPtr<Wasm::ModuleInstance>> linked_instances;
Vector<Wasm::Module> linked_modules; Vector<NonnullRefPtr<Wasm::Module>> linked_modules;
for (auto& name : modules_to_link_in) { for (auto& name : modules_to_link_in) {
auto parse_result = parse(name); auto parse_result = parse(name);
if (!parse_result.has_value()) { if (parse_result.is_null()) {
warnln("Failed to parse linked module '{}'", name); warnln("Failed to parse linked module '{}'", name);
return 1; return 1;
} }
linked_modules.append(parse_result.release_value()); linked_modules.append(parse_result.release_nonnull());
Wasm::Linker linker { linked_modules.last() }; Wasm::Linker linker { linked_modules.last() };
for (auto& instance : linked_instances) for (auto& instance : linked_instances)
linker.link(*instance); linker.link(*instance);
@ -678,7 +678,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
linked_instances.append(instantiation_result.release_value()); linked_instances.append(instantiation_result.release_value());
} }
Wasm::Linker linker { parse_result.value() }; Wasm::Linker linker { *parse_result };
for (auto& instance : linked_instances) for (auto& instance : linked_instances)
linker.link(*instance); linker.link(*instance);
@ -704,7 +704,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
for (auto& entry : linker.unresolved_imports()) { for (auto& entry : linker.unresolved_imports()) {
if (!entry.type.has<Wasm::TypeIndex>()) if (!entry.type.has<Wasm::TypeIndex>())
continue; continue;
auto type = parse_result.value().type_section().types()[entry.type.get<Wasm::TypeIndex>().value()]; auto type = parse_result->type_section().types()[entry.type.get<Wasm::TypeIndex>().value()];
auto address = machine.store().allocate(Wasm::HostFunction( auto address = machine.store().allocate(Wasm::HostFunction(
[name = entry.name, type = type](auto&, auto& arguments) -> Wasm::Result { [name = entry.name, type = type](auto&, auto& arguments) -> Wasm::Result {
StringBuilder argument_builder; StringBuilder argument_builder;
@ -743,7 +743,7 @@ ErrorOr<int> serenity_main(Main::Arguments arguments)
print_link_error(link_result.error()); print_link_error(link_result.error());
return 1; return 1;
} }
auto result = machine.instantiate(parse_result.value(), link_result.release_value()); auto result = machine.instantiate(*parse_result, link_result.release_value());
if (result.is_error()) { if (result.is_error()) {
warnln("Module instantiation failed: {}", result.error().error); warnln("Module instantiation failed: {}", result.error().error);
return 1; return 1;