From d5a5041c4b1b1da4e15d65fd720e6c984cf5b790 Mon Sep 17 00:00:00 2001 From: Drew Galbraith Date: Sun, 28 Jul 2024 19:02:03 -0700 Subject: [PATCH] Add multithreading to rust code. --- rust/lib/mammoth/src/bindings.rs | 4 +-- rust/lib/mammoth/src/lib.rs | 4 +++ rust/lib/mammoth/src/syscall.rs | 6 ++-- rust/lib/mammoth/src/thread.rs | 60 ++++++++++++++++++++++++++++++++ rust/usr/testbed/src/main.rs | 8 +++++ 5 files changed, 76 insertions(+), 6 deletions(-) create mode 100644 rust/lib/mammoth/src/thread.rs diff --git a/rust/lib/mammoth/src/bindings.rs b/rust/lib/mammoth/src/bindings.rs index 31568ef..6fdfc13 100644 --- a/rust/lib/mammoth/src/bindings.rs +++ b/rust/lib/mammoth/src/bindings.rs @@ -351,9 +351,7 @@ pub struct ZThreadStartReq { } #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct ZThreadExitReq { - pub _address: u8, -} +pub struct ZThreadExitReq {} extern "C" { #[link_name = "\u{1}_Z11ZThreadExitv"] pub fn ZThreadExit() -> z_err_t; diff --git a/rust/lib/mammoth/src/lib.rs b/rust/lib/mammoth/src/lib.rs index 3d8b7da..38564fd 100644 --- a/rust/lib/mammoth/src/lib.rs +++ b/rust/lib/mammoth/src/lib.rs @@ -3,10 +3,14 @@ #![allow(non_camel_case_types)] #![allow(non_snake_case)] +extern crate alloc; + use core::ffi::c_void; pub mod mem; +#[macro_use] pub mod syscall; +pub mod thread; // From /zion/include/ztypes.h const Z_INIT_SELF_PROC: u64 = 0x4000_0000; diff --git a/rust/lib/mammoth/src/syscall.rs b/rust/lib/mammoth/src/syscall.rs index ec8fe9d..8db78f6 100644 --- a/rust/lib/mammoth/src/syscall.rs +++ b/rust/lib/mammoth/src/syscall.rs @@ -138,10 +138,10 @@ impl fmt::Write for Writer { #[macro_export] macro_rules! debug { () => { - debug(""); + $crate::syscall::debug(""); }; ($fmt:literal) => { - debug($fmt); + $crate::syscall::debug($fmt); }; ($fmt:literal, $($val:expr),+) => {{ use core::fmt::Write as _; @@ -149,6 +149,6 @@ macro_rules! debug { let mut w = $crate::syscall::Writer::new(); write!(&mut w, $fmt, $($val),*).expect("Failed to format"); let s: String = w.into(); - debug(&s); + $crate::syscall::debug(&s); }}; } diff --git a/rust/lib/mammoth/src/thread.rs b/rust/lib/mammoth/src/thread.rs new file mode 100644 index 0000000..04c0366 --- /dev/null +++ b/rust/lib/mammoth/src/thread.rs @@ -0,0 +1,60 @@ +use crate::syscall; +use crate::syscall::z_cap_t; + +use core::ffi::c_void; + +pub type ThreadEntry = fn(*const c_void) -> (); + +#[no_mangle] +extern "C" fn entry_point(entry_ptr: *const ThreadEntry, arg1: *const c_void) -> ! { + debug!("Entry {:#p} arg1 {:#x}", entry_ptr, arg1 as u64); + let entry = unsafe { *entry_ptr }; + + entry(arg1); + + let _ = syscall::syscall(syscall::kZionThreadExit, &syscall::ZThreadExitReq {}); + + unreachable!(); +} + +// TODO: Add a Drop implementation that kills this thread and drops its capability. +pub struct Thread<'a> { + cap: z_cap_t, + // This field only exists to ensure that the entry reference will outlive the thread object + // itself. + _entry: &'a ThreadEntry, +} + +impl<'a> Thread<'a> { + pub fn spawn(entry: &'a ThreadEntry, arg1: *const c_void) -> Self { + let mut cap: z_cap_t = 0; + let req = syscall::ZThreadCreateReq { + proc_cap: unsafe { crate::SELF_PROC_CAP }, + thread_cap: &mut cap as *mut z_cap_t, + }; + + syscall::syscall(syscall::kZionThreadCreate, &req).expect("Failed to create thread."); + + syscall::syscall( + syscall::kZionThreadStart, + &syscall::ZThreadStartReq { + thread_cap: cap, + entry: entry_point as u64, + arg1: entry as *const ThreadEntry as u64, + arg2: arg1 as u64, + }, + ) + .expect("Failed to start thread."); + + Self { cap, _entry: entry } + } + + pub fn join(&self) -> Result<(), syscall::ZError> { + syscall::syscall( + syscall::kZionThreadWait, + &syscall::ZThreadWaitReq { + thread_cap: self.cap, + }, + ) + } +} diff --git a/rust/usr/testbed/src/main.rs b/rust/usr/testbed/src/main.rs index 79f9e2e..1b16921 100644 --- a/rust/usr/testbed/src/main.rs +++ b/rust/usr/testbed/src/main.rs @@ -9,6 +9,7 @@ use mammoth::debug; use mammoth::define_entry; use mammoth::syscall::debug; use mammoth::syscall::z_err_t; +use mammoth::thread; use yellowstone::GetEndpointRequest; use yellowstone::YellowstoneClient; @@ -41,5 +42,12 @@ pub extern "C" fn main() -> z_err_t { debug!("Got endpoint w/ cap: {:#x}", endpoint.endpoint); + let e: thread::ThreadEntry = |_| { + debug!("Testing 1 2 3"); + }; + let t = thread::Thread::spawn(&e, core::ptr::null()); + + t.join().expect("Failed to wait."); + 0 }