Move Processes and Threads to be stored in SharedPtr

Reference counting lets us pass these around a bit more easily.

SharedPtr was lightly tested using uint64_t in the main zion funcion.

Also add a sleep functionality instead of panicking. Functionally the
same right now since we don't preempt.
This commit is contained in:
Drew Galbraith 2023-05-29 15:50:38 -07:00
parent 9f3ffbf5b4
commit 80d2bf1aaa
6 changed files with 96 additions and 23 deletions

63
zion/lib/shared_ptr.h Normal file
View File

@ -0,0 +1,63 @@
#pragma once
#include <stdint.h>
#include "debug/debug.h"
template <typename T>
class SharedPtr {
public:
SharedPtr() : init_(false), ptr_(0), ref_cnt_(0) {}
// Takes ownership.
SharedPtr(T* ptr) {
ptr_ = ptr;
ref_cnt_ = new uint64_t(1);
}
SharedPtr(const SharedPtr<T>& other)
: ptr_(other.ptr_), ref_cnt_(other.ref_cnt_) {
(*ref_cnt_)++;
}
SharedPtr& operator=(const SharedPtr<T>& other) {
Cleanup();
ptr_ = other.ptr_;
ref_cnt_ = other.ref_cnt_;
(*ref_cnt_)++;
return *this;
}
~SharedPtr() { Cleanup(); }
T& operator*() { return *ptr_; }
const T& operator*() const { return *ptr_; }
T* operator->() { return ptr_; }
const T* operator->() const { return ptr_; }
T* ptr() { return ptr_; }
bool operator==(const SharedPtr<T>& other) { return ptr_ == other.ptr_; }
bool empty() { return !init_; }
private:
bool init_ = true;
T* ptr_;
uint64_t* ref_cnt_;
void Cleanup() {
if (!init_) {
return;
}
if (--(*ref_cnt_) == 0) {
dbgln("Deleting shared ptr: %m", ptr_);
delete ptr_;
delete ref_cnt_;
}
}
};
template <typename T, class... A>
SharedPtr<T> MakeShared(A... args) {
return {new T(args...)};
}

View File

@ -12,11 +12,11 @@ static uint64_t gNextId = 1;
}
Process* Process::RootProcess() {
SharedPtr<Process> Process::RootProcess() {
uint64_t pml4_addr = 0;
asm volatile("mov %%cr3, %0;" : "=r"(pml4_addr));
Process* proc = new Process(0, pml4_addr);
proc->threads_.PushBack(Thread::RootThread(proc));
SharedPtr<Process> proc(new Process(0, pml4_addr));
proc->threads_.PushBack(Thread::RootThread(proc.ptr()));
proc->next_thread_id_ = 1;
return proc;
@ -34,7 +34,7 @@ void Process::CreateThread(uint64_t elf_ptr) {
sched::EnqueueThread(thread);
}
Thread* Process::GetThread(uint64_t tid) {
SharedPtr<Thread> Process::GetThread(uint64_t tid) {
auto iter = threads_.begin();
while (iter != threads_.end()) {
if (iter->tid() == tid) {

View File

@ -3,6 +3,7 @@
#include <stdint.h>
#include "lib/linked_list.h"
#include "lib/shared_ptr.h"
// Forward decl due to cyclic dependency.
class Thread;
@ -15,15 +16,14 @@ class Process {
RUNNING,
FINISHED,
};
// Caller takes ownership of returned process.
static Process* RootProcess();
static SharedPtr<Process> RootProcess();
Process(uint64_t elf_ptr);
uint64_t id() { return id_; }
uint64_t cr3() { return cr3_; }
uint64_t id() const { return id_; }
uint64_t cr3() const { return cr3_; }
void CreateThread(uint64_t elf_ptr);
Thread* GetThread(uint64_t tid);
SharedPtr<Thread> GetThread(uint64_t tid);
// Checks the state of all child threads and transitions to
// finished if all have finished.
@ -39,5 +39,5 @@ class Process {
uint64_t next_thread_id_ = 0;
LinkedList<Thread*> threads_;
LinkedList<SharedPtr<Thread>> threads_;
};

View File

@ -8,7 +8,7 @@ namespace {
extern "C" void context_switch(uint64_t* current_esp, uint64_t* next_esp);
void DumpProcessStates(LinkedList<Process*>& proc_list) {
void DumpProcessStates(LinkedList<SharedPtr<Process>>& proc_list) {
dbgln("Process States: %u", proc_list.size());
auto iter = proc_list.begin();
while (iter != proc_list.end()) {
@ -20,8 +20,9 @@ void DumpProcessStates(LinkedList<Process*>& proc_list) {
class Scheduler {
public:
Scheduler() {
Process* root = Process::RootProcess();
runnable_threads_.PushBack(root->GetThread(0));
SharedPtr<Process> root = Process::RootProcess();
sleep_thread_ = root->GetThread(0);
runnable_threads_.PushBack(sleep_thread_);
proc_list_.PushBack(Process::RootProcess());
}
void Enable() { enabled_ = true; }
@ -38,7 +39,7 @@ class Scheduler {
}
asm volatile("cli");
Thread* prev = nullptr;
SharedPtr<Thread> prev;
if (CurrentThread().GetState() == Thread::RUNNING) {
prev = runnable_threads_.CycleFront();
prev->SetState(Thread::RUNNABLE);
@ -48,12 +49,14 @@ class Scheduler {
prev = runnable_threads_.PopFront();
}
SharedPtr<Thread> next;
if (runnable_threads_.size() == 0) {
next = sleep_thread_;
DumpProcessStates(proc_list_);
panic("FIXME: Implement Sleep");
} else {
next = runnable_threads_.PeekFront();
}
Thread* next = runnable_threads_.PeekFront();
if (next->GetState() != Thread::RUNNABLE) {
panic("Non-runnable thread in the queue");
}
@ -74,8 +77,10 @@ class Scheduler {
private:
bool enabled_ = false;
// TODO: move this to a separate process manager class.
LinkedList<Process*> proc_list_;
LinkedList<Thread*> runnable_threads_;
LinkedList<SharedPtr<Process>> proc_list_;
LinkedList<SharedPtr<Thread>> runnable_threads_;
SharedPtr<Thread> sleep_thread_;
};
static Scheduler* gScheduler = nullptr;

View File

@ -19,9 +19,11 @@ extern "C" void thread_init() {
} // namespace
Thread* Thread::RootThread(Process* root_proc) { return new Thread(root_proc); }
SharedPtr<Thread> Thread::RootThread(Process* root_proc) {
return new Thread(root_proc);
}
Thread::Thread(Process* proc, uint64_t tid, uint64_t elf_ptr)
Thread::Thread(const SharedPtr<Process>& proc, uint64_t tid, uint64_t elf_ptr)
: process_(proc), id_(tid), elf_ptr_(elf_ptr) {
uint64_t* stack = new uint64_t[512];
uint64_t* stack_ptr = stack + 511;

View File

@ -2,6 +2,8 @@
#include <stdint.h>
#include "lib/shared_ptr.h"
// Forward decl due to cyclic dependency.
class Process;
@ -15,9 +17,10 @@ class Thread {
BLOCKED,
FINISHED,
};
static Thread* RootThread(Process* root_proc);
static SharedPtr<Thread> RootThread(Process* root_proc);
explicit Thread(Process* proc, uint64_t tid, uint64_t elf_ptr);
explicit Thread(const SharedPtr<Process>& proc, uint64_t tid,
uint64_t elf_ptr);
uint64_t tid() { return id_; };
uint64_t pid();
@ -38,7 +41,7 @@ class Thread {
private:
// Special constructor for the root thread only.
Thread(Process* proc) : process_(proc), id_(0) {}
Process* process_;
SharedPtr<Process> process_;
uint64_t id_;
State state_ = RUNNABLE;