diff --git a/zion/lib/ref_counted.h b/zion/lib/ref_counted.h new file mode 100644 index 0000000..1926ba1 --- /dev/null +++ b/zion/lib/ref_counted.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +#include "debug/debug.h" + +template +class RefCounted { + public: + RefCounted() {} + ~RefCounted() { dbgln("RefCounted object destroyed"); } + void Adopt() { + if (ref_count_ != -1) { + panic("Adopting owned ptr"); + } else { + ref_count_ = 1; + } + } + + void Acquire() { + if (ref_count_ == -1) { + panic("Acquiring unowned ptr"); + } + ref_count_++; + } + bool Release() { + if (ref_count_ == -1 || ref_count_ == 0) { + panic("Releasing unowned ptr"); + } + return (--ref_count_) == 0; + } + + private: + // FIXME: This should be an atomic type. + uint64_t ref_count_ = -1; + // Disallow copy and move. + RefCounted(RefCounted&) = delete; + RefCounted(RefCounted&&) = delete; + RefCounted& operator=(RefCounted&) = delete; + RefCounted& operator=(RefCounted&&) = delete; +}; diff --git a/zion/lib/ref_ptr.h b/zion/lib/ref_ptr.h new file mode 100644 index 0000000..a6cdf65 --- /dev/null +++ b/zion/lib/ref_ptr.h @@ -0,0 +1,76 @@ +#pragma once + +template +class RefPtr; + +template +RefPtr AdoptPtr(T* ptr); + +template +class RefPtr { + public: + RefPtr() : ptr_(nullptr) {} + RefPtr(decltype(nullptr)) : ptr_(nullptr) {} + RefPtr(const RefPtr& other) : ptr_(other.ptr_) { + if (ptr_) { + ptr_->Acquire(); + } + } + RefPtr& operator=(const RefPtr& other) { + T* old = ptr_; + ptr_ = other.ptr_; + if (ptr_) { + ptr_->Acquire(); + } + if (old && old->Release()) { + delete old; + } + + return *this; + } + + RefPtr(RefPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; } + RefPtr& operator=(RefPtr&& other) { + // Swap + T* ptr = ptr_; + ptr_ = other.ptr_; + other.ptr_ = ptr; + return *this; + } + + T* get() const { return ptr_; }; + T& operator*() const { return *ptr_; } + T* operator->() const { return ptr_; } + operator bool() const { return ptr_ != nullptr; } + + bool operator==(decltype(nullptr)) const { return (ptr_ == nullptr); } + bool operator!=(decltype(nullptr)) const { return (ptr_ != nullptr); } + + bool operator==(const RefPtr& other) const { return (ptr_ == other.ptr_); } + bool operator!=(const RefPtr& other) const { return (ptr_ != other.ptr_); } + + private: + T* ptr_; + + friend RefPtr AdoptPtr(T* ptr); + RefPtr(T* ptr) : ptr_(ptr) { ptr->Adopt(); } +}; + +template +class MakeRefCountedFriend final { + public: + template + static RefPtr Make(Args&&... args) { + return AdoptPtr(new T(args...)); + } +}; + +template +RefPtr MakeRefCounted(Args&&... args) { + return MakeRefCountedFriend::Make(args...); +} + +template +RefPtr AdoptPtr(T* ptr) { + return RefPtr(ptr); +} diff --git a/zion/scheduler/process.cpp b/zion/scheduler/process.cpp index f0a8d6f..7db76b2 100644 --- a/zion/scheduler/process.cpp +++ b/zion/scheduler/process.cpp @@ -26,13 +26,13 @@ Process::Process() : id_(gNextId++), state_(RUNNING) { ZC_PROC_SPAWN_PROC | ZC_PROC_SPAWN_THREAD)); } -SharedPtr Process::CreateThread() { - SharedPtr thread{new Thread(*this, next_thread_id_++)}; +RefPtr Process::CreateThread() { + RefPtr thread = MakeRefCounted(*this, next_thread_id_++); threads_.PushBack(thread); return thread; } -SharedPtr Process::GetThread(uint64_t tid) { +RefPtr Process::GetThread(uint64_t tid) { auto iter = threads_.begin(); while (iter != threads_.end()) { if (iter->tid() == tid) { @@ -67,9 +67,9 @@ SharedPtr Process::GetCapability(uint64_t cid) { return {}; } -uint64_t Process::AddCapability(SharedPtr& thread) { +uint64_t Process::AddCapability(RefPtr& thread) { uint64_t cap_id = next_cap_id_++; caps_.PushBack( - new Capability(thread.ptr(), Capability::THREAD, cap_id, ZC_WRITE)); + new Capability(thread.get(), Capability::THREAD, cap_id, ZC_WRITE)); return cap_id; } diff --git a/zion/scheduler/process.h b/zion/scheduler/process.h index fa9b7d4..32d728c 100644 --- a/zion/scheduler/process.h +++ b/zion/scheduler/process.h @@ -4,6 +4,7 @@ #include "capability/capability.h" #include "lib/linked_list.h" +#include "lib/ref_ptr.h" #include "lib/shared_ptr.h" #include "memory/virtual_memory.h" @@ -24,11 +25,11 @@ class Process { uint64_t id() const { return id_; } VirtualMemory& vmm() { return vmm_; } - SharedPtr CreateThread(); - SharedPtr GetThread(uint64_t tid); + RefPtr CreateThread(); + RefPtr GetThread(uint64_t tid); SharedPtr GetCapability(uint64_t cid); - uint64_t AddCapability(SharedPtr& t); + uint64_t AddCapability(RefPtr& t); // Checks the state of all child threads and transitions to // finished if all have finished. void CheckState(); @@ -44,6 +45,6 @@ class Process { uint64_t next_thread_id_ = 0; uint64_t next_cap_id_ = 0x100; - LinkedList> threads_; + LinkedList> threads_; LinkedList> caps_; }; diff --git a/zion/scheduler/scheduler.cpp b/zion/scheduler/scheduler.cpp index 79c2f78..39496f1 100644 --- a/zion/scheduler/scheduler.cpp +++ b/zion/scheduler/scheduler.cpp @@ -48,7 +48,7 @@ void Scheduler::Preempt() { return; } - SharedPtr prev = current_thread_; + RefPtr prev = current_thread_; prev->SetState(Thread::RUNNABLE); current_thread_ = runnable_threads_.CycleFront(prev); @@ -62,7 +62,7 @@ void Scheduler::Yield() { } asm volatile("cli"); - SharedPtr prev = current_thread_; + RefPtr prev = current_thread_; if (prev == sleep_thread_) { if (runnable_threads_.size() == 0) { panic("Sleep thread yielded without next."); diff --git a/zion/scheduler/scheduler.h b/zion/scheduler/scheduler.h index f36772c..697e778 100644 --- a/zion/scheduler/scheduler.h +++ b/zion/scheduler/scheduler.h @@ -15,7 +15,7 @@ class Scheduler { Process& CurrentProcess() { return current_thread_->process(); } Thread& CurrentThread() { return *current_thread_; } - void Enqueue(const SharedPtr thread) { + void Enqueue(const RefPtr& thread) { runnable_threads_.PushBack(thread); } @@ -25,10 +25,10 @@ class Scheduler { private: bool enabled_ = false; - SharedPtr current_thread_; - LinkedList> runnable_threads_; + RefPtr current_thread_; + LinkedList> runnable_threads_; - SharedPtr sleep_thread_; + RefPtr sleep_thread_; Scheduler(); void SwapToCurrent(Thread& prev); diff --git a/zion/scheduler/thread.cpp b/zion/scheduler/thread.cpp index 80af108..ac35b2e 100644 --- a/zion/scheduler/thread.cpp +++ b/zion/scheduler/thread.cpp @@ -20,8 +20,12 @@ extern "C" void thread_init() { } // namespace -SharedPtr Thread::RootThread(Process& root_proc) { - return new Thread(root_proc); +RefPtr Thread::RootThread(Process& root_proc) { + return MakeRefCounted(root_proc); +} + +RefPtr Thread::Create(Process& proc, uint64_t tid) { + return MakeRefCounted(proc, tid); } Thread::Thread(Process& proc, uint64_t tid) : process_(proc), id_(tid) { diff --git a/zion/scheduler/thread.h b/zion/scheduler/thread.h index ab78d16..b2897dc 100644 --- a/zion/scheduler/thread.h +++ b/zion/scheduler/thread.h @@ -2,12 +2,13 @@ #include -#include "lib/shared_ptr.h" +#include "lib/ref_counted.h" +#include "lib/ref_ptr.h" // Forward decl due to cyclic dependency. class Process; -class Thread { +class Thread : public RefCounted { public: enum State { UNSPECIFIED, @@ -16,9 +17,8 @@ class Thread { RUNNABLE, FINISHED, }; - static SharedPtr RootThread(Process& root_proc); - - Thread(Process& proc, uint64_t tid); + static RefPtr RootThread(Process& root_proc); + static RefPtr Create(Process& proc, uint64_t tid); uint64_t tid() const { return id_; }; uint64_t pid() const; @@ -40,6 +40,8 @@ class Thread { void Exit(); private: + friend class MakeRefCountedFriend; + Thread(Process& proc, uint64_t tid); // Special constructor for the root thread only. Thread(Process& proc) : process_(proc), id_(0) {} Process& process_;