[zion] Dynamically check Capability type.

Instead of passing an enum with the capability when creating it, relying
on polymorphism and a template struct tag to determine the object type
at runtime.

This is cleaner and avoids errors where we pass the wrong capability
type with the cap and do a bad cast at runtime.
This commit is contained in:
Drew Galbraith 2023-06-16 14:53:57 -07:00
parent b4902a79ef
commit a47bac9966
13 changed files with 113 additions and 106 deletions

View File

@ -1,7 +1,6 @@
add_executable(zion
boot/acpi.cpp
boot/boot_info.cpp
capability/capability.cpp
capability/capability_table.cpp
common/gdt.cpp
common/load_gdt.s

View File

@ -1,52 +0,0 @@
#include "capability/capability.h"
#include "object/process.h"
#include "object/thread.h"
template <>
RefPtr<Process> Capability::obj<Process>() {
if (type_ != PROCESS) {
panic("Accessing %u cap as object.", type_);
}
return StaticCastRefPtr<Process>(obj_);
}
template <>
RefPtr<Thread> Capability::obj<Thread>() {
if (type_ != THREAD) {
panic("Accessing %u cap as object.", type_);
}
return StaticCastRefPtr<Thread>(obj_);
}
template <>
RefPtr<AddressSpace> Capability::obj<AddressSpace>() {
if (type_ != ADDRESS_SPACE) {
panic("Accessing %u cap as object.", type_);
}
return StaticCastRefPtr<AddressSpace>(obj_);
}
template <>
RefPtr<MemoryObject> Capability::obj<MemoryObject>() {
if (type_ != MEMORY_OBJECT) {
panic("Accessing %u cap as object.", type_);
}
return StaticCastRefPtr<MemoryObject>(obj_);
}
template <>
RefPtr<Channel> Capability::obj<Channel>() {
if (type_ != CHANNEL) {
panic("Accessing %u cap as object.", type_);
}
return StaticCastRefPtr<Channel>(obj_);
}
template <>
RefPtr<Port> Capability::obj<Port>() {
if (type_ != PORT) {
panic("Accessing %u cap as object.", type_);
}
return StaticCastRefPtr<Port>(obj_);
}

View File

@ -10,23 +10,12 @@ class Thread;
class Capability : public RefCounted<Capability> {
public:
enum Type {
UNDEFINED,
PROCESS,
THREAD,
ADDRESS_SPACE,
MEMORY_OBJECT,
CHANNEL,
PORT,
};
Capability(const RefPtr<KernelObject>& obj, Type type, uint64_t id,
uint64_t permissions)
: obj_(obj), type_(type), id_(id), permissions_(permissions) {}
Capability(const RefPtr<KernelObject>& obj, uint64_t id, uint64_t permissions)
: obj_(obj), id_(id), permissions_(permissions) {}
template <typename T>
Capability(const RefPtr<T>& obj, Type type, uint64_t id, uint64_t permissions)
: Capability(StaticCastRefPtr<KernelObject>(obj), type, id, permissions) {
}
Capability(const RefPtr<T>& obj, uint64_t id, uint64_t permissions)
: Capability(StaticCastRefPtr<KernelObject>(obj), id, permissions) {}
template <typename T>
RefPtr<T> obj();
@ -34,8 +23,6 @@ class Capability : public RefCounted<Capability> {
uint64_t id() { return id_; }
void set_id(uint64_t id) { id_ = id; }
bool CheckType(Type type) { return type_ == type; }
uint64_t permissions() { return permissions_; }
bool HasPermissions(uint64_t requested) {
return (permissions_ & requested) == requested;
@ -43,7 +30,14 @@ class Capability : public RefCounted<Capability> {
private:
RefPtr<KernelObject> obj_;
Type type_;
uint64_t id_;
uint64_t permissions_;
};
template <typename T>
RefPtr<T> Capability::obj() {
if (obj_->TypeTag() != KernelObjectTag<T>::type) {
return nullptr;
}
return StaticCastRefPtr<T>(obj_);
}

View File

@ -13,14 +13,13 @@ class CapabilityTable {
CapabilityTable& operator=(CapabilityTable&) = delete;
template <typename T>
uint64_t AddNewCapability(const RefPtr<T>& object, Capability::Type type,
uint64_t permissions);
uint64_t AddNewCapability(const RefPtr<T>& object, uint64_t permissions);
uint64_t AddExistingCapability(const RefPtr<Capability>& cap);
// FIXME: Remove reliance on this.
template <typename T>
void AddNewCapabilityWithId(uint64_t id, const RefPtr<T>& object,
Capability::Type type, uint64_t permissions);
uint64_t permissions);
RefPtr<Capability> GetCapability(uint64_t id);
RefPtr<Capability> ReleaseCapability(uint64_t id);
@ -35,20 +34,16 @@ class CapabilityTable {
template <typename T>
uint64_t CapabilityTable::AddNewCapability(const RefPtr<T>& object,
Capability::Type type,
uint64_t permissions) {
MutexHolder h(lock_);
uint64_t id = next_cap_id_++;
capabilities_.PushBack(
MakeRefCounted<Capability>(object, type, id, permissions));
capabilities_.PushBack(MakeRefCounted<Capability>(object, id, permissions));
return id;
}
template <typename T>
void CapabilityTable::AddNewCapabilityWithId(uint64_t id,
const RefPtr<T>& object,
Capability::Type type,
uint64_t permissions) {
capabilities_.PushBack(
MakeRefCounted<Capability>(object, type, id, permissions));
capabilities_.PushBack(MakeRefCounted<Capability>(object, id, permissions));
}

View File

@ -6,6 +6,13 @@
#include "memory/user_stack_manager.h"
#include "object/memory_object.h"
class AddressSpace;
template <>
struct KernelObjectTag<AddressSpace> {
static const uint64_t type = KernelObject::ADDRESS_SPACE;
};
// VirtualMemory class holds a memory space for an individual process.
//
// Memory Regions are predefined for simplicity for now. However, in general
@ -26,6 +33,8 @@
// 0xFFFFFFFF 90000000 - 0xFFFFFFFF 9FFFFFFF : KERNEL_STACK (256 MiB)
class AddressSpace : public KernelObject {
public:
uint64_t TypeTag() override { return KernelObject::ADDRESS_SPACE; }
enum MemoryType {
UNSPECIFIED,
UNMAPPED,

View File

@ -10,8 +10,16 @@
#include "object/kernel_object.h"
#include "usr/zcall_internal.h"
class Channel;
template <>
struct KernelObjectTag<Channel> {
static const uint64_t type = KernelObject::CHANNEL;
};
class Channel : public KernelObject {
public:
uint64_t TypeTag() override { return KernelObject::CHANNEL; }
static Pair<RefPtr<Channel>, RefPtr<Channel>> CreateChannelPair();
RefPtr<Channel> peer() { return peer_; }

View File

@ -2,4 +2,22 @@
#include "lib/ref_counted.h"
class KernelObject : public RefCounted<KernelObject> {};
class KernelObject : public RefCounted<KernelObject> {
public:
enum ObjectType {
INVALID = 0x0,
PROCESS = 0x1,
THREAD = 0x2,
ADDRESS_SPACE = 0x3,
MEMORY_OBJECT = 0x4,
CHANNEL = 0x5,
PORT = 0x6,
};
virtual uint64_t TypeTag() = 0;
};
template <typename T>
struct KernelObjectTag {
static const int type = KernelObject::INVALID;
};

View File

@ -3,6 +3,12 @@
#include "lib/linked_list.h"
#include "object/kernel_object.h"
class MemoryObject;
template <>
struct KernelObjectTag<MemoryObject> {
static const uint64_t type = KernelObject::MEMORY_OBJECT;
};
/*
* MemoryObject is a page-aligned set of memory that corresponds
* to physical pages.
@ -11,6 +17,7 @@
* */
class MemoryObject : public KernelObject {
public:
uint64_t TypeTag() override { return KernelObject::MEMORY_OBJECT; }
MemoryObject(uint64_t size);
uint64_t size() { return size_; }

View File

@ -6,8 +6,17 @@
#include "object/thread.h"
#include "usr/zcall_internal.h"
class Port;
template <>
struct KernelObjectTag<Port> {
static const uint64_t type = KernelObject::PORT;
};
class Port : public KernelObject {
public:
uint64_t TypeTag() override { return KernelObject::PORT; }
Port();
z_err_t Write(const ZMessage& msg);

View File

@ -23,10 +23,8 @@ RefPtr<Process> Process::RootProcess() {
RefPtr<Process> Process::Create() {
auto proc = MakeRefCounted<Process>();
proc->caps_.AddNewCapabilityWithId(Z_INIT_PROC_SELF, proc,
Capability::PROCESS,
ZC_PROC_SPAWN_PROC | ZC_PROC_SPAWN_THREAD);
proc->caps_.AddNewCapabilityWithId(Z_INIT_VMAS_SELF, proc->vmas(),
Capability::ADDRESS_SPACE, ZC_WRITE);
proc->caps_.AddNewCapabilityWithId(Z_INIT_VMAS_SELF, proc->vmas(), ZC_WRITE);
return proc;
}
@ -77,31 +75,29 @@ uint64_t Process::AddCapability(const RefPtr<Capability>& cap) {
return caps_.AddExistingCapability(cap);
}
uint64_t Process::AddCapability(const RefPtr<Thread>& thread) {
return caps_.AddNewCapability(thread, Capability::THREAD, ZC_WRITE);
return caps_.AddNewCapability(thread, ZC_WRITE);
}
uint64_t Process::AddCapability(const RefPtr<Process>& proc) {
return caps_.AddNewCapability(proc, Capability::PROCESS,
ZC_WRITE | ZC_PROC_SPAWN_THREAD);
return caps_.AddNewCapability(proc, ZC_WRITE | ZC_PROC_SPAWN_THREAD);
}
uint64_t Process::AddCapability(const RefPtr<AddressSpace>& vmas) {
return caps_.AddNewCapability(vmas, Capability::ADDRESS_SPACE, ZC_WRITE);
return caps_.AddNewCapability(vmas, ZC_WRITE);
}
uint64_t Process::AddCapability(const RefPtr<MemoryObject>& vmmo) {
return caps_.AddNewCapability(vmmo, Capability::MEMORY_OBJECT, ZC_WRITE);
return caps_.AddNewCapability(vmmo, ZC_WRITE);
}
uint64_t Process::AddCapability(const RefPtr<Channel>& chan) {
return caps_.AddNewCapability(chan, Capability::CHANNEL, ZC_WRITE | ZC_READ);
return caps_.AddNewCapability(chan, ZC_WRITE | ZC_READ);
}
uint64_t Process::AddCapability(const RefPtr<Port>& port) {
return caps_.AddNewCapability(port, Capability::PORT, ZC_WRITE | ZC_READ);
return caps_.AddNewCapability(port, ZC_WRITE | ZC_READ);
}
void Process::AddCapability(uint64_t cap_id, const RefPtr<MemoryObject>& vmmo) {
caps_.AddNewCapabilityWithId(cap_id, vmmo, Capability::MEMORY_OBJECT,
ZC_WRITE);
caps_.AddNewCapabilityWithId(cap_id, vmmo, ZC_WRITE);
}

View File

@ -14,8 +14,14 @@
// Forward decl due to cyclic dependency.
class Thread;
template <>
struct KernelObjectTag<Process> {
static const uint64_t type = KernelObject::PROCESS;
};
class Process : public KernelObject {
public:
uint64_t TypeTag() override { return KernelObject::PROCESS; }
enum State {
UNSPECIFIED,
SETUP,

View File

@ -7,9 +7,16 @@
// Forward decl due to cyclic dependency.
class Process;
class Thread;
template <>
struct KernelObjectTag<Thread> {
static const uint64_t type = KernelObject::THREAD;
};
class Thread : public KernelObject {
public:
uint64_t TypeTag() override { return KernelObject::THREAD; }
enum State {
UNSPECIFIED,
CREATED,

View File

@ -16,6 +16,13 @@
#include "scheduler/scheduler.h"
#include "usr/zcall_internal.h"
#define RET_IF_NULL(expr) \
{ \
if (!expr) { \
return Z_ERR_CAP_TYPE; \
} \
}
#define EFER 0xC0000080
#define STAR 0xC0000081
#define LSTAR 0xC0000082
@ -46,14 +53,10 @@ void InitSyscall() {
SetMSR(LSTAR, reinterpret_cast<uint64_t>(syscall_enter));
}
z_err_t ValidateCap(const RefPtr<Capability>& cap, Capability::Type type,
uint64_t permissions) {
z_err_t ValidateCap(const RefPtr<Capability>& cap, uint64_t permissions) {
if (!cap) {
return Z_ERR_CAP_NOT_FOUND;
}
if (!cap->CheckType(type)) {
return Z_ERR_CAP_TYPE;
}
if (!cap->HasPermissions(permissions)) {
return Z_ERR_CAP_DENIED;
}
@ -63,7 +66,7 @@ z_err_t ValidateCap(const RefPtr<Capability>& cap, Capability::Type type,
z_err_t ProcessSpawn(ZProcessSpawnReq* req, ZProcessSpawnResp* resp) {
auto& curr_proc = gScheduler->CurrentProcess();
auto cap = curr_proc.GetCapability(req->proc_cap);
RET_ERR(ValidateCap(cap, Capability::PROCESS, ZC_PROC_SPAWN_PROC));
RET_ERR(ValidateCap(cap, ZC_PROC_SPAWN_PROC));
RefPtr<Process> proc = Process::Create();
gProcMan->InsertProcess(proc);
@ -86,9 +89,10 @@ z_err_t ProcessSpawn(ZProcessSpawnReq* req, ZProcessSpawnResp* resp) {
z_err_t ThreadCreate(ZThreadCreateReq* req, ZThreadCreateResp* resp) {
auto& curr_proc = gScheduler->CurrentProcess();
auto cap = curr_proc.GetCapability(req->proc_cap);
RET_ERR(ValidateCap(cap, Capability::PROCESS, ZC_PROC_SPAWN_THREAD));
RET_ERR(ValidateCap(cap, ZC_PROC_SPAWN_THREAD));
auto parent_proc = cap->obj<Process>();
RET_IF_NULL(parent_proc);
auto thread = parent_proc->CreateThread();
resp->thread_cap = curr_proc.AddCapability(thread);
@ -98,9 +102,10 @@ z_err_t ThreadCreate(ZThreadCreateReq* req, ZThreadCreateResp* resp) {
z_err_t ThreadStart(ZThreadStartReq* req) {
auto& curr_proc = gScheduler->CurrentProcess();
auto cap = curr_proc.GetCapability(req->thread_cap);
RET_ERR(ValidateCap(cap, Capability::THREAD, ZC_WRITE));
RET_ERR(ValidateCap(cap, ZC_WRITE));
auto thread = cap->obj<Thread>();
RET_IF_NULL(thread);
// FIXME: validate entry point is in user space.
thread->Start(req->entry, req->arg1, req->arg2);
return Z_OK;
@ -110,11 +115,14 @@ z_err_t AddressSpaceMap(ZAddressSpaceMapReq* req, ZAddressSpaceMapResp* resp) {
auto& curr_proc = gScheduler->CurrentProcess();
auto vmas_cap = curr_proc.GetCapability(req->vmas_cap);
auto vmmo_cap = curr_proc.GetCapability(req->vmmo_cap);
RET_ERR(ValidateCap(vmas_cap, Capability::ADDRESS_SPACE, ZC_WRITE));
RET_ERR(ValidateCap(vmmo_cap, Capability::MEMORY_OBJECT, ZC_WRITE));
RET_ERR(ValidateCap(vmas_cap, ZC_WRITE));
RET_ERR(ValidateCap(vmmo_cap, ZC_WRITE));
auto vmas = vmas_cap->obj<AddressSpace>();
auto vmmo = vmmo_cap->obj<MemoryObject>();
RET_IF_NULL(vmas);
RET_IF_NULL(vmmo);
dbgln("Ptr %x, %x", vmas.get(), vmmo.get());
// FIXME: Validation necessary.
if (req->vmas_offset != 0) {
vmas->MapInMemoryObject(req->vmas_offset, vmmo);
@ -169,9 +177,10 @@ z_err_t ChannelCreate(ZChannelCreateResp* resp) {
z_err_t ChannelSend(ZChannelSendReq* req) {
auto& proc = gScheduler->CurrentProcess();
auto chan_cap = proc.GetCapability(req->chan_cap);
RET_ERR(ValidateCap(chan_cap, Capability::CHANNEL, ZC_WRITE));
RET_ERR(ValidateCap(chan_cap, ZC_WRITE));
auto chan = chan_cap->obj<Channel>();
RET_IF_NULL(chan);
chan->Write(req->message);
return Z_OK;
}
@ -179,18 +188,20 @@ z_err_t ChannelSend(ZChannelSendReq* req) {
z_err_t ChannelRecv(ZChannelRecvReq* req) {
auto& proc = gScheduler->CurrentProcess();
auto chan_cap = proc.GetCapability(req->chan_cap);
RET_ERR(ValidateCap(chan_cap, Capability::CHANNEL, ZC_READ));
RET_ERR(ValidateCap(chan_cap, ZC_READ));
auto chan = chan_cap->obj<Channel>();
RET_IF_NULL(chan);
return chan->Read(req->message);
}
z_err_t PortRecv(ZPortRecvReq* req) {
auto& proc = gScheduler->CurrentProcess();
auto port_cap = proc.GetCapability(req->port_cap);
RET_ERR(ValidateCap(port_cap, Capability::PORT, ZC_READ));
RET_ERR(ValidateCap(port_cap, ZC_READ));
auto port = port_cap->obj<Port>();
RET_IF_NULL(port);
return port->Read(req->message);
}