diff --git a/lib/mammoth/include/mammoth/endpoint_server.h b/lib/mammoth/include/mammoth/endpoint_server.h index c4e3e5a..7224d11 100644 --- a/lib/mammoth/include/mammoth/endpoint_server.h +++ b/lib/mammoth/include/mammoth/endpoint_server.h @@ -5,6 +5,8 @@ #include #include "mammoth/endpoint_client.h" +#include "mammoth/request_context.h" +#include "mammoth/response_context.h" class EndpointServer { public: @@ -12,19 +14,24 @@ class EndpointServer { EndpointServer(const EndpointServer&) = delete; EndpointServer& operator=(const EndpointServer&) = delete; - static glcr::ErrorOr> Create(); - static glcr::UniquePtr Adopt(z_cap_t endpoint_cap); - glcr::ErrorOr> CreateClient(); - // FIXME: Release Cap here. z_cap_t GetCap() { return endpoint_cap_; } - glcr::ErrorCode Recieve(uint64_t* num_bytes, void* data, + glcr::ErrorCode Receive(uint64_t* num_bytes, void* data, z_cap_t* reply_port_cap); + glcr::ErrorCode RunServer(); + + virtual glcr::ErrorCode HandleRequest(RequestContext& request, + ResponseContext& response) = 0; + + protected: + EndpointServer(z_cap_t cap) : endpoint_cap_(cap) {} + private: z_cap_t endpoint_cap_; - EndpointServer(z_cap_t cap) : endpoint_cap_(cap) {} + static const uint64_t kBufferSize = 1024; + uint8_t recieve_buffer_[kBufferSize]; }; diff --git a/lib/mammoth/include/mammoth/request_context.h b/lib/mammoth/include/mammoth/request_context.h new file mode 100644 index 0000000..0f8ca96 --- /dev/null +++ b/lib/mammoth/include/mammoth/request_context.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +class RequestContext { + public: + RequestContext(void* buffer, uint64_t buffer_length) + : buffer_(buffer), buffer_length_(buffer_length) { + if (buffer_length_ < sizeof(uint64_t)) { + request_id_ = -1; + } else { + request_id_ = *reinterpret_cast(buffer); + } + } + + uint64_t request_id() { return request_id_; } + + template + glcr::ErrorCode As(T** arg) { + if (buffer_length_ < sizeof(T)) { + return glcr::INVALID_ARGUMENT; + } + *arg = reinterpret_cast(buffer_); + return glcr::OK; + } + + private: + uint64_t request_id_; + void* buffer_; + uint64_t buffer_length_; +}; diff --git a/lib/mammoth/include/mammoth/response_context.h b/lib/mammoth/include/mammoth/response_context.h new file mode 100644 index 0000000..fbd0d2a --- /dev/null +++ b/lib/mammoth/include/mammoth/response_context.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include +#include + +class ResponseContext { + public: + ResponseContext(z_cap_t reply_port) : reply_port_(reply_port) {} + + ResponseContext(ResponseContext&) = delete; + + template + glcr::ErrorCode WriteStruct(const T& response) { + // FIXME: Here and below probably don't count as written on error. + written_ = true; + return ZReplyPortSend(reply_port_, sizeof(T), &response, 0, nullptr); + } + + template + glcr::ErrorCode WriteStructWithCap(const T& response, z_cap_t capability) { + written_ = true; + return ZReplyPortSend(reply_port_, sizeof(T), &response, 1, &capability); + } + + glcr::ErrorCode WriteError(glcr::ErrorCode code) { + uint64_t response[2]{ + static_cast(-1), + code, + }; + written_ = true; + return ZReplyPortSend(reply_port_, sizeof(response), &response, 0, nullptr); + } + + bool HasWritten() { return written_; } + + private: + z_cap_t reply_port_; + bool written_ = false; +}; diff --git a/lib/mammoth/src/endpoint_server.cpp b/lib/mammoth/src/endpoint_server.cpp index cacc5a6..42f5ecd 100644 --- a/lib/mammoth/src/endpoint_server.cpp +++ b/lib/mammoth/src/endpoint_server.cpp @@ -1,14 +1,6 @@ #include "mammoth/endpoint_server.h" -glcr::ErrorOr> EndpointServer::Create() { - uint64_t cap; - RET_ERR(ZEndpointCreate(&cap)); - return glcr::UniquePtr(new EndpointServer(cap)); -} - -glcr::UniquePtr EndpointServer::Adopt(z_cap_t endpoint_cap) { - return glcr::UniquePtr(new EndpointServer(endpoint_cap)); -} +#include "mammoth/debug.h" glcr::ErrorOr> EndpointServer::CreateClient() { uint64_t client_cap; @@ -17,7 +9,24 @@ glcr::ErrorOr> EndpointServer::CreateClient() { return EndpointClient::AdoptEndpoint(client_cap); } -glcr::ErrorCode EndpointServer::Recieve(uint64_t* num_bytes, void* data, +glcr::ErrorCode EndpointServer::Receive(uint64_t* num_bytes, void* data, z_cap_t* reply_port_cap) { return ZEndpointRecv(endpoint_cap_, num_bytes, data, reply_port_cap); } + +glcr::ErrorCode EndpointServer::RunServer() { + while (true) { + uint64_t message_size = kBufferSize; + uint64_t reply_port_cap = 0; + RET_ERR(Receive(&message_size, recieve_buffer_, &reply_port_cap)); + + RequestContext request(recieve_buffer_, message_size); + ResponseContext response(reply_port_cap); + // FIXME: Consider pumping these errors into the response as well. + RET_ERR(HandleRequest(request, response)); + if (!response.HasWritten()) { + dbgln("Returning without having written a response. Req type %x", + request.request_id()); + } + } +} diff --git a/sys/denali/ahci/command.cpp b/sys/denali/ahci/command.cpp index aa8a020..986ca6d 100644 --- a/sys/denali/ahci/command.cpp +++ b/sys/denali/ahci/command.cpp @@ -7,8 +7,8 @@ Command::~Command() {} DmaReadCommand::DmaReadCommand(uint64_t lba, uint64_t sector_cnt, - DmaCallback callback, z_cap_t reply_port) - : reply_port_(reply_port), + DmaCallback callback, ResponseContext& response) + : response_(response), lba_(lba), sector_cnt_(sector_cnt), callback_(callback) { @@ -50,5 +50,5 @@ void DmaReadCommand::PopulatePrdt(PhysicalRegionDescriptor* prdt) { prdt[0].byte_count = region_.size(); } void DmaReadCommand::Callback() { - callback_(reply_port_, lba_, sector_cnt_, region_.cap()); + callback_(response_, lba_, sector_cnt_, region_.cap()); } diff --git a/sys/denali/ahci/command.h b/sys/denali/ahci/command.h index f502b01..9870685 100644 --- a/sys/denali/ahci/command.h +++ b/sys/denali/ahci/command.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include "ahci/ahci.h" @@ -15,9 +16,9 @@ class Command { class DmaReadCommand : public Command { public: - typedef void (*DmaCallback)(z_cap_t, uint64_t, uint64_t, z_cap_t); + typedef void (*DmaCallback)(ResponseContext&, uint64_t, uint64_t, z_cap_t); DmaReadCommand(uint64_t lba, uint64_t sector_cnt, DmaCallback callback, - z_cap_t reply_port); + ResponseContext& reply_port); virtual ~DmaReadCommand() override; @@ -27,7 +28,7 @@ class DmaReadCommand : public Command { void Callback() override; private: - z_cap_t reply_port_; + ResponseContext& response_; uint64_t lba_; uint64_t sector_cnt_; DmaCallback callback_; diff --git a/sys/denali/client/denali_client.cpp b/sys/denali/client/denali_client.cpp index 6f870b2..9d95f00 100644 --- a/sys/denali/client/denali_client.cpp +++ b/sys/denali/client/denali_client.cpp @@ -6,13 +6,14 @@ glcr::ErrorOr DenaliClient::ReadSectors( uint64_t device_id, uint64_t lba, uint64_t num_sectors) { - DenaliRead read{ + DenaliReadRequest read{ .device_id = device_id, .lba = lba, .size = num_sectors, }; auto pair_or = - endpoint_->CallEndpointGetCap(read); + endpoint_->CallEndpointGetCap( + read); if (!pair_or) { return pair_or.error(); } diff --git a/sys/denali/denali.cpp b/sys/denali/denali.cpp index 382ec07..043de69 100644 --- a/sys/denali/denali.cpp +++ b/sys/denali/denali.cpp @@ -19,14 +19,14 @@ uint64_t main(uint64_t init_port_cap) { ASSIGN_OR_RETURN(MappedMemoryRegion ahci_region, stub.GetAhciConfig()); ASSIGN_OR_RETURN(auto driver, AhciDriver::Init(ahci_region)); - ASSIGN_OR_RETURN(glcr::UniquePtr endpoint, - EndpointServer::Create()); + ASSIGN_OR_RETURN(glcr::UniquePtr server, + DenaliServer::Create(*driver)); + ASSIGN_OR_RETURN(glcr::UniquePtr client, - endpoint->CreateClient()); + server->CreateClient()); check(stub.Register("denali", *client)); - DenaliServer server(glcr::Move(endpoint), *driver); - RET_ERR(server.RunServer()); + RET_ERR(server->RunServer()); // FIXME: Add thread join. return 0; } diff --git a/sys/denali/denali_server.cpp b/sys/denali/denali_server.cpp index f34b21a..caa29a4 100644 --- a/sys/denali/denali_server.cpp +++ b/sys/denali/denali_server.cpp @@ -7,61 +7,57 @@ namespace { DenaliServer* gServer = nullptr; -void HandleResponse(z_cap_t reply_port, uint64_t lba, uint64_t size, +void HandleResponse(ResponseContext& response, uint64_t lba, uint64_t size, z_cap_t mem) { - gServer->HandleResponse(reply_port, lba, size, mem); + gServer->HandleResponse(response, lba, size, mem); } } // namespace -DenaliServer::DenaliServer(glcr::UniquePtr server, - AhciDriver& driver) - : server_(glcr::Move(server)), driver_(driver) { - gServer = this; +glcr::ErrorOr> DenaliServer::Create( + AhciDriver& driver) { + z_cap_t cap; + RET_ERR(ZEndpointCreate(&cap)); + return glcr::UniquePtr(new DenaliServer(cap, driver)); } -glcr::ErrorCode DenaliServer::RunServer() { - while (true) { - uint64_t buff_size = kBuffSize; - z_cap_t reply_port; - RET_ERR(server_->Recieve(&buff_size, read_buffer_, &reply_port)); - if (buff_size < sizeof(uint64_t)) { - dbgln("Skipping invalid message"); - continue; - } - uint64_t type = *reinterpret_cast(read_buffer_); - switch (type) { - case Z_INVALID: - dbgln(reinterpret_cast(read_buffer_)); - break; - case DENALI_READ: { - DenaliRead* read_req = reinterpret_cast(read_buffer_); - uint64_t memcap = 0; - RET_ERR(HandleRead(*read_req, reply_port)); - break; +glcr::ErrorCode DenaliServer::HandleRequest(RequestContext& request, + ResponseContext& response) { + switch (request.request_id()) { + case DENALI_READ: { + DenaliReadRequest* req = 0; + glcr::ErrorCode err = request.As(&req); + if (err != glcr::OK) { + response.WriteError(err); } - default: - dbgln("Invalid message type."); - return glcr::UNIMPLEMENTED; + err = HandleRead(req, response); + if (err != glcr::OK) { + response.WriteError(err); + } + break; } + default: + response.WriteError(glcr::UNIMPLEMENTED); + break; } + return glcr::OK; } -glcr::ErrorCode DenaliServer::HandleRead(const DenaliRead& read, - z_cap_t reply_port) { - ASSIGN_OR_RETURN(AhciDevice * device, driver_.GetDevice(read.device_id)); +glcr::ErrorCode DenaliServer::HandleRead(DenaliReadRequest* request, + ResponseContext& context) { + ASSIGN_OR_RETURN(AhciDevice * device, driver_.GetDevice(request->device_id)); - device->IssueCommand( - new DmaReadCommand(read.lba, read.size, ::HandleResponse, reply_port)); + device->IssueCommand(new DmaReadCommand(request->lba, request->size, + ::HandleResponse, context)); return glcr::OK; } -void DenaliServer::HandleResponse(z_cap_t reply_port, uint64_t lba, +void DenaliServer::HandleResponse(ResponseContext& response, uint64_t lba, uint64_t size, z_cap_t mem) { DenaliReadResponse resp{ .device_id = 0, .lba = lba, .size = size, }; - check(ZReplyPortSend(reply_port, sizeof(resp), &resp, 1, &mem)); + check(response.WriteStructWithCap(resp, mem)); } diff --git a/sys/denali/denali_server.h b/sys/denali/denali_server.h index 2cef169..677185d 100644 --- a/sys/denali/denali_server.h +++ b/sys/denali/denali_server.h @@ -6,21 +6,26 @@ #include "ahci/ahci_driver.h" #include "denali/denali.h" -class DenaliServer { +class DenaliServer : public EndpointServer { public: - DenaliServer(glcr::UniquePtr server, AhciDriver& driver); + static glcr::ErrorOr> Create( + AhciDriver& driver); - glcr::ErrorCode RunServer(); - - void HandleResponse(z_cap_t reply_port, uint64_t lba, uint64_t size, + void HandleResponse(ResponseContext& response, uint64_t lba, uint64_t size, z_cap_t cap); + virtual glcr::ErrorCode HandleRequest(RequestContext& request, + ResponseContext& response) override; + private: static const uint64_t kBuffSize = 1024; - glcr::UniquePtr server_; uint8_t read_buffer_[kBuffSize]; AhciDriver& driver_; - glcr::ErrorCode HandleRead(const DenaliRead& read, z_cap_t reply_port); + DenaliServer(z_cap_t endpoint_cap, AhciDriver& driver) + : EndpointServer(endpoint_cap), driver_(driver) {} + + glcr::ErrorCode HandleRead(DenaliReadRequest* request, + ResponseContext& context); }; diff --git a/sys/denali/include/denali/denali.h b/sys/denali/include/denali/denali.h index 806592a..398e68c 100644 --- a/sys/denali/include/denali/denali.h +++ b/sys/denali/include/denali/denali.h @@ -5,7 +5,7 @@ #define DENALI_INVALID 0 #define DENALI_READ 100 -struct DenaliRead { +struct DenaliReadRequest { uint64_t request_type = DENALI_READ; uint64_t device_id; diff --git a/sys/yellowstone/yellowstone.cpp b/sys/yellowstone/yellowstone.cpp index 6d3a3fe..d7ec505 100644 --- a/sys/yellowstone/yellowstone.cpp +++ b/sys/yellowstone/yellowstone.cpp @@ -13,16 +13,15 @@ uint64_t main(uint64_t port_cap) { check(ParseInitPort(port_cap)); ASSIGN_OR_RETURN(auto server, YellowstoneServer::Create()); - Thread server_thread = server->RunServer(); Thread registration_thread = server->RunRegistration(); uint64_t vaddr; check(ZAddressSpaceMap(gSelfVmasCap, 0, gBootDenaliVmmoCap, &vaddr)); ASSIGN_OR_RETURN(glcr::UniquePtr client, - server->GetServerClient()); + server->CreateClient()); check(SpawnProcessFromElfRegion(vaddr, glcr::Move(client))); - check(server_thread.Join()); + check(server->RunServer()); check(registration_thread.Join()); dbgln("Yellowstone Finished Successfully."); return 0; diff --git a/sys/yellowstone/yellowstone_server.cpp b/sys/yellowstone/yellowstone_server.cpp index 244650c..e1d965f 100644 --- a/sys/yellowstone/yellowstone_server.cpp +++ b/sys/yellowstone/yellowstone_server.cpp @@ -12,11 +12,6 @@ namespace { -void ServerThreadBootstrap(void* yellowstone) { - dbgln("Yellowstone server starting"); - static_cast(yellowstone)->ServerThread(); -} - void RegistrationThreadBootstrap(void* yellowstone) { dbgln("Yellowstone registration starting"); static_cast(yellowstone)->RegistrationThread(); @@ -40,71 +35,60 @@ glcr::ErrorOr HandleDenaliRegistration(z_cap_t endpoint_cap) { } // namespace glcr::ErrorOr> YellowstoneServer::Create() { - ASSIGN_OR_RETURN(auto server, EndpointServer::Create()); + z_cap_t cap; + RET_ERR(ZEndpointCreate(&cap)); ASSIGN_OR_RETURN(PortServer port, PortServer::Create()); - return glcr::UniquePtr( - new YellowstoneServer(glcr::Move(server), port)); + return glcr::UniquePtr(new YellowstoneServer(cap, port)); } -YellowstoneServer::YellowstoneServer(glcr::UniquePtr server, - PortServer port) - : server_(glcr::Move(server)), register_port_(port) {} - -Thread YellowstoneServer::RunServer() { - return Thread(ServerThreadBootstrap, this); -} +YellowstoneServer::YellowstoneServer(z_cap_t endpoint_cap, PortServer port) + : EndpointServer(endpoint_cap), register_port_(port) {} Thread YellowstoneServer::RunRegistration() { return Thread(RegistrationThreadBootstrap, this); } -void YellowstoneServer::ServerThread() { - while (true) { - uint64_t num_bytes = kBufferSize; - uint64_t reply_port_cap; - // FIXME: Error handling. - check(server_->Recieve(&num_bytes, server_buffer_, &reply_port_cap)); - YellowstoneGetReq* req = - reinterpret_cast(server_buffer_); - switch (req->type) { - case kYellowstoneGetAhci: { - dbgln("Yellowstone::GetAHCI"); - YellowstoneGetAhciResp resp{ - .type = kYellowstoneGetAhci, - .ahci_phys_offset = pci_reader_.GetAhciPhysical(), - }; - check(ZReplyPortSend(reply_port_cap, sizeof(resp), &resp, 0, nullptr)); - break; - } - case kYellowstoneGetRegistration: { - dbgln("Yellowstone::GetRegistration"); - auto client_or = register_port_.CreateClient(); - if (!client_or.ok()) { - check(client_or.error()); - } - YellowstoneGetRegistrationResp resp; - uint64_t reg_cap = client_or.value().cap(); - check(ZReplyPortSend(reply_port_cap, sizeof(resp), &resp, 1, ®_cap)); - break; - } - case kYellowstoneGetDenali: { - dbgln("Yellowstone::GetDenali"); - z_cap_t new_denali; - check(ZCapDuplicate(denali_cap_, &new_denali)); - YellowstoneGetDenaliResp resp{ - .type = kYellowstoneGetDenali, - .device_id = device_id_, - .lba_offset = lba_offset_, - }; - check(ZReplyPortSend(reply_port_cap, sizeof(resp), &resp, 1, - &new_denali)); - break; - } - default: - dbgln("Unknown request type: %x", req->type); - break; +glcr::ErrorCode YellowstoneServer::HandleRequest(RequestContext& request, + ResponseContext& response) { + switch (request.request_id()) { + case kYellowstoneGetAhci: { + dbgln("Yellowstone::GetAHCI"); + YellowstoneGetAhciResp resp{ + .type = kYellowstoneGetAhci, + .ahci_phys_offset = pci_reader_.GetAhciPhysical(), + }; + RET_ERR(response.WriteStruct(resp)); + break; } + case kYellowstoneGetRegistration: { + dbgln("Yellowstone::GetRegistration"); + auto client_or = register_port_.CreateClient(); + if (!client_or.ok()) { + check(client_or.error()); + } + YellowstoneGetRegistrationResp resp; + uint64_t reg_cap = client_or.value().cap(); + RET_ERR(response.WriteStructWithCap(resp, reg_cap)); + break; + } + case kYellowstoneGetDenali: { + dbgln("Yellowstone::GetDenali"); + z_cap_t new_denali; + check(ZCapDuplicate(denali_cap_, &new_denali)); + YellowstoneGetDenaliResp resp{ + .type = kYellowstoneGetDenali, + .device_id = device_id_, + .lba_offset = lba_offset_, + }; + RET_ERR(response.WriteStructWithCap(resp, new_denali)); + break; + } + default: + dbgln("Unknown request type: %x", request.request_id()); + return glcr::UNIMPLEMENTED; + break; } + return glcr::OK; } void YellowstoneServer::RegistrationThread() { @@ -127,7 +111,7 @@ void YellowstoneServer::RegistrationThread() { uint64_t vaddr; check( ZAddressSpaceMap(gSelfVmasCap, 0, gBootVictoriaFallsVmmoCap, &vaddr)); - auto client_or = GetServerClient(); + auto client_or = CreateClient(); if (!client_or.ok()) { check(client_or.error()); } @@ -144,8 +128,3 @@ void YellowstoneServer::RegistrationThread() { dbgln(name.cstr()); } } - -glcr::ErrorOr> -YellowstoneServer::GetServerClient() { - return server_->CreateClient(); -} diff --git a/sys/yellowstone/yellowstone_server.h b/sys/yellowstone/yellowstone_server.h index 609f238..bd643dc 100644 --- a/sys/yellowstone/yellowstone_server.h +++ b/sys/yellowstone/yellowstone_server.h @@ -8,20 +8,19 @@ #include "hw/pcie.h" -class YellowstoneServer { +class YellowstoneServer : public EndpointServer { public: static glcr::ErrorOr> Create(); - Thread RunServer(); Thread RunRegistration(); - void ServerThread(); void RegistrationThread(); - glcr::ErrorOr> GetServerClient(); + virtual glcr::ErrorCode HandleRequest(RequestContext& request, + ResponseContext& response) override; private: - glcr::UniquePtr server_; + // FIXME: Separate this to its own service. PortServer register_port_; static const uint64_t kBufferSize = 128; @@ -36,5 +35,5 @@ class YellowstoneServer { PciReader pci_reader_; - YellowstoneServer(glcr::UniquePtr server, PortServer port); + YellowstoneServer(z_cap_t endpoint_cap, PortServer port); };