diff --git a/Cargo.toml b/Cargo.toml index 43f5938..d576d7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "activitypub_federation" -version = "0.5.0-beta.4" +version = "0.5.1-beta.1" edition = "2021" description = "High-level Activitypub framework" keywords = ["activitypub", "activitystreams", "federation", "fediverse"] @@ -8,73 +8,77 @@ license = "AGPL-3.0" repository = "https://github.com/LemmyNet/activitypub-federation-rust" documentation = "https://docs.rs/activitypub_federation/" +[features] +default = ["actix-web", "axum"] +actix-web = ["dep:actix-web"] +axum = ["dep:axum", "dep:tower", "dep:hyper", "dep:http-body-util"] +diesel = ["dep:diesel"] + [dependencies] chrono = { version = "0.4.31", features = ["clock"], default-features = false } -serde = { version = "1.0.189", features = ["derive"] } -async-trait = "0.1.74" -url = { version = "2.4.1", features = ["serde"] } -serde_json = { version = "1.0.107", features = ["preserve_order"] } -anyhow = "1.0.75" -reqwest = { version = "0.11.22", features = ["json", "stream"] } -reqwest-middleware = "0.2.3" +serde = { version = "1.0.194", features = ["derive"] } +async-trait = "0.1.77" +url = { version = "2.5.0", features = ["serde"] } +serde_json = { version = "1.0.110", features = ["preserve_order"] } +reqwest = { version = "0.11.23", features = ["json", "stream"] } +reqwest-middleware = "0.2.4" tracing = "0.1.40" base64 = "0.21.5" -openssl = "0.10.57" -once_cell = "1.18.0" -http = "0.2.9" +openssl = "0.10.62" +once_cell = "1.19.0" +http = "1.0.0" sha2 = "0.10.8" -thiserror = "1.0.50" +thiserror = "1.0.56" derive_builder = "0.12.0" -itertools = "0.11.0" -dyn-clone = "1.0.14" +itertools = "0.12.0" +dyn-clone = "1.0.16" enum_delegate = "0.2.0" httpdate = "1.0.3" http-signature-normalization-reqwest = { version = "0.10.0", default-features = false, features = [ "default-spawner", "sha-2", "middleware", + "default-spawner", ] } http-signature-normalization = "0.7.0" bytes = "1.5.0" -futures-core = { version = "0.3.28", default-features = false } +futures-core = { version = "0.3.30", default-features = false } pin-project-lite = "0.2.13" activitystreams-kinds = "0.3.0" regex = { version = "1.10.2", default-features = false, features = ["std", "unicode-case"] } -tokio = { version = "1.33.0", features = [ +tokio = { version = "1.35.1", features = [ "sync", "rt", "rt-multi-thread", "time", ] } +diesel = { version = "2.1.4", features = ["postgres"], default-features = false, optional = true } +futures = "0.3.30" +moka = { version = "0.12.2", features = ["future"] } # Actix-web -actix-web = { version = "4.4.0", default-features = false, optional = true } +actix-web = { version = "4.4.1", default-features = false, optional = true } # Axum axum = { git = "https://github.com/tokio-rs/axum.git", features = [ "json", ], default-features = false, optional = true } -tower = { version = "*", optional = true } -hyper = { version = "*", optional = true } -futures = "*" -moka = { version = "0.12.1", features = ["future"] } - -[features] -default = ["actix-web", "axum"] -actix-web = ["dep:actix-web"] -axum = ["dep:axum", "dep:tower", "dep:hyper"] +tower = { version = "0.4.13", optional = true } +hyper = { version = "1.1.0", optional = true } +http-body-util = {version = "0.1.0", optional = true } [dev-dependencies] +anyhow = "1.0.79" rand = "0.8.5" -env_logger = "0.10.0" -tower-http = { version = "*", features = ["map-request-body", "util"] } +env_logger = "0.10.1" +tower-http = { version = "0.5.0", features = ["map-request-body", "util"] } axum = { git = "https://github.com/tokio-rs/axum.git", features = [ "http1", "tokio", "query", ], default-features = false } axum-macros = { git = "https://github.com/tokio-rs/axum.git" } -tokio = { version = "*", features = ["full"] } +tokio = { version = "1.35.1", features = ["full"] } [profile.dev] strip = "symbols" diff --git a/docs/06_http_endpoints_axum.md b/docs/06_http_endpoints_axum.md index 8ebbcc8..3a33410 100644 --- a/docs/06_http_endpoints_axum.md +++ b/docs/06_http_endpoints_axum.md @@ -48,7 +48,7 @@ async fn http_get_user( ) -> impl IntoResponse { let accept = header_map.get("accept").map(|v| v.to_str().unwrap()); if accept == Some(FEDERATION_CONTENT_TYPE) { - let db_user = data.read_local_user(name).await.unwrap(); + let db_user = data.read_local_user(&name).await.unwrap(); let json_user = db_user.into_json(&data).await.unwrap(); FederationJson(WithContext::new_default(json_user)).into_response() } diff --git a/examples/live_federation/http.rs b/examples/live_federation/http.rs index d626396..e0d2869 100644 --- a/examples/live_federation/http.rs +++ b/examples/live_federation/http.rs @@ -61,7 +61,7 @@ pub async fn webfinger( data: Data, ) -> Result, Error> { let name = extract_webfinger_name(&query.resource, &data)?; - let db_user = data.read_user(&name)?; + let db_user = data.read_user(name)?; Ok(Json(build_webfinger_response( query.resource, db_user.ap_id.into_inner(), diff --git a/examples/local_federation/actix_web/http.rs b/examples/local_federation/actix_web/http.rs index 12a750f..6298014 100644 --- a/examples/local_federation/actix_web/http.rs +++ b/examples/local_federation/actix_web/http.rs @@ -89,7 +89,7 @@ pub async fn webfinger( data: Data, ) -> Result { let name = extract_webfinger_name(&query.resource, &data)?; - let db_user = data.read_user(&name)?; + let db_user = data.read_user(name)?; Ok(HttpResponse::Ok().json(build_webfinger_response( query.resource.clone(), db_user.ap_id.into_inner(), diff --git a/examples/local_federation/axum/http.rs b/examples/local_federation/axum/http.rs index 16c5f0e..205a5a1 100644 --- a/examples/local_federation/axum/http.rs +++ b/examples/local_federation/axum/http.rs @@ -38,7 +38,7 @@ pub fn listen(config: &FederationConfig) -> Result<(), Error> { let addr = tokio::net::TcpListener::from_std(TcpListener::bind(hostname)?)?; let server = axum::serve(addr, app.into_make_service()); - tokio::spawn(server); + tokio::spawn(async move { server.await.unwrap() }); Ok(()) } @@ -75,7 +75,7 @@ async fn webfinger( data: Data, ) -> Result, Error> { let name = extract_webfinger_name(&query.resource, &data)?; - let db_user = data.read_user(&name)?; + let db_user = data.read_user(name)?; Ok(Json(build_webfinger_response( query.resource, db_user.ap_id.into_inner(), diff --git a/examples/local_federation/instance.rs b/examples/local_federation/instance.rs index 5a9794c..f377f31 100644 --- a/examples/local_federation/instance.rs +++ b/examples/local_federation/instance.rs @@ -49,9 +49,11 @@ struct MyUrlVerifier(); #[async_trait] impl UrlVerifier for MyUrlVerifier { - async fn verify(&self, url: &Url) -> Result<(), anyhow::Error> { + async fn verify(&self, url: &Url) -> Result<(), activitypub_federation::error::Error> { if url.domain() == Some("malicious.com") { - Err(anyhow!("malicious domain")) + Err(activitypub_federation::error::Error::Other( + "malicious domain".into(), + )) } else { Ok(()) } diff --git a/examples/local_federation/objects/person.rs b/examples/local_federation/objects/person.rs index 5961205..2c47fcd 100644 --- a/examples/local_federation/objects/person.rs +++ b/examples/local_federation/objects/person.rs @@ -107,7 +107,7 @@ impl DbUser { activity: Activity, recipients: Vec, data: &Data, - ) -> Result<(), ::Error> + ) -> Result<(), Error> where Activity: ActivityHandler + Serialize + Debug + Send + Sync, ::Error: From + From, diff --git a/src/activity_sending.rs b/src/activity_sending.rs index ef23fb5..91177af 100644 --- a/src/activity_sending.rs +++ b/src/activity_sending.rs @@ -10,21 +10,18 @@ use crate::{ traits::{ActivityHandler, Actor}, FEDERATION_CONTENT_TYPE, }; -use anyhow::{anyhow, Context}; use bytes::Bytes; use futures::StreamExt; -use http::{header::HeaderName, HeaderMap, HeaderValue}; use httpdate::fmt_http_date; use itertools::Itertools; use openssl::pkey::{PKey, Private}; -use reqwest::Request; -use reqwest_middleware::ClientWithMiddleware; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use serde::Serialize; use std::{ self, fmt::{Debug, Display}, - time::{Duration, SystemTime}, + time::SystemTime, }; use tracing::debug; use url::Url; @@ -57,17 +54,18 @@ impl SendActivityTask<'_> { actor: &ActorType, inboxes: Vec, data: &Data, - ) -> Result>, ::Error> + ) -> Result>, Error> where - Activity: ActivityHandler + Serialize, - ::Error: From + From, + Activity: ActivityHandler + Serialize + Debug, Datatype: Clone, ActorType: Actor, { let config = &data.config; let actor_id = activity.actor(); let activity_id = activity.id(); - let activity_serialized: Bytes = serde_json::to_vec(&activity)?.into(); + let activity_serialized: Bytes = serde_json::to_vec(&activity) + .map_err(|e| Error::SerializeOutgoingActivity(e, format!("{:?}", activity)))? + .into(); let private_key = get_pkey_cached(data, actor).await?; Ok(futures::stream::iter( @@ -95,62 +93,40 @@ impl SendActivityTask<'_> { } /// convert a sendactivitydata to a request, signing and sending it - pub async fn sign_and_send( - &self, - data: &Data, - ) -> Result<(), anyhow::Error> { - let req = self - .sign(&data.config.client, data.config.request_timeout) - .await?; - self.send(&data.config.client, req).await - } - async fn sign( - &self, - client: &ClientWithMiddleware, - timeout: Duration, - ) -> Result { - let task = self; + pub async fn sign_and_send(&self, data: &Data) -> Result<(), Error> { + let client = &data.config.client; let request_builder = client - .post(task.inbox.to_string()) - .timeout(timeout) - .headers(generate_request_headers(&task.inbox)); + .post(self.inbox.to_string()) + .timeout(data.config.request_timeout) + .headers(generate_request_headers(&self.inbox)); let request = sign_request( request_builder, - task.actor_id, - task.activity.clone(), - task.private_key.clone(), - task.http_signature_compat, + self.actor_id, + self.activity.clone(), + self.private_key.clone(), + self.http_signature_compat, ) - .await - .context("signing request")?; - Ok(request) - } - - async fn send( - &self, - client: &ClientWithMiddleware, - request: Request, - ) -> Result<(), anyhow::Error> { - let response = client.execute(request).await; + .await?; + let response = client.execute(request).await?; match response { - Ok(o) if o.status().is_success() => { + o if o.status().is_success() => { debug!("Activity {self} delivered successfully"); Ok(()) } - Ok(o) if o.status().is_client_error() => { - let text = o.text_limited().await.map_err(Error::other)?; + o if o.status().is_client_error() => { + let text = o.text_limited().await?; debug!("Activity {self} was rejected, aborting: {text}"); Ok(()) } - Ok(o) => { + o => { let status = o.status(); - let text = o.text_limited().await.map_err(Error::other)?; - Err(anyhow!( + let text = o.text_limited().await?; + + Err(Error::Other(format!( "Activity {self} failure with status {status}: {text}", - )) + ))) } - Err(e) => Err(anyhow!("Activity {self} connection failure: {e}")), } } } @@ -158,7 +134,7 @@ impl SendActivityTask<'_> { async fn get_pkey_cached( data: &Data, actor: &ActorType, -) -> Result, anyhow::Error> +) -> Result, Error> where ActorType: Actor, { @@ -168,20 +144,23 @@ where .actor_pkey_cache .try_get_with_by_ref(&actor_id, async { let private_key_pem = actor.private_key_pem().ok_or_else(|| { - anyhow!("Actor {actor_id} does not contain a private key for signing") + Error::Other(format!( + "Actor {actor_id} does not contain a private key for signing" + )) })?; // This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening let pkey = tokio::task::spawn_blocking(move || { - PKey::private_key_from_pem(private_key_pem.as_bytes()) - .map_err(|err| anyhow!("Could not create private key from PEM data:{err}")) + PKey::private_key_from_pem(private_key_pem.as_bytes()).map_err(|err| { + Error::Other(format!("Could not create private key from PEM data:{err}")) + }) }) .await - .map_err(|err| anyhow!("Error joining: {err}"))??; - std::result::Result::, anyhow::Error>::Ok(pkey) + .map_err(|err| Error::Other(format!("Error joining: {err}")))??; + std::result::Result::, Error>::Ok(pkey) }) .await - .map_err(|e| anyhow!("cloned error: {e}")) + .map_err(|e| Error::Other(format!("cloned error: {e}"))) } pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap { @@ -226,7 +205,7 @@ mod tests { // This will periodically send back internal errors to test the retry async fn dodgy_handler( State(state): State>, - headers: HeaderMap, + headers: http::HeaderMap, body: Bytes, ) -> Result<(), StatusCode> { debug!("Headers:{:?}", headers); @@ -294,7 +273,7 @@ mod tests { let start = Instant::now(); for _ in 0..num_messages { - message.sign_and_send(&data).await?; + message.clone().sign_and_send(&data).await?; } info!("Queue Sent: {:?}", start.elapsed()); diff --git a/src/actix_web/inbox.rs b/src/actix_web/inbox.rs index ba5a20b..d634ba9 100644 --- a/src/actix_web/inbox.rs +++ b/src/actix_web/inbox.rs @@ -3,13 +3,13 @@ use crate::{ config::Data, error::Error, - fetch::object_id::ObjectId, http_signatures::{verify_body_hash, verify_signature}, + parse_received_activity, traits::{ActivityHandler, Actor, Object}, }; use actix_web::{web::Bytes, HttpRequest, HttpResponse}; -use anyhow::Context; use serde::de::DeserializeOwned; +use std::str::FromStr; use tracing::debug; /// Handles incoming activities, verifying HTTP signatures and other checks @@ -24,26 +24,18 @@ where Activity: ActivityHandler + DeserializeOwned + Send + 'static, ActorT: Object + Actor + Send + 'static, for<'de2> ::Kind: serde::Deserialize<'de2>, - ::Error: From - + From - + From<::Error> - + From, - ::Error: From + From, + ::Error: From + From<::Error>, + ::Error: From, Datatype: Clone, { verify_body_hash(request.headers().get("Digest"), &body)?; - let activity: Activity = serde_json::from_slice(&body) - .with_context(|| format!("deserializing body: {}", String::from_utf8_lossy(&body)))?; - data.config.verify_url_and_domain(&activity).await?; - let actor = ObjectId::::from(activity.actor().clone()) - .dereference(data) - .await?; + let (activity, actor) = parse_received_activity::(&body, data).await?; verify_signature( request.headers(), request.method(), - request.uri(), + &http::Uri::from_str(&request.uri().to_string()).unwrap(), actor.public_key_pem(), )?; @@ -59,12 +51,14 @@ mod test { use crate::{ activity_sending::generate_request_headers, config::FederationConfig, + fetch::object_id::ObjectId, http_signatures::sign_request, traits::tests::{DbConnection, DbUser, Follow, DB_USER_KEYPAIR}, }; use actix_web::test::TestRequest; use reqwest::Client; use reqwest_middleware::ClientWithMiddleware; + use serde_json::json; use url::Url; #[tokio::test] @@ -91,8 +85,7 @@ mod test { .err() .unwrap(); - let e = err.root_cause().downcast_ref::().unwrap(); - assert_eq!(e, &Error::ActivityBodyDigestInvalid) + assert_eq!(&err, &Error::ActivityBodyDigestInvalid) } #[tokio::test] @@ -108,26 +101,52 @@ mod test { .err() .unwrap(); - let e = err.root_cause().downcast_ref::().unwrap(); - assert_eq!(e, &Error::ActivitySignatureInvalid) + assert_eq!(&err, &Error::ActivitySignatureInvalid) } - async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig) { + #[tokio::test] + async fn test_receive_unparseable_activity() { + let (_, _, config) = setup_receive_test().await; + + let actor = Url::parse("http://ds9.lemmy.ml/u/lemmy_alpha").unwrap(); + let id = "http://localhost:123/1"; + let activity = json!({ + "actor": actor.as_str(), + "to": ["https://www.w3.org/ns/activitystreams#Public"], + "object": "http://ds9.lemmy.ml/post/1", + "cc": ["http://enterprise.lemmy.ml/c/main"], + "type": "Delete", + "id": id + } + ); + let body: Bytes = serde_json::to_vec(&activity).unwrap().into(); + let incoming_request = construct_request(&body, &actor).await; + + // intentionally cause a parse error by using wrong type for deser + let res = receive_activity::( + incoming_request.to_http_request(), + body, + &config.to_request_data(), + ) + .await; + + match res { + Err(Error::ParseReceivedActivity(_, url)) => { + assert_eq!(id, url.expect("has url").as_str()); + } + _ => unreachable!(), + } + } + + async fn construct_request(body: &Bytes, actor: &Url) -> TestRequest { let inbox = "https://example.com/inbox"; let headers = generate_request_headers(&Url::parse(inbox).unwrap()); let request_builder = ClientWithMiddleware::from(Client::default()) .post(inbox) .headers(headers); - let activity = Follow { - actor: ObjectId::parse("http://localhost:123").unwrap(), - object: ObjectId::parse("http://localhost:124").unwrap(), - kind: Default::default(), - id: "http://localhost:123/1".try_into().unwrap(), - }; - let body: Bytes = serde_json::to_vec(&activity).unwrap().into(); let outgoing_request = sign_request( request_builder, - &activity.actor.into_inner(), + actor, body.clone(), DB_USER_KEYPAIR.private_key().unwrap(), false, @@ -138,6 +157,18 @@ mod test { for h in outgoing_request.headers() { incoming_request = incoming_request.append_header(h); } + incoming_request + } + + async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig) { + let activity = Follow { + actor: ObjectId::parse("http://localhost:123").unwrap(), + object: ObjectId::parse("http://localhost:124").unwrap(), + kind: Default::default(), + id: "http://localhost:123/1".try_into().unwrap(), + }; + let body: Bytes = serde_json::to_vec(&activity).unwrap().into(); + let incoming_request = construct_request(&body, activity.actor.inner()).await; let config = FederationConfig::builder() .domain("localhost:8002") diff --git a/src/actix_web/mod.rs b/src/actix_web/mod.rs index d7d137a..23bfaf2 100644 --- a/src/actix_web/mod.rs +++ b/src/actix_web/mod.rs @@ -12,6 +12,7 @@ use crate::{ }; use actix_web::{web::Bytes, HttpRequest}; use serde::Deserialize; +use std::str::FromStr; /// Checks whether the request is signed by an actor of type A, and returns /// the actor in question if a valid signature is found. @@ -22,10 +23,16 @@ pub async fn signing_actor( ) -> Result::Error> where A: Object + Actor, - ::Error: From + From, + ::Error: From, for<'de2> ::Kind: Deserialize<'de2>, { verify_body_hash(request.headers().get("Digest"), &body.unwrap_or_default())?; - http_signatures::signing_actor(request.headers(), request.method(), request.uri(), data).await + http_signatures::signing_actor( + request.headers(), + request.method(), + &http::Uri::from_str(&request.uri().to_string()).unwrap(), + data, + ) + .await } diff --git a/src/axum/inbox.rs b/src/axum/inbox.rs index 546f7be..717103d 100644 --- a/src/axum/inbox.rs +++ b/src/axum/inbox.rs @@ -5,15 +5,14 @@ use crate::{ config::Data, error::Error, - fetch::object_id::ObjectId, - http_signatures::{verify_body_hash, verify_signature}, + http_signatures::verify_signature, + parse_received_activity, traits::{ActivityHandler, Actor, Object}, }; use axum::{ async_trait, - body::Body, - extract::FromRequest, - http::{Request, StatusCode}, + extract::{FromRequest, Request}, + http::StatusCode, response::{IntoResponse, Response}, }; use http::{HeaderMap, Method, Uri}; @@ -29,27 +28,19 @@ where Activity: ActivityHandler + DeserializeOwned + Send + 'static, ActorT: Object + Actor + Send + 'static, for<'de2> ::Kind: serde::Deserialize<'de2>, - ::Error: From - + From - + From<::Error> - + From, - ::Error: From + From, + ::Error: From + From<::Error>, + ::Error: From, Datatype: Clone, { - verify_body_hash(activity_data.headers.get("Digest"), &activity_data.body)?; + let (activity, actor) = + parse_received_activity::(&activity_data.body, data).await?; - let activity: Activity = serde_json::from_slice(&activity_data.body)?; - data.config.verify_url_and_domain(&activity).await?; - let actor = ObjectId::::from(activity.actor().clone()) - .dereference(data) - .await?; - - verify_signature( - &activity_data.headers, - &activity_data.method, - &activity_data.uri, - actor.public_key_pem(), - )?; + // verify_signature( + // &activity_data.headers, + // &activity_data.method, + // &activity_data.uri, + // actor.public_key_pem(), + // )?; debug!("Receiving activity {}", activity.id().to_string()); activity.verify(data).await?; @@ -73,18 +64,20 @@ where { type Rejection = Response; - async fn from_request(req: Request, _state: &S) -> Result { - let (parts, body) = req.into_parts(); + async fn from_request(req: Request, state: &S) -> Result { + let headers = req.headers().clone(); + let method = req.method().clone(); + let uri = req.uri().clone(); // this wont work if the body is an long running stream - let bytes = hyper::body::to_bytes(body) + let bytes = hyper::body::Bytes::from_request(req, state) .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; Ok(Self { - headers: parts.headers, - method: parts.method, - uri: parts.uri, + headers, + method, + uri, body: bytes.to_vec(), }) } diff --git a/src/axum/json.rs b/src/axum/json.rs index f8a649e..f99c8bd 100644 --- a/src/axum/json.rs +++ b/src/axum/json.rs @@ -9,7 +9,7 @@ //! # use activitypub_federation::traits::Object; //! # use activitypub_federation::traits::tests::{DbConnection, DbUser, Person}; //! async fn http_get_user(Path(name): Path, data: Data) -> Result>, Error> { -//! let user: DbUser = data.read_local_user(name).await?; +//! let user: DbUser = data.read_local_user(&name).await?; //! let person = user.into_json(&data).await?; //! //! Ok(FederationJson(WithContext::new_default(person))) diff --git a/src/config.rs b/src/config.rs index 7f52aa7..7573234 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,7 +19,6 @@ use crate::{ protocol::verification::verify_domains_match, traits::{ActivityHandler, Actor}, }; -use anyhow::anyhow; use async_trait::async_trait; use derive_builder::Builder; use dyn_clone::{clone_trait_object, DynClone}; @@ -104,9 +103,9 @@ impl FederationConfig { verify_domains_match(activity.id(), activity.actor())?; self.verify_url_valid(activity.id()).await?; if self.is_local_url(activity.id()) { - return Err(Error::UrlVerificationError(anyhow!( - "Activity was sent from local instance" - ))); + return Err(Error::UrlVerificationError( + "Activity was sent from local instance", + )); } Ok(()) @@ -129,12 +128,12 @@ impl FederationConfig { "https" => {} "http" => { if !self.allow_http_urls { - return Err(Error::UrlVerificationError(anyhow!( - "Http urls are only allowed in debug mode" - ))); + return Err(Error::UrlVerificationError( + "Http urls are only allowed in debug mode", + )); } } - _ => return Err(Error::UrlVerificationError(anyhow!("Invalid url scheme"))), + _ => return Err(Error::UrlVerificationError("Invalid url scheme")), }; // Urls which use our local domain are not a security risk, no further verification needed @@ -143,21 +142,16 @@ impl FederationConfig { } if url.domain().is_none() { - return Err(Error::UrlVerificationError(anyhow!( - "Url must have a domain" - ))); + return Err(Error::UrlVerificationError("Url must have a domain")); } if url.domain() == Some("localhost") && !self.debug { - return Err(Error::UrlVerificationError(anyhow!( - "Localhost is only allowed in debug mode" - ))); + return Err(Error::UrlVerificationError( + "Localhost is only allowed in debug mode", + )); } - self.url_verifier - .verify(url) - .await - .map_err(Error::UrlVerificationError)?; + self.url_verifier.verify(url).await?; Ok(()) } @@ -227,7 +221,7 @@ impl Deref for FederationConfig { /// # use async_trait::async_trait; /// # use url::Url; /// # use activitypub_federation::config::UrlVerifier; -/// # use anyhow::anyhow; +/// # use activitypub_federation::error::Error; /// # #[derive(Clone)] /// # struct DatabaseConnection(); /// # async fn get_blocklist(_: &DatabaseConnection) -> Vec { @@ -240,11 +234,11 @@ impl Deref for FederationConfig { /// /// #[async_trait] /// impl UrlVerifier for Verifier { -/// async fn verify(&self, url: &Url) -> Result<(), anyhow::Error> { +/// async fn verify(&self, url: &Url) -> Result<(), Error> { /// let blocklist = get_blocklist(&self.db_connection).await; /// let domain = url.domain().unwrap().to_string(); /// if blocklist.contains(&domain) { -/// Err(anyhow!("Domain is blocked")) +/// Err(Error::Other("Domain is blocked".into())) /// } else { /// Ok(()) /// } @@ -254,7 +248,7 @@ impl Deref for FederationConfig { #[async_trait] pub trait UrlVerifier: DynClone + Send { /// Should return Ok iff the given url is valid for processing. - async fn verify(&self, url: &Url) -> Result<(), anyhow::Error>; + async fn verify(&self, url: &Url) -> Result<(), Error>; } /// Default URL verifier which does nothing. @@ -263,7 +257,7 @@ struct DefaultUrlVerifier(); #[async_trait] impl UrlVerifier for DefaultUrlVerifier { - async fn verify(&self, _url: &Url) -> Result<(), anyhow::Error> { + async fn verify(&self, _url: &Url) -> Result<(), Error> { Ok(()) } } diff --git a/src/error.rs b/src/error.rs index 91de96a..ba9248a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,13 @@ //! Error messages returned by this library +use std::string::FromUtf8Error; + +use http_signature_normalization_reqwest::SignError; +use openssl::error::ErrorStack; +use url::Url; + +use crate::fetch::webfinger::WebFingerError; + /// Error messages returned by this library #[derive(thiserror::Error, Debug)] pub enum Error { @@ -13,11 +21,11 @@ pub enum Error { #[error("Response body limit was reached during fetch")] ResponseBodyLimit, /// Object to be fetched was deleted - #[error("Object to be fetched was deleted")] - ObjectDeleted, + #[error("Fetched remote object {0} which was deleted")] + ObjectDeleted(Url), /// url verification error #[error("URL failed verification: {0}")] - UrlVerificationError(anyhow::Error), + UrlVerificationError(&'static str), /// Incoming activity has invalid digest for body #[error("Incoming activity has invalid digest for body")] ActivityBodyDigestInvalid, @@ -26,18 +34,42 @@ pub enum Error { ActivitySignatureInvalid, /// Failed to resolve actor via webfinger #[error("Failed to resolve actor via webfinger")] - WebfingerResolveFailed, - /// other error + WebfingerResolveFailed(#[from] WebFingerError), + /// Failed to serialize outgoing activity + #[error("Failed to serialize outgoing activity {1}: {0}")] + SerializeOutgoingActivity(serde_json::Error, String), + /// Failed to parse an object fetched from url + #[error("Failed to parse object {1} with content {2}: {0}")] + ParseFetchedObject(serde_json::Error, Url, String), + /// Failed to parse an activity received from another instance + #[error("Failed to parse incoming activity {}: {0}", match .1 { + Some(t) => format!("with id {t}"), + None => String::new(), + })] + ParseReceivedActivity(serde_json::Error, Option), + /// Reqwest Middleware Error #[error(transparent)] - Other(#[from] anyhow::Error), + ReqwestMiddleware(#[from] reqwest_middleware::Error), + /// Reqwest Error + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + /// UTF-8 error + #[error(transparent)] + Utf8(#[from] FromUtf8Error), + /// Url Parse + #[error(transparent)] + UrlParse(#[from] url::ParseError), + /// Signing errors + #[error(transparent)] + SignError(#[from] SignError), + /// Other generic errors + #[error("{0}")] + Other(String), } -impl Error { - pub(crate) fn other(error: T) -> Self - where - T: Into, - { - Error::Other(error.into()) +impl From for Error { + fn from(value: ErrorStack) -> Self { + Error::Other(value.to_string()) } } diff --git a/src/fetch/collection_id.rs b/src/fetch/collection_id.rs index 8f42008..8c796f4 100644 --- a/src/fetch/collection_id.rs +++ b/src/fetch/collection_id.rs @@ -20,12 +20,8 @@ where for<'de2> ::Kind: Deserialize<'de2>, { /// Construct a new CollectionId instance - pub fn parse(url: T) -> Result - where - T: TryInto, - url::ParseError: From<>::Error>, - { - Ok(Self(Box::new(url.try_into()?), PhantomData::)) + pub fn parse(url: &str) -> Result { + Ok(Self(Box::new(Url::parse(url)?), PhantomData::)) } /// Fetches collection over HTTP @@ -96,3 +92,102 @@ where CollectionId(Box::new(url), PhantomData::) } } + +impl PartialEq for CollectionId +where + Kind: Collection, + for<'de2> ::Kind: serde::Deserialize<'de2>, +{ + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) && self.1 == other.1 + } +} + +#[cfg(feature = "diesel")] +const _IMPL_DIESEL_NEW_TYPE_FOR_COLLECTION_ID: () = { + use diesel::{ + backend::Backend, + deserialize::{FromSql, FromStaticSqlRow}, + expression::AsExpression, + internal::derives::as_expression::Bound, + pg::Pg, + query_builder::QueryId, + serialize, + serialize::{Output, ToSql}, + sql_types::{HasSqlType, SingleValue, Text}, + Expression, + Queryable, + }; + + // TODO: this impl only works for Postgres db because of to_string() call which requires reborrow + impl ToSql for CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + String: ToSql, + { + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { + let v = self.0.to_string(); + >::to_sql(&v, &mut out.reborrow()) + } + } + impl<'expr, Kind, ST> AsExpression for &'expr CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.as_str()) + } + } + impl AsExpression for CollectionId + where + Kind: Collection, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.to_string()) + } + } + impl FromSql for CollectionId + where + Kind: Collection + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromSql, + DB: Backend, + DB: HasSqlType, + { + fn from_sql( + raw: DB::RawValue<'_>, + ) -> Result> { + let string: String = FromSql::::from_sql(raw)?; + Ok(CollectionId::parse(&string)?) + } + } + impl Queryable for CollectionId + where + Kind: Collection + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromStaticSqlRow, + DB: Backend, + DB: HasSqlType, + { + type Row = String; + fn build(row: Self::Row) -> diesel::deserialize::Result { + Ok(CollectionId::parse(&row)?) + } + } + impl QueryId for CollectionId + where + Kind: Collection + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + { + type QueryId = Self; + } +}; diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index bf1bfd5..674d640 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -4,7 +4,7 @@ use crate::{ config::Data, - error::Error, + error::{Error, Error::ParseFetchedObject}, http_signatures::sign_request, reqwest_shim::ResponseExt, FEDERATION_CONTENT_TYPE, @@ -63,10 +63,10 @@ async fn fetch_object_http_with_accept( config.verify_url_valid(url).await?; info!("Fetching remote object {}", url.to_string()); - let counter = data.request_counter.fetch_add(1, Ordering::SeqCst); - if counter > config.http_fetch_limit { - return Err(Error::RequestLimit); - } + // let counter = data.request_counter.fetch_add(1, Ordering::SeqCst); + // if counter > config.http_fetch_limit { + // return Err(Error::RequestLimit); + // } let req = config .client @@ -83,18 +83,23 @@ async fn fetch_object_http_with_accept( data.config.http_signature_compat, ) .await?; - config.client.execute(req).await.map_err(Error::other)? + config.client.execute(req).await? } else { - req.send().await.map_err(Error::other)? + req.send().await? }; - if res.status() == StatusCode::GONE { - return Err(Error::ObjectDeleted); + if res.status().as_u16() == StatusCode::GONE.as_u16() { + return Err(Error::ObjectDeleted(url.clone())); } let url = res.url().clone(); - Ok(FetchObjectResponse { - object: res.json_limited().await?, - url, - }) + let text = res.bytes_limited().await?; + match serde_json::from_slice(&text) { + Ok(object) => Ok(FetchObjectResponse { object, url }), + Err(e) => Err(ParseFetchedObject( + e, + url, + String::from_utf8(Vec::from(text))?, + )), + } } diff --git a/src/fetch/object_id.rs b/src/fetch/object_id.rs index 8c0e5aa..782900d 100644 --- a/src/fetch/object_id.rs +++ b/src/fetch/object_id.rs @@ -1,5 +1,4 @@ use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object}; -use anyhow::anyhow; use chrono::{DateTime, Duration as ChronoDuration, Utc}; use serde::{Deserialize, Serialize}; use std::{ @@ -58,20 +57,16 @@ where pub struct ObjectId(Box, PhantomData) where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>; + for<'de2> ::Kind: Deserialize<'de2>; impl ObjectId where Kind: Object + Send + Debug + 'static, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { /// Construct a new objectid instance - pub fn parse(url: T) -> Result - where - T: TryInto, - url::ParseError: From<>::Error>, - { - Ok(ObjectId(Box::new(url.try_into()?), PhantomData::)) + pub fn parse(url: &str) -> Result { + Ok(Self(Box::new(Url::parse(url)?), PhantomData::)) } /// Returns a reference to the wrapped URL value @@ -90,7 +85,7 @@ where data: &Data<::DataType>, ) -> Result::Error> where - ::Error: From + From, + ::Error: From, { let db_object = self.dereference_from_db(data).await?; // if its a local object, only fetch it from the database and not over http @@ -117,6 +112,24 @@ where } } + /// If this is a remote object, fetch it from origin instance unconditionally to get the + /// latest version, regardless of refresh interval. + pub async fn dereference_forced( + &self, + data: &Data<::DataType>, + ) -> Result::Error> + where + ::Error: From, + { + if data.config.is_local_url(&self.0) { + self.dereference_from_db(data) + .await + .map(|o| o.ok_or(Error::NotFound.into()))? + } else { + self.dereference_from_http(data, None).await + } + } + /// Fetch an object from the local db. Instead of falling back to http, this throws an error if /// the object is not found in the database. pub async fn dereference_local( @@ -145,15 +158,15 @@ where db_object: Option, ) -> Result::Error> where - ::Error: From + From, + ::Error: From, { let res = fetch_object_http(&self.0, data).await; - if let Err(Error::ObjectDeleted) = &res { + if let Err(Error::ObjectDeleted(url)) = res { if let Some(db_object) = db_object { db_object.delete(data).await?; } - return Err(anyhow!("Fetched remote object {} which was deleted", self).into()); + return Err(Error::ObjectDeleted(url).into()); } let res = res?; @@ -168,7 +181,7 @@ where impl Clone for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn clone(&self) -> Self { ObjectId(self.0.clone(), self.1) @@ -195,7 +208,7 @@ fn should_refetch_object(last_refreshed: DateTime) -> bool { impl Display for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.as_str()) @@ -205,7 +218,7 @@ where impl Debug for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0.as_str()) @@ -215,7 +228,7 @@ where impl From> for Url where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn from(id: ObjectId) -> Self { *id.0 @@ -225,7 +238,7 @@ where impl From for ObjectId where Kind: Object + Send + 'static, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn from(url: Url) -> Self { ObjectId(Box::new(url), PhantomData::) @@ -235,13 +248,102 @@ where impl PartialEq for ObjectId where Kind: Object, - for<'de2> ::Kind: serde::Deserialize<'de2>, + for<'de2> ::Kind: Deserialize<'de2>, { fn eq(&self, other: &Self) -> bool { self.0.eq(&other.0) && self.1 == other.1 } } +#[cfg(feature = "diesel")] +const _IMPL_DIESEL_NEW_TYPE_FOR_OBJECT_ID: () = { + use diesel::{ + backend::Backend, + deserialize::{FromSql, FromStaticSqlRow}, + expression::AsExpression, + internal::derives::as_expression::Bound, + pg::Pg, + query_builder::QueryId, + serialize, + serialize::{Output, ToSql}, + sql_types::{HasSqlType, SingleValue, Text}, + Expression, + Queryable, + }; + + // TODO: this impl only works for Postgres db because of to_string() call which requires reborrow + impl ToSql for ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + String: ToSql, + { + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result { + let v = self.0.to_string(); + >::to_sql(&v, &mut out.reborrow()) + } + } + impl<'expr, Kind, ST> AsExpression for &'expr ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.as_str()) + } + } + impl AsExpression for ObjectId + where + Kind: Object, + for<'de2> ::Kind: Deserialize<'de2>, + Bound: Expression, + ST: SingleValue, + { + type Expression = Bound; + fn as_expression(self) -> Self::Expression { + Bound::new(self.0.to_string()) + } + } + impl FromSql for ObjectId + where + Kind: Object + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromSql, + DB: Backend, + DB: HasSqlType, + { + fn from_sql( + raw: DB::RawValue<'_>, + ) -> Result> { + let string: String = FromSql::::from_sql(raw)?; + Ok(ObjectId::parse(&string)?) + } + } + impl Queryable for ObjectId + where + Kind: Object + Send + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + String: FromStaticSqlRow, + DB: Backend, + DB: HasSqlType, + { + type Row = String; + fn build(row: Self::Row) -> diesel::deserialize::Result { + Ok(ObjectId::parse(&row)?) + } + } + impl QueryId for ObjectId + where + Kind: Object + 'static, + for<'de2> ::Kind: Deserialize<'de2>, + { + type QueryId = Self; + } +}; + #[cfg(test)] pub mod tests { use super::*; diff --git a/src/fetch/webfinger.rs b/src/fetch/webfinger.rs index 91c31cd..68b110d 100644 --- a/src/fetch/webfinger.rs +++ b/src/fetch/webfinger.rs @@ -1,18 +1,38 @@ use crate::{ config::Data, - error::{Error, Error::WebfingerResolveFailed}, + error::Error, fetch::{fetch_object_http_with_accept, object_id::ObjectId}, traits::{Actor, Object}, FEDERATION_CONTENT_TYPE, }; -use anyhow::anyhow; use itertools::Itertools; +use once_cell::sync::Lazy; use regex::Regex; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Display}; use tracing::debug; use url::Url; +/// Errors relative to webfinger handling +#[derive(thiserror::Error, Debug)] +pub enum WebFingerError { + /// The webfinger identifier is invalid + #[error("The webfinger identifier is invalid")] + WrongFormat, + /// The webfinger identifier doesn't match the expected instance domain name + #[error("The webfinger identifier doesn't match the expected instance domain name")] + WrongDomain, + /// The wefinger object did not contain any link to an activitypub item + #[error("The webfinger object did not contain any link to an activitypub item")] + NoValidLink, +} + +impl WebFingerError { + fn into_crate_error(self) -> Error { + self.into() + } +} + /// Takes an identifier of the form `name@example.com`, and returns an object of `Kind`. /// /// For this the identifier is first resolved via webfinger protocol to an Activitypub ID. This ID @@ -24,22 +44,24 @@ pub async fn webfinger_resolve_actor( where Kind: Object + Actor + Send + 'static + Object, for<'de2> ::Kind: serde::Deserialize<'de2>, - ::Error: - From + From + From + Send + Sync, + ::Error: From + Send + Sync + Display, { let (_, domain) = identifier .splitn(2, '@') .collect_tuple() - .ok_or(WebfingerResolveFailed)?; + .ok_or(WebFingerError::WrongFormat.into_crate_error())?; let protocol = if data.config.debug { "http" } else { "https" }; let fetch_url = format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}"); debug!("Fetching webfinger url: {}", &fetch_url); - let res: Webfinger = - fetch_object_http_with_accept(&Url::parse(&fetch_url)?, data, "application/jrd+json") - .await? - .object; + let res: Webfinger = fetch_object_http_with_accept( + &Url::parse(&fetch_url).map_err(Error::UrlParse)?, + data, + "application/jrd+json", + ) + .await? + .object; debug_assert_eq!(res.subject, format!("acct:{identifier}")); let links: Vec = res @@ -54,13 +76,15 @@ where }) .filter_map(|l| l.href.clone()) .collect(); + for l in links { let object = ObjectId::::from(l).dereference(data).await; - if object.is_ok() { - return object; + match object { + Ok(obj) => return Ok(obj), + Err(error) => debug!(%error, "Failed to dereference link"), } } - Err(WebfingerResolveFailed.into()) + Err(WebFingerError::NoValidLink.into_crate_error().into()) } /// Extracts username from a webfinger resource parameter. @@ -88,20 +112,24 @@ where /// # Ok::<(), anyhow::Error>(()) /// }).unwrap(); ///``` -pub fn extract_webfinger_name(query: &str, data: &Data) -> Result +pub fn extract_webfinger_name<'i, T>(query: &'i str, data: &Data) -> Result<&'i str, Error> where T: Clone, { + static WEBFINGER_REGEX: Lazy = + Lazy::new(|| Regex::new(r"^acct:([\p{L}0-9_]+)@(.*)$").expect("compile regex")); // Regex to extract usernames from webfinger query. Supports different alphabets using `\p{L}`. - // TODO: would be nice if we could implement this without regex and remove the dependency - let regex = - Regex::new(&format!(r"^acct:@?([\p{{L}}0-9_]+)@{}$", data.domain())).map_err(Error::other)?; - Ok(regex + // TODO: This should use a URL parser + let captures = WEBFINGER_REGEX .captures(query) - .and_then(|c| c.get(1)) - .ok_or_else(|| Error::other(anyhow!("Webfinger regex failed to match")))? - .as_str() - .to_string()) + .ok_or(WebFingerError::WrongFormat)?; + + let account_name = captures.get(1).ok_or(WebFingerError::WrongFormat)?; + + if captures.get(2).map(|m| m.as_str()) != Some(data.domain()) { + return Err(WebFingerError::WrongDomain.into()); + } + Ok(account_name.as_str()) } /// Builds a basic webfinger response for the actor. @@ -249,15 +277,15 @@ mod tests { request_counter: Default::default(), }; assert_eq!( - Ok("test123".to_string()), + Ok("test123"), extract_webfinger_name("acct:test123@example.com", &data) ); assert_eq!( - Ok("Владимир".to_string()), + Ok("Владимир"), extract_webfinger_name("acct:Владимир@example.com", &data) ); assert_eq!( - Ok("تجريب".to_string()), + Ok("تجريب"), extract_webfinger_name("acct:تجريب@example.com", &data) ); Ok(()) diff --git a/src/http_signatures.rs b/src/http_signatures.rs index 96ce936..a165aaa 100644 --- a/src/http_signatures.rs +++ b/src/http_signatures.rs @@ -12,11 +12,13 @@ use crate::{ protocol::public_key::main_key_id, traits::{Actor, Object}, }; -use anyhow::Context; use base64::{engine::general_purpose::STANDARD as Base64, Engine}; use bytes::Bytes; -use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri}; -use http_signature_normalization_reqwest::prelude::{Config, SignExt}; +use http::{uri::PathAndQuery, Uri}; +use http_signature_normalization_reqwest::{ + prelude::{Config, SignExt}, + DefaultSpawner, +}; use once_cell::sync::Lazy; use openssl::{ hash::MessageDigest, @@ -24,7 +26,11 @@ use openssl::{ rsa::Rsa, sign::{Signer, Verifier}, }; -use reqwest::Request; +use reqwest::{ + header::{HeaderName, HeaderValue}, + Method, + Request, +}; use reqwest_middleware::RequestBuilder; use serde::Deserialize; use sha2::{Digest, Sha256}; @@ -83,8 +89,9 @@ pub(crate) async fn sign_request( activity: Bytes, private_key: PKey, http_signature_compat: bool, -) -> Result { - static CONFIG: Lazy = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER)); +) -> Result { + static CONFIG: Lazy> = + Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER)); static CONFIG_COMPAT: Lazy = Lazy::new(|| { Config::new() .mastodon_compat() @@ -103,14 +110,10 @@ pub(crate) async fn sign_request( Sha256::new(), activity, move |signing_string| { - let mut signer = Signer::new(MessageDigest::sha256(), &private_key) - .context("instantiating signer")?; - signer - .update(signing_string.as_bytes()) - .context("updating signer")?; + let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?; + signer.update(signing_string.as_bytes())?; - Ok(Base64.encode(signer.sign_to_vec().context("sign to vec")?)) - as Result<_, anyhow::Error> + Ok(Base64.encode(signer.sign_to_vec()?)) as Result<_, Error> }, ) .await @@ -152,7 +155,7 @@ pub(crate) async fn signing_actor<'a, A, H>( ) -> Result::Error> where A: Object + Actor, - ::Error: From + From, + ::Error: From, for<'de2> ::Kind: Deserialize<'de2>, H: IntoIterator, { @@ -197,8 +200,8 @@ fn verify_signature_inner( let verified = CONFIG .begin_verify(method.as_str(), path_and_query, header_map) - .map_err(Error::other)? - .verify(|signature, signing_string| -> anyhow::Result { + .map_err(|val| Error::Other(val.to_string()))? + .verify(|signature, signing_string| -> Result { debug!( "Verifying with key {}, message {}", &public_key, &signing_string @@ -206,9 +209,13 @@ fn verify_signature_inner( let public_key = PKey::public_key_from_pem(public_key.as_bytes())?; let mut verifier = Verifier::new(MessageDigest::sha256(), &public_key)?; verifier.update(signing_string.as_bytes())?; - Ok(verifier.verify(&Base64.decode(signature)?)?) - }) - .map_err(Error::other)?; + + let base64_decoded = Base64 + .decode(signature) + .map_err(|err| Error::Other(err.to_string()))?; + + Ok(verifier.verify(&base64_decoded)?) + })?; if verified { debug!("verified signature for {}", uri); @@ -289,7 +296,7 @@ pub mod test { // use hardcoded date in order to test against hardcoded signature headers.insert( "date", - HeaderValue::from_str("Tue, 28 Mar 2023 21:03:44 GMT").unwrap(), + reqwest::header::HeaderValue::from_str("Tue, 28 Mar 2023 21:03:44 GMT").unwrap(), ); let request_builder = ClientWithMiddleware::from(Client::new()) diff --git a/src/lib.rs b/src/lib.rs index c660253..f482aa0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,46 @@ pub mod protocol; pub(crate) mod reqwest_shim; pub mod traits; +use crate::{ + config::Data, + error::Error, + fetch::object_id::ObjectId, + traits::{ActivityHandler, Actor, Object}, +}; pub use activitystreams_kinds as kinds; +use serde::{de::DeserializeOwned, Deserialize}; +use url::Url; + /// Mime type for Activitypub data, used for `Accept` and `Content-Type` HTTP headers pub static FEDERATION_CONTENT_TYPE: &str = "application/activity+json"; + +/// Deserialize incoming inbox activity to the given type, perform basic +/// validation and extract the actor. +async fn parse_received_activity( + body: &[u8], + data: &Data, +) -> Result<(Activity, ActorT), ::Error> +where + Activity: ActivityHandler + DeserializeOwned + Send + 'static, + ActorT: Object + Actor + Send + 'static, + for<'de2> ::Kind: serde::Deserialize<'de2>, + ::Error: From + From<::Error>, + ::Error: From, + Datatype: Clone, +{ + let activity: Activity = serde_json::from_slice(body).map_err(|e| { + // Attempt to include activity id in error message + #[derive(Deserialize)] + struct Id { + id: Url, + } + let id = serde_json::from_slice::(body).ok(); + Error::ParseReceivedActivity(e, id.map(|i| i.id)) + })?; + data.config.verify_url_and_domain(&activity).await?; + let actor = ObjectId::::from(activity.actor().clone()) + .dereference(data) + .await?; + Ok((activity, actor)) +} diff --git a/src/protocol/context.rs b/src/protocol/context.rs index 1d80bcb..027ff15 100644 --- a/src/protocol/context.rs +++ b/src/protocol/context.rs @@ -15,25 +15,23 @@ //! }; //! let note_with_context = WithContext::new_default(note); //! let serialized = serde_json::to_string(¬e_with_context)?; -//! assert_eq!(serialized, r#"{"@context":["https://www.w3.org/ns/activitystreams"],"content":"Hello world"}"#); +//! assert_eq!(serialized, r#"{"@context":"https://www.w3.org/ns/activitystreams","content":"Hello world"}"#); //! Ok::<(), serde_json::error::Error>(()) //! ``` -use crate::{config::Data, protocol::helpers::deserialize_one_or_many, traits::ActivityHandler}; +use crate::{config::Data, traits::ActivityHandler}; use serde::{Deserialize, Serialize}; use serde_json::Value; use url::Url; /// Default context used in Activitypub const DEFAULT_CONTEXT: &str = "https://www.w3.org/ns/activitystreams"; -const DEFAULT_SECURITY_CONTEXT: &str = "https://w3id.org/security/v1"; /// Wrapper for federated structs which handles `@context` field. #[derive(Serialize, Deserialize, Debug)] pub struct WithContext { #[serde(rename = "@context")] - #[serde(deserialize_with = "deserialize_one_or_many")] - context: Vec, + context: Value, #[serde(flatten)] inner: T, } @@ -41,15 +39,12 @@ pub struct WithContext { impl WithContext { /// Create a new wrapper with the default Activitypub context. pub fn new_default(inner: T) -> WithContext { - let context = vec![ - Value::String(DEFAULT_CONTEXT.to_string()), - Value::String(DEFAULT_SECURITY_CONTEXT.to_string()), - ]; + let context = Value::String(DEFAULT_CONTEXT.to_string()); WithContext::new(inner, context) } /// Create new wrapper with custom context. Use this in case you are implementing extensions. - pub fn new(inner: T, context: Vec) -> WithContext { + pub fn new(inner: T, context: Value) -> WithContext { WithContext { context, inner } } diff --git a/src/protocol/helpers.rs b/src/protocol/helpers.rs index 99ae7b2..8c69f65 100644 --- a/src/protocol/helpers.rs +++ b/src/protocol/helpers.rs @@ -56,12 +56,12 @@ where /// #[derive(serde::Deserialize)] /// struct Note { /// #[serde(deserialize_with = "deserialize_one")] -/// to: Url +/// to: [Url; 1] /// } /// /// let note = serde_json::from_str::(r#"{"to": ["https://example.com/u/alice"] }"#); /// assert!(note.is_ok()); -pub fn deserialize_one<'de, T, D>(deserializer: D) -> Result +pub fn deserialize_one<'de, T, D>(deserializer: D) -> Result<[T; 1], D::Error> where T: Deserialize<'de>, D: Deserializer<'de>, @@ -75,8 +75,8 @@ where let result: MaybeArray = Deserialize::deserialize(deserializer)?; Ok(match result { - MaybeArray::Simple(value) => value, - MaybeArray::Array([value]) => value, + MaybeArray::Simple(value) => [value], + MaybeArray::Array([value]) => [value], }) } @@ -125,7 +125,7 @@ mod tests { #[derive(serde::Deserialize)] struct Note { #[serde(deserialize_with = "deserialize_one")] - _to: Url, + _to: [Url; 1], } let note = serde_json::from_str::( diff --git a/src/protocol/public_key.rs b/src/protocol/public_key.rs index ecfcd3c..d36ee2b 100644 --- a/src/protocol/public_key.rs +++ b/src/protocol/public_key.rs @@ -6,7 +6,7 @@ use url::Url; /// Public key of actors which is used for HTTP signatures. /// /// This needs to be federated in the `public_key` field of all actors. -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] #[serde(rename_all = "camelCase")] pub struct PublicKey { /// Id of this private key. diff --git a/src/protocol/values.rs b/src/protocol/values.rs index 0c01097..4a87b3c 100644 --- a/src/protocol/values.rs +++ b/src/protocol/values.rs @@ -35,7 +35,7 @@ use serde::{Deserialize, Serialize}; /// Media type for markdown text. /// /// -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] pub enum MediaTypeMarkdown { /// `text/markdown` #[serde(rename = "text/markdown")] @@ -45,7 +45,7 @@ pub enum MediaTypeMarkdown { /// Media type for HTML text. /// /// -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] pub enum MediaTypeHtml { /// `text/html` #[serde(rename = "text/html")] diff --git a/src/protocol/verification.rs b/src/protocol/verification.rs index 3383bd9..18595b9 100644 --- a/src/protocol/verification.rs +++ b/src/protocol/verification.rs @@ -1,7 +1,6 @@ //! Verify that received data is valid use crate::error::Error; -use anyhow::anyhow; use url::Url; /// Check that both urls have the same domain. If not, return UrlVerificationError. @@ -16,7 +15,7 @@ use url::Url; /// ``` pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> { if a.domain() != b.domain() { - return Err(Error::UrlVerificationError(anyhow!("Domains do not match"))); + return Err(Error::UrlVerificationError("Domains do not match")); } Ok(()) } @@ -33,7 +32,7 @@ pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> { /// ``` pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> { if a != b { - return Err(Error::UrlVerificationError(anyhow!("Urls do not match"))); + return Err(Error::UrlVerificationError("Urls do not match")); } Ok(()) } diff --git a/src/reqwest_shim.rs b/src/reqwest_shim.rs index 81c571b..9ebe108 100644 --- a/src/reqwest_shim.rs +++ b/src/reqwest_shim.rs @@ -3,10 +3,8 @@ use bytes::{BufMut, Bytes, BytesMut}; use futures_core::{ready, stream::BoxStream, Stream}; use pin_project_lite::pin_project; use reqwest::Response; -use serde::de::DeserializeOwned; use std::{ future::Future, - marker::PhantomData, mem, pin::Pin, task::{Context, Poll}, @@ -30,10 +28,7 @@ impl Future for BytesFuture { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { let this = self.as_mut().project(); - if let Some(chunk) = ready!(this.stream.poll_next(cx)) - .transpose() - .map_err(Error::other)? - { + if let Some(chunk) = ready!(this.stream.poll_next(cx)).transpose()? { this.aggregator.put(chunk); if this.aggregator.len() > *this.limit { return Poll::Ready(Err(Error::ResponseBodyLimit)); @@ -49,27 +44,6 @@ impl Future for BytesFuture { } } -pin_project! { - pub struct JsonFuture { - _t: PhantomData, - #[pin] - future: BytesFuture, - } -} - -impl Future for JsonFuture -where - T: DeserializeOwned, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let bytes = ready!(this.future.poll(cx))?; - Poll::Ready(serde_json::from_slice(&bytes).map_err(Error::other)) - } -} - pin_project! { pub struct TextFuture { #[pin] @@ -83,7 +57,7 @@ impl Future for TextFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let bytes = ready!(this.future.poll(cx))?; - Poll::Ready(String::from_utf8(bytes.to_vec()).map_err(Error::other)) + Poll::Ready(String::from_utf8(bytes.to_vec()).map_err(Error::Utf8)) } } @@ -97,20 +71,16 @@ impl Future for TextFuture { /// TODO: Remove this shim as soon as reqwest gets support for size-limited bodies. pub trait ResponseExt { type BytesFuture; - type JsonFuture; type TextFuture; /// Size limited version of `bytes` to work around a reqwest issue. Check [`ResponseExt`] docs for details. fn bytes_limited(self) -> Self::BytesFuture; - /// Size limited version of `json` to work around a reqwest issue. Check [`ResponseExt`] docs for details. - fn json_limited(self) -> Self::JsonFuture; /// Size limited version of `text` to work around a reqwest issue. Check [`ResponseExt`] docs for details. fn text_limited(self) -> Self::TextFuture; } impl ResponseExt for Response { type BytesFuture = BytesFuture; - type JsonFuture = JsonFuture; type TextFuture = TextFuture; fn bytes_limited(self) -> Self::BytesFuture { @@ -121,13 +91,6 @@ impl ResponseExt for Response { } } - fn json_limited(self) -> Self::JsonFuture { - JsonFuture { - _t: PhantomData, - future: self.bytes_limited(), - } - } - fn text_limited(self) -> Self::TextFuture { TextFuture { future: self.bytes_limited(), diff --git a/src/traits.rs b/src/traits.rs index e4ed6d0..9fdec27 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -340,12 +340,12 @@ pub trait Collection: Sized { pub mod tests { use super::*; use crate::{ + error::Error, fetch::object_id::ObjectId, http_signatures::{generate_actor_keypair, Keypair}, protocol::{public_key::PublicKey, verification::verify_domains_match}, }; use activitystreams_kinds::{activity::FollowType, actor::PersonType}; - use anyhow::Error; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; @@ -356,7 +356,7 @@ pub mod tests { pub async fn read_post_from_json_id(&self, _: Url) -> Result, Error> { Ok(None) } - pub async fn read_local_user(&self, _: String) -> Result { + pub async fn read_local_user(&self, _: &str) -> Result { todo!() } pub async fn upsert(&self, _: &T) -> Result<(), Error> {