From 5ab560ce12ed0df3450968cfe4211e398ff2a8d7 Mon Sep 17 00:00:00 2001 From: Wedson Almeida Filho Date: Wed, 27 Mar 2024 22:36:00 -0300 Subject: [PATCH] rust: alloc: update `VecExt` to take allocation flags We also rename the methods by removing the `try_` prefix since the names are available due to our usage of the `no_global_oom_handling` config when building the `alloc` crate. Reviewed-by: Boqun Feng Signed-off-by: Wedson Almeida Filho Reviewed-by: Benno Lossin Link: https://lore.kernel.org/r/20240328013603.206764-8-wedsonaf@gmail.com Signed-off-by: Miguel Ojeda --- rust/kernel/alloc/vec_ext.rs | 158 +++++++++++++++++++++++++++++++---- rust/kernel/error.rs | 11 +-- rust/kernel/lib.rs | 1 - rust/kernel/str.rs | 6 +- rust/kernel/types.rs | 4 +- samples/rust/rust_minimal.rs | 6 +- 6 files changed, 152 insertions(+), 34 deletions(-) diff --git a/rust/kernel/alloc/vec_ext.rs b/rust/kernel/alloc/vec_ext.rs index 311e62cc5784..e24d7c7675ca 100644 --- a/rust/kernel/alloc/vec_ext.rs +++ b/rust/kernel/alloc/vec_ext.rs @@ -2,47 +2,175 @@ //! Extensions to [`Vec`] for fallible allocations. -use alloc::{collections::TryReserveError, vec::Vec}; +use super::Flags; +use alloc::{alloc::AllocError, vec::Vec}; use core::result::Result; /// Extensions to [`Vec`]. pub trait VecExt: Sized { /// Creates a new [`Vec`] instance with at least the given capacity. - fn try_with_capacity(capacity: usize) -> Result; + /// + /// # Examples + /// + /// ``` + /// let v = Vec::::with_capacity(20, GFP_KERNEL)?; + /// + /// assert!(v.capacity() >= 20); + /// # Ok::<(), Error>(()) + /// ``` + fn with_capacity(capacity: usize, flags: Flags) -> Result; /// Appends an element to the back of the [`Vec`] instance. - fn try_push(&mut self, v: T) -> Result<(), TryReserveError>; + /// + /// # Examples + /// + /// ``` + /// let mut v = Vec::new(); + /// v.push(1, GFP_KERNEL)?; + /// assert_eq!(&v, &[1]); + /// + /// v.push(2, GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 2]); + /// # Ok::<(), Error>(()) + /// ``` + fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError>; /// Pushes clones of the elements of slice into the [`Vec`] instance. - fn try_extend_from_slice(&mut self, other: &[T]) -> Result<(), TryReserveError> + /// + /// # Examples + /// + /// ``` + /// let mut v = Vec::new(); + /// v.push(1, GFP_KERNEL)?; + /// + /// v.extend_from_slice(&[20, 30, 40], GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 20, 30, 40]); + /// + /// v.extend_from_slice(&[50, 60], GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 20, 30, 40, 50, 60]); + /// # Ok::<(), Error>(()) + /// ``` + fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError> where T: Clone; + + /// Ensures that the capacity exceeds the length by at least `additional` elements. + /// + /// # Examples + /// + /// ``` + /// let mut v = Vec::new(); + /// v.push(1, GFP_KERNEL)?; + /// + /// v.reserve(10, GFP_KERNEL)?; + /// let cap = v.capacity(); + /// assert!(cap >= 10); + /// + /// v.reserve(10, GFP_KERNEL)?; + /// let new_cap = v.capacity(); + /// assert_eq!(new_cap, cap); + /// + /// # Ok::<(), Error>(()) + /// ``` + fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError>; } impl VecExt for Vec { - fn try_with_capacity(capacity: usize) -> Result { + fn with_capacity(capacity: usize, flags: Flags) -> Result { let mut v = Vec::new(); - v.try_reserve(capacity)?; + >::reserve(&mut v, capacity, flags)?; Ok(v) } - fn try_push(&mut self, v: T) -> Result<(), TryReserveError> { - if let Err(retry) = self.push_within_capacity(v) { - self.try_reserve(1)?; - let _ = self.push_within_capacity(retry); - } + fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError> { + >::reserve(self, 1, flags)?; + let s = self.spare_capacity_mut(); + s[0].write(v); + + // SAFETY: We just initialised the first spare entry, so it is safe to increase the length + // by 1. We also know that the new length is <= capacity because of the previous call to + // `reserve` above. + unsafe { self.set_len(self.len() + 1) }; Ok(()) } - fn try_extend_from_slice(&mut self, other: &[T]) -> Result<(), TryReserveError> + fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError> where T: Clone, { - self.try_reserve(other.len())?; - for item in other { - self.try_push(item.clone())?; + >::reserve(self, other.len(), flags)?; + for (slot, item) in core::iter::zip(self.spare_capacity_mut(), other) { + slot.write(item.clone()); } + // SAFETY: We just initialised the `other.len()` spare entries, so it is safe to increase + // the length by the same amount. We also know that the new length is <= capacity because + // of the previous call to `reserve` above. + unsafe { self.set_len(self.len() + other.len()) }; Ok(()) } + + #[cfg(any(test, testlib))] + fn reserve(&mut self, additional: usize, _flags: Flags) -> Result<(), AllocError> { + Vec::reserve(self, additional); + Ok(()) + } + + #[cfg(not(any(test, testlib)))] + fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError> { + let len = self.len(); + let cap = self.capacity(); + + if cap - len >= additional { + return Ok(()); + } + + if core::mem::size_of::() == 0 { + // The capacity is already `usize::MAX` for SZTs, we can't go higher. + return Err(AllocError); + } + + // We know cap is <= `isize::MAX` because `Layout::array` fails if the resulting byte size + // is greater than `isize::MAX`. So the multiplication by two won't overflow. + let new_cap = core::cmp::max(cap * 2, len.checked_add(additional).ok_or(AllocError)?); + let layout = core::alloc::Layout::array::(new_cap).map_err(|_| AllocError)?; + + let (ptr, len, cap) = destructure(self); + + // SAFETY: `ptr` is valid because it's either NULL or comes from a previous call to + // `krealloc_aligned`. We also verified that the type is not a ZST. + let new_ptr = unsafe { super::allocator::krealloc_aligned(ptr.cast(), layout, flags) }; + if new_ptr.is_null() { + // SAFETY: We are just rebuilding the existing `Vec` with no changes. + unsafe { rebuild(self, ptr, len, cap) }; + Err(AllocError) + } else { + // SAFETY: `ptr` has been reallocated with the layout for `new_cap` elements. New cap + // is greater than `cap`, so it continues to be >= `len`. + unsafe { rebuild(self, new_ptr.cast::(), len, new_cap) }; + Ok(()) + } + } +} + +#[cfg(not(any(test, testlib)))] +fn destructure(v: &mut Vec) -> (*mut T, usize, usize) { + let mut tmp = Vec::new(); + core::mem::swap(&mut tmp, v); + let mut tmp = core::mem::ManuallyDrop::new(tmp); + let len = tmp.len(); + let cap = tmp.capacity(); + (tmp.as_mut_ptr(), len, cap) +} + +/// Rebuilds a `Vec` from a pointer, length, and capacity. +/// +/// # Safety +/// +/// The same as [`Vec::from_raw_parts`]. +#[cfg(not(any(test, testlib)))] +unsafe fn rebuild(v: &mut Vec, ptr: *mut T, len: usize, cap: usize) { + // SAFETY: The safety requirements from this function satisfy those of `from_raw_parts`. + let mut tmp = unsafe { Vec::from_raw_parts(ptr, len, cap) }; + core::mem::swap(&mut tmp, v); } diff --git a/rust/kernel/error.rs b/rust/kernel/error.rs index 4786d3ee1e92..e53466937796 100644 --- a/rust/kernel/error.rs +++ b/rust/kernel/error.rs @@ -6,10 +6,7 @@ use crate::str::CStr; -use alloc::{ - alloc::{AllocError, LayoutError}, - collections::TryReserveError, -}; +use alloc::alloc::{AllocError, LayoutError}; use core::convert::From; use core::fmt; @@ -192,12 +189,6 @@ fn from(_: Utf8Error) -> Error { } } -impl From for Error { - fn from(_: TryReserveError) -> Error { - code::ENOMEM - } -} - impl From for Error { fn from(_: LayoutError) -> Error { code::ENOMEM diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index d3d345aed218..1e910fe7c2c7 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -18,7 +18,6 @@ #![feature(new_uninit)] #![feature(receiver_trait)] #![feature(unsize)] -#![feature(vec_push_within_capacity)] // Ensure conditional compilation based on the kernel configuration works; // otherwise we may silently break things like initcall handling. diff --git a/rust/kernel/str.rs b/rust/kernel/str.rs index 14ef4344cf6e..f454252c6215 100644 --- a/rust/kernel/str.rs +++ b/rust/kernel/str.rs @@ -2,7 +2,7 @@ //! String representations. -use crate::alloc::vec_ext::VecExt; +use crate::alloc::{flags::*, vec_ext::VecExt}; use alloc::alloc::AllocError; use alloc::vec::Vec; use core::fmt::{self, Write}; @@ -807,7 +807,7 @@ pub fn try_from_fmt(args: fmt::Arguments<'_>) -> Result { let size = f.bytes_written(); // Allocate a vector with the required number of bytes, and write to it. - let mut buf = Vec::try_with_capacity(size)?; + let mut buf = as VecExt<_>>::with_capacity(size, GFP_KERNEL)?; // SAFETY: The buffer stored in `buf` is at least of size `size` and is valid for writes. let mut f = unsafe { Formatter::from_buffer(buf.as_mut_ptr(), size) }; f.write_fmt(args)?; @@ -856,7 +856,7 @@ impl<'a> TryFrom<&'a CStr> for CString { fn try_from(cstr: &'a CStr) -> Result { let mut buf = Vec::new(); - buf.try_extend_from_slice(cstr.as_bytes_with_nul()) + as VecExt<_>>::extend_from_slice(&mut buf, cstr.as_bytes_with_nul(), GFP_KERNEL) .map_err(|_| AllocError)?; // INVARIANT: The `CStr` and `CString` types have the same invariants for diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs index aa77bad9bce4..8fad61268465 100644 --- a/rust/kernel/types.rs +++ b/rust/kernel/types.rs @@ -157,11 +157,11 @@ unsafe fn from_foreign(_: *const core::ffi::c_void) -> Self {} /// let mut vec = /// ScopeGuard::new_with_data(Vec::new(), |v| pr_info!("vec had {} elements\n", v.len())); /// -/// vec.try_push(10u8)?; +/// vec.push(10u8, GFP_KERNEL)?; /// if arg { /// return Ok(()); /// } -/// vec.try_push(20u8)?; +/// vec.push(20u8, GFP_KERNEL)?; /// Ok(()) /// } /// diff --git a/samples/rust/rust_minimal.rs b/samples/rust/rust_minimal.rs index dc05f4bbe27e..2a9eaab62d1c 100644 --- a/samples/rust/rust_minimal.rs +++ b/samples/rust/rust_minimal.rs @@ -22,9 +22,9 @@ fn init(_module: &'static ThisModule) -> Result { pr_info!("Am I built-in? {}\n", !cfg!(MODULE)); let mut numbers = Vec::new(); - numbers.try_push(72)?; - numbers.try_push(108)?; - numbers.try_push(200)?; + numbers.push(72, GFP_KERNEL)?; + numbers.push(108, GFP_KERNEL)?; + numbers.push(200, GFP_KERNEL)?; Ok(RustMinimal { numbers }) }