diff --git a/Cargo.toml b/Cargo.toml index 31be78b..2d050c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,6 +85,7 @@ http02 = { package = "http", version = "0.2.12", optional = true } # Axum axum = { version = "0.8.4", features = [ "json", + "original-uri", ], default-features = false, optional = true } tower = { version = "0.5.2", optional = true } diff --git a/examples/local_federation/axum/http.rs b/examples/local_federation/axum/http.rs index e010a73..4cbe319 100644 --- a/examples/local_federation/axum/http.rs +++ b/examples/local_federation/axum/http.rs @@ -29,6 +29,7 @@ pub fn listen(config: &FederationConfig) -> Result<(), Error> { let hostname = config.domain(); info!("Listening with axum on {hostname}"); let config = config.clone(); + let app = Router::new() .route("/{user}/inbox", post(http_post_user_inbox)) .route("/{user}", get(http_get_user)) diff --git a/src/axum/inbox.rs b/src/axum/inbox.rs index 84567e0..29b8e72 100644 --- a/src/axum/inbox.rs +++ b/src/axum/inbox.rs @@ -11,11 +11,11 @@ use crate::{ }; use axum::{ body::Body, - extract::FromRequest, + extract::{FromRequest, FromRequestParts, OriginalUri}, http::{Request, StatusCode}, response::{IntoResponse, Response}, }; -use http::{HeaderMap, Method, Uri}; +use http::{HeaderMap, Method}; use serde::de::DeserializeOwned; use tracing::debug; @@ -53,7 +53,7 @@ where pub struct ActivityData { headers: HeaderMap, method: Method, - uri: Uri, + uri: OriginalUri, body: Vec, } @@ -63,8 +63,14 @@ where { type Rejection = Response; - async fn from_request(req: Request, _state: &S) -> Result { - let (parts, body) = req.into_parts(); + async fn from_request(req: Request, state: &S) -> Result { + let (mut parts, body) = req.into_parts(); + + // take the full URI to handle nested routers + // OriginalUri::from_request_parts has an Infallible error type + let uri = OriginalUri::from_request_parts(&mut parts, state) + .await + .expect("infallible"); // this wont work if the body is an long running stream let bytes = axum::body::to_bytes(body, usize::MAX) @@ -74,7 +80,7 @@ where Ok(Self { headers: parts.headers, method: parts.method, - uri: parts.uri, + uri, body: bytes.to_vec(), }) }