From 376039869d9b8600499ff0f14d35e674336e1fd8 Mon Sep 17 00:00:00 2001 From: Drew Galbraith Date: Sat, 6 Jul 2024 00:50:54 -0700 Subject: [PATCH] Move to tower middleware for handling CORS. --- Cargo.lock | 19 +++++++++++++++++++ Cargo.toml | 3 +++ src/main.rs | 18 ++++++++++++------ src/routes/tasks.rs | 34 +++------------------------------- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7f1c6f0..37e7bf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -186,10 +186,13 @@ version = "0.1.0" dependencies = [ "axum", "dotenvy", + "http", "serde", "serde_json", "sqlx", "tokio", + "tower", + "tower-http", ] [[package]] @@ -1552,6 +1555,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags 2.5.0", + "bytes", + "http", + "http-body", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index 7e6ad3c..4272624 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,10 @@ edition = "2021" [dependencies] axum = "0.7" dotenvy = "0.15" +http = "1.1.0" serde = "1.0" serde_json = "1.0" sqlx = { version = "0.7", features=[ "runtime-tokio", "sqlite" ] } tokio = { version = "1.38", features=[ "full" ] } +tower = "0.4.13" +tower-http = { version = "0.5.2", features=[ "cors" ] } diff --git a/src/main.rs b/src/main.rs index 0a7c01e..9dd5205 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,10 @@ use axum::http::{StatusCode, Uri}; use axum::response::IntoResponse; use axum::routing::{delete, get}; use dotenvy::dotenv; +use http::{HeaderName, Method}; use sqlx::SqlitePool; +use tower::ServiceBuilder; +use tower_http::cors::{Any, CorsLayer}; #[tokio::main] async fn main() { @@ -24,13 +27,16 @@ async fn main() { .fallback(handle404) .route( "/tasks", - get(routes::tasks::list) - .post(routes::tasks::create) - .options(routes::tasks::options), + get(routes::tasks::list).post(routes::tasks::create), ) - .route( - "/tasks/:task_id", - delete(routes::tasks::delete).options(routes::tasks::options), + .route("/tasks/:task_id", delete(routes::tasks::delete)) + .layer( + ServiceBuilder::new().layer( + CorsLayer::new() + .allow_headers([HeaderName::from_lowercase(b"content-type").unwrap()]) + .allow_methods([Method::GET, Method::POST, Method::DELETE]) + .allow_origin(Any), + ), ) .with_state(state); diff --git a/src/routes/tasks.rs b/src/routes/tasks.rs index cd5ec41..f66bbf2 100644 --- a/src/routes/tasks.rs +++ b/src/routes/tasks.rs @@ -10,11 +10,7 @@ use crate::{global::AppState, models::Task}; pub async fn list(state: State) -> impl IntoResponse { let tasks = Task::all(&state.db_pool).await.unwrap(); - ( - StatusCode::OK, - [(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")], - serde_json::to_string(&tasks).unwrap(), - ) + (StatusCode::OK, serde_json::to_string(&tasks).unwrap()) } #[derive(Deserialize)] @@ -27,7 +23,6 @@ pub async fn create(state: State, Json(req): Json) -> impl In if req.title.is_empty() { return ( StatusCode::BAD_REQUEST, - [(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")], serde_json::to_string(&FormErrorResponse { message: "Failed to validate new task".to_string(), field_errors: [( @@ -45,34 +40,11 @@ pub async fn create(state: State, Json(req): Json) -> impl In let new_task = task.insert(&state.db_pool).await.unwrap(); - ( - StatusCode::OK, - [(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")], - serde_json::to_string(&new_task).unwrap(), - ) -} - -pub async fn options() -> impl IntoResponse { - ( - StatusCode::OK, - [ - (header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"), - ( - header::ACCESS_CONTROL_ALLOW_METHODS, - "POST, GET, OPTIONS, DELETE", - ), - (header::ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type"), - (header::ACCESS_CONTROL_MAX_AGE, "86400"), - ], - ) + (StatusCode::OK, serde_json::to_string(&new_task).unwrap()) } pub async fn delete(state: State, Path(id): Path) -> impl IntoResponse { Task::delete(&state.db_pool, id).await.unwrap(); - ( - StatusCode::OK, - [(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")], - "", - ) + (StatusCode::OK, "") }