diff --git a/zion/lib/shared_ptr.h b/zion/lib/shared_ptr.h new file mode 100644 index 0000000..7af3b10 --- /dev/null +++ b/zion/lib/shared_ptr.h @@ -0,0 +1,63 @@ +#pragma once + +#include + +#include "debug/debug.h" + +template +class SharedPtr { + public: + SharedPtr() : init_(false), ptr_(0), ref_cnt_(0) {} + // Takes ownership. + SharedPtr(T* ptr) { + ptr_ = ptr; + ref_cnt_ = new uint64_t(1); + } + + SharedPtr(const SharedPtr& other) + : ptr_(other.ptr_), ref_cnt_(other.ref_cnt_) { + (*ref_cnt_)++; + } + + SharedPtr& operator=(const SharedPtr& other) { + Cleanup(); + ptr_ = other.ptr_; + ref_cnt_ = other.ref_cnt_; + (*ref_cnt_)++; + return *this; + } + + ~SharedPtr() { Cleanup(); } + + T& operator*() { return *ptr_; } + const T& operator*() const { return *ptr_; } + T* operator->() { return ptr_; } + const T* operator->() const { return ptr_; } + + T* ptr() { return ptr_; } + + bool operator==(const SharedPtr& other) { return ptr_ == other.ptr_; } + + bool empty() { return !init_; } + + private: + bool init_ = true; + T* ptr_; + uint64_t* ref_cnt_; + + void Cleanup() { + if (!init_) { + return; + } + if (--(*ref_cnt_) == 0) { + dbgln("Deleting shared ptr: %m", ptr_); + delete ptr_; + delete ref_cnt_; + } + } +}; + +template +SharedPtr MakeShared(A... args) { + return {new T(args...)}; +} diff --git a/zion/scheduler/process.cpp b/zion/scheduler/process.cpp index ac80b61..d3bfdbc 100644 --- a/zion/scheduler/process.cpp +++ b/zion/scheduler/process.cpp @@ -12,11 +12,11 @@ static uint64_t gNextId = 1; } -Process* Process::RootProcess() { +SharedPtr Process::RootProcess() { uint64_t pml4_addr = 0; asm volatile("mov %%cr3, %0;" : "=r"(pml4_addr)); - Process* proc = new Process(0, pml4_addr); - proc->threads_.PushBack(Thread::RootThread(proc)); + SharedPtr proc(new Process(0, pml4_addr)); + proc->threads_.PushBack(Thread::RootThread(proc.ptr())); proc->next_thread_id_ = 1; return proc; @@ -34,7 +34,7 @@ void Process::CreateThread(uint64_t elf_ptr) { sched::EnqueueThread(thread); } -Thread* Process::GetThread(uint64_t tid) { +SharedPtr Process::GetThread(uint64_t tid) { auto iter = threads_.begin(); while (iter != threads_.end()) { if (iter->tid() == tid) { diff --git a/zion/scheduler/process.h b/zion/scheduler/process.h index b5f19d5..3014b3e 100644 --- a/zion/scheduler/process.h +++ b/zion/scheduler/process.h @@ -3,6 +3,7 @@ #include #include "lib/linked_list.h" +#include "lib/shared_ptr.h" // Forward decl due to cyclic dependency. class Thread; @@ -15,15 +16,14 @@ class Process { RUNNING, FINISHED, }; - // Caller takes ownership of returned process. - static Process* RootProcess(); + static SharedPtr RootProcess(); Process(uint64_t elf_ptr); - uint64_t id() { return id_; } - uint64_t cr3() { return cr3_; } + uint64_t id() const { return id_; } + uint64_t cr3() const { return cr3_; } void CreateThread(uint64_t elf_ptr); - Thread* GetThread(uint64_t tid); + SharedPtr GetThread(uint64_t tid); // Checks the state of all child threads and transitions to // finished if all have finished. @@ -39,5 +39,5 @@ class Process { uint64_t next_thread_id_ = 0; - LinkedList threads_; + LinkedList> threads_; }; diff --git a/zion/scheduler/scheduler.cpp b/zion/scheduler/scheduler.cpp index 121e378..91c653f 100644 --- a/zion/scheduler/scheduler.cpp +++ b/zion/scheduler/scheduler.cpp @@ -8,7 +8,7 @@ namespace { extern "C" void context_switch(uint64_t* current_esp, uint64_t* next_esp); -void DumpProcessStates(LinkedList& proc_list) { +void DumpProcessStates(LinkedList>& proc_list) { dbgln("Process States: %u", proc_list.size()); auto iter = proc_list.begin(); while (iter != proc_list.end()) { @@ -20,8 +20,9 @@ void DumpProcessStates(LinkedList& proc_list) { class Scheduler { public: Scheduler() { - Process* root = Process::RootProcess(); - runnable_threads_.PushBack(root->GetThread(0)); + SharedPtr root = Process::RootProcess(); + sleep_thread_ = root->GetThread(0); + runnable_threads_.PushBack(sleep_thread_); proc_list_.PushBack(Process::RootProcess()); } void Enable() { enabled_ = true; } @@ -38,7 +39,7 @@ class Scheduler { } asm volatile("cli"); - Thread* prev = nullptr; + SharedPtr prev; if (CurrentThread().GetState() == Thread::RUNNING) { prev = runnable_threads_.CycleFront(); prev->SetState(Thread::RUNNABLE); @@ -48,12 +49,14 @@ class Scheduler { prev = runnable_threads_.PopFront(); } + SharedPtr next; if (runnable_threads_.size() == 0) { + next = sleep_thread_; DumpProcessStates(proc_list_); - panic("FIXME: Implement Sleep"); + } else { + next = runnable_threads_.PeekFront(); } - Thread* next = runnable_threads_.PeekFront(); if (next->GetState() != Thread::RUNNABLE) { panic("Non-runnable thread in the queue"); } @@ -74,8 +77,10 @@ class Scheduler { private: bool enabled_ = false; // TODO: move this to a separate process manager class. - LinkedList proc_list_; - LinkedList runnable_threads_; + LinkedList> proc_list_; + LinkedList> runnable_threads_; + + SharedPtr sleep_thread_; }; static Scheduler* gScheduler = nullptr; diff --git a/zion/scheduler/thread.cpp b/zion/scheduler/thread.cpp index 2521abc..a513d3b 100644 --- a/zion/scheduler/thread.cpp +++ b/zion/scheduler/thread.cpp @@ -19,9 +19,11 @@ extern "C" void thread_init() { } // namespace -Thread* Thread::RootThread(Process* root_proc) { return new Thread(root_proc); } +SharedPtr Thread::RootThread(Process* root_proc) { + return new Thread(root_proc); +} -Thread::Thread(Process* proc, uint64_t tid, uint64_t elf_ptr) +Thread::Thread(const SharedPtr& proc, uint64_t tid, uint64_t elf_ptr) : process_(proc), id_(tid), elf_ptr_(elf_ptr) { uint64_t* stack = new uint64_t[512]; uint64_t* stack_ptr = stack + 511; diff --git a/zion/scheduler/thread.h b/zion/scheduler/thread.h index f47e775..da8036b 100644 --- a/zion/scheduler/thread.h +++ b/zion/scheduler/thread.h @@ -2,6 +2,8 @@ #include +#include "lib/shared_ptr.h" + // Forward decl due to cyclic dependency. class Process; @@ -15,9 +17,10 @@ class Thread { BLOCKED, FINISHED, }; - static Thread* RootThread(Process* root_proc); + static SharedPtr RootThread(Process* root_proc); - explicit Thread(Process* proc, uint64_t tid, uint64_t elf_ptr); + explicit Thread(const SharedPtr& proc, uint64_t tid, + uint64_t elf_ptr); uint64_t tid() { return id_; }; uint64_t pid(); @@ -38,7 +41,7 @@ class Thread { private: // Special constructor for the root thread only. Thread(Process* proc) : process_(proc), id_(0) {} - Process* process_; + SharedPtr process_; uint64_t id_; State state_ = RUNNABLE;