LibCore: Let LocalSocket send and receive messages with SCM_RIGHTS

These new methods combine send/receive with send_fd/receive_fd.

This is the 'correct' way to use SCM_RIGHTS, rather than trying to
emulate the Serenity behavior on other Unixes.
This commit is contained in:
Andrew Kaster 2024-04-17 16:40:57 -06:00 committed by Tim Flynn
parent 0e699743c4
commit a18c7c4405
2 changed files with 73 additions and 0 deletions

View file

@ -10,6 +10,8 @@
namespace Core {
static constexpr size_t MAX_LOCAL_SOCKET_TRANSFER_FDS = 64;
ErrorOr<int> Socket::create_fd(SocketDomain domain, SocketType type)
{
int socket_domain;
@ -362,6 +364,73 @@ ErrorOr<void> LocalSocket::send_fd(int fd)
#endif
}
ErrorOr<ssize_t> LocalSocket::send_message(ReadonlyBytes data, int flags, Vector<int, 1> fds)
{
size_t const num_fds = fds.size();
if (num_fds == 0)
return m_helper.write(data, flags | default_flags());
if (num_fds > MAX_LOCAL_SOCKET_TRANSFER_FDS)
return Error::from_string_literal("Too many file descriptors to send");
auto const fd_payload_size = num_fds * sizeof(int);
alignas(struct cmsghdr) char control_buf[CMSG_SPACE(sizeof(int) * MAX_LOCAL_SOCKET_TRANSFER_FDS)] {};
auto* header = new (control_buf) cmsghdr {
.cmsg_len = static_cast<socklen_t>(CMSG_LEN(fd_payload_size)),
.cmsg_level = SOL_SOCKET,
.cmsg_type = SCM_RIGHTS,
};
memcpy(CMSG_DATA(header), fds.data(), fd_payload_size);
struct iovec iov {
.iov_base = const_cast<u8*>(data.data()),
.iov_len = data.size(),
};
struct msghdr msg = {};
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = header;
msg.msg_controllen = CMSG_LEN(fd_payload_size);
return TRY(Core::System::sendmsg(m_helper.fd(), &msg, default_flags() | flags));
}
ErrorOr<Bytes> LocalSocket::receive_message(AK::Bytes buffer, int flags, Vector<int>& fds)
{
struct iovec iov {
.iov_base = buffer.data(),
.iov_len = buffer.size(),
};
alignas(struct cmsghdr) char control_buf[CMSG_SPACE(sizeof(int) * MAX_LOCAL_SOCKET_TRANSFER_FDS)] {};
struct msghdr msg = {};
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = control_buf;
msg.msg_controllen = sizeof(control_buf);
auto nread = TRY(Core::System::recvmsg(m_helper.fd(), &msg, default_flags() | flags));
if (nread == 0) {
m_helper.did_reach_eof_on_read();
return buffer.trim(nread);
}
fds.clear();
struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
while (cmsg != nullptr) {
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
size_t num_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
auto* fd_data = reinterpret_cast<int*>(CMSG_DATA(cmsg));
for (size_t i = 0; i < num_fds; ++i) {
fds.append(fd_data[i]);
}
}
AK_IGNORE_DIAGNOSTIC("-Wsign-compare", cmsg = CMSG_NXTHDR(&msg, cmsg));
}
return buffer.trim(nread);
}
ErrorOr<pid_t> LocalSocket::peer_pid() const
{
#if defined(AK_OS_MACOS) || defined(AK_OS_IOS)

View file

@ -329,6 +329,10 @@ public:
ErrorOr<int> receive_fd(int flags);
ErrorOr<void> send_fd(int fd);
ErrorOr<Bytes> receive_message(Bytes buffer, int flags, Vector<int>& fds);
ErrorOr<ssize_t> send_message(ReadonlyBytes msg, int flags, Vector<int, 1> fds = {});
ErrorOr<pid_t> peer_pid() const;
ErrorOr<Bytes> read_without_waiting(Bytes buffer);