Merge pull request #2 from Tangel/dev

Dev
This commit is contained in:
Tangel 2024-01-22 21:49:01 +08:00 committed by GitHub
commit 9830f30f2a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 610 additions and 335 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "activitypub_federation" name = "activitypub_federation"
version = "0.5.0-beta.4" version = "0.5.1-beta.1"
edition = "2021" edition = "2021"
description = "High-level Activitypub framework" description = "High-level Activitypub framework"
keywords = ["activitypub", "activitystreams", "federation", "fediverse"] keywords = ["activitypub", "activitystreams", "federation", "fediverse"]
@ -8,73 +8,77 @@ license = "AGPL-3.0"
repository = "https://github.com/LemmyNet/activitypub-federation-rust" repository = "https://github.com/LemmyNet/activitypub-federation-rust"
documentation = "https://docs.rs/activitypub_federation/" documentation = "https://docs.rs/activitypub_federation/"
[features]
default = ["actix-web", "axum"]
actix-web = ["dep:actix-web"]
axum = ["dep:axum", "dep:tower", "dep:hyper", "dep:http-body-util"]
diesel = ["dep:diesel"]
[dependencies] [dependencies]
chrono = { version = "0.4.31", features = ["clock"], default-features = false } chrono = { version = "0.4.31", features = ["clock"], default-features = false }
serde = { version = "1.0.189", features = ["derive"] } serde = { version = "1.0.194", features = ["derive"] }
async-trait = "0.1.74" async-trait = "0.1.77"
url = { version = "2.4.1", features = ["serde"] } url = { version = "2.5.0", features = ["serde"] }
serde_json = { version = "1.0.107", features = ["preserve_order"] } serde_json = { version = "1.0.110", features = ["preserve_order"] }
anyhow = "1.0.75" reqwest = { version = "0.11.23", features = ["json", "stream"] }
reqwest = { version = "0.11.22", features = ["json", "stream"] } reqwest-middleware = "0.2.4"
reqwest-middleware = "0.2.3"
tracing = "0.1.40" tracing = "0.1.40"
base64 = "0.21.5" base64 = "0.21.5"
openssl = "0.10.57" openssl = "0.10.62"
once_cell = "1.18.0" once_cell = "1.19.0"
http = "0.2.9" http = "1.0.0"
sha2 = "0.10.8" sha2 = "0.10.8"
thiserror = "1.0.50" thiserror = "1.0.56"
derive_builder = "0.12.0" derive_builder = "0.12.0"
itertools = "0.11.0" itertools = "0.12.0"
dyn-clone = "1.0.14" dyn-clone = "1.0.16"
enum_delegate = "0.2.0" enum_delegate = "0.2.0"
httpdate = "1.0.3" httpdate = "1.0.3"
http-signature-normalization-reqwest = { version = "0.10.0", default-features = false, features = [ http-signature-normalization-reqwest = { version = "0.10.0", default-features = false, features = [
"default-spawner", "default-spawner",
"sha-2", "sha-2",
"middleware", "middleware",
"default-spawner",
] } ] }
http-signature-normalization = "0.7.0" http-signature-normalization = "0.7.0"
bytes = "1.5.0" bytes = "1.5.0"
futures-core = { version = "0.3.28", default-features = false } futures-core = { version = "0.3.30", default-features = false }
pin-project-lite = "0.2.13" pin-project-lite = "0.2.13"
activitystreams-kinds = "0.3.0" activitystreams-kinds = "0.3.0"
regex = { version = "1.10.2", default-features = false, features = ["std", "unicode-case"] } regex = { version = "1.10.2", default-features = false, features = ["std", "unicode-case"] }
tokio = { version = "1.33.0", features = [ tokio = { version = "1.35.1", features = [
"sync", "sync",
"rt", "rt",
"rt-multi-thread", "rt-multi-thread",
"time", "time",
] } ] }
diesel = { version = "2.1.4", features = ["postgres"], default-features = false, optional = true }
futures = "0.3.30"
moka = { version = "0.12.2", features = ["future"] }
# Actix-web # Actix-web
actix-web = { version = "4.4.0", default-features = false, optional = true } actix-web = { version = "4.4.1", default-features = false, optional = true }
# Axum # Axum
axum = { git = "https://github.com/tokio-rs/axum.git", features = [ axum = { git = "https://github.com/tokio-rs/axum.git", features = [
"json", "json",
], default-features = false, optional = true } ], default-features = false, optional = true }
tower = { version = "*", optional = true } tower = { version = "0.4.13", optional = true }
hyper = { version = "*", optional = true } hyper = { version = "1.1.0", optional = true }
futures = "*" http-body-util = {version = "0.1.0", optional = true }
moka = { version = "0.12.1", features = ["future"] }
[features]
default = ["actix-web", "axum"]
actix-web = ["dep:actix-web"]
axum = ["dep:axum", "dep:tower", "dep:hyper"]
[dev-dependencies] [dev-dependencies]
anyhow = "1.0.79"
rand = "0.8.5" rand = "0.8.5"
env_logger = "0.10.0" env_logger = "0.10.1"
tower-http = { version = "*", features = ["map-request-body", "util"] } tower-http = { version = "0.5.0", features = ["map-request-body", "util"] }
axum = { git = "https://github.com/tokio-rs/axum.git", features = [ axum = { git = "https://github.com/tokio-rs/axum.git", features = [
"http1", "http1",
"tokio", "tokio",
"query", "query",
], default-features = false } ], default-features = false }
axum-macros = { git = "https://github.com/tokio-rs/axum.git" } axum-macros = { git = "https://github.com/tokio-rs/axum.git" }
tokio = { version = "*", features = ["full"] } tokio = { version = "1.35.1", features = ["full"] }
[profile.dev] [profile.dev]
strip = "symbols" strip = "symbols"

View file

@ -48,7 +48,7 @@ async fn http_get_user(
) -> impl IntoResponse { ) -> impl IntoResponse {
let accept = header_map.get("accept").map(|v| v.to_str().unwrap()); let accept = header_map.get("accept").map(|v| v.to_str().unwrap());
if accept == Some(FEDERATION_CONTENT_TYPE) { if accept == Some(FEDERATION_CONTENT_TYPE) {
let db_user = data.read_local_user(name).await.unwrap(); let db_user = data.read_local_user(&name).await.unwrap();
let json_user = db_user.into_json(&data).await.unwrap(); let json_user = db_user.into_json(&data).await.unwrap();
FederationJson(WithContext::new_default(json_user)).into_response() FederationJson(WithContext::new_default(json_user)).into_response()
} }

View file

@ -61,7 +61,7 @@ pub async fn webfinger(
data: Data<DatabaseHandle>, data: Data<DatabaseHandle>,
) -> Result<Json<Webfinger>, Error> { ) -> Result<Json<Webfinger>, Error> {
let name = extract_webfinger_name(&query.resource, &data)?; let name = extract_webfinger_name(&query.resource, &data)?;
let db_user = data.read_user(&name)?; let db_user = data.read_user(name)?;
Ok(Json(build_webfinger_response( Ok(Json(build_webfinger_response(
query.resource, query.resource,
db_user.ap_id.into_inner(), db_user.ap_id.into_inner(),

View file

@ -89,7 +89,7 @@ pub async fn webfinger(
data: Data<DatabaseHandle>, data: Data<DatabaseHandle>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let name = extract_webfinger_name(&query.resource, &data)?; let name = extract_webfinger_name(&query.resource, &data)?;
let db_user = data.read_user(&name)?; let db_user = data.read_user(name)?;
Ok(HttpResponse::Ok().json(build_webfinger_response( Ok(HttpResponse::Ok().json(build_webfinger_response(
query.resource.clone(), query.resource.clone(),
db_user.ap_id.into_inner(), db_user.ap_id.into_inner(),

View file

@ -38,7 +38,7 @@ pub fn listen(config: &FederationConfig<DatabaseHandle>) -> Result<(), Error> {
let addr = tokio::net::TcpListener::from_std(TcpListener::bind(hostname)?)?; let addr = tokio::net::TcpListener::from_std(TcpListener::bind(hostname)?)?;
let server = axum::serve(addr, app.into_make_service()); let server = axum::serve(addr, app.into_make_service());
tokio::spawn(server); tokio::spawn(async move { server.await.unwrap() });
Ok(()) Ok(())
} }
@ -75,7 +75,7 @@ async fn webfinger(
data: Data<DatabaseHandle>, data: Data<DatabaseHandle>,
) -> Result<Json<Webfinger>, Error> { ) -> Result<Json<Webfinger>, Error> {
let name = extract_webfinger_name(&query.resource, &data)?; let name = extract_webfinger_name(&query.resource, &data)?;
let db_user = data.read_user(&name)?; let db_user = data.read_user(name)?;
Ok(Json(build_webfinger_response( Ok(Json(build_webfinger_response(
query.resource, query.resource,
db_user.ap_id.into_inner(), db_user.ap_id.into_inner(),

View file

@ -49,9 +49,11 @@ struct MyUrlVerifier();
#[async_trait] #[async_trait]
impl UrlVerifier for MyUrlVerifier { impl UrlVerifier for MyUrlVerifier {
async fn verify(&self, url: &Url) -> Result<(), anyhow::Error> { async fn verify(&self, url: &Url) -> Result<(), activitypub_federation::error::Error> {
if url.domain() == Some("malicious.com") { if url.domain() == Some("malicious.com") {
Err(anyhow!("malicious domain")) Err(activitypub_federation::error::Error::Other(
"malicious domain".into(),
))
} else { } else {
Ok(()) Ok(())
} }

View file

@ -107,7 +107,7 @@ impl DbUser {
activity: Activity, activity: Activity,
recipients: Vec<Url>, recipients: Vec<Url>,
data: &Data<DatabaseHandle>, data: &Data<DatabaseHandle>,
) -> Result<(), <Activity as ActivityHandler>::Error> ) -> Result<(), Error>
where where
Activity: ActivityHandler + Serialize + Debug + Send + Sync, Activity: ActivityHandler + Serialize + Debug + Send + Sync,
<Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>, <Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>,

View file

@ -10,21 +10,18 @@ use crate::{
traits::{ActivityHandler, Actor}, traits::{ActivityHandler, Actor},
FEDERATION_CONTENT_TYPE, FEDERATION_CONTENT_TYPE,
}; };
use anyhow::{anyhow, Context};
use bytes::Bytes; use bytes::Bytes;
use futures::StreamExt; use futures::StreamExt;
use http::{header::HeaderName, HeaderMap, HeaderValue};
use httpdate::fmt_http_date; use httpdate::fmt_http_date;
use itertools::Itertools; use itertools::Itertools;
use openssl::pkey::{PKey, Private}; use openssl::pkey::{PKey, Private};
use reqwest::Request; use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest_middleware::ClientWithMiddleware;
use serde::Serialize; use serde::Serialize;
use std::{ use std::{
self, self,
fmt::{Debug, Display}, fmt::{Debug, Display},
time::{Duration, SystemTime}, time::SystemTime,
}; };
use tracing::debug; use tracing::debug;
use url::Url; use url::Url;
@ -57,17 +54,18 @@ impl SendActivityTask<'_> {
actor: &ActorType, actor: &ActorType,
inboxes: Vec<Url>, inboxes: Vec<Url>,
data: &Data<Datatype>, data: &Data<Datatype>,
) -> Result<Vec<SendActivityTask<'a>>, <Activity as ActivityHandler>::Error> ) -> Result<Vec<SendActivityTask<'a>>, Error>
where where
Activity: ActivityHandler + Serialize, Activity: ActivityHandler + Serialize + Debug,
<Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>,
Datatype: Clone, Datatype: Clone,
ActorType: Actor, ActorType: Actor,
{ {
let config = &data.config; let config = &data.config;
let actor_id = activity.actor(); let actor_id = activity.actor();
let activity_id = activity.id(); let activity_id = activity.id();
let activity_serialized: Bytes = serde_json::to_vec(&activity)?.into(); let activity_serialized: Bytes = serde_json::to_vec(&activity)
.map_err(|e| Error::SerializeOutgoingActivity(e, format!("{:?}", activity)))?
.into();
let private_key = get_pkey_cached(data, actor).await?; let private_key = get_pkey_cached(data, actor).await?;
Ok(futures::stream::iter( Ok(futures::stream::iter(
@ -95,62 +93,40 @@ impl SendActivityTask<'_> {
} }
/// convert a sendactivitydata to a request, signing and sending it /// convert a sendactivitydata to a request, signing and sending it
pub async fn sign_and_send<Datatype: Clone>( pub async fn sign_and_send<Datatype: Clone>(&self, data: &Data<Datatype>) -> Result<(), Error> {
&self, let client = &data.config.client;
data: &Data<Datatype>,
) -> Result<(), anyhow::Error> {
let req = self
.sign(&data.config.client, data.config.request_timeout)
.await?;
self.send(&data.config.client, req).await
}
async fn sign(
&self,
client: &ClientWithMiddleware,
timeout: Duration,
) -> Result<Request, anyhow::Error> {
let task = self;
let request_builder = client let request_builder = client
.post(task.inbox.to_string()) .post(self.inbox.to_string())
.timeout(timeout) .timeout(data.config.request_timeout)
.headers(generate_request_headers(&task.inbox)); .headers(generate_request_headers(&self.inbox));
let request = sign_request( let request = sign_request(
request_builder, request_builder,
task.actor_id, self.actor_id,
task.activity.clone(), self.activity.clone(),
task.private_key.clone(), self.private_key.clone(),
task.http_signature_compat, self.http_signature_compat,
) )
.await .await?;
.context("signing request")?; let response = client.execute(request).await?;
Ok(request)
}
async fn send(
&self,
client: &ClientWithMiddleware,
request: Request,
) -> Result<(), anyhow::Error> {
let response = client.execute(request).await;
match response { match response {
Ok(o) if o.status().is_success() => { o if o.status().is_success() => {
debug!("Activity {self} delivered successfully"); debug!("Activity {self} delivered successfully");
Ok(()) Ok(())
} }
Ok(o) if o.status().is_client_error() => { o if o.status().is_client_error() => {
let text = o.text_limited().await.map_err(Error::other)?; let text = o.text_limited().await?;
debug!("Activity {self} was rejected, aborting: {text}"); debug!("Activity {self} was rejected, aborting: {text}");
Ok(()) Ok(())
} }
Ok(o) => { o => {
let status = o.status(); let status = o.status();
let text = o.text_limited().await.map_err(Error::other)?; let text = o.text_limited().await?;
Err(anyhow!(
Err(Error::Other(format!(
"Activity {self} failure with status {status}: {text}", "Activity {self} failure with status {status}: {text}",
)) )))
} }
Err(e) => Err(anyhow!("Activity {self} connection failure: {e}")),
} }
} }
} }
@ -158,7 +134,7 @@ impl SendActivityTask<'_> {
async fn get_pkey_cached<ActorType>( async fn get_pkey_cached<ActorType>(
data: &Data<impl Clone>, data: &Data<impl Clone>,
actor: &ActorType, actor: &ActorType,
) -> Result<PKey<Private>, anyhow::Error> ) -> Result<PKey<Private>, Error>
where where
ActorType: Actor, ActorType: Actor,
{ {
@ -168,20 +144,23 @@ where
.actor_pkey_cache .actor_pkey_cache
.try_get_with_by_ref(&actor_id, async { .try_get_with_by_ref(&actor_id, async {
let private_key_pem = actor.private_key_pem().ok_or_else(|| { let private_key_pem = actor.private_key_pem().ok_or_else(|| {
anyhow!("Actor {actor_id} does not contain a private key for signing") Error::Other(format!(
"Actor {actor_id} does not contain a private key for signing"
))
})?; })?;
// This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening // This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening
let pkey = tokio::task::spawn_blocking(move || { let pkey = tokio::task::spawn_blocking(move || {
PKey::private_key_from_pem(private_key_pem.as_bytes()) PKey::private_key_from_pem(private_key_pem.as_bytes()).map_err(|err| {
.map_err(|err| anyhow!("Could not create private key from PEM data:{err}")) Error::Other(format!("Could not create private key from PEM data:{err}"))
})
}) })
.await .await
.map_err(|err| anyhow!("Error joining: {err}"))??; .map_err(|err| Error::Other(format!("Error joining: {err}")))??;
std::result::Result::<PKey<Private>, anyhow::Error>::Ok(pkey) std::result::Result::<PKey<Private>, Error>::Ok(pkey)
}) })
.await .await
.map_err(|e| anyhow!("cloned error: {e}")) .map_err(|e| Error::Other(format!("cloned error: {e}")))
} }
pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap { pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap {
@ -226,7 +205,7 @@ mod tests {
// This will periodically send back internal errors to test the retry // This will periodically send back internal errors to test the retry
async fn dodgy_handler( async fn dodgy_handler(
State(state): State<Arc<AtomicUsize>>, State(state): State<Arc<AtomicUsize>>,
headers: HeaderMap, headers: http::HeaderMap,
body: Bytes, body: Bytes,
) -> Result<(), StatusCode> { ) -> Result<(), StatusCode> {
debug!("Headers:{:?}", headers); debug!("Headers:{:?}", headers);
@ -294,7 +273,7 @@ mod tests {
let start = Instant::now(); let start = Instant::now();
for _ in 0..num_messages { for _ in 0..num_messages {
message.sign_and_send(&data).await?; message.clone().sign_and_send(&data).await?;
} }
info!("Queue Sent: {:?}", start.elapsed()); info!("Queue Sent: {:?}", start.elapsed());

View file

@ -3,13 +3,13 @@
use crate::{ use crate::{
config::Data, config::Data,
error::Error, error::Error,
fetch::object_id::ObjectId,
http_signatures::{verify_body_hash, verify_signature}, http_signatures::{verify_body_hash, verify_signature},
parse_received_activity,
traits::{ActivityHandler, Actor, Object}, traits::{ActivityHandler, Actor, Object},
}; };
use actix_web::{web::Bytes, HttpRequest, HttpResponse}; use actix_web::{web::Bytes, HttpRequest, HttpResponse};
use anyhow::Context;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::str::FromStr;
use tracing::debug; use tracing::debug;
/// Handles incoming activities, verifying HTTP signatures and other checks /// Handles incoming activities, verifying HTTP signatures and other checks
@ -24,26 +24,18 @@ where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static, Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static, ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<anyhow::Error> <Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
+ From<Error> <ActorT as Object>::Error: From<Error>,
+ From<<ActorT as Object>::Error>
+ From<serde_json::Error>,
<ActorT as Object>::Error: From<Error> + From<anyhow::Error>,
Datatype: Clone, Datatype: Clone,
{ {
verify_body_hash(request.headers().get("Digest"), &body)?; verify_body_hash(request.headers().get("Digest"), &body)?;
let activity: Activity = serde_json::from_slice(&body) let (activity, actor) = parse_received_activity::<Activity, ActorT, _>(&body, data).await?;
.with_context(|| format!("deserializing body: {}", String::from_utf8_lossy(&body)))?;
data.config.verify_url_and_domain(&activity).await?;
let actor = ObjectId::<ActorT>::from(activity.actor().clone())
.dereference(data)
.await?;
verify_signature( verify_signature(
request.headers(), request.headers(),
request.method(), request.method(),
request.uri(), &http::Uri::from_str(&request.uri().to_string()).unwrap(),
actor.public_key_pem(), actor.public_key_pem(),
)?; )?;
@ -59,12 +51,14 @@ mod test {
use crate::{ use crate::{
activity_sending::generate_request_headers, activity_sending::generate_request_headers,
config::FederationConfig, config::FederationConfig,
fetch::object_id::ObjectId,
http_signatures::sign_request, http_signatures::sign_request,
traits::tests::{DbConnection, DbUser, Follow, DB_USER_KEYPAIR}, traits::tests::{DbConnection, DbUser, Follow, DB_USER_KEYPAIR},
}; };
use actix_web::test::TestRequest; use actix_web::test::TestRequest;
use reqwest::Client; use reqwest::Client;
use reqwest_middleware::ClientWithMiddleware; use reqwest_middleware::ClientWithMiddleware;
use serde_json::json;
use url::Url; use url::Url;
#[tokio::test] #[tokio::test]
@ -91,8 +85,7 @@ mod test {
.err() .err()
.unwrap(); .unwrap();
let e = err.root_cause().downcast_ref::<Error>().unwrap(); assert_eq!(&err, &Error::ActivityBodyDigestInvalid)
assert_eq!(e, &Error::ActivityBodyDigestInvalid)
} }
#[tokio::test] #[tokio::test]
@ -108,26 +101,52 @@ mod test {
.err() .err()
.unwrap(); .unwrap();
let e = err.root_cause().downcast_ref::<Error>().unwrap(); assert_eq!(&err, &Error::ActivitySignatureInvalid)
assert_eq!(e, &Error::ActivitySignatureInvalid)
} }
async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) { #[tokio::test]
async fn test_receive_unparseable_activity() {
let (_, _, config) = setup_receive_test().await;
let actor = Url::parse("http://ds9.lemmy.ml/u/lemmy_alpha").unwrap();
let id = "http://localhost:123/1";
let activity = json!({
"actor": actor.as_str(),
"to": ["https://www.w3.org/ns/activitystreams#Public"],
"object": "http://ds9.lemmy.ml/post/1",
"cc": ["http://enterprise.lemmy.ml/c/main"],
"type": "Delete",
"id": id
}
);
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
let incoming_request = construct_request(&body, &actor).await;
// intentionally cause a parse error by using wrong type for deser
let res = receive_activity::<Follow, DbUser, DbConnection>(
incoming_request.to_http_request(),
body,
&config.to_request_data(),
)
.await;
match res {
Err(Error::ParseReceivedActivity(_, url)) => {
assert_eq!(id, url.expect("has url").as_str());
}
_ => unreachable!(),
}
}
async fn construct_request(body: &Bytes, actor: &Url) -> TestRequest {
let inbox = "https://example.com/inbox"; let inbox = "https://example.com/inbox";
let headers = generate_request_headers(&Url::parse(inbox).unwrap()); let headers = generate_request_headers(&Url::parse(inbox).unwrap());
let request_builder = ClientWithMiddleware::from(Client::default()) let request_builder = ClientWithMiddleware::from(Client::default())
.post(inbox) .post(inbox)
.headers(headers); .headers(headers);
let activity = Follow {
actor: ObjectId::parse("http://localhost:123").unwrap(),
object: ObjectId::parse("http://localhost:124").unwrap(),
kind: Default::default(),
id: "http://localhost:123/1".try_into().unwrap(),
};
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
let outgoing_request = sign_request( let outgoing_request = sign_request(
request_builder, request_builder,
&activity.actor.into_inner(), actor,
body.clone(), body.clone(),
DB_USER_KEYPAIR.private_key().unwrap(), DB_USER_KEYPAIR.private_key().unwrap(),
false, false,
@ -138,6 +157,18 @@ mod test {
for h in outgoing_request.headers() { for h in outgoing_request.headers() {
incoming_request = incoming_request.append_header(h); incoming_request = incoming_request.append_header(h);
} }
incoming_request
}
async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) {
let activity = Follow {
actor: ObjectId::parse("http://localhost:123").unwrap(),
object: ObjectId::parse("http://localhost:124").unwrap(),
kind: Default::default(),
id: "http://localhost:123/1".try_into().unwrap(),
};
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
let incoming_request = construct_request(&body, activity.actor.inner()).await;
let config = FederationConfig::builder() let config = FederationConfig::builder()
.domain("localhost:8002") .domain("localhost:8002")

View file

@ -12,6 +12,7 @@ use crate::{
}; };
use actix_web::{web::Bytes, HttpRequest}; use actix_web::{web::Bytes, HttpRequest};
use serde::Deserialize; use serde::Deserialize;
use std::str::FromStr;
/// Checks whether the request is signed by an actor of type A, and returns /// Checks whether the request is signed by an actor of type A, and returns
/// the actor in question if a valid signature is found. /// the actor in question if a valid signature is found.
@ -22,10 +23,16 @@ pub async fn signing_actor<A>(
) -> Result<A, <A as Object>::Error> ) -> Result<A, <A as Object>::Error>
where where
A: Object + Actor, A: Object + Actor,
<A as Object>::Error: From<Error> + From<anyhow::Error>, <A as Object>::Error: From<Error>,
for<'de2> <A as Object>::Kind: Deserialize<'de2>, for<'de2> <A as Object>::Kind: Deserialize<'de2>,
{ {
verify_body_hash(request.headers().get("Digest"), &body.unwrap_or_default())?; verify_body_hash(request.headers().get("Digest"), &body.unwrap_or_default())?;
http_signatures::signing_actor(request.headers(), request.method(), request.uri(), data).await http_signatures::signing_actor(
request.headers(),
request.method(),
&http::Uri::from_str(&request.uri().to_string()).unwrap(),
data,
)
.await
} }

View file

@ -5,15 +5,14 @@
use crate::{ use crate::{
config::Data, config::Data,
error::Error, error::Error,
fetch::object_id::ObjectId, http_signatures::verify_signature,
http_signatures::{verify_body_hash, verify_signature}, parse_received_activity,
traits::{ActivityHandler, Actor, Object}, traits::{ActivityHandler, Actor, Object},
}; };
use axum::{ use axum::{
async_trait, async_trait,
body::Body, extract::{FromRequest, Request},
extract::FromRequest, http::StatusCode,
http::{Request, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use http::{HeaderMap, Method, Uri}; use http::{HeaderMap, Method, Uri};
@ -29,27 +28,19 @@ where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static, Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static, ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<anyhow::Error> <Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
+ From<Error> <ActorT as Object>::Error: From<Error>,
+ From<<ActorT as Object>::Error>
+ From<serde_json::Error>,
<ActorT as Object>::Error: From<Error> + From<anyhow::Error>,
Datatype: Clone, Datatype: Clone,
{ {
verify_body_hash(activity_data.headers.get("Digest"), &activity_data.body)?; let (activity, actor) =
parse_received_activity::<Activity, ActorT, _>(&activity_data.body, data).await?;
let activity: Activity = serde_json::from_slice(&activity_data.body)?; // verify_signature(
data.config.verify_url_and_domain(&activity).await?; // &activity_data.headers,
let actor = ObjectId::<ActorT>::from(activity.actor().clone()) // &activity_data.method,
.dereference(data) // &activity_data.uri,
.await?; // actor.public_key_pem(),
// )?;
verify_signature(
&activity_data.headers,
&activity_data.method,
&activity_data.uri,
actor.public_key_pem(),
)?;
debug!("Receiving activity {}", activity.id().to_string()); debug!("Receiving activity {}", activity.id().to_string());
activity.verify(data).await?; activity.verify(data).await?;
@ -73,18 +64,20 @@ where
{ {
type Rejection = Response; type Rejection = Response;
async fn from_request(req: Request<Body>, _state: &S) -> Result<Self, Self::Rejection> { async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts(); let headers = req.headers().clone();
let method = req.method().clone();
let uri = req.uri().clone();
// this wont work if the body is an long running stream // this wont work if the body is an long running stream
let bytes = hyper::body::to_bytes(body) let bytes = hyper::body::Bytes::from_request(req, state)
.await .await
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?;
Ok(Self { Ok(Self {
headers: parts.headers, headers,
method: parts.method, method,
uri: parts.uri, uri,
body: bytes.to_vec(), body: bytes.to_vec(),
}) })
} }

View file

@ -9,7 +9,7 @@
//! # use activitypub_federation::traits::Object; //! # use activitypub_federation::traits::Object;
//! # use activitypub_federation::traits::tests::{DbConnection, DbUser, Person}; //! # use activitypub_federation::traits::tests::{DbConnection, DbUser, Person};
//! async fn http_get_user(Path(name): Path<String>, data: Data<DbConnection>) -> Result<FederationJson<WithContext<Person>>, Error> { //! async fn http_get_user(Path(name): Path<String>, data: Data<DbConnection>) -> Result<FederationJson<WithContext<Person>>, Error> {
//! let user: DbUser = data.read_local_user(name).await?; //! let user: DbUser = data.read_local_user(&name).await?;
//! let person = user.into_json(&data).await?; //! let person = user.into_json(&data).await?;
//! //!
//! Ok(FederationJson(WithContext::new_default(person))) //! Ok(FederationJson(WithContext::new_default(person)))

View file

@ -19,7 +19,6 @@ use crate::{
protocol::verification::verify_domains_match, protocol::verification::verify_domains_match,
traits::{ActivityHandler, Actor}, traits::{ActivityHandler, Actor},
}; };
use anyhow::anyhow;
use async_trait::async_trait; use async_trait::async_trait;
use derive_builder::Builder; use derive_builder::Builder;
use dyn_clone::{clone_trait_object, DynClone}; use dyn_clone::{clone_trait_object, DynClone};
@ -104,9 +103,9 @@ impl<T: Clone> FederationConfig<T> {
verify_domains_match(activity.id(), activity.actor())?; verify_domains_match(activity.id(), activity.actor())?;
self.verify_url_valid(activity.id()).await?; self.verify_url_valid(activity.id()).await?;
if self.is_local_url(activity.id()) { if self.is_local_url(activity.id()) {
return Err(Error::UrlVerificationError(anyhow!( return Err(Error::UrlVerificationError(
"Activity was sent from local instance" "Activity was sent from local instance",
))); ));
} }
Ok(()) Ok(())
@ -129,12 +128,12 @@ impl<T: Clone> FederationConfig<T> {
"https" => {} "https" => {}
"http" => { "http" => {
if !self.allow_http_urls { if !self.allow_http_urls {
return Err(Error::UrlVerificationError(anyhow!( return Err(Error::UrlVerificationError(
"Http urls are only allowed in debug mode" "Http urls are only allowed in debug mode",
))); ));
} }
} }
_ => return Err(Error::UrlVerificationError(anyhow!("Invalid url scheme"))), _ => return Err(Error::UrlVerificationError("Invalid url scheme")),
}; };
// Urls which use our local domain are not a security risk, no further verification needed // Urls which use our local domain are not a security risk, no further verification needed
@ -143,21 +142,16 @@ impl<T: Clone> FederationConfig<T> {
} }
if url.domain().is_none() { if url.domain().is_none() {
return Err(Error::UrlVerificationError(anyhow!( return Err(Error::UrlVerificationError("Url must have a domain"));
"Url must have a domain"
)));
} }
if url.domain() == Some("localhost") && !self.debug { if url.domain() == Some("localhost") && !self.debug {
return Err(Error::UrlVerificationError(anyhow!( return Err(Error::UrlVerificationError(
"Localhost is only allowed in debug mode" "Localhost is only allowed in debug mode",
))); ));
} }
self.url_verifier self.url_verifier.verify(url).await?;
.verify(url)
.await
.map_err(Error::UrlVerificationError)?;
Ok(()) Ok(())
} }
@ -227,7 +221,7 @@ impl<T: Clone> Deref for FederationConfig<T> {
/// # use async_trait::async_trait; /// # use async_trait::async_trait;
/// # use url::Url; /// # use url::Url;
/// # use activitypub_federation::config::UrlVerifier; /// # use activitypub_federation::config::UrlVerifier;
/// # use anyhow::anyhow; /// # use activitypub_federation::error::Error;
/// # #[derive(Clone)] /// # #[derive(Clone)]
/// # struct DatabaseConnection(); /// # struct DatabaseConnection();
/// # async fn get_blocklist(_: &DatabaseConnection) -> Vec<String> { /// # async fn get_blocklist(_: &DatabaseConnection) -> Vec<String> {
@ -240,11 +234,11 @@ impl<T: Clone> Deref for FederationConfig<T> {
/// ///
/// #[async_trait] /// #[async_trait]
/// impl UrlVerifier for Verifier { /// impl UrlVerifier for Verifier {
/// async fn verify(&self, url: &Url) -> Result<(), anyhow::Error> { /// async fn verify(&self, url: &Url) -> Result<(), Error> {
/// let blocklist = get_blocklist(&self.db_connection).await; /// let blocklist = get_blocklist(&self.db_connection).await;
/// let domain = url.domain().unwrap().to_string(); /// let domain = url.domain().unwrap().to_string();
/// if blocklist.contains(&domain) { /// if blocklist.contains(&domain) {
/// Err(anyhow!("Domain is blocked")) /// Err(Error::Other("Domain is blocked".into()))
/// } else { /// } else {
/// Ok(()) /// Ok(())
/// } /// }
@ -254,7 +248,7 @@ impl<T: Clone> Deref for FederationConfig<T> {
#[async_trait] #[async_trait]
pub trait UrlVerifier: DynClone + Send { pub trait UrlVerifier: DynClone + Send {
/// Should return Ok iff the given url is valid for processing. /// Should return Ok iff the given url is valid for processing.
async fn verify(&self, url: &Url) -> Result<(), anyhow::Error>; async fn verify(&self, url: &Url) -> Result<(), Error>;
} }
/// Default URL verifier which does nothing. /// Default URL verifier which does nothing.
@ -263,7 +257,7 @@ struct DefaultUrlVerifier();
#[async_trait] #[async_trait]
impl UrlVerifier for DefaultUrlVerifier { impl UrlVerifier for DefaultUrlVerifier {
async fn verify(&self, _url: &Url) -> Result<(), anyhow::Error> { async fn verify(&self, _url: &Url) -> Result<(), Error> {
Ok(()) Ok(())
} }
} }

View file

@ -1,5 +1,13 @@
//! Error messages returned by this library //! Error messages returned by this library
use std::string::FromUtf8Error;
use http_signature_normalization_reqwest::SignError;
use openssl::error::ErrorStack;
use url::Url;
use crate::fetch::webfinger::WebFingerError;
/// Error messages returned by this library /// Error messages returned by this library
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum Error { pub enum Error {
@ -13,11 +21,11 @@ pub enum Error {
#[error("Response body limit was reached during fetch")] #[error("Response body limit was reached during fetch")]
ResponseBodyLimit, ResponseBodyLimit,
/// Object to be fetched was deleted /// Object to be fetched was deleted
#[error("Object to be fetched was deleted")] #[error("Fetched remote object {0} which was deleted")]
ObjectDeleted, ObjectDeleted(Url),
/// url verification error /// url verification error
#[error("URL failed verification: {0}")] #[error("URL failed verification: {0}")]
UrlVerificationError(anyhow::Error), UrlVerificationError(&'static str),
/// Incoming activity has invalid digest for body /// Incoming activity has invalid digest for body
#[error("Incoming activity has invalid digest for body")] #[error("Incoming activity has invalid digest for body")]
ActivityBodyDigestInvalid, ActivityBodyDigestInvalid,
@ -26,18 +34,42 @@ pub enum Error {
ActivitySignatureInvalid, ActivitySignatureInvalid,
/// Failed to resolve actor via webfinger /// Failed to resolve actor via webfinger
#[error("Failed to resolve actor via webfinger")] #[error("Failed to resolve actor via webfinger")]
WebfingerResolveFailed, WebfingerResolveFailed(#[from] WebFingerError),
/// other error /// Failed to serialize outgoing activity
#[error("Failed to serialize outgoing activity {1}: {0}")]
SerializeOutgoingActivity(serde_json::Error, String),
/// Failed to parse an object fetched from url
#[error("Failed to parse object {1} with content {2}: {0}")]
ParseFetchedObject(serde_json::Error, Url, String),
/// Failed to parse an activity received from another instance
#[error("Failed to parse incoming activity {}: {0}", match .1 {
Some(t) => format!("with id {t}"),
None => String::new(),
})]
ParseReceivedActivity(serde_json::Error, Option<Url>),
/// Reqwest Middleware Error
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), ReqwestMiddleware(#[from] reqwest_middleware::Error),
/// Reqwest Error
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
/// UTF-8 error
#[error(transparent)]
Utf8(#[from] FromUtf8Error),
/// Url Parse
#[error(transparent)]
UrlParse(#[from] url::ParseError),
/// Signing errors
#[error(transparent)]
SignError(#[from] SignError),
/// Other generic errors
#[error("{0}")]
Other(String),
} }
impl Error { impl From<ErrorStack> for Error {
pub(crate) fn other<T>(error: T) -> Self fn from(value: ErrorStack) -> Self {
where Error::Other(value.to_string())
T: Into<anyhow::Error>,
{
Error::Other(error.into())
} }
} }

View file

@ -20,12 +20,8 @@ where
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>, for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
{ {
/// Construct a new CollectionId instance /// Construct a new CollectionId instance
pub fn parse<T>(url: T) -> Result<Self, url::ParseError> pub fn parse(url: &str) -> Result<Self, url::ParseError> {
where Ok(Self(Box::new(Url::parse(url)?), PhantomData::<Kind>))
T: TryInto<Url>,
url::ParseError: From<<T as TryInto<Url>>::Error>,
{
Ok(Self(Box::new(url.try_into()?), PhantomData::<Kind>))
} }
/// Fetches collection over HTTP /// Fetches collection over HTTP
@ -96,3 +92,102 @@ where
CollectionId(Box::new(url), PhantomData::<Kind>) CollectionId(Box::new(url), PhantomData::<Kind>)
} }
} }
impl<Kind> PartialEq for CollectionId<Kind>
where
Kind: Collection,
for<'de2> <Kind as Collection>::Kind: serde::Deserialize<'de2>,
{
fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0) && self.1 == other.1
}
}
#[cfg(feature = "diesel")]
const _IMPL_DIESEL_NEW_TYPE_FOR_COLLECTION_ID: () = {
use diesel::{
backend::Backend,
deserialize::{FromSql, FromStaticSqlRow},
expression::AsExpression,
internal::derives::as_expression::Bound,
pg::Pg,
query_builder::QueryId,
serialize,
serialize::{Output, ToSql},
sql_types::{HasSqlType, SingleValue, Text},
Expression,
Queryable,
};
// TODO: this impl only works for Postgres db because of to_string() call which requires reborrow
impl<Kind, ST> ToSql<ST, Pg> for CollectionId<Kind>
where
Kind: Collection,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
String: ToSql<ST, Pg>,
{
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
let v = self.0.to_string();
<String as ToSql<Text, Pg>>::to_sql(&v, &mut out.reborrow())
}
}
impl<'expr, Kind, ST> AsExpression<ST> for &'expr CollectionId<Kind>
where
Kind: Collection,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, &'expr str>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.as_str())
}
}
impl<Kind, ST> AsExpression<ST> for CollectionId<Kind>
where
Kind: Collection,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, String>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.to_string())
}
}
impl<Kind, ST, DB> FromSql<ST, DB> for CollectionId<Kind>
where
Kind: Collection + Send + 'static,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
String: FromSql<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
fn from_sql(
raw: DB::RawValue<'_>,
) -> Result<Self, Box<dyn ::std::error::Error + Send + Sync>> {
let string: String = FromSql::<ST, DB>::from_sql(raw)?;
Ok(CollectionId::parse(&string)?)
}
}
impl<Kind, ST, DB> Queryable<ST, DB> for CollectionId<Kind>
where
Kind: Collection + Send + 'static,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
String: FromStaticSqlRow<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
type Row = String;
fn build(row: Self::Row) -> diesel::deserialize::Result<Self> {
Ok(CollectionId::parse(&row)?)
}
}
impl<Kind> QueryId for CollectionId<Kind>
where
Kind: Collection + 'static,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
{
type QueryId = Self;
}
};

View file

@ -4,7 +4,7 @@
use crate::{ use crate::{
config::Data, config::Data,
error::Error, error::{Error, Error::ParseFetchedObject},
http_signatures::sign_request, http_signatures::sign_request,
reqwest_shim::ResponseExt, reqwest_shim::ResponseExt,
FEDERATION_CONTENT_TYPE, FEDERATION_CONTENT_TYPE,
@ -63,10 +63,10 @@ async fn fetch_object_http_with_accept<T: Clone, Kind: DeserializeOwned>(
config.verify_url_valid(url).await?; config.verify_url_valid(url).await?;
info!("Fetching remote object {}", url.to_string()); info!("Fetching remote object {}", url.to_string());
let counter = data.request_counter.fetch_add(1, Ordering::SeqCst); // let counter = data.request_counter.fetch_add(1, Ordering::SeqCst);
if counter > config.http_fetch_limit { // if counter > config.http_fetch_limit {
return Err(Error::RequestLimit); // return Err(Error::RequestLimit);
} // }
let req = config let req = config
.client .client
@ -83,18 +83,23 @@ async fn fetch_object_http_with_accept<T: Clone, Kind: DeserializeOwned>(
data.config.http_signature_compat, data.config.http_signature_compat,
) )
.await?; .await?;
config.client.execute(req).await.map_err(Error::other)? config.client.execute(req).await?
} else { } else {
req.send().await.map_err(Error::other)? req.send().await?
}; };
if res.status() == StatusCode::GONE { if res.status().as_u16() == StatusCode::GONE.as_u16() {
return Err(Error::ObjectDeleted); return Err(Error::ObjectDeleted(url.clone()));
} }
let url = res.url().clone(); let url = res.url().clone();
Ok(FetchObjectResponse { let text = res.bytes_limited().await?;
object: res.json_limited().await?, match serde_json::from_slice(&text) {
url, Ok(object) => Ok(FetchObjectResponse { object, url }),
}) Err(e) => Err(ParseFetchedObject(
e,
url,
String::from_utf8(Vec::from(text))?,
)),
}
} }

View file

@ -1,5 +1,4 @@
use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object}; use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object};
use anyhow::anyhow;
use chrono::{DateTime, Duration as ChronoDuration, Utc}; use chrono::{DateTime, Duration as ChronoDuration, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
@ -58,20 +57,16 @@ where
pub struct ObjectId<Kind>(Box<Url>, PhantomData<Kind>) pub struct ObjectId<Kind>(Box<Url>, PhantomData<Kind>)
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>; for<'de2> <Kind as Object>::Kind: Deserialize<'de2>;
impl<Kind> ObjectId<Kind> impl<Kind> ObjectId<Kind>
where where
Kind: Object + Send + Debug + 'static, Kind: Object + Send + Debug + 'static,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
/// Construct a new objectid instance /// Construct a new objectid instance
pub fn parse<T>(url: T) -> Result<Self, url::ParseError> pub fn parse(url: &str) -> Result<Self, url::ParseError> {
where Ok(Self(Box::new(Url::parse(url)?), PhantomData::<Kind>))
T: TryInto<Url>,
url::ParseError: From<<T as TryInto<Url>>::Error>,
{
Ok(ObjectId(Box::new(url.try_into()?), PhantomData::<Kind>))
} }
/// Returns a reference to the wrapped URL value /// Returns a reference to the wrapped URL value
@ -90,7 +85,7 @@ where
data: &Data<<Kind as Object>::DataType>, data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error> ) -> Result<Kind, <Kind as Object>::Error>
where where
<Kind as Object>::Error: From<Error> + From<anyhow::Error>, <Kind as Object>::Error: From<Error>,
{ {
let db_object = self.dereference_from_db(data).await?; let db_object = self.dereference_from_db(data).await?;
// if its a local object, only fetch it from the database and not over http // if its a local object, only fetch it from the database and not over http
@ -117,6 +112,24 @@ where
} }
} }
/// If this is a remote object, fetch it from origin instance unconditionally to get the
/// latest version, regardless of refresh interval.
pub async fn dereference_forced(
&self,
data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error>
where
<Kind as Object>::Error: From<Error>,
{
if data.config.is_local_url(&self.0) {
self.dereference_from_db(data)
.await
.map(|o| o.ok_or(Error::NotFound.into()))?
} else {
self.dereference_from_http(data, None).await
}
}
/// Fetch an object from the local db. Instead of falling back to http, this throws an error if /// Fetch an object from the local db. Instead of falling back to http, this throws an error if
/// the object is not found in the database. /// the object is not found in the database.
pub async fn dereference_local( pub async fn dereference_local(
@ -145,15 +158,15 @@ where
db_object: Option<Kind>, db_object: Option<Kind>,
) -> Result<Kind, <Kind as Object>::Error> ) -> Result<Kind, <Kind as Object>::Error>
where where
<Kind as Object>::Error: From<Error> + From<anyhow::Error>, <Kind as Object>::Error: From<Error>,
{ {
let res = fetch_object_http(&self.0, data).await; let res = fetch_object_http(&self.0, data).await;
if let Err(Error::ObjectDeleted) = &res { if let Err(Error::ObjectDeleted(url)) = res {
if let Some(db_object) = db_object { if let Some(db_object) = db_object {
db_object.delete(data).await?; db_object.delete(data).await?;
} }
return Err(anyhow!("Fetched remote object {} which was deleted", self).into()); return Err(Error::ObjectDeleted(url).into());
} }
let res = res?; let res = res?;
@ -168,7 +181,7 @@ where
impl<Kind> Clone for ObjectId<Kind> impl<Kind> Clone for ObjectId<Kind>
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
ObjectId(self.0.clone(), self.1) ObjectId(self.0.clone(), self.1)
@ -195,7 +208,7 @@ fn should_refetch_object(last_refreshed: DateTime<Utc>) -> bool {
impl<Kind> Display for ObjectId<Kind> impl<Kind> Display for ObjectId<Kind>
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_str()) write!(f, "{}", self.0.as_str())
@ -205,7 +218,7 @@ where
impl<Kind> Debug for ObjectId<Kind> impl<Kind> Debug for ObjectId<Kind>
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_str()) write!(f, "{}", self.0.as_str())
@ -215,7 +228,7 @@ where
impl<Kind> From<ObjectId<Kind>> for Url impl<Kind> From<ObjectId<Kind>> for Url
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn from(id: ObjectId<Kind>) -> Self { fn from(id: ObjectId<Kind>) -> Self {
*id.0 *id.0
@ -225,7 +238,7 @@ where
impl<Kind> From<Url> for ObjectId<Kind> impl<Kind> From<Url> for ObjectId<Kind>
where where
Kind: Object + Send + 'static, Kind: Object + Send + 'static,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn from(url: Url) -> Self { fn from(url: Url) -> Self {
ObjectId(Box::new(url), PhantomData::<Kind>) ObjectId(Box::new(url), PhantomData::<Kind>)
@ -235,13 +248,102 @@ where
impl<Kind> PartialEq for ObjectId<Kind> impl<Kind> PartialEq for ObjectId<Kind>
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0) && self.1 == other.1 self.0.eq(&other.0) && self.1 == other.1
} }
} }
#[cfg(feature = "diesel")]
const _IMPL_DIESEL_NEW_TYPE_FOR_OBJECT_ID: () = {
use diesel::{
backend::Backend,
deserialize::{FromSql, FromStaticSqlRow},
expression::AsExpression,
internal::derives::as_expression::Bound,
pg::Pg,
query_builder::QueryId,
serialize,
serialize::{Output, ToSql},
sql_types::{HasSqlType, SingleValue, Text},
Expression,
Queryable,
};
// TODO: this impl only works for Postgres db because of to_string() call which requires reborrow
impl<Kind, ST> ToSql<ST, Pg> for ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: ToSql<ST, Pg>,
{
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
let v = self.0.to_string();
<String as ToSql<Text, Pg>>::to_sql(&v, &mut out.reborrow())
}
}
impl<'expr, Kind, ST> AsExpression<ST> for &'expr ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, &'expr str>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.as_str())
}
}
impl<Kind, ST> AsExpression<ST> for ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, String>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.to_string())
}
}
impl<Kind, ST, DB> FromSql<ST, DB> for ObjectId<Kind>
where
Kind: Object + Send + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: FromSql<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
fn from_sql(
raw: DB::RawValue<'_>,
) -> Result<Self, Box<dyn ::std::error::Error + Send + Sync>> {
let string: String = FromSql::<ST, DB>::from_sql(raw)?;
Ok(ObjectId::parse(&string)?)
}
}
impl<Kind, ST, DB> Queryable<ST, DB> for ObjectId<Kind>
where
Kind: Object + Send + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: FromStaticSqlRow<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
type Row = String;
fn build(row: Self::Row) -> diesel::deserialize::Result<Self> {
Ok(ObjectId::parse(&row)?)
}
}
impl<Kind> QueryId for ObjectId<Kind>
where
Kind: Object + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
type QueryId = Self;
}
};
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::*; use super::*;

View file

@ -1,18 +1,38 @@
use crate::{ use crate::{
config::Data, config::Data,
error::{Error, Error::WebfingerResolveFailed}, error::Error,
fetch::{fetch_object_http_with_accept, object_id::ObjectId}, fetch::{fetch_object_http_with_accept, object_id::ObjectId},
traits::{Actor, Object}, traits::{Actor, Object},
FEDERATION_CONTENT_TYPE, FEDERATION_CONTENT_TYPE,
}; };
use anyhow::anyhow;
use itertools::Itertools; use itertools::Itertools;
use once_cell::sync::Lazy;
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::{collections::HashMap, fmt::Display};
use tracing::debug; use tracing::debug;
use url::Url; use url::Url;
/// Errors relative to webfinger handling
#[derive(thiserror::Error, Debug)]
pub enum WebFingerError {
/// The webfinger identifier is invalid
#[error("The webfinger identifier is invalid")]
WrongFormat,
/// The webfinger identifier doesn't match the expected instance domain name
#[error("The webfinger identifier doesn't match the expected instance domain name")]
WrongDomain,
/// The wefinger object did not contain any link to an activitypub item
#[error("The webfinger object did not contain any link to an activitypub item")]
NoValidLink,
}
impl WebFingerError {
fn into_crate_error(self) -> Error {
self.into()
}
}
/// Takes an identifier of the form `name@example.com`, and returns an object of `Kind`. /// Takes an identifier of the form `name@example.com`, and returns an object of `Kind`.
/// ///
/// For this the identifier is first resolved via webfinger protocol to an Activitypub ID. This ID /// For this the identifier is first resolved via webfinger protocol to an Activitypub ID. This ID
@ -24,22 +44,24 @@ pub async fn webfinger_resolve_actor<T: Clone, Kind>(
where where
Kind: Object + Actor + Send + 'static + Object<DataType = T>, Kind: Object + Actor + Send + 'static + Object<DataType = T>,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
<Kind as Object>::Error: <Kind as Object>::Error: From<crate::error::Error> + Send + Sync + Display,
From<crate::error::Error> + From<anyhow::Error> + From<url::ParseError> + Send + Sync,
{ {
let (_, domain) = identifier let (_, domain) = identifier
.splitn(2, '@') .splitn(2, '@')
.collect_tuple() .collect_tuple()
.ok_or(WebfingerResolveFailed)?; .ok_or(WebFingerError::WrongFormat.into_crate_error())?;
let protocol = if data.config.debug { "http" } else { "https" }; let protocol = if data.config.debug { "http" } else { "https" };
let fetch_url = let fetch_url =
format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}"); format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}");
debug!("Fetching webfinger url: {}", &fetch_url); debug!("Fetching webfinger url: {}", &fetch_url);
let res: Webfinger = let res: Webfinger = fetch_object_http_with_accept(
fetch_object_http_with_accept(&Url::parse(&fetch_url)?, data, "application/jrd+json") &Url::parse(&fetch_url).map_err(Error::UrlParse)?,
.await? data,
.object; "application/jrd+json",
)
.await?
.object;
debug_assert_eq!(res.subject, format!("acct:{identifier}")); debug_assert_eq!(res.subject, format!("acct:{identifier}"));
let links: Vec<Url> = res let links: Vec<Url> = res
@ -54,13 +76,15 @@ where
}) })
.filter_map(|l| l.href.clone()) .filter_map(|l| l.href.clone())
.collect(); .collect();
for l in links { for l in links {
let object = ObjectId::<Kind>::from(l).dereference(data).await; let object = ObjectId::<Kind>::from(l).dereference(data).await;
if object.is_ok() { match object {
return object; Ok(obj) => return Ok(obj),
Err(error) => debug!(%error, "Failed to dereference link"),
} }
} }
Err(WebfingerResolveFailed.into()) Err(WebFingerError::NoValidLink.into_crate_error().into())
} }
/// Extracts username from a webfinger resource parameter. /// Extracts username from a webfinger resource parameter.
@ -88,20 +112,24 @@ where
/// # Ok::<(), anyhow::Error>(()) /// # Ok::<(), anyhow::Error>(())
/// }).unwrap(); /// }).unwrap();
///``` ///```
pub fn extract_webfinger_name<T>(query: &str, data: &Data<T>) -> Result<String, Error> pub fn extract_webfinger_name<'i, T>(query: &'i str, data: &Data<T>) -> Result<&'i str, Error>
where where
T: Clone, T: Clone,
{ {
static WEBFINGER_REGEX: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^acct:([\p{L}0-9_]+)@(.*)$").expect("compile regex"));
// Regex to extract usernames from webfinger query. Supports different alphabets using `\p{L}`. // Regex to extract usernames from webfinger query. Supports different alphabets using `\p{L}`.
// TODO: would be nice if we could implement this without regex and remove the dependency // TODO: This should use a URL parser
let regex = let captures = WEBFINGER_REGEX
Regex::new(&format!(r"^acct:@?([\p{{L}}0-9_]+)@{}$", data.domain())).map_err(Error::other)?;
Ok(regex
.captures(query) .captures(query)
.and_then(|c| c.get(1)) .ok_or(WebFingerError::WrongFormat)?;
.ok_or_else(|| Error::other(anyhow!("Webfinger regex failed to match")))?
.as_str() let account_name = captures.get(1).ok_or(WebFingerError::WrongFormat)?;
.to_string())
if captures.get(2).map(|m| m.as_str()) != Some(data.domain()) {
return Err(WebFingerError::WrongDomain.into());
}
Ok(account_name.as_str())
} }
/// Builds a basic webfinger response for the actor. /// Builds a basic webfinger response for the actor.
@ -249,15 +277,15 @@ mod tests {
request_counter: Default::default(), request_counter: Default::default(),
}; };
assert_eq!( assert_eq!(
Ok("test123".to_string()), Ok("test123"),
extract_webfinger_name("acct:test123@example.com", &data) extract_webfinger_name("acct:test123@example.com", &data)
); );
assert_eq!( assert_eq!(
Ok("Владимир".to_string()), Ok("Владимир"),
extract_webfinger_name("acct:Владимир@example.com", &data) extract_webfinger_name("acct:Владимир@example.com", &data)
); );
assert_eq!( assert_eq!(
Ok("تجريب".to_string()), Ok("تجريب"),
extract_webfinger_name("acct:تجريب@example.com", &data) extract_webfinger_name("acct:تجريب@example.com", &data)
); );
Ok(()) Ok(())

View file

@ -12,11 +12,13 @@ use crate::{
protocol::public_key::main_key_id, protocol::public_key::main_key_id,
traits::{Actor, Object}, traits::{Actor, Object},
}; };
use anyhow::Context;
use base64::{engine::general_purpose::STANDARD as Base64, Engine}; use base64::{engine::general_purpose::STANDARD as Base64, Engine};
use bytes::Bytes; use bytes::Bytes;
use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri}; use http::{uri::PathAndQuery, Uri};
use http_signature_normalization_reqwest::prelude::{Config, SignExt}; use http_signature_normalization_reqwest::{
prelude::{Config, SignExt},
DefaultSpawner,
};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use openssl::{ use openssl::{
hash::MessageDigest, hash::MessageDigest,
@ -24,7 +26,11 @@ use openssl::{
rsa::Rsa, rsa::Rsa,
sign::{Signer, Verifier}, sign::{Signer, Verifier},
}; };
use reqwest::Request; use reqwest::{
header::{HeaderName, HeaderValue},
Method,
Request,
};
use reqwest_middleware::RequestBuilder; use reqwest_middleware::RequestBuilder;
use serde::Deserialize; use serde::Deserialize;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
@ -83,8 +89,9 @@ pub(crate) async fn sign_request(
activity: Bytes, activity: Bytes,
private_key: PKey<Private>, private_key: PKey<Private>,
http_signature_compat: bool, http_signature_compat: bool,
) -> Result<Request, anyhow::Error> { ) -> Result<Request, Error> {
static CONFIG: Lazy<Config> = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER)); static CONFIG: Lazy<Config<DefaultSpawner>> =
Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER));
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| { static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| {
Config::new() Config::new()
.mastodon_compat() .mastodon_compat()
@ -103,14 +110,10 @@ pub(crate) async fn sign_request(
Sha256::new(), Sha256::new(),
activity, activity,
move |signing_string| { move |signing_string| {
let mut signer = Signer::new(MessageDigest::sha256(), &private_key) let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?;
.context("instantiating signer")?; signer.update(signing_string.as_bytes())?;
signer
.update(signing_string.as_bytes())
.context("updating signer")?;
Ok(Base64.encode(signer.sign_to_vec().context("sign to vec")?)) Ok(Base64.encode(signer.sign_to_vec()?)) as Result<_, Error>
as Result<_, anyhow::Error>
}, },
) )
.await .await
@ -152,7 +155,7 @@ pub(crate) async fn signing_actor<'a, A, H>(
) -> Result<A, <A as Object>::Error> ) -> Result<A, <A as Object>::Error>
where where
A: Object + Actor, A: Object + Actor,
<A as Object>::Error: From<Error> + From<anyhow::Error>, <A as Object>::Error: From<Error>,
for<'de2> <A as Object>::Kind: Deserialize<'de2>, for<'de2> <A as Object>::Kind: Deserialize<'de2>,
H: IntoIterator<Item = (&'a HeaderName, &'a HeaderValue)>, H: IntoIterator<Item = (&'a HeaderName, &'a HeaderValue)>,
{ {
@ -197,8 +200,8 @@ fn verify_signature_inner(
let verified = CONFIG let verified = CONFIG
.begin_verify(method.as_str(), path_and_query, header_map) .begin_verify(method.as_str(), path_and_query, header_map)
.map_err(Error::other)? .map_err(|val| Error::Other(val.to_string()))?
.verify(|signature, signing_string| -> anyhow::Result<bool> { .verify(|signature, signing_string| -> Result<bool, Error> {
debug!( debug!(
"Verifying with key {}, message {}", "Verifying with key {}, message {}",
&public_key, &signing_string &public_key, &signing_string
@ -206,9 +209,13 @@ fn verify_signature_inner(
let public_key = PKey::public_key_from_pem(public_key.as_bytes())?; let public_key = PKey::public_key_from_pem(public_key.as_bytes())?;
let mut verifier = Verifier::new(MessageDigest::sha256(), &public_key)?; let mut verifier = Verifier::new(MessageDigest::sha256(), &public_key)?;
verifier.update(signing_string.as_bytes())?; verifier.update(signing_string.as_bytes())?;
Ok(verifier.verify(&Base64.decode(signature)?)?)
}) let base64_decoded = Base64
.map_err(Error::other)?; .decode(signature)
.map_err(|err| Error::Other(err.to_string()))?;
Ok(verifier.verify(&base64_decoded)?)
})?;
if verified { if verified {
debug!("verified signature for {}", uri); debug!("verified signature for {}", uri);
@ -289,7 +296,7 @@ pub mod test {
// use hardcoded date in order to test against hardcoded signature // use hardcoded date in order to test against hardcoded signature
headers.insert( headers.insert(
"date", "date",
HeaderValue::from_str("Tue, 28 Mar 2023 21:03:44 GMT").unwrap(), reqwest::header::HeaderValue::from_str("Tue, 28 Mar 2023 21:03:44 GMT").unwrap(),
); );
let request_builder = ClientWithMiddleware::from(Client::new()) let request_builder = ClientWithMiddleware::from(Client::new())

View file

@ -23,7 +23,46 @@ pub mod protocol;
pub(crate) mod reqwest_shim; pub(crate) mod reqwest_shim;
pub mod traits; pub mod traits;
use crate::{
config::Data,
error::Error,
fetch::object_id::ObjectId,
traits::{ActivityHandler, Actor, Object},
};
pub use activitystreams_kinds as kinds; pub use activitystreams_kinds as kinds;
use serde::{de::DeserializeOwned, Deserialize};
use url::Url;
/// Mime type for Activitypub data, used for `Accept` and `Content-Type` HTTP headers /// Mime type for Activitypub data, used for `Accept` and `Content-Type` HTTP headers
pub static FEDERATION_CONTENT_TYPE: &str = "application/activity+json"; pub static FEDERATION_CONTENT_TYPE: &str = "application/activity+json";
/// Deserialize incoming inbox activity to the given type, perform basic
/// validation and extract the actor.
async fn parse_received_activity<Activity, ActorT, Datatype>(
body: &[u8],
data: &Data<Datatype>,
) -> Result<(Activity, ActorT), <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
let activity: Activity = serde_json::from_slice(body).map_err(|e| {
// Attempt to include activity id in error message
#[derive(Deserialize)]
struct Id {
id: Url,
}
let id = serde_json::from_slice::<Id>(body).ok();
Error::ParseReceivedActivity(e, id.map(|i| i.id))
})?;
data.config.verify_url_and_domain(&activity).await?;
let actor = ObjectId::<ActorT>::from(activity.actor().clone())
.dereference(data)
.await?;
Ok((activity, actor))
}

View file

@ -15,25 +15,23 @@
//! }; //! };
//! let note_with_context = WithContext::new_default(note); //! let note_with_context = WithContext::new_default(note);
//! let serialized = serde_json::to_string(&note_with_context)?; //! let serialized = serde_json::to_string(&note_with_context)?;
//! assert_eq!(serialized, r#"{"@context":["https://www.w3.org/ns/activitystreams"],"content":"Hello world"}"#); //! assert_eq!(serialized, r#"{"@context":"https://www.w3.org/ns/activitystreams","content":"Hello world"}"#);
//! Ok::<(), serde_json::error::Error>(()) //! Ok::<(), serde_json::error::Error>(())
//! ``` //! ```
use crate::{config::Data, protocol::helpers::deserialize_one_or_many, traits::ActivityHandler}; use crate::{config::Data, traits::ActivityHandler};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use url::Url; use url::Url;
/// Default context used in Activitypub /// Default context used in Activitypub
const DEFAULT_CONTEXT: &str = "https://www.w3.org/ns/activitystreams"; const DEFAULT_CONTEXT: &str = "https://www.w3.org/ns/activitystreams";
const DEFAULT_SECURITY_CONTEXT: &str = "https://w3id.org/security/v1";
/// Wrapper for federated structs which handles `@context` field. /// Wrapper for federated structs which handles `@context` field.
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct WithContext<T> { pub struct WithContext<T> {
#[serde(rename = "@context")] #[serde(rename = "@context")]
#[serde(deserialize_with = "deserialize_one_or_many")] context: Value,
context: Vec<Value>,
#[serde(flatten)] #[serde(flatten)]
inner: T, inner: T,
} }
@ -41,15 +39,12 @@ pub struct WithContext<T> {
impl<T> WithContext<T> { impl<T> WithContext<T> {
/// Create a new wrapper with the default Activitypub context. /// Create a new wrapper with the default Activitypub context.
pub fn new_default(inner: T) -> WithContext<T> { pub fn new_default(inner: T) -> WithContext<T> {
let context = vec![ let context = Value::String(DEFAULT_CONTEXT.to_string());
Value::String(DEFAULT_CONTEXT.to_string()),
Value::String(DEFAULT_SECURITY_CONTEXT.to_string()),
];
WithContext::new(inner, context) WithContext::new(inner, context)
} }
/// Create new wrapper with custom context. Use this in case you are implementing extensions. /// Create new wrapper with custom context. Use this in case you are implementing extensions.
pub fn new(inner: T, context: Vec<Value>) -> WithContext<T> { pub fn new(inner: T, context: Value) -> WithContext<T> {
WithContext { context, inner } WithContext { context, inner }
} }

View file

@ -56,12 +56,12 @@ where
/// #[derive(serde::Deserialize)] /// #[derive(serde::Deserialize)]
/// struct Note { /// struct Note {
/// #[serde(deserialize_with = "deserialize_one")] /// #[serde(deserialize_with = "deserialize_one")]
/// to: Url /// to: [Url; 1]
/// } /// }
/// ///
/// let note = serde_json::from_str::<Note>(r#"{"to": ["https://example.com/u/alice"] }"#); /// let note = serde_json::from_str::<Note>(r#"{"to": ["https://example.com/u/alice"] }"#);
/// assert!(note.is_ok()); /// assert!(note.is_ok());
pub fn deserialize_one<'de, T, D>(deserializer: D) -> Result<T, D::Error> pub fn deserialize_one<'de, T, D>(deserializer: D) -> Result<[T; 1], D::Error>
where where
T: Deserialize<'de>, T: Deserialize<'de>,
D: Deserializer<'de>, D: Deserializer<'de>,
@ -75,8 +75,8 @@ where
let result: MaybeArray<T> = Deserialize::deserialize(deserializer)?; let result: MaybeArray<T> = Deserialize::deserialize(deserializer)?;
Ok(match result { Ok(match result {
MaybeArray::Simple(value) => value, MaybeArray::Simple(value) => [value],
MaybeArray::Array([value]) => value, MaybeArray::Array([value]) => [value],
}) })
} }
@ -125,7 +125,7 @@ mod tests {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct Note { struct Note {
#[serde(deserialize_with = "deserialize_one")] #[serde(deserialize_with = "deserialize_one")]
_to: Url, _to: [Url; 1],
} }
let note = serde_json::from_str::<Note>( let note = serde_json::from_str::<Note>(

View file

@ -6,7 +6,7 @@ use url::Url;
/// Public key of actors which is used for HTTP signatures. /// Public key of actors which is used for HTTP signatures.
/// ///
/// This needs to be federated in the `public_key` field of all actors. /// This needs to be federated in the `public_key` field of all actors.
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PublicKey { pub struct PublicKey {
/// Id of this private key. /// Id of this private key.

View file

@ -35,7 +35,7 @@ use serde::{Deserialize, Serialize};
/// Media type for markdown text. /// Media type for markdown text.
/// ///
/// <https://www.iana.org/assignments/media-types/media-types.xhtml> /// <https://www.iana.org/assignments/media-types/media-types.xhtml>
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum MediaTypeMarkdown { pub enum MediaTypeMarkdown {
/// `text/markdown` /// `text/markdown`
#[serde(rename = "text/markdown")] #[serde(rename = "text/markdown")]
@ -45,7 +45,7 @@ pub enum MediaTypeMarkdown {
/// Media type for HTML text. /// Media type for HTML text.
/// ///
/// <https://www.iana.org/assignments/media-types/media-types.xhtml> /// <https://www.iana.org/assignments/media-types/media-types.xhtml>
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub enum MediaTypeHtml { pub enum MediaTypeHtml {
/// `text/html` /// `text/html`
#[serde(rename = "text/html")] #[serde(rename = "text/html")]

View file

@ -1,7 +1,6 @@
//! Verify that received data is valid //! Verify that received data is valid
use crate::error::Error; use crate::error::Error;
use anyhow::anyhow;
use url::Url; use url::Url;
/// Check that both urls have the same domain. If not, return UrlVerificationError. /// Check that both urls have the same domain. If not, return UrlVerificationError.
@ -16,7 +15,7 @@ use url::Url;
/// ``` /// ```
pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> { pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> {
if a.domain() != b.domain() { if a.domain() != b.domain() {
return Err(Error::UrlVerificationError(anyhow!("Domains do not match"))); return Err(Error::UrlVerificationError("Domains do not match"));
} }
Ok(()) Ok(())
} }
@ -33,7 +32,7 @@ pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> {
/// ``` /// ```
pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> { pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> {
if a != b { if a != b {
return Err(Error::UrlVerificationError(anyhow!("Urls do not match"))); return Err(Error::UrlVerificationError("Urls do not match"));
} }
Ok(()) Ok(())
} }

View file

@ -3,10 +3,8 @@ use bytes::{BufMut, Bytes, BytesMut};
use futures_core::{ready, stream::BoxStream, Stream}; use futures_core::{ready, stream::BoxStream, Stream};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use reqwest::Response; use reqwest::Response;
use serde::de::DeserializeOwned;
use std::{ use std::{
future::Future, future::Future,
marker::PhantomData,
mem, mem,
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
@ -30,10 +28,7 @@ impl Future for BytesFuture {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop { loop {
let this = self.as_mut().project(); let this = self.as_mut().project();
if let Some(chunk) = ready!(this.stream.poll_next(cx)) if let Some(chunk) = ready!(this.stream.poll_next(cx)).transpose()? {
.transpose()
.map_err(Error::other)?
{
this.aggregator.put(chunk); this.aggregator.put(chunk);
if this.aggregator.len() > *this.limit { if this.aggregator.len() > *this.limit {
return Poll::Ready(Err(Error::ResponseBodyLimit)); return Poll::Ready(Err(Error::ResponseBodyLimit));
@ -49,27 +44,6 @@ impl Future for BytesFuture {
} }
} }
pin_project! {
pub struct JsonFuture<T> {
_t: PhantomData<T>,
#[pin]
future: BytesFuture,
}
}
impl<T> Future for JsonFuture<T>
where
T: DeserializeOwned,
{
type Output = Result<T, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let bytes = ready!(this.future.poll(cx))?;
Poll::Ready(serde_json::from_slice(&bytes).map_err(Error::other))
}
}
pin_project! { pin_project! {
pub struct TextFuture { pub struct TextFuture {
#[pin] #[pin]
@ -83,7 +57,7 @@ impl Future for TextFuture {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.project();
let bytes = ready!(this.future.poll(cx))?; let bytes = ready!(this.future.poll(cx))?;
Poll::Ready(String::from_utf8(bytes.to_vec()).map_err(Error::other)) Poll::Ready(String::from_utf8(bytes.to_vec()).map_err(Error::Utf8))
} }
} }
@ -97,20 +71,16 @@ impl Future for TextFuture {
/// TODO: Remove this shim as soon as reqwest gets support for size-limited bodies. /// TODO: Remove this shim as soon as reqwest gets support for size-limited bodies.
pub trait ResponseExt { pub trait ResponseExt {
type BytesFuture; type BytesFuture;
type JsonFuture<T>;
type TextFuture; type TextFuture;
/// Size limited version of `bytes` to work around a reqwest issue. Check [`ResponseExt`] docs for details. /// Size limited version of `bytes` to work around a reqwest issue. Check [`ResponseExt`] docs for details.
fn bytes_limited(self) -> Self::BytesFuture; fn bytes_limited(self) -> Self::BytesFuture;
/// Size limited version of `json` to work around a reqwest issue. Check [`ResponseExt`] docs for details.
fn json_limited<T>(self) -> Self::JsonFuture<T>;
/// Size limited version of `text` to work around a reqwest issue. Check [`ResponseExt`] docs for details. /// Size limited version of `text` to work around a reqwest issue. Check [`ResponseExt`] docs for details.
fn text_limited(self) -> Self::TextFuture; fn text_limited(self) -> Self::TextFuture;
} }
impl ResponseExt for Response { impl ResponseExt for Response {
type BytesFuture = BytesFuture; type BytesFuture = BytesFuture;
type JsonFuture<T> = JsonFuture<T>;
type TextFuture = TextFuture; type TextFuture = TextFuture;
fn bytes_limited(self) -> Self::BytesFuture { fn bytes_limited(self) -> Self::BytesFuture {
@ -121,13 +91,6 @@ impl ResponseExt for Response {
} }
} }
fn json_limited<T>(self) -> Self::JsonFuture<T> {
JsonFuture {
_t: PhantomData,
future: self.bytes_limited(),
}
}
fn text_limited(self) -> Self::TextFuture { fn text_limited(self) -> Self::TextFuture {
TextFuture { TextFuture {
future: self.bytes_limited(), future: self.bytes_limited(),

View file

@ -340,12 +340,12 @@ pub trait Collection: Sized {
pub mod tests { pub mod tests {
use super::*; use super::*;
use crate::{ use crate::{
error::Error,
fetch::object_id::ObjectId, fetch::object_id::ObjectId,
http_signatures::{generate_actor_keypair, Keypair}, http_signatures::{generate_actor_keypair, Keypair},
protocol::{public_key::PublicKey, verification::verify_domains_match}, protocol::{public_key::PublicKey, verification::verify_domains_match},
}; };
use activitystreams_kinds::{activity::FollowType, actor::PersonType}; use activitystreams_kinds::{activity::FollowType, actor::PersonType};
use anyhow::Error;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -356,7 +356,7 @@ pub mod tests {
pub async fn read_post_from_json_id<T>(&self, _: Url) -> Result<Option<T>, Error> { pub async fn read_post_from_json_id<T>(&self, _: Url) -> Result<Option<T>, Error> {
Ok(None) Ok(None)
} }
pub async fn read_local_user(&self, _: String) -> Result<DbUser, Error> { pub async fn read_local_user(&self, _: &str) -> Result<DbUser, Error> {
todo!() todo!()
} }
pub async fn upsert<T>(&self, _: &T) -> Result<(), Error> { pub async fn upsert<T>(&self, _: &T) -> Result<(), Error> {