Make Capability RefCounted

This commit is contained in:
Drew Galbraith 2023-06-07 06:21:36 -07:00
parent 6c10c57bfa
commit 55340e2917
5 changed files with 22 additions and 111 deletions

View File

@ -8,7 +8,7 @@
class Process;
class Thread;
class Capability {
class Capability : public RefCounted<Capability> {
public:
enum Type {
UNDEFINED,

View File

@ -1,89 +0,0 @@
#pragma once
#include <stdint.h>
#include "debug/debug.h"
template <typename T>
class SharedPtr {
public:
SharedPtr() : init_(false), ptr_(0), ref_cnt_(0) {}
// Takes ownership.
SharedPtr(T* ptr) {
ptr_ = ptr;
ref_cnt_ = new uint64_t(1);
}
SharedPtr(const SharedPtr<T>& other)
: init_(other.init_), ptr_(other.ptr_), ref_cnt_(other.ref_cnt_) {
(*ref_cnt_)++;
}
SharedPtr& operator=(const SharedPtr<T>& other) {
Cleanup();
init_ = other.init_;
ptr_ = other.ptr_;
ref_cnt_ = other.ref_cnt_;
(*ref_cnt_)++;
return *this;
}
~SharedPtr() { Cleanup(); }
T& operator*() {
CheckValid();
return *ptr_;
}
const T& operator*() const {
CheckValid();
return *ptr_;
}
T* operator->() {
CheckValid();
return ptr_;
}
const T* operator->() const {
CheckValid();
return ptr_;
}
T* ptr() {
CheckValid();
return ptr_;
}
bool operator==(const SharedPtr<T>& other) {
CheckValid();
other.CheckValid();
return ptr_ == other.ptr_;
}
bool empty() { return !init_; }
private:
bool init_ = true;
T* ptr_;
uint64_t* ref_cnt_;
void Cleanup() {
if (!init_) {
return;
}
if (--(*ref_cnt_) == 0) {
dbgln("Deleting shared ptr: %m", ptr_);
delete ptr_;
delete ref_cnt_;
}
}
void CheckValid() const {
if (!init_) {
panic("Accessing invalid shared ptr");
}
}
};
template <typename T, class... A>
SharedPtr<T> MakeShared(A... args) {
return {new T(args...)};
}

View File

@ -23,10 +23,10 @@ RefPtr<Process> Process::RootProcess() {
RefPtr<Process> Process::Create() {
auto proc = MakeRefCounted<Process>();
proc->caps_.PushBack(
new Capability(proc, Capability::PROCESS, Z_INIT_PROC_SELF,
ZC_PROC_SPAWN_PROC | ZC_PROC_SPAWN_THREAD));
proc->caps_.PushBack(new Capability(proc->vmas(), Capability::ADDRESS_SPACE,
Z_INIT_VMAS_SELF, ZC_WRITE));
MakeRefCounted<Capability>(proc, Capability::PROCESS, Z_INIT_PROC_SELF,
ZC_PROC_SPAWN_PROC | ZC_PROC_SPAWN_THREAD));
proc->caps_.PushBack(MakeRefCounted<Capability>(
proc->vmas(), Capability::ADDRESS_SPACE, Z_INIT_VMAS_SELF, ZC_WRITE));
return proc;
}
@ -62,7 +62,7 @@ void Process::CheckState() {
state_ = FINISHED;
}
SharedPtr<Capability> Process::GetCapability(uint64_t cid) {
RefPtr<Capability> Process::GetCapability(uint64_t cid) {
auto iter = caps_.begin();
while (iter != caps_.end()) {
if (iter->id() == cid) {
@ -77,30 +77,31 @@ SharedPtr<Capability> Process::GetCapability(uint64_t cid) {
uint64_t Process::AddCapability(const RefPtr<Thread>& thread) {
uint64_t cap_id = next_cap_id_++;
caps_.PushBack(new Capability(thread, Capability::THREAD, cap_id, ZC_WRITE));
caps_.PushBack(
MakeRefCounted<Capability>(thread, Capability::THREAD, cap_id, ZC_WRITE));
return cap_id;
}
uint64_t Process::AddCapability(const RefPtr<Process>& p) {
uint64_t cap_id = next_cap_id_++;
caps_.PushBack(new Capability(p, Capability::PROCESS, cap_id,
ZC_WRITE | ZC_PROC_SPAWN_THREAD));
caps_.PushBack(MakeRefCounted<Capability>(p, Capability::PROCESS, cap_id,
ZC_WRITE | ZC_PROC_SPAWN_THREAD));
return cap_id;
}
uint64_t Process::AddCapability(const RefPtr<AddressSpace>& vmas) {
uint64_t cap_id = next_cap_id_++;
caps_.PushBack(
new Capability(vmas, Capability::ADDRESS_SPACE, cap_id, ZC_WRITE));
caps_.PushBack(MakeRefCounted<Capability>(vmas, Capability::ADDRESS_SPACE,
cap_id, ZC_WRITE));
return cap_id;
}
uint64_t Process::AddCapability(const RefPtr<MemoryObject>& vmmo) {
uint64_t cap_id = next_cap_id_++;
caps_.PushBack(
new Capability(vmmo, Capability::MEMORY_OBJECT, cap_id, ZC_WRITE));
caps_.PushBack(MakeRefCounted<Capability>(vmmo, Capability::MEMORY_OBJECT,
cap_id, ZC_WRITE));
return cap_id;
}
void Process::AddCapability(uint64_t cap_id, const RefPtr<MemoryObject>& vmmo) {
caps_.PushBack(
new Capability(vmmo, Capability::MEMORY_OBJECT, cap_id, ZC_WRITE));
caps_.PushBack(MakeRefCounted<Capability>(vmmo, Capability::MEMORY_OBJECT,
cap_id, ZC_WRITE));
}

View File

@ -5,7 +5,6 @@
#include "capability/capability.h"
#include "lib/linked_list.h"
#include "lib/ref_ptr.h"
#include "lib/shared_ptr.h"
#include "object/address_space.h"
// Forward decl due to cyclic dependency.
@ -28,7 +27,7 @@ class Process : public KernelObject {
RefPtr<Thread> CreateThread();
RefPtr<Thread> GetThread(uint64_t tid);
SharedPtr<Capability> GetCapability(uint64_t cid);
RefPtr<Capability> GetCapability(uint64_t cid);
uint64_t AddCapability(const RefPtr<Thread>& t);
uint64_t AddCapability(const RefPtr<Process>& p);
uint64_t AddCapability(const RefPtr<AddressSpace>& vmas);
@ -53,5 +52,5 @@ class Process : public KernelObject {
uint64_t next_cap_id_ = 0x100;
LinkedList<RefPtr<Thread>> threads_;
LinkedList<SharedPtr<Capability>> caps_;
LinkedList<RefPtr<Capability>> caps_;
};

View File

@ -59,7 +59,7 @@ void InitSyscall() {
uint64_t ProcessSpawn(ZProcessSpawnReq* req, ZProcessSpawnResp* resp) {
auto& curr_proc = gScheduler->CurrentProcess();
auto cap = curr_proc.GetCapability(req->proc_cap);
if (cap.empty()) {
if (!cap) {
return ZE_NOT_FOUND;
}
if (!cap->CheckType(Capability::PROCESS)) {
@ -80,7 +80,7 @@ uint64_t ProcessSpawn(ZProcessSpawnReq* req, ZProcessSpawnResp* resp) {
uint64_t ThreadCreate(ZThreadCreateReq* req, ZThreadCreateResp* resp) {
auto& curr_proc = gScheduler->CurrentProcess();
auto cap = curr_proc.GetCapability(req->proc_cap);
if (cap.empty()) {
if (!cap) {
return ZE_NOT_FOUND;
}
if (!cap->CheckType(Capability::PROCESS)) {
@ -101,7 +101,7 @@ uint64_t ThreadCreate(ZThreadCreateReq* req, ZThreadCreateResp* resp) {
uint64_t ThreadStart(ZThreadStartReq* req) {
auto& curr_proc = gScheduler->CurrentProcess();
auto cap = curr_proc.GetCapability(req->thread_cap);
if (cap.empty()) {
if (!cap) {
return ZE_NOT_FOUND;
}
if (!cap->CheckType(Capability::THREAD)) {
@ -122,7 +122,7 @@ uint64_t AddressSpaceMap(ZAddressSpaceMapReq* req, ZAddressSpaceMapResp* resp) {
auto& curr_proc = gScheduler->CurrentProcess();
auto vmas_cap = curr_proc.GetCapability(req->vmas_cap);
auto vmmo_cap = curr_proc.GetCapability(req->vmmo_cap);
if (vmas_cap.empty() || vmmo_cap.empty()) {
if (!vmas_cap || !vmmo_cap) {
return ZE_NOT_FOUND;
}
if (!vmas_cap->CheckType(Capability::ADDRESS_SPACE) ||