Create a RefCounted type and use it for Thread.
This should prevent me from actually creating 2 shared ptrs of a single kernel object with their separate ref counts.
This commit is contained in:
parent
d9b17d96d7
commit
2e1357255c
|
@ -0,0 +1,41 @@
|
|||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "debug/debug.h"
|
||||
|
||||
template <typename T>
|
||||
class RefCounted {
|
||||
public:
|
||||
RefCounted() {}
|
||||
~RefCounted() { dbgln("RefCounted object destroyed"); }
|
||||
void Adopt() {
|
||||
if (ref_count_ != -1) {
|
||||
panic("Adopting owned ptr");
|
||||
} else {
|
||||
ref_count_ = 1;
|
||||
}
|
||||
}
|
||||
|
||||
void Acquire() {
|
||||
if (ref_count_ == -1) {
|
||||
panic("Acquiring unowned ptr");
|
||||
}
|
||||
ref_count_++;
|
||||
}
|
||||
bool Release() {
|
||||
if (ref_count_ == -1 || ref_count_ == 0) {
|
||||
panic("Releasing unowned ptr");
|
||||
}
|
||||
return (--ref_count_) == 0;
|
||||
}
|
||||
|
||||
private:
|
||||
// FIXME: This should be an atomic type.
|
||||
uint64_t ref_count_ = -1;
|
||||
// Disallow copy and move.
|
||||
RefCounted(RefCounted&) = delete;
|
||||
RefCounted(RefCounted&&) = delete;
|
||||
RefCounted& operator=(RefCounted&) = delete;
|
||||
RefCounted& operator=(RefCounted&&) = delete;
|
||||
};
|
|
@ -0,0 +1,76 @@
|
|||
#pragma once
|
||||
|
||||
template <typename T>
|
||||
class RefPtr;
|
||||
|
||||
template <typename T>
|
||||
RefPtr<T> AdoptPtr(T* ptr);
|
||||
|
||||
template <typename T>
|
||||
class RefPtr {
|
||||
public:
|
||||
RefPtr() : ptr_(nullptr) {}
|
||||
RefPtr(decltype(nullptr)) : ptr_(nullptr) {}
|
||||
RefPtr(const RefPtr& other) : ptr_(other.ptr_) {
|
||||
if (ptr_) {
|
||||
ptr_->Acquire();
|
||||
}
|
||||
}
|
||||
RefPtr& operator=(const RefPtr& other) {
|
||||
T* old = ptr_;
|
||||
ptr_ = other.ptr_;
|
||||
if (ptr_) {
|
||||
ptr_->Acquire();
|
||||
}
|
||||
if (old && old->Release()) {
|
||||
delete old;
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
RefPtr(RefPtr&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; }
|
||||
RefPtr& operator=(RefPtr&& other) {
|
||||
// Swap
|
||||
T* ptr = ptr_;
|
||||
ptr_ = other.ptr_;
|
||||
other.ptr_ = ptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
T* get() const { return ptr_; };
|
||||
T& operator*() const { return *ptr_; }
|
||||
T* operator->() const { return ptr_; }
|
||||
operator bool() const { return ptr_ != nullptr; }
|
||||
|
||||
bool operator==(decltype(nullptr)) const { return (ptr_ == nullptr); }
|
||||
bool operator!=(decltype(nullptr)) const { return (ptr_ != nullptr); }
|
||||
|
||||
bool operator==(const RefPtr<T>& other) const { return (ptr_ == other.ptr_); }
|
||||
bool operator!=(const RefPtr<T>& other) const { return (ptr_ != other.ptr_); }
|
||||
|
||||
private:
|
||||
T* ptr_;
|
||||
|
||||
friend RefPtr<T> AdoptPtr<T>(T* ptr);
|
||||
RefPtr(T* ptr) : ptr_(ptr) { ptr->Adopt(); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class MakeRefCountedFriend final {
|
||||
public:
|
||||
template <typename... Args>
|
||||
static RefPtr<T> Make(Args&&... args) {
|
||||
return AdoptPtr(new T(args...));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename... Args>
|
||||
RefPtr<T> MakeRefCounted(Args&&... args) {
|
||||
return MakeRefCountedFriend<T>::Make(args...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RefPtr<T> AdoptPtr(T* ptr) {
|
||||
return RefPtr(ptr);
|
||||
}
|
|
@ -26,13 +26,13 @@ Process::Process() : id_(gNextId++), state_(RUNNING) {
|
|||
ZC_PROC_SPAWN_PROC | ZC_PROC_SPAWN_THREAD));
|
||||
}
|
||||
|
||||
SharedPtr<Thread> Process::CreateThread() {
|
||||
SharedPtr<Thread> thread{new Thread(*this, next_thread_id_++)};
|
||||
RefPtr<Thread> Process::CreateThread() {
|
||||
RefPtr<Thread> thread = MakeRefCounted<Thread>(*this, next_thread_id_++);
|
||||
threads_.PushBack(thread);
|
||||
return thread;
|
||||
}
|
||||
|
||||
SharedPtr<Thread> Process::GetThread(uint64_t tid) {
|
||||
RefPtr<Thread> Process::GetThread(uint64_t tid) {
|
||||
auto iter = threads_.begin();
|
||||
while (iter != threads_.end()) {
|
||||
if (iter->tid() == tid) {
|
||||
|
@ -67,9 +67,9 @@ SharedPtr<Capability> Process::GetCapability(uint64_t cid) {
|
|||
return {};
|
||||
}
|
||||
|
||||
uint64_t Process::AddCapability(SharedPtr<Thread>& thread) {
|
||||
uint64_t Process::AddCapability(RefPtr<Thread>& thread) {
|
||||
uint64_t cap_id = next_cap_id_++;
|
||||
caps_.PushBack(
|
||||
new Capability(thread.ptr(), Capability::THREAD, cap_id, ZC_WRITE));
|
||||
new Capability(thread.get(), Capability::THREAD, cap_id, ZC_WRITE));
|
||||
return cap_id;
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#include "capability/capability.h"
|
||||
#include "lib/linked_list.h"
|
||||
#include "lib/ref_ptr.h"
|
||||
#include "lib/shared_ptr.h"
|
||||
#include "memory/virtual_memory.h"
|
||||
|
||||
|
@ -24,11 +25,11 @@ class Process {
|
|||
uint64_t id() const { return id_; }
|
||||
VirtualMemory& vmm() { return vmm_; }
|
||||
|
||||
SharedPtr<Thread> CreateThread();
|
||||
SharedPtr<Thread> GetThread(uint64_t tid);
|
||||
RefPtr<Thread> CreateThread();
|
||||
RefPtr<Thread> GetThread(uint64_t tid);
|
||||
|
||||
SharedPtr<Capability> GetCapability(uint64_t cid);
|
||||
uint64_t AddCapability(SharedPtr<Thread>& t);
|
||||
uint64_t AddCapability(RefPtr<Thread>& t);
|
||||
// Checks the state of all child threads and transitions to
|
||||
// finished if all have finished.
|
||||
void CheckState();
|
||||
|
@ -44,6 +45,6 @@ class Process {
|
|||
uint64_t next_thread_id_ = 0;
|
||||
uint64_t next_cap_id_ = 0x100;
|
||||
|
||||
LinkedList<SharedPtr<Thread>> threads_;
|
||||
LinkedList<RefPtr<Thread>> threads_;
|
||||
LinkedList<SharedPtr<Capability>> caps_;
|
||||
};
|
||||
|
|
|
@ -48,7 +48,7 @@ void Scheduler::Preempt() {
|
|||
return;
|
||||
}
|
||||
|
||||
SharedPtr<Thread> prev = current_thread_;
|
||||
RefPtr<Thread> prev = current_thread_;
|
||||
prev->SetState(Thread::RUNNABLE);
|
||||
current_thread_ = runnable_threads_.CycleFront(prev);
|
||||
|
||||
|
@ -62,7 +62,7 @@ void Scheduler::Yield() {
|
|||
}
|
||||
asm volatile("cli");
|
||||
|
||||
SharedPtr<Thread> prev = current_thread_;
|
||||
RefPtr<Thread> prev = current_thread_;
|
||||
if (prev == sleep_thread_) {
|
||||
if (runnable_threads_.size() == 0) {
|
||||
panic("Sleep thread yielded without next.");
|
||||
|
|
|
@ -15,7 +15,7 @@ class Scheduler {
|
|||
Process& CurrentProcess() { return current_thread_->process(); }
|
||||
Thread& CurrentThread() { return *current_thread_; }
|
||||
|
||||
void Enqueue(const SharedPtr<Thread> thread) {
|
||||
void Enqueue(const RefPtr<Thread>& thread) {
|
||||
runnable_threads_.PushBack(thread);
|
||||
}
|
||||
|
||||
|
@ -25,10 +25,10 @@ class Scheduler {
|
|||
private:
|
||||
bool enabled_ = false;
|
||||
|
||||
SharedPtr<Thread> current_thread_;
|
||||
LinkedList<SharedPtr<Thread>> runnable_threads_;
|
||||
RefPtr<Thread> current_thread_;
|
||||
LinkedList<RefPtr<Thread>> runnable_threads_;
|
||||
|
||||
SharedPtr<Thread> sleep_thread_;
|
||||
RefPtr<Thread> sleep_thread_;
|
||||
|
||||
Scheduler();
|
||||
void SwapToCurrent(Thread& prev);
|
||||
|
|
|
@ -20,8 +20,12 @@ extern "C" void thread_init() {
|
|||
|
||||
} // namespace
|
||||
|
||||
SharedPtr<Thread> Thread::RootThread(Process& root_proc) {
|
||||
return new Thread(root_proc);
|
||||
RefPtr<Thread> Thread::RootThread(Process& root_proc) {
|
||||
return MakeRefCounted<Thread>(root_proc);
|
||||
}
|
||||
|
||||
RefPtr<Thread> Thread::Create(Process& proc, uint64_t tid) {
|
||||
return MakeRefCounted<Thread>(proc, tid);
|
||||
}
|
||||
|
||||
Thread::Thread(Process& proc, uint64_t tid) : process_(proc), id_(tid) {
|
||||
|
|
|
@ -2,12 +2,13 @@
|
|||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "lib/shared_ptr.h"
|
||||
#include "lib/ref_counted.h"
|
||||
#include "lib/ref_ptr.h"
|
||||
|
||||
// Forward decl due to cyclic dependency.
|
||||
class Process;
|
||||
|
||||
class Thread {
|
||||
class Thread : public RefCounted<Thread> {
|
||||
public:
|
||||
enum State {
|
||||
UNSPECIFIED,
|
||||
|
@ -16,9 +17,8 @@ class Thread {
|
|||
RUNNABLE,
|
||||
FINISHED,
|
||||
};
|
||||
static SharedPtr<Thread> RootThread(Process& root_proc);
|
||||
|
||||
Thread(Process& proc, uint64_t tid);
|
||||
static RefPtr<Thread> RootThread(Process& root_proc);
|
||||
static RefPtr<Thread> Create(Process& proc, uint64_t tid);
|
||||
|
||||
uint64_t tid() const { return id_; };
|
||||
uint64_t pid() const;
|
||||
|
@ -40,6 +40,8 @@ class Thread {
|
|||
void Exit();
|
||||
|
||||
private:
|
||||
friend class MakeRefCountedFriend<Thread>;
|
||||
Thread(Process& proc, uint64_t tid);
|
||||
// Special constructor for the root thread only.
|
||||
Thread(Process& proc) : process_(proc), id_(0) {}
|
||||
Process& process_;
|
||||
|
|
Loading…
Reference in New Issue