Move to tower middleware for handling CORS.

This commit is contained in:
Drew Galbraith 2024-07-06 00:50:54 -07:00
parent 92ab547b0b
commit 376039869d
4 changed files with 37 additions and 37 deletions

19
Cargo.lock generated
View File

@ -186,10 +186,13 @@ version = "0.1.0"
dependencies = [ dependencies = [
"axum", "axum",
"dotenvy", "dotenvy",
"http",
"serde", "serde",
"serde_json", "serde_json",
"sqlx", "sqlx",
"tokio", "tokio",
"tower",
"tower-http",
] ]
[[package]] [[package]]
@ -1552,6 +1555,22 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "tower-http"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
"bitflags 2.5.0",
"bytes",
"http",
"http-body",
"http-body-util",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.2" version = "0.3.2"

View File

@ -6,7 +6,10 @@ edition = "2021"
[dependencies] [dependencies]
axum = "0.7" axum = "0.7"
dotenvy = "0.15" dotenvy = "0.15"
http = "1.1.0"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
sqlx = { version = "0.7", features=[ "runtime-tokio", "sqlite" ] } sqlx = { version = "0.7", features=[ "runtime-tokio", "sqlite" ] }
tokio = { version = "1.38", features=[ "full" ] } tokio = { version = "1.38", features=[ "full" ] }
tower = "0.4.13"
tower-http = { version = "0.5.2", features=[ "cors" ] }

View File

@ -6,7 +6,10 @@ use axum::http::{StatusCode, Uri};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::{delete, get}; use axum::routing::{delete, get};
use dotenvy::dotenv; use dotenvy::dotenv;
use http::{HeaderName, Method};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use tower::ServiceBuilder;
use tower_http::cors::{Any, CorsLayer};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -24,13 +27,16 @@ async fn main() {
.fallback(handle404) .fallback(handle404)
.route( .route(
"/tasks", "/tasks",
get(routes::tasks::list) get(routes::tasks::list).post(routes::tasks::create),
.post(routes::tasks::create)
.options(routes::tasks::options),
) )
.route( .route("/tasks/:task_id", delete(routes::tasks::delete))
"/tasks/:task_id", .layer(
delete(routes::tasks::delete).options(routes::tasks::options), ServiceBuilder::new().layer(
CorsLayer::new()
.allow_headers([HeaderName::from_lowercase(b"content-type").unwrap()])
.allow_methods([Method::GET, Method::POST, Method::DELETE])
.allow_origin(Any),
),
) )
.with_state(state); .with_state(state);

View File

@ -10,11 +10,7 @@ use crate::{global::AppState, models::Task};
pub async fn list(state: State<AppState>) -> impl IntoResponse { pub async fn list(state: State<AppState>) -> impl IntoResponse {
let tasks = Task::all(&state.db_pool).await.unwrap(); let tasks = Task::all(&state.db_pool).await.unwrap();
( (StatusCode::OK, serde_json::to_string(&tasks).unwrap())
StatusCode::OK,
[(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
serde_json::to_string(&tasks).unwrap(),
)
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -27,7 +23,6 @@ pub async fn create(state: State<AppState>, Json(req): Json<NewTask>) -> impl In
if req.title.is_empty() { if req.title.is_empty() {
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
[(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
serde_json::to_string(&FormErrorResponse { serde_json::to_string(&FormErrorResponse {
message: "Failed to validate new task".to_string(), message: "Failed to validate new task".to_string(),
field_errors: [( field_errors: [(
@ -45,34 +40,11 @@ pub async fn create(state: State<AppState>, Json(req): Json<NewTask>) -> impl In
let new_task = task.insert(&state.db_pool).await.unwrap(); let new_task = task.insert(&state.db_pool).await.unwrap();
( (StatusCode::OK, serde_json::to_string(&new_task).unwrap())
StatusCode::OK,
[(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
serde_json::to_string(&new_task).unwrap(),
)
}
pub async fn options() -> impl IntoResponse {
(
StatusCode::OK,
[
(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"),
(
header::ACCESS_CONTROL_ALLOW_METHODS,
"POST, GET, OPTIONS, DELETE",
),
(header::ACCESS_CONTROL_ALLOW_HEADERS, "Content-Type"),
(header::ACCESS_CONTROL_MAX_AGE, "86400"),
],
)
} }
pub async fn delete(state: State<AppState>, Path(id): Path<i64>) -> impl IntoResponse { pub async fn delete(state: State<AppState>, Path(id): Path<i64>) -> impl IntoResponse {
Task::delete(&state.db_pool, id).await.unwrap(); Task::delete(&state.db_pool, id).await.unwrap();
( (StatusCode::OK, "")
StatusCode::OK,
[(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")],
"",
)
} }