AK: Make Function a little bit constexpr

Just enough to make it possible to pass a lambda to a constexpr function
taking an AK::Function parameter and have that constexpr function call
the passed-in AK::Function.

The main thing is that bit_cast<>s of pointers aren't ok in constexpr
functions, and neither is placement new. So add a union with stronger
typing for the constexpr case.

The body of the `if (is_constexpr_evaluated())` in
`init_with_callable()` is identical to the `if constexpr` right
after it. But `if constexpr (is_constexpr_evaluated())` always
evaluates `is_constexpr_evaluated()` in a constexpr context, so
this can't just add ` || is_constexpr_evaluated()` to that
`if constexpr`.
This commit is contained in:
Nico Weber 2024-12-19 21:39:16 -05:00
parent 761320c4d2
commit e4c36876d2
4 changed files with 56 additions and 30 deletions

View file

@ -79,7 +79,7 @@ public:
{
}
~Function()
constexpr ~Function()
{
clear(false);
}
@ -94,14 +94,14 @@ public:
}
template<typename CallableType>
Function(CallableType&& callable)
constexpr Function(CallableType&& callable)
requires((IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, Out, In...> && !IsSame<RemoveCVReference<CallableType>, Function>))
{
init_with_callable(forward<CallableType>(callable), CallableKind::FunctionObject);
}
template<typename FunctionType>
Function(FunctionType f)
constexpr Function(FunctionType f)
requires((IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, Out, In...> && !IsSame<RemoveCVReference<FunctionType>, Function>))
{
init_with_callable(move(f), CallableKind::FunctionPointer);
@ -113,9 +113,10 @@ public:
}
// Note: Despite this method being const, a mutable lambda _may_ modify its own captures.
Out operator()(In... in) const
constexpr Out operator()(In... in) const
{
auto* wrapper = callable_wrapper();
if (!is_constant_evaluated())
VERIFY(wrapper);
++m_call_nesting_level;
ScopeGuard guard([this] {
@ -128,7 +129,7 @@ public:
explicit operator bool() const { return !!callable_wrapper(); }
template<typename CallableType>
Function& operator=(CallableType&& callable)
constexpr Function& operator=(CallableType&& callable)
requires((IsFunctionObject<CallableType> && IsCallableWithArguments<CallableType, Out, In...>))
{
clear();
@ -137,7 +138,7 @@ public:
}
template<typename FunctionType>
Function& operator=(FunctionType f)
constexpr Function& operator=(FunctionType f)
requires((IsFunctionPointer<FunctionType> && IsCallableWithArguments<RemovePointer<FunctionType>, Out, In...>))
{
clear();
@ -171,8 +172,8 @@ private:
public:
virtual ~CallableWrapperBase() = default;
// Note: This is not const to allow storing mutable lambdas.
virtual Out call(In...) = 0;
virtual void destroy() = 0;
virtual constexpr Out call(In...) = 0;
virtual constexpr void destroy() = 0;
virtual void init_and_swap(u8*, size_t) = 0;
};
@ -182,17 +183,17 @@ private:
AK_MAKE_NONCOPYABLE(CallableWrapper);
public:
explicit CallableWrapper(CallableType&& callable)
explicit constexpr CallableWrapper(CallableType&& callable)
: m_callable(move(callable))
{
}
Out call(In... in) final override
Out constexpr call(In... in) final override
{
return m_callable(forward<In>(in)...);
}
void destroy() final override
void constexpr destroy() final override
{
delete this;
}
@ -214,7 +215,7 @@ private:
Outline,
};
CallableWrapperBase* callable_wrapper() const
constexpr CallableWrapperBase* callable_wrapper() const
{
switch (m_kind) {
case FunctionKind::NullPointer:
@ -222,13 +223,13 @@ private:
case FunctionKind::Inline:
return bit_cast<CallableWrapperBase*>(&m_storage);
case FunctionKind::Outline:
return *bit_cast<CallableWrapperBase**>(&m_storage);
return m_storage.wrapper;
default:
VERIFY_NOT_REACHED();
}
}
void clear(bool may_defer = true)
constexpr void clear(bool may_defer = true)
{
bool called_from_inside_function = m_call_nesting_level > 0;
// NOTE: This VERIFY could fail because a Function is destroyed from within itself.
@ -250,7 +251,7 @@ private:
}
template<typename Callable>
void init_with_callable(Callable&& callable, CallableKind callable_kind)
constexpr void init_with_callable(Callable&& callable, CallableKind callable_kind)
{
if constexpr (alignof(Callable) > ExcessiveAlignmentThreshold && !AccommodateExcessiveAlignmentRequirements) {
static_assert(
@ -259,20 +260,27 @@ private:
"check your capture list if it is a lambda expression, "
"and make sure your callable object is not excessively aligned.");
}
if (!is_constant_evaluated())
VERIFY(m_call_nesting_level == 0);
using WrapperType = CallableWrapper<Callable>;
if (is_constant_evaluated()) {
m_storage.wrapper = new WrapperType(forward<Callable>(callable));
m_kind = FunctionKind::Outline;
} else {
#ifndef KERNEL
if constexpr (alignof(Callable) > inline_alignment || sizeof(WrapperType) > inline_capacity) {
*bit_cast<CallableWrapperBase**>(&m_storage) = new WrapperType(forward<Callable>(callable));
m_storage.wrapper = new WrapperType(forward<Callable>(callable));
m_kind = FunctionKind::Outline;
} else {
#endif
static_assert(sizeof(WrapperType) <= inline_capacity);
new (m_storage) WrapperType(forward<Callable>(callable));
new (m_storage.storage) WrapperType(forward<Callable>(callable));
m_kind = FunctionKind::Inline;
#ifndef KERNEL
}
#endif
}
if (callable_kind == CallableKind::FunctionObject)
m_size = sizeof(WrapperType);
else
@ -288,11 +296,11 @@ private:
case FunctionKind::NullPointer:
break;
case FunctionKind::Inline:
other_wrapper->init_and_swap(m_storage, inline_capacity);
other_wrapper->init_and_swap(m_storage.storage, inline_capacity);
m_kind = FunctionKind::Inline;
break;
case FunctionKind::Outline:
*bit_cast<CallableWrapperBase**>(&m_storage) = other_wrapper;
m_storage.wrapper = other_wrapper;
m_kind = FunctionKind::Outline;
break;
default:
@ -315,7 +323,10 @@ private:
static constexpr size_t inline_capacity = 6 * sizeof(void*);
#endif
alignas(inline_alignment) u8 m_storage[inline_capacity];
alignas(inline_alignment) union {
u8 storage[inline_capacity];
CallableWrapperBase* wrapper;
} m_storage;
};
}

View file

@ -13,12 +13,12 @@ namespace AK {
template<typename Callback>
class ScopeGuard {
public:
ScopeGuard(Callback callback)
constexpr ScopeGuard(Callback callback)
: m_callback(move(callback))
{
}
~ScopeGuard()
constexpr ~ScopeGuard()
{
m_callback();
}

View file

@ -35,6 +35,7 @@ set(AK_TEST_SOURCES
TestFloatingPointParsing.cpp
TestFlyString.cpp
TestFormat.cpp
TestFunction.cpp
TestFuzzyMatch.cpp
TestGeneratorAK.cpp
TestGenericLexer.cpp

14
Tests/AK/TestFunction.cpp Normal file
View file

@ -0,0 +1,14 @@
/*
* Copyright (c) 2024, Nico Weber <thakis@chromium.org>
*
* SPDX-License-Identifier: BSD-2-Clause
*/
#include <AK/Function.h>
constexpr int const_call(Function<int(int)> f, int i)
{
return f(i);
}
constinit int i = const_call([](int i) { return i; }, 4);