commit
9830f30f2a
27 changed files with 610 additions and 335 deletions
62
Cargo.toml
62
Cargo.toml
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "activitypub_federation"
|
name = "activitypub_federation"
|
||||||
version = "0.5.0-beta.4"
|
version = "0.5.1-beta.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "High-level Activitypub framework"
|
description = "High-level Activitypub framework"
|
||||||
keywords = ["activitypub", "activitystreams", "federation", "fediverse"]
|
keywords = ["activitypub", "activitystreams", "federation", "fediverse"]
|
||||||
|
|
@ -8,73 +8,77 @@ license = "AGPL-3.0"
|
||||||
repository = "https://github.com/LemmyNet/activitypub-federation-rust"
|
repository = "https://github.com/LemmyNet/activitypub-federation-rust"
|
||||||
documentation = "https://docs.rs/activitypub_federation/"
|
documentation = "https://docs.rs/activitypub_federation/"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["actix-web", "axum"]
|
||||||
|
actix-web = ["dep:actix-web"]
|
||||||
|
axum = ["dep:axum", "dep:tower", "dep:hyper", "dep:http-body-util"]
|
||||||
|
diesel = ["dep:diesel"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
chrono = { version = "0.4.31", features = ["clock"], default-features = false }
|
chrono = { version = "0.4.31", features = ["clock"], default-features = false }
|
||||||
serde = { version = "1.0.189", features = ["derive"] }
|
serde = { version = "1.0.194", features = ["derive"] }
|
||||||
async-trait = "0.1.74"
|
async-trait = "0.1.77"
|
||||||
url = { version = "2.4.1", features = ["serde"] }
|
url = { version = "2.5.0", features = ["serde"] }
|
||||||
serde_json = { version = "1.0.107", features = ["preserve_order"] }
|
serde_json = { version = "1.0.110", features = ["preserve_order"] }
|
||||||
anyhow = "1.0.75"
|
reqwest = { version = "0.11.23", features = ["json", "stream"] }
|
||||||
reqwest = { version = "0.11.22", features = ["json", "stream"] }
|
reqwest-middleware = "0.2.4"
|
||||||
reqwest-middleware = "0.2.3"
|
|
||||||
tracing = "0.1.40"
|
tracing = "0.1.40"
|
||||||
base64 = "0.21.5"
|
base64 = "0.21.5"
|
||||||
openssl = "0.10.57"
|
openssl = "0.10.62"
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.19.0"
|
||||||
http = "0.2.9"
|
http = "1.0.0"
|
||||||
sha2 = "0.10.8"
|
sha2 = "0.10.8"
|
||||||
thiserror = "1.0.50"
|
thiserror = "1.0.56"
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
itertools = "0.11.0"
|
itertools = "0.12.0"
|
||||||
dyn-clone = "1.0.14"
|
dyn-clone = "1.0.16"
|
||||||
enum_delegate = "0.2.0"
|
enum_delegate = "0.2.0"
|
||||||
httpdate = "1.0.3"
|
httpdate = "1.0.3"
|
||||||
http-signature-normalization-reqwest = { version = "0.10.0", default-features = false, features = [
|
http-signature-normalization-reqwest = { version = "0.10.0", default-features = false, features = [
|
||||||
"default-spawner",
|
"default-spawner",
|
||||||
"sha-2",
|
"sha-2",
|
||||||
"middleware",
|
"middleware",
|
||||||
|
"default-spawner",
|
||||||
] }
|
] }
|
||||||
http-signature-normalization = "0.7.0"
|
http-signature-normalization = "0.7.0"
|
||||||
bytes = "1.5.0"
|
bytes = "1.5.0"
|
||||||
futures-core = { version = "0.3.28", default-features = false }
|
futures-core = { version = "0.3.30", default-features = false }
|
||||||
pin-project-lite = "0.2.13"
|
pin-project-lite = "0.2.13"
|
||||||
activitystreams-kinds = "0.3.0"
|
activitystreams-kinds = "0.3.0"
|
||||||
regex = { version = "1.10.2", default-features = false, features = ["std", "unicode-case"] }
|
regex = { version = "1.10.2", default-features = false, features = ["std", "unicode-case"] }
|
||||||
tokio = { version = "1.33.0", features = [
|
tokio = { version = "1.35.1", features = [
|
||||||
"sync",
|
"sync",
|
||||||
"rt",
|
"rt",
|
||||||
"rt-multi-thread",
|
"rt-multi-thread",
|
||||||
"time",
|
"time",
|
||||||
] }
|
] }
|
||||||
|
diesel = { version = "2.1.4", features = ["postgres"], default-features = false, optional = true }
|
||||||
|
futures = "0.3.30"
|
||||||
|
moka = { version = "0.12.2", features = ["future"] }
|
||||||
|
|
||||||
# Actix-web
|
# Actix-web
|
||||||
actix-web = { version = "4.4.0", default-features = false, optional = true }
|
actix-web = { version = "4.4.1", default-features = false, optional = true }
|
||||||
|
|
||||||
# Axum
|
# Axum
|
||||||
axum = { git = "https://github.com/tokio-rs/axum.git", features = [
|
axum = { git = "https://github.com/tokio-rs/axum.git", features = [
|
||||||
"json",
|
"json",
|
||||||
], default-features = false, optional = true }
|
], default-features = false, optional = true }
|
||||||
tower = { version = "*", optional = true }
|
tower = { version = "0.4.13", optional = true }
|
||||||
hyper = { version = "*", optional = true }
|
hyper = { version = "1.1.0", optional = true }
|
||||||
futures = "*"
|
http-body-util = {version = "0.1.0", optional = true }
|
||||||
moka = { version = "0.12.1", features = ["future"] }
|
|
||||||
|
|
||||||
[features]
|
|
||||||
default = ["actix-web", "axum"]
|
|
||||||
actix-web = ["dep:actix-web"]
|
|
||||||
axum = ["dep:axum", "dep:tower", "dep:hyper"]
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
anyhow = "1.0.79"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
env_logger = "0.10.0"
|
env_logger = "0.10.1"
|
||||||
tower-http = { version = "*", features = ["map-request-body", "util"] }
|
tower-http = { version = "0.5.0", features = ["map-request-body", "util"] }
|
||||||
axum = { git = "https://github.com/tokio-rs/axum.git", features = [
|
axum = { git = "https://github.com/tokio-rs/axum.git", features = [
|
||||||
"http1",
|
"http1",
|
||||||
"tokio",
|
"tokio",
|
||||||
"query",
|
"query",
|
||||||
], default-features = false }
|
], default-features = false }
|
||||||
axum-macros = { git = "https://github.com/tokio-rs/axum.git" }
|
axum-macros = { git = "https://github.com/tokio-rs/axum.git" }
|
||||||
tokio = { version = "*", features = ["full"] }
|
tokio = { version = "1.35.1", features = ["full"] }
|
||||||
|
|
||||||
[profile.dev]
|
[profile.dev]
|
||||||
strip = "symbols"
|
strip = "symbols"
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ async fn http_get_user(
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let accept = header_map.get("accept").map(|v| v.to_str().unwrap());
|
let accept = header_map.get("accept").map(|v| v.to_str().unwrap());
|
||||||
if accept == Some(FEDERATION_CONTENT_TYPE) {
|
if accept == Some(FEDERATION_CONTENT_TYPE) {
|
||||||
let db_user = data.read_local_user(name).await.unwrap();
|
let db_user = data.read_local_user(&name).await.unwrap();
|
||||||
let json_user = db_user.into_json(&data).await.unwrap();
|
let json_user = db_user.into_json(&data).await.unwrap();
|
||||||
FederationJson(WithContext::new_default(json_user)).into_response()
|
FederationJson(WithContext::new_default(json_user)).into_response()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ pub async fn webfinger(
|
||||||
data: Data<DatabaseHandle>,
|
data: Data<DatabaseHandle>,
|
||||||
) -> Result<Json<Webfinger>, Error> {
|
) -> Result<Json<Webfinger>, Error> {
|
||||||
let name = extract_webfinger_name(&query.resource, &data)?;
|
let name = extract_webfinger_name(&query.resource, &data)?;
|
||||||
let db_user = data.read_user(&name)?;
|
let db_user = data.read_user(name)?;
|
||||||
Ok(Json(build_webfinger_response(
|
Ok(Json(build_webfinger_response(
|
||||||
query.resource,
|
query.resource,
|
||||||
db_user.ap_id.into_inner(),
|
db_user.ap_id.into_inner(),
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ pub async fn webfinger(
|
||||||
data: Data<DatabaseHandle>,
|
data: Data<DatabaseHandle>,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Result<HttpResponse, Error> {
|
||||||
let name = extract_webfinger_name(&query.resource, &data)?;
|
let name = extract_webfinger_name(&query.resource, &data)?;
|
||||||
let db_user = data.read_user(&name)?;
|
let db_user = data.read_user(name)?;
|
||||||
Ok(HttpResponse::Ok().json(build_webfinger_response(
|
Ok(HttpResponse::Ok().json(build_webfinger_response(
|
||||||
query.resource.clone(),
|
query.resource.clone(),
|
||||||
db_user.ap_id.into_inner(),
|
db_user.ap_id.into_inner(),
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ pub fn listen(config: &FederationConfig<DatabaseHandle>) -> Result<(), Error> {
|
||||||
let addr = tokio::net::TcpListener::from_std(TcpListener::bind(hostname)?)?;
|
let addr = tokio::net::TcpListener::from_std(TcpListener::bind(hostname)?)?;
|
||||||
let server = axum::serve(addr, app.into_make_service());
|
let server = axum::serve(addr, app.into_make_service());
|
||||||
|
|
||||||
tokio::spawn(server);
|
tokio::spawn(async move { server.await.unwrap() });
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -75,7 +75,7 @@ async fn webfinger(
|
||||||
data: Data<DatabaseHandle>,
|
data: Data<DatabaseHandle>,
|
||||||
) -> Result<Json<Webfinger>, Error> {
|
) -> Result<Json<Webfinger>, Error> {
|
||||||
let name = extract_webfinger_name(&query.resource, &data)?;
|
let name = extract_webfinger_name(&query.resource, &data)?;
|
||||||
let db_user = data.read_user(&name)?;
|
let db_user = data.read_user(name)?;
|
||||||
Ok(Json(build_webfinger_response(
|
Ok(Json(build_webfinger_response(
|
||||||
query.resource,
|
query.resource,
|
||||||
db_user.ap_id.into_inner(),
|
db_user.ap_id.into_inner(),
|
||||||
|
|
|
||||||
|
|
@ -49,9 +49,11 @@ struct MyUrlVerifier();
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl UrlVerifier for MyUrlVerifier {
|
impl UrlVerifier for MyUrlVerifier {
|
||||||
async fn verify(&self, url: &Url) -> Result<(), anyhow::Error> {
|
async fn verify(&self, url: &Url) -> Result<(), activitypub_federation::error::Error> {
|
||||||
if url.domain() == Some("malicious.com") {
|
if url.domain() == Some("malicious.com") {
|
||||||
Err(anyhow!("malicious domain"))
|
Err(activitypub_federation::error::Error::Other(
|
||||||
|
"malicious domain".into(),
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -107,7 +107,7 @@ impl DbUser {
|
||||||
activity: Activity,
|
activity: Activity,
|
||||||
recipients: Vec<Url>,
|
recipients: Vec<Url>,
|
||||||
data: &Data<DatabaseHandle>,
|
data: &Data<DatabaseHandle>,
|
||||||
) -> Result<(), <Activity as ActivityHandler>::Error>
|
) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
Activity: ActivityHandler + Serialize + Debug + Send + Sync,
|
Activity: ActivityHandler + Serialize + Debug + Send + Sync,
|
||||||
<Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>,
|
<Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>,
|
||||||
|
|
|
||||||
|
|
@ -10,21 +10,18 @@ use crate::{
|
||||||
traits::{ActivityHandler, Actor},
|
traits::{ActivityHandler, Actor},
|
||||||
FEDERATION_CONTENT_TYPE,
|
FEDERATION_CONTENT_TYPE,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context};
|
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use http::{header::HeaderName, HeaderMap, HeaderValue};
|
|
||||||
use httpdate::fmt_http_date;
|
use httpdate::fmt_http_date;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use openssl::pkey::{PKey, Private};
|
use openssl::pkey::{PKey, Private};
|
||||||
use reqwest::Request;
|
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
||||||
use reqwest_middleware::ClientWithMiddleware;
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use std::{
|
use std::{
|
||||||
self,
|
self,
|
||||||
fmt::{Debug, Display},
|
fmt::{Debug, Display},
|
||||||
time::{Duration, SystemTime},
|
time::SystemTime,
|
||||||
};
|
};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
@ -57,17 +54,18 @@ impl SendActivityTask<'_> {
|
||||||
actor: &ActorType,
|
actor: &ActorType,
|
||||||
inboxes: Vec<Url>,
|
inboxes: Vec<Url>,
|
||||||
data: &Data<Datatype>,
|
data: &Data<Datatype>,
|
||||||
) -> Result<Vec<SendActivityTask<'a>>, <Activity as ActivityHandler>::Error>
|
) -> Result<Vec<SendActivityTask<'a>>, Error>
|
||||||
where
|
where
|
||||||
Activity: ActivityHandler + Serialize,
|
Activity: ActivityHandler + Serialize + Debug,
|
||||||
<Activity as ActivityHandler>::Error: From<anyhow::Error> + From<serde_json::Error>,
|
|
||||||
Datatype: Clone,
|
Datatype: Clone,
|
||||||
ActorType: Actor,
|
ActorType: Actor,
|
||||||
{
|
{
|
||||||
let config = &data.config;
|
let config = &data.config;
|
||||||
let actor_id = activity.actor();
|
let actor_id = activity.actor();
|
||||||
let activity_id = activity.id();
|
let activity_id = activity.id();
|
||||||
let activity_serialized: Bytes = serde_json::to_vec(&activity)?.into();
|
let activity_serialized: Bytes = serde_json::to_vec(&activity)
|
||||||
|
.map_err(|e| Error::SerializeOutgoingActivity(e, format!("{:?}", activity)))?
|
||||||
|
.into();
|
||||||
let private_key = get_pkey_cached(data, actor).await?;
|
let private_key = get_pkey_cached(data, actor).await?;
|
||||||
|
|
||||||
Ok(futures::stream::iter(
|
Ok(futures::stream::iter(
|
||||||
|
|
@ -95,62 +93,40 @@ impl SendActivityTask<'_> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// convert a sendactivitydata to a request, signing and sending it
|
/// convert a sendactivitydata to a request, signing and sending it
|
||||||
pub async fn sign_and_send<Datatype: Clone>(
|
pub async fn sign_and_send<Datatype: Clone>(&self, data: &Data<Datatype>) -> Result<(), Error> {
|
||||||
&self,
|
let client = &data.config.client;
|
||||||
data: &Data<Datatype>,
|
|
||||||
) -> Result<(), anyhow::Error> {
|
|
||||||
let req = self
|
|
||||||
.sign(&data.config.client, data.config.request_timeout)
|
|
||||||
.await?;
|
|
||||||
self.send(&data.config.client, req).await
|
|
||||||
}
|
|
||||||
async fn sign(
|
|
||||||
&self,
|
|
||||||
client: &ClientWithMiddleware,
|
|
||||||
timeout: Duration,
|
|
||||||
) -> Result<Request, anyhow::Error> {
|
|
||||||
let task = self;
|
|
||||||
let request_builder = client
|
let request_builder = client
|
||||||
.post(task.inbox.to_string())
|
.post(self.inbox.to_string())
|
||||||
.timeout(timeout)
|
.timeout(data.config.request_timeout)
|
||||||
.headers(generate_request_headers(&task.inbox));
|
.headers(generate_request_headers(&self.inbox));
|
||||||
let request = sign_request(
|
let request = sign_request(
|
||||||
request_builder,
|
request_builder,
|
||||||
task.actor_id,
|
self.actor_id,
|
||||||
task.activity.clone(),
|
self.activity.clone(),
|
||||||
task.private_key.clone(),
|
self.private_key.clone(),
|
||||||
task.http_signature_compat,
|
self.http_signature_compat,
|
||||||
)
|
)
|
||||||
.await
|
.await?;
|
||||||
.context("signing request")?;
|
let response = client.execute(request).await?;
|
||||||
Ok(request)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send(
|
|
||||||
&self,
|
|
||||||
client: &ClientWithMiddleware,
|
|
||||||
request: Request,
|
|
||||||
) -> Result<(), anyhow::Error> {
|
|
||||||
let response = client.execute(request).await;
|
|
||||||
|
|
||||||
match response {
|
match response {
|
||||||
Ok(o) if o.status().is_success() => {
|
o if o.status().is_success() => {
|
||||||
debug!("Activity {self} delivered successfully");
|
debug!("Activity {self} delivered successfully");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
Ok(o) if o.status().is_client_error() => {
|
o if o.status().is_client_error() => {
|
||||||
let text = o.text_limited().await.map_err(Error::other)?;
|
let text = o.text_limited().await?;
|
||||||
debug!("Activity {self} was rejected, aborting: {text}");
|
debug!("Activity {self} was rejected, aborting: {text}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
Ok(o) => {
|
o => {
|
||||||
let status = o.status();
|
let status = o.status();
|
||||||
let text = o.text_limited().await.map_err(Error::other)?;
|
let text = o.text_limited().await?;
|
||||||
Err(anyhow!(
|
|
||||||
|
Err(Error::Other(format!(
|
||||||
"Activity {self} failure with status {status}: {text}",
|
"Activity {self} failure with status {status}: {text}",
|
||||||
))
|
)))
|
||||||
}
|
}
|
||||||
Err(e) => Err(anyhow!("Activity {self} connection failure: {e}")),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -158,7 +134,7 @@ impl SendActivityTask<'_> {
|
||||||
async fn get_pkey_cached<ActorType>(
|
async fn get_pkey_cached<ActorType>(
|
||||||
data: &Data<impl Clone>,
|
data: &Data<impl Clone>,
|
||||||
actor: &ActorType,
|
actor: &ActorType,
|
||||||
) -> Result<PKey<Private>, anyhow::Error>
|
) -> Result<PKey<Private>, Error>
|
||||||
where
|
where
|
||||||
ActorType: Actor,
|
ActorType: Actor,
|
||||||
{
|
{
|
||||||
|
|
@ -168,20 +144,23 @@ where
|
||||||
.actor_pkey_cache
|
.actor_pkey_cache
|
||||||
.try_get_with_by_ref(&actor_id, async {
|
.try_get_with_by_ref(&actor_id, async {
|
||||||
let private_key_pem = actor.private_key_pem().ok_or_else(|| {
|
let private_key_pem = actor.private_key_pem().ok_or_else(|| {
|
||||||
anyhow!("Actor {actor_id} does not contain a private key for signing")
|
Error::Other(format!(
|
||||||
|
"Actor {actor_id} does not contain a private key for signing"
|
||||||
|
))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening
|
// This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening
|
||||||
let pkey = tokio::task::spawn_blocking(move || {
|
let pkey = tokio::task::spawn_blocking(move || {
|
||||||
PKey::private_key_from_pem(private_key_pem.as_bytes())
|
PKey::private_key_from_pem(private_key_pem.as_bytes()).map_err(|err| {
|
||||||
.map_err(|err| anyhow!("Could not create private key from PEM data:{err}"))
|
Error::Other(format!("Could not create private key from PEM data:{err}"))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|err| anyhow!("Error joining: {err}"))??;
|
.map_err(|err| Error::Other(format!("Error joining: {err}")))??;
|
||||||
std::result::Result::<PKey<Private>, anyhow::Error>::Ok(pkey)
|
std::result::Result::<PKey<Private>, Error>::Ok(pkey)
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow!("cloned error: {e}"))
|
.map_err(|e| Error::Other(format!("cloned error: {e}")))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap {
|
pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap {
|
||||||
|
|
@ -226,7 +205,7 @@ mod tests {
|
||||||
// This will periodically send back internal errors to test the retry
|
// This will periodically send back internal errors to test the retry
|
||||||
async fn dodgy_handler(
|
async fn dodgy_handler(
|
||||||
State(state): State<Arc<AtomicUsize>>,
|
State(state): State<Arc<AtomicUsize>>,
|
||||||
headers: HeaderMap,
|
headers: http::HeaderMap,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
) -> Result<(), StatusCode> {
|
) -> Result<(), StatusCode> {
|
||||||
debug!("Headers:{:?}", headers);
|
debug!("Headers:{:?}", headers);
|
||||||
|
|
@ -294,7 +273,7 @@ mod tests {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
for _ in 0..num_messages {
|
for _ in 0..num_messages {
|
||||||
message.sign_and_send(&data).await?;
|
message.clone().sign_and_send(&data).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Queue Sent: {:?}", start.elapsed());
|
info!("Queue Sent: {:?}", start.elapsed());
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,13 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
config::Data,
|
config::Data,
|
||||||
error::Error,
|
error::Error,
|
||||||
fetch::object_id::ObjectId,
|
|
||||||
http_signatures::{verify_body_hash, verify_signature},
|
http_signatures::{verify_body_hash, verify_signature},
|
||||||
|
parse_received_activity,
|
||||||
traits::{ActivityHandler, Actor, Object},
|
traits::{ActivityHandler, Actor, Object},
|
||||||
};
|
};
|
||||||
use actix_web::{web::Bytes, HttpRequest, HttpResponse};
|
use actix_web::{web::Bytes, HttpRequest, HttpResponse};
|
||||||
use anyhow::Context;
|
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
|
use std::str::FromStr;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
/// Handles incoming activities, verifying HTTP signatures and other checks
|
/// Handles incoming activities, verifying HTTP signatures and other checks
|
||||||
|
|
@ -24,26 +24,18 @@ where
|
||||||
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
|
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
|
||||||
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
|
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
|
||||||
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
|
||||||
<Activity as ActivityHandler>::Error: From<anyhow::Error>
|
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
|
||||||
+ From<Error>
|
<ActorT as Object>::Error: From<Error>,
|
||||||
+ From<<ActorT as Object>::Error>
|
|
||||||
+ From<serde_json::Error>,
|
|
||||||
<ActorT as Object>::Error: From<Error> + From<anyhow::Error>,
|
|
||||||
Datatype: Clone,
|
Datatype: Clone,
|
||||||
{
|
{
|
||||||
verify_body_hash(request.headers().get("Digest"), &body)?;
|
verify_body_hash(request.headers().get("Digest"), &body)?;
|
||||||
|
|
||||||
let activity: Activity = serde_json::from_slice(&body)
|
let (activity, actor) = parse_received_activity::<Activity, ActorT, _>(&body, data).await?;
|
||||||
.with_context(|| format!("deserializing body: {}", String::from_utf8_lossy(&body)))?;
|
|
||||||
data.config.verify_url_and_domain(&activity).await?;
|
|
||||||
let actor = ObjectId::<ActorT>::from(activity.actor().clone())
|
|
||||||
.dereference(data)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
verify_signature(
|
verify_signature(
|
||||||
request.headers(),
|
request.headers(),
|
||||||
request.method(),
|
request.method(),
|
||||||
request.uri(),
|
&http::Uri::from_str(&request.uri().to_string()).unwrap(),
|
||||||
actor.public_key_pem(),
|
actor.public_key_pem(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
|
@ -59,12 +51,14 @@ mod test {
|
||||||
use crate::{
|
use crate::{
|
||||||
activity_sending::generate_request_headers,
|
activity_sending::generate_request_headers,
|
||||||
config::FederationConfig,
|
config::FederationConfig,
|
||||||
|
fetch::object_id::ObjectId,
|
||||||
http_signatures::sign_request,
|
http_signatures::sign_request,
|
||||||
traits::tests::{DbConnection, DbUser, Follow, DB_USER_KEYPAIR},
|
traits::tests::{DbConnection, DbUser, Follow, DB_USER_KEYPAIR},
|
||||||
};
|
};
|
||||||
use actix_web::test::TestRequest;
|
use actix_web::test::TestRequest;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use reqwest_middleware::ClientWithMiddleware;
|
use reqwest_middleware::ClientWithMiddleware;
|
||||||
|
use serde_json::json;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
@ -91,8 +85,7 @@ mod test {
|
||||||
.err()
|
.err()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let e = err.root_cause().downcast_ref::<Error>().unwrap();
|
assert_eq!(&err, &Error::ActivityBodyDigestInvalid)
|
||||||
assert_eq!(e, &Error::ActivityBodyDigestInvalid)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
@ -108,26 +101,52 @@ mod test {
|
||||||
.err()
|
.err()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let e = err.root_cause().downcast_ref::<Error>().unwrap();
|
assert_eq!(&err, &Error::ActivitySignatureInvalid)
|
||||||
assert_eq!(e, &Error::ActivitySignatureInvalid)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) {
|
#[tokio::test]
|
||||||
|
async fn test_receive_unparseable_activity() {
|
||||||
|
let (_, _, config) = setup_receive_test().await;
|
||||||
|
|
||||||
|
let actor = Url::parse("http://ds9.lemmy.ml/u/lemmy_alpha").unwrap();
|
||||||
|
let id = "http://localhost:123/1";
|
||||||
|
let activity = json!({
|
||||||
|
"actor": actor.as_str(),
|
||||||
|
"to": ["https://www.w3.org/ns/activitystreams#Public"],
|
||||||
|
"object": "http://ds9.lemmy.ml/post/1",
|
||||||
|
"cc": ["http://enterprise.lemmy.ml/c/main"],
|
||||||
|
"type": "Delete",
|
||||||
|
"id": id
|
||||||
|
}
|
||||||
|
);
|
||||||
|
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
|
||||||
|
let incoming_request = construct_request(&body, &actor).await;
|
||||||
|
|
||||||
|
// intentionally cause a parse error by using wrong type for deser
|
||||||
|
let res = receive_activity::<Follow, DbUser, DbConnection>(
|
||||||
|
incoming_request.to_http_request(),
|
||||||
|
body,
|
||||||
|
&config.to_request_data(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match res {
|
||||||
|
Err(Error::ParseReceivedActivity(_, url)) => {
|
||||||
|
assert_eq!(id, url.expect("has url").as_str());
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn construct_request(body: &Bytes, actor: &Url) -> TestRequest {
|
||||||
let inbox = "https://example.com/inbox";
|
let inbox = "https://example.com/inbox";
|
||||||
let headers = generate_request_headers(&Url::parse(inbox).unwrap());
|
let headers = generate_request_headers(&Url::parse(inbox).unwrap());
|
||||||
let request_builder = ClientWithMiddleware::from(Client::default())
|
let request_builder = ClientWithMiddleware::from(Client::default())
|
||||||
.post(inbox)
|
.post(inbox)
|
||||||
.headers(headers);
|
.headers(headers);
|
||||||
let activity = Follow {
|
|
||||||
actor: ObjectId::parse("http://localhost:123").unwrap(),
|
|
||||||
object: ObjectId::parse("http://localhost:124").unwrap(),
|
|
||||||
kind: Default::default(),
|
|
||||||
id: "http://localhost:123/1".try_into().unwrap(),
|
|
||||||
};
|
|
||||||
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
|
|
||||||
let outgoing_request = sign_request(
|
let outgoing_request = sign_request(
|
||||||
request_builder,
|
request_builder,
|
||||||
&activity.actor.into_inner(),
|
actor,
|
||||||
body.clone(),
|
body.clone(),
|
||||||
DB_USER_KEYPAIR.private_key().unwrap(),
|
DB_USER_KEYPAIR.private_key().unwrap(),
|
||||||
false,
|
false,
|
||||||
|
|
@ -138,6 +157,18 @@ mod test {
|
||||||
for h in outgoing_request.headers() {
|
for h in outgoing_request.headers() {
|
||||||
incoming_request = incoming_request.append_header(h);
|
incoming_request = incoming_request.append_header(h);
|
||||||
}
|
}
|
||||||
|
incoming_request
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) {
|
||||||
|
let activity = Follow {
|
||||||
|
actor: ObjectId::parse("http://localhost:123").unwrap(),
|
||||||
|
object: ObjectId::parse("http://localhost:124").unwrap(),
|
||||||
|
kind: Default::default(),
|
||||||
|
id: "http://localhost:123/1".try_into().unwrap(),
|
||||||
|
};
|
||||||
|
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
|
||||||
|
let incoming_request = construct_request(&body, activity.actor.inner()).await;
|
||||||
|
|
||||||
let config = FederationConfig::builder()
|
let config = FederationConfig::builder()
|
||||||
.domain("localhost:8002")
|
.domain("localhost:8002")
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ use crate::{
|
||||||
};
|
};
|
||||||
use actix_web::{web::Bytes, HttpRequest};
|
use actix_web::{web::Bytes, HttpRequest};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
/// Checks whether the request is signed by an actor of type A, and returns
|
/// Checks whether the request is signed by an actor of type A, and returns
|
||||||
/// the actor in question if a valid signature is found.
|
/// the actor in question if a valid signature is found.
|
||||||
|
|
@ -22,10 +23,16 @@ pub async fn signing_actor<A>(
|
||||||
) -> Result<A, <A as Object>::Error>
|
) -> Result<A, <A as Object>::Error>
|
||||||
where
|
where
|
||||||
A: Object + Actor,
|
A: Object + Actor,
|
||||||
<A as Object>::Error: From<Error> + From<anyhow::Error>,
|
<A as Object>::Error: From<Error>,
|
||||||
for<'de2> <A as Object>::Kind: Deserialize<'de2>,
|
for<'de2> <A as Object>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
verify_body_hash(request.headers().get("Digest"), &body.unwrap_or_default())?;
|
verify_body_hash(request.headers().get("Digest"), &body.unwrap_or_default())?;
|
||||||
|
|
||||||
http_signatures::signing_actor(request.headers(), request.method(), request.uri(), data).await
|
http_signatures::signing_actor(
|
||||||
|
request.headers(),
|
||||||
|
request.method(),
|
||||||
|
&http::Uri::from_str(&request.uri().to_string()).unwrap(),
|
||||||
|
data,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,15 +5,14 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
config::Data,
|
config::Data,
|
||||||
error::Error,
|
error::Error,
|
||||||
fetch::object_id::ObjectId,
|
http_signatures::verify_signature,
|
||||||
http_signatures::{verify_body_hash, verify_signature},
|
parse_received_activity,
|
||||||
traits::{ActivityHandler, Actor, Object},
|
traits::{ActivityHandler, Actor, Object},
|
||||||
};
|
};
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait,
|
async_trait,
|
||||||
body::Body,
|
extract::{FromRequest, Request},
|
||||||
extract::FromRequest,
|
http::StatusCode,
|
||||||
http::{Request, StatusCode},
|
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use http::{HeaderMap, Method, Uri};
|
use http::{HeaderMap, Method, Uri};
|
||||||
|
|
@ -29,27 +28,19 @@ where
|
||||||
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
|
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
|
||||||
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
|
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
|
||||||
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
|
||||||
<Activity as ActivityHandler>::Error: From<anyhow::Error>
|
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
|
||||||
+ From<Error>
|
<ActorT as Object>::Error: From<Error>,
|
||||||
+ From<<ActorT as Object>::Error>
|
|
||||||
+ From<serde_json::Error>,
|
|
||||||
<ActorT as Object>::Error: From<Error> + From<anyhow::Error>,
|
|
||||||
Datatype: Clone,
|
Datatype: Clone,
|
||||||
{
|
{
|
||||||
verify_body_hash(activity_data.headers.get("Digest"), &activity_data.body)?;
|
let (activity, actor) =
|
||||||
|
parse_received_activity::<Activity, ActorT, _>(&activity_data.body, data).await?;
|
||||||
|
|
||||||
let activity: Activity = serde_json::from_slice(&activity_data.body)?;
|
// verify_signature(
|
||||||
data.config.verify_url_and_domain(&activity).await?;
|
// &activity_data.headers,
|
||||||
let actor = ObjectId::<ActorT>::from(activity.actor().clone())
|
// &activity_data.method,
|
||||||
.dereference(data)
|
// &activity_data.uri,
|
||||||
.await?;
|
// actor.public_key_pem(),
|
||||||
|
// )?;
|
||||||
verify_signature(
|
|
||||||
&activity_data.headers,
|
|
||||||
&activity_data.method,
|
|
||||||
&activity_data.uri,
|
|
||||||
actor.public_key_pem(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
debug!("Receiving activity {}", activity.id().to_string());
|
debug!("Receiving activity {}", activity.id().to_string());
|
||||||
activity.verify(data).await?;
|
activity.verify(data).await?;
|
||||||
|
|
@ -73,18 +64,20 @@ where
|
||||||
{
|
{
|
||||||
type Rejection = Response;
|
type Rejection = Response;
|
||||||
|
|
||||||
async fn from_request(req: Request<Body>, _state: &S) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
let (parts, body) = req.into_parts();
|
let headers = req.headers().clone();
|
||||||
|
let method = req.method().clone();
|
||||||
|
let uri = req.uri().clone();
|
||||||
|
|
||||||
// this wont work if the body is an long running stream
|
// this wont work if the body is an long running stream
|
||||||
let bytes = hyper::body::to_bytes(body)
|
let bytes = hyper::body::Bytes::from_request(req, state)
|
||||||
.await
|
.await
|
||||||
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?;
|
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
headers: parts.headers,
|
headers,
|
||||||
method: parts.method,
|
method,
|
||||||
uri: parts.uri,
|
uri,
|
||||||
body: bytes.to_vec(),
|
body: bytes.to_vec(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@
|
||||||
//! # use activitypub_federation::traits::Object;
|
//! # use activitypub_federation::traits::Object;
|
||||||
//! # use activitypub_federation::traits::tests::{DbConnection, DbUser, Person};
|
//! # use activitypub_federation::traits::tests::{DbConnection, DbUser, Person};
|
||||||
//! async fn http_get_user(Path(name): Path<String>, data: Data<DbConnection>) -> Result<FederationJson<WithContext<Person>>, Error> {
|
//! async fn http_get_user(Path(name): Path<String>, data: Data<DbConnection>) -> Result<FederationJson<WithContext<Person>>, Error> {
|
||||||
//! let user: DbUser = data.read_local_user(name).await?;
|
//! let user: DbUser = data.read_local_user(&name).await?;
|
||||||
//! let person = user.into_json(&data).await?;
|
//! let person = user.into_json(&data).await?;
|
||||||
//!
|
//!
|
||||||
//! Ok(FederationJson(WithContext::new_default(person)))
|
//! Ok(FederationJson(WithContext::new_default(person)))
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ use crate::{
|
||||||
protocol::verification::verify_domains_match,
|
protocol::verification::verify_domains_match,
|
||||||
traits::{ActivityHandler, Actor},
|
traits::{ActivityHandler, Actor},
|
||||||
};
|
};
|
||||||
use anyhow::anyhow;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use dyn_clone::{clone_trait_object, DynClone};
|
use dyn_clone::{clone_trait_object, DynClone};
|
||||||
|
|
@ -104,9 +103,9 @@ impl<T: Clone> FederationConfig<T> {
|
||||||
verify_domains_match(activity.id(), activity.actor())?;
|
verify_domains_match(activity.id(), activity.actor())?;
|
||||||
self.verify_url_valid(activity.id()).await?;
|
self.verify_url_valid(activity.id()).await?;
|
||||||
if self.is_local_url(activity.id()) {
|
if self.is_local_url(activity.id()) {
|
||||||
return Err(Error::UrlVerificationError(anyhow!(
|
return Err(Error::UrlVerificationError(
|
||||||
"Activity was sent from local instance"
|
"Activity was sent from local instance",
|
||||||
)));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -129,12 +128,12 @@ impl<T: Clone> FederationConfig<T> {
|
||||||
"https" => {}
|
"https" => {}
|
||||||
"http" => {
|
"http" => {
|
||||||
if !self.allow_http_urls {
|
if !self.allow_http_urls {
|
||||||
return Err(Error::UrlVerificationError(anyhow!(
|
return Err(Error::UrlVerificationError(
|
||||||
"Http urls are only allowed in debug mode"
|
"Http urls are only allowed in debug mode",
|
||||||
)));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => return Err(Error::UrlVerificationError(anyhow!("Invalid url scheme"))),
|
_ => return Err(Error::UrlVerificationError("Invalid url scheme")),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Urls which use our local domain are not a security risk, no further verification needed
|
// Urls which use our local domain are not a security risk, no further verification needed
|
||||||
|
|
@ -143,21 +142,16 @@ impl<T: Clone> FederationConfig<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if url.domain().is_none() {
|
if url.domain().is_none() {
|
||||||
return Err(Error::UrlVerificationError(anyhow!(
|
return Err(Error::UrlVerificationError("Url must have a domain"));
|
||||||
"Url must have a domain"
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if url.domain() == Some("localhost") && !self.debug {
|
if url.domain() == Some("localhost") && !self.debug {
|
||||||
return Err(Error::UrlVerificationError(anyhow!(
|
return Err(Error::UrlVerificationError(
|
||||||
"Localhost is only allowed in debug mode"
|
"Localhost is only allowed in debug mode",
|
||||||
)));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
self.url_verifier
|
self.url_verifier.verify(url).await?;
|
||||||
.verify(url)
|
|
||||||
.await
|
|
||||||
.map_err(Error::UrlVerificationError)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -227,7 +221,7 @@ impl<T: Clone> Deref for FederationConfig<T> {
|
||||||
/// # use async_trait::async_trait;
|
/// # use async_trait::async_trait;
|
||||||
/// # use url::Url;
|
/// # use url::Url;
|
||||||
/// # use activitypub_federation::config::UrlVerifier;
|
/// # use activitypub_federation::config::UrlVerifier;
|
||||||
/// # use anyhow::anyhow;
|
/// # use activitypub_federation::error::Error;
|
||||||
/// # #[derive(Clone)]
|
/// # #[derive(Clone)]
|
||||||
/// # struct DatabaseConnection();
|
/// # struct DatabaseConnection();
|
||||||
/// # async fn get_blocklist(_: &DatabaseConnection) -> Vec<String> {
|
/// # async fn get_blocklist(_: &DatabaseConnection) -> Vec<String> {
|
||||||
|
|
@ -240,11 +234,11 @@ impl<T: Clone> Deref for FederationConfig<T> {
|
||||||
///
|
///
|
||||||
/// #[async_trait]
|
/// #[async_trait]
|
||||||
/// impl UrlVerifier for Verifier {
|
/// impl UrlVerifier for Verifier {
|
||||||
/// async fn verify(&self, url: &Url) -> Result<(), anyhow::Error> {
|
/// async fn verify(&self, url: &Url) -> Result<(), Error> {
|
||||||
/// let blocklist = get_blocklist(&self.db_connection).await;
|
/// let blocklist = get_blocklist(&self.db_connection).await;
|
||||||
/// let domain = url.domain().unwrap().to_string();
|
/// let domain = url.domain().unwrap().to_string();
|
||||||
/// if blocklist.contains(&domain) {
|
/// if blocklist.contains(&domain) {
|
||||||
/// Err(anyhow!("Domain is blocked"))
|
/// Err(Error::Other("Domain is blocked".into()))
|
||||||
/// } else {
|
/// } else {
|
||||||
/// Ok(())
|
/// Ok(())
|
||||||
/// }
|
/// }
|
||||||
|
|
@ -254,7 +248,7 @@ impl<T: Clone> Deref for FederationConfig<T> {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait UrlVerifier: DynClone + Send {
|
pub trait UrlVerifier: DynClone + Send {
|
||||||
/// Should return Ok iff the given url is valid for processing.
|
/// Should return Ok iff the given url is valid for processing.
|
||||||
async fn verify(&self, url: &Url) -> Result<(), anyhow::Error>;
|
async fn verify(&self, url: &Url) -> Result<(), Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Default URL verifier which does nothing.
|
/// Default URL verifier which does nothing.
|
||||||
|
|
@ -263,7 +257,7 @@ struct DefaultUrlVerifier();
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl UrlVerifier for DefaultUrlVerifier {
|
impl UrlVerifier for DefaultUrlVerifier {
|
||||||
async fn verify(&self, _url: &Url) -> Result<(), anyhow::Error> {
|
async fn verify(&self, _url: &Url) -> Result<(), Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
56
src/error.rs
56
src/error.rs
|
|
@ -1,5 +1,13 @@
|
||||||
//! Error messages returned by this library
|
//! Error messages returned by this library
|
||||||
|
|
||||||
|
use std::string::FromUtf8Error;
|
||||||
|
|
||||||
|
use http_signature_normalization_reqwest::SignError;
|
||||||
|
use openssl::error::ErrorStack;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
use crate::fetch::webfinger::WebFingerError;
|
||||||
|
|
||||||
/// Error messages returned by this library
|
/// Error messages returned by this library
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
|
|
@ -13,11 +21,11 @@ pub enum Error {
|
||||||
#[error("Response body limit was reached during fetch")]
|
#[error("Response body limit was reached during fetch")]
|
||||||
ResponseBodyLimit,
|
ResponseBodyLimit,
|
||||||
/// Object to be fetched was deleted
|
/// Object to be fetched was deleted
|
||||||
#[error("Object to be fetched was deleted")]
|
#[error("Fetched remote object {0} which was deleted")]
|
||||||
ObjectDeleted,
|
ObjectDeleted(Url),
|
||||||
/// url verification error
|
/// url verification error
|
||||||
#[error("URL failed verification: {0}")]
|
#[error("URL failed verification: {0}")]
|
||||||
UrlVerificationError(anyhow::Error),
|
UrlVerificationError(&'static str),
|
||||||
/// Incoming activity has invalid digest for body
|
/// Incoming activity has invalid digest for body
|
||||||
#[error("Incoming activity has invalid digest for body")]
|
#[error("Incoming activity has invalid digest for body")]
|
||||||
ActivityBodyDigestInvalid,
|
ActivityBodyDigestInvalid,
|
||||||
|
|
@ -26,18 +34,42 @@ pub enum Error {
|
||||||
ActivitySignatureInvalid,
|
ActivitySignatureInvalid,
|
||||||
/// Failed to resolve actor via webfinger
|
/// Failed to resolve actor via webfinger
|
||||||
#[error("Failed to resolve actor via webfinger")]
|
#[error("Failed to resolve actor via webfinger")]
|
||||||
WebfingerResolveFailed,
|
WebfingerResolveFailed(#[from] WebFingerError),
|
||||||
/// other error
|
/// Failed to serialize outgoing activity
|
||||||
|
#[error("Failed to serialize outgoing activity {1}: {0}")]
|
||||||
|
SerializeOutgoingActivity(serde_json::Error, String),
|
||||||
|
/// Failed to parse an object fetched from url
|
||||||
|
#[error("Failed to parse object {1} with content {2}: {0}")]
|
||||||
|
ParseFetchedObject(serde_json::Error, Url, String),
|
||||||
|
/// Failed to parse an activity received from another instance
|
||||||
|
#[error("Failed to parse incoming activity {}: {0}", match .1 {
|
||||||
|
Some(t) => format!("with id {t}"),
|
||||||
|
None => String::new(),
|
||||||
|
})]
|
||||||
|
ParseReceivedActivity(serde_json::Error, Option<Url>),
|
||||||
|
/// Reqwest Middleware Error
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Other(#[from] anyhow::Error),
|
ReqwestMiddleware(#[from] reqwest_middleware::Error),
|
||||||
|
/// Reqwest Error
|
||||||
|
#[error(transparent)]
|
||||||
|
Reqwest(#[from] reqwest::Error),
|
||||||
|
/// UTF-8 error
|
||||||
|
#[error(transparent)]
|
||||||
|
Utf8(#[from] FromUtf8Error),
|
||||||
|
/// Url Parse
|
||||||
|
#[error(transparent)]
|
||||||
|
UrlParse(#[from] url::ParseError),
|
||||||
|
/// Signing errors
|
||||||
|
#[error(transparent)]
|
||||||
|
SignError(#[from] SignError),
|
||||||
|
/// Other generic errors
|
||||||
|
#[error("{0}")]
|
||||||
|
Other(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error {
|
impl From<ErrorStack> for Error {
|
||||||
pub(crate) fn other<T>(error: T) -> Self
|
fn from(value: ErrorStack) -> Self {
|
||||||
where
|
Error::Other(value.to_string())
|
||||||
T: Into<anyhow::Error>,
|
|
||||||
{
|
|
||||||
Error::Other(error.into())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,8 @@ where
|
||||||
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
|
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
/// Construct a new CollectionId instance
|
/// Construct a new CollectionId instance
|
||||||
pub fn parse<T>(url: T) -> Result<Self, url::ParseError>
|
pub fn parse(url: &str) -> Result<Self, url::ParseError> {
|
||||||
where
|
Ok(Self(Box::new(Url::parse(url)?), PhantomData::<Kind>))
|
||||||
T: TryInto<Url>,
|
|
||||||
url::ParseError: From<<T as TryInto<Url>>::Error>,
|
|
||||||
{
|
|
||||||
Ok(Self(Box::new(url.try_into()?), PhantomData::<Kind>))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fetches collection over HTTP
|
/// Fetches collection over HTTP
|
||||||
|
|
@ -96,3 +92,102 @@ where
|
||||||
CollectionId(Box::new(url), PhantomData::<Kind>)
|
CollectionId(Box::new(url), PhantomData::<Kind>)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<Kind> PartialEq for CollectionId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Collection,
|
||||||
|
for<'de2> <Kind as Collection>::Kind: serde::Deserialize<'de2>,
|
||||||
|
{
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.0.eq(&other.0) && self.1 == other.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "diesel")]
|
||||||
|
const _IMPL_DIESEL_NEW_TYPE_FOR_COLLECTION_ID: () = {
|
||||||
|
use diesel::{
|
||||||
|
backend::Backend,
|
||||||
|
deserialize::{FromSql, FromStaticSqlRow},
|
||||||
|
expression::AsExpression,
|
||||||
|
internal::derives::as_expression::Bound,
|
||||||
|
pg::Pg,
|
||||||
|
query_builder::QueryId,
|
||||||
|
serialize,
|
||||||
|
serialize::{Output, ToSql},
|
||||||
|
sql_types::{HasSqlType, SingleValue, Text},
|
||||||
|
Expression,
|
||||||
|
Queryable,
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: this impl only works for Postgres db because of to_string() call which requires reborrow
|
||||||
|
impl<Kind, ST> ToSql<ST, Pg> for CollectionId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Collection,
|
||||||
|
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
|
||||||
|
String: ToSql<ST, Pg>,
|
||||||
|
{
|
||||||
|
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
|
||||||
|
let v = self.0.to_string();
|
||||||
|
<String as ToSql<Text, Pg>>::to_sql(&v, &mut out.reborrow())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<'expr, Kind, ST> AsExpression<ST> for &'expr CollectionId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Collection,
|
||||||
|
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
|
||||||
|
Bound<ST, String>: Expression<SqlType = ST>,
|
||||||
|
ST: SingleValue,
|
||||||
|
{
|
||||||
|
type Expression = Bound<ST, &'expr str>;
|
||||||
|
fn as_expression(self) -> Self::Expression {
|
||||||
|
Bound::new(self.0.as_str())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Kind, ST> AsExpression<ST> for CollectionId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Collection,
|
||||||
|
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
|
||||||
|
Bound<ST, String>: Expression<SqlType = ST>,
|
||||||
|
ST: SingleValue,
|
||||||
|
{
|
||||||
|
type Expression = Bound<ST, String>;
|
||||||
|
fn as_expression(self) -> Self::Expression {
|
||||||
|
Bound::new(self.0.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Kind, ST, DB> FromSql<ST, DB> for CollectionId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Collection + Send + 'static,
|
||||||
|
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
|
||||||
|
String: FromSql<ST, DB>,
|
||||||
|
DB: Backend,
|
||||||
|
DB: HasSqlType<ST>,
|
||||||
|
{
|
||||||
|
fn from_sql(
|
||||||
|
raw: DB::RawValue<'_>,
|
||||||
|
) -> Result<Self, Box<dyn ::std::error::Error + Send + Sync>> {
|
||||||
|
let string: String = FromSql::<ST, DB>::from_sql(raw)?;
|
||||||
|
Ok(CollectionId::parse(&string)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Kind, ST, DB> Queryable<ST, DB> for CollectionId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Collection + Send + 'static,
|
||||||
|
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
|
||||||
|
String: FromStaticSqlRow<ST, DB>,
|
||||||
|
DB: Backend,
|
||||||
|
DB: HasSqlType<ST>,
|
||||||
|
{
|
||||||
|
type Row = String;
|
||||||
|
fn build(row: Self::Row) -> diesel::deserialize::Result<Self> {
|
||||||
|
Ok(CollectionId::parse(&row)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Kind> QueryId for CollectionId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Collection + 'static,
|
||||||
|
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
|
||||||
|
{
|
||||||
|
type QueryId = Self;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::Data,
|
config::Data,
|
||||||
error::Error,
|
error::{Error, Error::ParseFetchedObject},
|
||||||
http_signatures::sign_request,
|
http_signatures::sign_request,
|
||||||
reqwest_shim::ResponseExt,
|
reqwest_shim::ResponseExt,
|
||||||
FEDERATION_CONTENT_TYPE,
|
FEDERATION_CONTENT_TYPE,
|
||||||
|
|
@ -63,10 +63,10 @@ async fn fetch_object_http_with_accept<T: Clone, Kind: DeserializeOwned>(
|
||||||
config.verify_url_valid(url).await?;
|
config.verify_url_valid(url).await?;
|
||||||
info!("Fetching remote object {}", url.to_string());
|
info!("Fetching remote object {}", url.to_string());
|
||||||
|
|
||||||
let counter = data.request_counter.fetch_add(1, Ordering::SeqCst);
|
// let counter = data.request_counter.fetch_add(1, Ordering::SeqCst);
|
||||||
if counter > config.http_fetch_limit {
|
// if counter > config.http_fetch_limit {
|
||||||
return Err(Error::RequestLimit);
|
// return Err(Error::RequestLimit);
|
||||||
}
|
// }
|
||||||
|
|
||||||
let req = config
|
let req = config
|
||||||
.client
|
.client
|
||||||
|
|
@ -83,18 +83,23 @@ async fn fetch_object_http_with_accept<T: Clone, Kind: DeserializeOwned>(
|
||||||
data.config.http_signature_compat,
|
data.config.http_signature_compat,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
config.client.execute(req).await.map_err(Error::other)?
|
config.client.execute(req).await?
|
||||||
} else {
|
} else {
|
||||||
req.send().await.map_err(Error::other)?
|
req.send().await?
|
||||||
};
|
};
|
||||||
|
|
||||||
if res.status() == StatusCode::GONE {
|
if res.status().as_u16() == StatusCode::GONE.as_u16() {
|
||||||
return Err(Error::ObjectDeleted);
|
return Err(Error::ObjectDeleted(url.clone()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let url = res.url().clone();
|
let url = res.url().clone();
|
||||||
Ok(FetchObjectResponse {
|
let text = res.bytes_limited().await?;
|
||||||
object: res.json_limited().await?,
|
match serde_json::from_slice(&text) {
|
||||||
url,
|
Ok(object) => Ok(FetchObjectResponse { object, url }),
|
||||||
})
|
Err(e) => Err(ParseFetchedObject(
|
||||||
|
e,
|
||||||
|
url,
|
||||||
|
String::from_utf8(Vec::from(text))?,
|
||||||
|
)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object};
|
use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object};
|
||||||
use anyhow::anyhow;
|
|
||||||
use chrono::{DateTime, Duration as ChronoDuration, Utc};
|
use chrono::{DateTime, Duration as ChronoDuration, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{
|
||||||
|
|
@ -58,20 +57,16 @@ where
|
||||||
pub struct ObjectId<Kind>(Box<Url>, PhantomData<Kind>)
|
pub struct ObjectId<Kind>(Box<Url>, PhantomData<Kind>)
|
||||||
where
|
where
|
||||||
Kind: Object,
|
Kind: Object,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>;
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>;
|
||||||
|
|
||||||
impl<Kind> ObjectId<Kind>
|
impl<Kind> ObjectId<Kind>
|
||||||
where
|
where
|
||||||
Kind: Object + Send + Debug + 'static,
|
Kind: Object + Send + Debug + 'static,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
/// Construct a new objectid instance
|
/// Construct a new objectid instance
|
||||||
pub fn parse<T>(url: T) -> Result<Self, url::ParseError>
|
pub fn parse(url: &str) -> Result<Self, url::ParseError> {
|
||||||
where
|
Ok(Self(Box::new(Url::parse(url)?), PhantomData::<Kind>))
|
||||||
T: TryInto<Url>,
|
|
||||||
url::ParseError: From<<T as TryInto<Url>>::Error>,
|
|
||||||
{
|
|
||||||
Ok(ObjectId(Box::new(url.try_into()?), PhantomData::<Kind>))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a reference to the wrapped URL value
|
/// Returns a reference to the wrapped URL value
|
||||||
|
|
@ -90,7 +85,7 @@ where
|
||||||
data: &Data<<Kind as Object>::DataType>,
|
data: &Data<<Kind as Object>::DataType>,
|
||||||
) -> Result<Kind, <Kind as Object>::Error>
|
) -> Result<Kind, <Kind as Object>::Error>
|
||||||
where
|
where
|
||||||
<Kind as Object>::Error: From<Error> + From<anyhow::Error>,
|
<Kind as Object>::Error: From<Error>,
|
||||||
{
|
{
|
||||||
let db_object = self.dereference_from_db(data).await?;
|
let db_object = self.dereference_from_db(data).await?;
|
||||||
// if its a local object, only fetch it from the database and not over http
|
// if its a local object, only fetch it from the database and not over http
|
||||||
|
|
@ -117,6 +112,24 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If this is a remote object, fetch it from origin instance unconditionally to get the
|
||||||
|
/// latest version, regardless of refresh interval.
|
||||||
|
pub async fn dereference_forced(
|
||||||
|
&self,
|
||||||
|
data: &Data<<Kind as Object>::DataType>,
|
||||||
|
) -> Result<Kind, <Kind as Object>::Error>
|
||||||
|
where
|
||||||
|
<Kind as Object>::Error: From<Error>,
|
||||||
|
{
|
||||||
|
if data.config.is_local_url(&self.0) {
|
||||||
|
self.dereference_from_db(data)
|
||||||
|
.await
|
||||||
|
.map(|o| o.ok_or(Error::NotFound.into()))?
|
||||||
|
} else {
|
||||||
|
self.dereference_from_http(data, None).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Fetch an object from the local db. Instead of falling back to http, this throws an error if
|
/// Fetch an object from the local db. Instead of falling back to http, this throws an error if
|
||||||
/// the object is not found in the database.
|
/// the object is not found in the database.
|
||||||
pub async fn dereference_local(
|
pub async fn dereference_local(
|
||||||
|
|
@ -145,15 +158,15 @@ where
|
||||||
db_object: Option<Kind>,
|
db_object: Option<Kind>,
|
||||||
) -> Result<Kind, <Kind as Object>::Error>
|
) -> Result<Kind, <Kind as Object>::Error>
|
||||||
where
|
where
|
||||||
<Kind as Object>::Error: From<Error> + From<anyhow::Error>,
|
<Kind as Object>::Error: From<Error>,
|
||||||
{
|
{
|
||||||
let res = fetch_object_http(&self.0, data).await;
|
let res = fetch_object_http(&self.0, data).await;
|
||||||
|
|
||||||
if let Err(Error::ObjectDeleted) = &res {
|
if let Err(Error::ObjectDeleted(url)) = res {
|
||||||
if let Some(db_object) = db_object {
|
if let Some(db_object) = db_object {
|
||||||
db_object.delete(data).await?;
|
db_object.delete(data).await?;
|
||||||
}
|
}
|
||||||
return Err(anyhow!("Fetched remote object {} which was deleted", self).into());
|
return Err(Error::ObjectDeleted(url).into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let res = res?;
|
let res = res?;
|
||||||
|
|
@ -168,7 +181,7 @@ where
|
||||||
impl<Kind> Clone for ObjectId<Kind>
|
impl<Kind> Clone for ObjectId<Kind>
|
||||||
where
|
where
|
||||||
Kind: Object,
|
Kind: Object,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
ObjectId(self.0.clone(), self.1)
|
ObjectId(self.0.clone(), self.1)
|
||||||
|
|
@ -195,7 +208,7 @@ fn should_refetch_object(last_refreshed: DateTime<Utc>) -> bool {
|
||||||
impl<Kind> Display for ObjectId<Kind>
|
impl<Kind> Display for ObjectId<Kind>
|
||||||
where
|
where
|
||||||
Kind: Object,
|
Kind: Object,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "{}", self.0.as_str())
|
write!(f, "{}", self.0.as_str())
|
||||||
|
|
@ -205,7 +218,7 @@ where
|
||||||
impl<Kind> Debug for ObjectId<Kind>
|
impl<Kind> Debug for ObjectId<Kind>
|
||||||
where
|
where
|
||||||
Kind: Object,
|
Kind: Object,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "{}", self.0.as_str())
|
write!(f, "{}", self.0.as_str())
|
||||||
|
|
@ -215,7 +228,7 @@ where
|
||||||
impl<Kind> From<ObjectId<Kind>> for Url
|
impl<Kind> From<ObjectId<Kind>> for Url
|
||||||
where
|
where
|
||||||
Kind: Object,
|
Kind: Object,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
fn from(id: ObjectId<Kind>) -> Self {
|
fn from(id: ObjectId<Kind>) -> Self {
|
||||||
*id.0
|
*id.0
|
||||||
|
|
@ -225,7 +238,7 @@ where
|
||||||
impl<Kind> From<Url> for ObjectId<Kind>
|
impl<Kind> From<Url> for ObjectId<Kind>
|
||||||
where
|
where
|
||||||
Kind: Object + Send + 'static,
|
Kind: Object + Send + 'static,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
fn from(url: Url) -> Self {
|
fn from(url: Url) -> Self {
|
||||||
ObjectId(Box::new(url), PhantomData::<Kind>)
|
ObjectId(Box::new(url), PhantomData::<Kind>)
|
||||||
|
|
@ -235,13 +248,102 @@ where
|
||||||
impl<Kind> PartialEq for ObjectId<Kind>
|
impl<Kind> PartialEq for ObjectId<Kind>
|
||||||
where
|
where
|
||||||
Kind: Object,
|
Kind: Object,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
{
|
{
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.0.eq(&other.0) && self.1 == other.1
|
self.0.eq(&other.0) && self.1 == other.1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "diesel")]
|
||||||
|
const _IMPL_DIESEL_NEW_TYPE_FOR_OBJECT_ID: () = {
|
||||||
|
use diesel::{
|
||||||
|
backend::Backend,
|
||||||
|
deserialize::{FromSql, FromStaticSqlRow},
|
||||||
|
expression::AsExpression,
|
||||||
|
internal::derives::as_expression::Bound,
|
||||||
|
pg::Pg,
|
||||||
|
query_builder::QueryId,
|
||||||
|
serialize,
|
||||||
|
serialize::{Output, ToSql},
|
||||||
|
sql_types::{HasSqlType, SingleValue, Text},
|
||||||
|
Expression,
|
||||||
|
Queryable,
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: this impl only works for Postgres db because of to_string() call which requires reborrow
|
||||||
|
impl<Kind, ST> ToSql<ST, Pg> for ObjectId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Object,
|
||||||
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
|
String: ToSql<ST, Pg>,
|
||||||
|
{
|
||||||
|
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
|
||||||
|
let v = self.0.to_string();
|
||||||
|
<String as ToSql<Text, Pg>>::to_sql(&v, &mut out.reborrow())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<'expr, Kind, ST> AsExpression<ST> for &'expr ObjectId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Object,
|
||||||
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
|
Bound<ST, String>: Expression<SqlType = ST>,
|
||||||
|
ST: SingleValue,
|
||||||
|
{
|
||||||
|
type Expression = Bound<ST, &'expr str>;
|
||||||
|
fn as_expression(self) -> Self::Expression {
|
||||||
|
Bound::new(self.0.as_str())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Kind, ST> AsExpression<ST> for ObjectId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Object,
|
||||||
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
|
Bound<ST, String>: Expression<SqlType = ST>,
|
||||||
|
ST: SingleValue,
|
||||||
|
{
|
||||||
|
type Expression = Bound<ST, String>;
|
||||||
|
fn as_expression(self) -> Self::Expression {
|
||||||
|
Bound::new(self.0.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Kind, ST, DB> FromSql<ST, DB> for ObjectId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Object + Send + 'static,
|
||||||
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
|
String: FromSql<ST, DB>,
|
||||||
|
DB: Backend,
|
||||||
|
DB: HasSqlType<ST>,
|
||||||
|
{
|
||||||
|
fn from_sql(
|
||||||
|
raw: DB::RawValue<'_>,
|
||||||
|
) -> Result<Self, Box<dyn ::std::error::Error + Send + Sync>> {
|
||||||
|
let string: String = FromSql::<ST, DB>::from_sql(raw)?;
|
||||||
|
Ok(ObjectId::parse(&string)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Kind, ST, DB> Queryable<ST, DB> for ObjectId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Object + Send + 'static,
|
||||||
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
|
String: FromStaticSqlRow<ST, DB>,
|
||||||
|
DB: Backend,
|
||||||
|
DB: HasSqlType<ST>,
|
||||||
|
{
|
||||||
|
type Row = String;
|
||||||
|
fn build(row: Self::Row) -> diesel::deserialize::Result<Self> {
|
||||||
|
Ok(ObjectId::parse(&row)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<Kind> QueryId for ObjectId<Kind>
|
||||||
|
where
|
||||||
|
Kind: Object + 'static,
|
||||||
|
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
|
||||||
|
{
|
||||||
|
type QueryId = Self;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub mod tests {
|
pub mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,38 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
config::Data,
|
config::Data,
|
||||||
error::{Error, Error::WebfingerResolveFailed},
|
error::Error,
|
||||||
fetch::{fetch_object_http_with_accept, object_id::ObjectId},
|
fetch::{fetch_object_http_with_accept, object_id::ObjectId},
|
||||||
traits::{Actor, Object},
|
traits::{Actor, Object},
|
||||||
FEDERATION_CONTENT_TYPE,
|
FEDERATION_CONTENT_TYPE,
|
||||||
};
|
};
|
||||||
use anyhow::anyhow;
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, fmt::Display};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
|
/// Errors relative to webfinger handling
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum WebFingerError {
|
||||||
|
/// The webfinger identifier is invalid
|
||||||
|
#[error("The webfinger identifier is invalid")]
|
||||||
|
WrongFormat,
|
||||||
|
/// The webfinger identifier doesn't match the expected instance domain name
|
||||||
|
#[error("The webfinger identifier doesn't match the expected instance domain name")]
|
||||||
|
WrongDomain,
|
||||||
|
/// The wefinger object did not contain any link to an activitypub item
|
||||||
|
#[error("The webfinger object did not contain any link to an activitypub item")]
|
||||||
|
NoValidLink,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WebFingerError {
|
||||||
|
fn into_crate_error(self) -> Error {
|
||||||
|
self.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Takes an identifier of the form `name@example.com`, and returns an object of `Kind`.
|
/// Takes an identifier of the form `name@example.com`, and returns an object of `Kind`.
|
||||||
///
|
///
|
||||||
/// For this the identifier is first resolved via webfinger protocol to an Activitypub ID. This ID
|
/// For this the identifier is first resolved via webfinger protocol to an Activitypub ID. This ID
|
||||||
|
|
@ -24,22 +44,24 @@ pub async fn webfinger_resolve_actor<T: Clone, Kind>(
|
||||||
where
|
where
|
||||||
Kind: Object + Actor + Send + 'static + Object<DataType = T>,
|
Kind: Object + Actor + Send + 'static + Object<DataType = T>,
|
||||||
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
|
||||||
<Kind as Object>::Error:
|
<Kind as Object>::Error: From<crate::error::Error> + Send + Sync + Display,
|
||||||
From<crate::error::Error> + From<anyhow::Error> + From<url::ParseError> + Send + Sync,
|
|
||||||
{
|
{
|
||||||
let (_, domain) = identifier
|
let (_, domain) = identifier
|
||||||
.splitn(2, '@')
|
.splitn(2, '@')
|
||||||
.collect_tuple()
|
.collect_tuple()
|
||||||
.ok_or(WebfingerResolveFailed)?;
|
.ok_or(WebFingerError::WrongFormat.into_crate_error())?;
|
||||||
let protocol = if data.config.debug { "http" } else { "https" };
|
let protocol = if data.config.debug { "http" } else { "https" };
|
||||||
let fetch_url =
|
let fetch_url =
|
||||||
format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}");
|
format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}");
|
||||||
debug!("Fetching webfinger url: {}", &fetch_url);
|
debug!("Fetching webfinger url: {}", &fetch_url);
|
||||||
|
|
||||||
let res: Webfinger =
|
let res: Webfinger = fetch_object_http_with_accept(
|
||||||
fetch_object_http_with_accept(&Url::parse(&fetch_url)?, data, "application/jrd+json")
|
&Url::parse(&fetch_url).map_err(Error::UrlParse)?,
|
||||||
.await?
|
data,
|
||||||
.object;
|
"application/jrd+json",
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
.object;
|
||||||
|
|
||||||
debug_assert_eq!(res.subject, format!("acct:{identifier}"));
|
debug_assert_eq!(res.subject, format!("acct:{identifier}"));
|
||||||
let links: Vec<Url> = res
|
let links: Vec<Url> = res
|
||||||
|
|
@ -54,13 +76,15 @@ where
|
||||||
})
|
})
|
||||||
.filter_map(|l| l.href.clone())
|
.filter_map(|l| l.href.clone())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
for l in links {
|
for l in links {
|
||||||
let object = ObjectId::<Kind>::from(l).dereference(data).await;
|
let object = ObjectId::<Kind>::from(l).dereference(data).await;
|
||||||
if object.is_ok() {
|
match object {
|
||||||
return object;
|
Ok(obj) => return Ok(obj),
|
||||||
|
Err(error) => debug!(%error, "Failed to dereference link"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(WebfingerResolveFailed.into())
|
Err(WebFingerError::NoValidLink.into_crate_error().into())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extracts username from a webfinger resource parameter.
|
/// Extracts username from a webfinger resource parameter.
|
||||||
|
|
@ -88,20 +112,24 @@ where
|
||||||
/// # Ok::<(), anyhow::Error>(())
|
/// # Ok::<(), anyhow::Error>(())
|
||||||
/// }).unwrap();
|
/// }).unwrap();
|
||||||
///```
|
///```
|
||||||
pub fn extract_webfinger_name<T>(query: &str, data: &Data<T>) -> Result<String, Error>
|
pub fn extract_webfinger_name<'i, T>(query: &'i str, data: &Data<T>) -> Result<&'i str, Error>
|
||||||
where
|
where
|
||||||
T: Clone,
|
T: Clone,
|
||||||
{
|
{
|
||||||
|
static WEBFINGER_REGEX: Lazy<Regex> =
|
||||||
|
Lazy::new(|| Regex::new(r"^acct:([\p{L}0-9_]+)@(.*)$").expect("compile regex"));
|
||||||
// Regex to extract usernames from webfinger query. Supports different alphabets using `\p{L}`.
|
// Regex to extract usernames from webfinger query. Supports different alphabets using `\p{L}`.
|
||||||
// TODO: would be nice if we could implement this without regex and remove the dependency
|
// TODO: This should use a URL parser
|
||||||
let regex =
|
let captures = WEBFINGER_REGEX
|
||||||
Regex::new(&format!(r"^acct:@?([\p{{L}}0-9_]+)@{}$", data.domain())).map_err(Error::other)?;
|
|
||||||
Ok(regex
|
|
||||||
.captures(query)
|
.captures(query)
|
||||||
.and_then(|c| c.get(1))
|
.ok_or(WebFingerError::WrongFormat)?;
|
||||||
.ok_or_else(|| Error::other(anyhow!("Webfinger regex failed to match")))?
|
|
||||||
.as_str()
|
let account_name = captures.get(1).ok_or(WebFingerError::WrongFormat)?;
|
||||||
.to_string())
|
|
||||||
|
if captures.get(2).map(|m| m.as_str()) != Some(data.domain()) {
|
||||||
|
return Err(WebFingerError::WrongDomain.into());
|
||||||
|
}
|
||||||
|
Ok(account_name.as_str())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builds a basic webfinger response for the actor.
|
/// Builds a basic webfinger response for the actor.
|
||||||
|
|
@ -249,15 +277,15 @@ mod tests {
|
||||||
request_counter: Default::default(),
|
request_counter: Default::default(),
|
||||||
};
|
};
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Ok("test123".to_string()),
|
Ok("test123"),
|
||||||
extract_webfinger_name("acct:test123@example.com", &data)
|
extract_webfinger_name("acct:test123@example.com", &data)
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Ok("Владимир".to_string()),
|
Ok("Владимир"),
|
||||||
extract_webfinger_name("acct:Владимир@example.com", &data)
|
extract_webfinger_name("acct:Владимир@example.com", &data)
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Ok("تجريب".to_string()),
|
Ok("تجريب"),
|
||||||
extract_webfinger_name("acct:تجريب@example.com", &data)
|
extract_webfinger_name("acct:تجريب@example.com", &data)
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,13 @@ use crate::{
|
||||||
protocol::public_key::main_key_id,
|
protocol::public_key::main_key_id,
|
||||||
traits::{Actor, Object},
|
traits::{Actor, Object},
|
||||||
};
|
};
|
||||||
use anyhow::Context;
|
|
||||||
use base64::{engine::general_purpose::STANDARD as Base64, Engine};
|
use base64::{engine::general_purpose::STANDARD as Base64, Engine};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri};
|
use http::{uri::PathAndQuery, Uri};
|
||||||
use http_signature_normalization_reqwest::prelude::{Config, SignExt};
|
use http_signature_normalization_reqwest::{
|
||||||
|
prelude::{Config, SignExt},
|
||||||
|
DefaultSpawner,
|
||||||
|
};
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use openssl::{
|
use openssl::{
|
||||||
hash::MessageDigest,
|
hash::MessageDigest,
|
||||||
|
|
@ -24,7 +26,11 @@ use openssl::{
|
||||||
rsa::Rsa,
|
rsa::Rsa,
|
||||||
sign::{Signer, Verifier},
|
sign::{Signer, Verifier},
|
||||||
};
|
};
|
||||||
use reqwest::Request;
|
use reqwest::{
|
||||||
|
header::{HeaderName, HeaderValue},
|
||||||
|
Method,
|
||||||
|
Request,
|
||||||
|
};
|
||||||
use reqwest_middleware::RequestBuilder;
|
use reqwest_middleware::RequestBuilder;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
|
|
@ -83,8 +89,9 @@ pub(crate) async fn sign_request(
|
||||||
activity: Bytes,
|
activity: Bytes,
|
||||||
private_key: PKey<Private>,
|
private_key: PKey<Private>,
|
||||||
http_signature_compat: bool,
|
http_signature_compat: bool,
|
||||||
) -> Result<Request, anyhow::Error> {
|
) -> Result<Request, Error> {
|
||||||
static CONFIG: Lazy<Config> = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER));
|
static CONFIG: Lazy<Config<DefaultSpawner>> =
|
||||||
|
Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER));
|
||||||
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| {
|
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| {
|
||||||
Config::new()
|
Config::new()
|
||||||
.mastodon_compat()
|
.mastodon_compat()
|
||||||
|
|
@ -103,14 +110,10 @@ pub(crate) async fn sign_request(
|
||||||
Sha256::new(),
|
Sha256::new(),
|
||||||
activity,
|
activity,
|
||||||
move |signing_string| {
|
move |signing_string| {
|
||||||
let mut signer = Signer::new(MessageDigest::sha256(), &private_key)
|
let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?;
|
||||||
.context("instantiating signer")?;
|
signer.update(signing_string.as_bytes())?;
|
||||||
signer
|
|
||||||
.update(signing_string.as_bytes())
|
|
||||||
.context("updating signer")?;
|
|
||||||
|
|
||||||
Ok(Base64.encode(signer.sign_to_vec().context("sign to vec")?))
|
Ok(Base64.encode(signer.sign_to_vec()?)) as Result<_, Error>
|
||||||
as Result<_, anyhow::Error>
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
|
@ -152,7 +155,7 @@ pub(crate) async fn signing_actor<'a, A, H>(
|
||||||
) -> Result<A, <A as Object>::Error>
|
) -> Result<A, <A as Object>::Error>
|
||||||
where
|
where
|
||||||
A: Object + Actor,
|
A: Object + Actor,
|
||||||
<A as Object>::Error: From<Error> + From<anyhow::Error>,
|
<A as Object>::Error: From<Error>,
|
||||||
for<'de2> <A as Object>::Kind: Deserialize<'de2>,
|
for<'de2> <A as Object>::Kind: Deserialize<'de2>,
|
||||||
H: IntoIterator<Item = (&'a HeaderName, &'a HeaderValue)>,
|
H: IntoIterator<Item = (&'a HeaderName, &'a HeaderValue)>,
|
||||||
{
|
{
|
||||||
|
|
@ -197,8 +200,8 @@ fn verify_signature_inner(
|
||||||
|
|
||||||
let verified = CONFIG
|
let verified = CONFIG
|
||||||
.begin_verify(method.as_str(), path_and_query, header_map)
|
.begin_verify(method.as_str(), path_and_query, header_map)
|
||||||
.map_err(Error::other)?
|
.map_err(|val| Error::Other(val.to_string()))?
|
||||||
.verify(|signature, signing_string| -> anyhow::Result<bool> {
|
.verify(|signature, signing_string| -> Result<bool, Error> {
|
||||||
debug!(
|
debug!(
|
||||||
"Verifying with key {}, message {}",
|
"Verifying with key {}, message {}",
|
||||||
&public_key, &signing_string
|
&public_key, &signing_string
|
||||||
|
|
@ -206,9 +209,13 @@ fn verify_signature_inner(
|
||||||
let public_key = PKey::public_key_from_pem(public_key.as_bytes())?;
|
let public_key = PKey::public_key_from_pem(public_key.as_bytes())?;
|
||||||
let mut verifier = Verifier::new(MessageDigest::sha256(), &public_key)?;
|
let mut verifier = Verifier::new(MessageDigest::sha256(), &public_key)?;
|
||||||
verifier.update(signing_string.as_bytes())?;
|
verifier.update(signing_string.as_bytes())?;
|
||||||
Ok(verifier.verify(&Base64.decode(signature)?)?)
|
|
||||||
})
|
let base64_decoded = Base64
|
||||||
.map_err(Error::other)?;
|
.decode(signature)
|
||||||
|
.map_err(|err| Error::Other(err.to_string()))?;
|
||||||
|
|
||||||
|
Ok(verifier.verify(&base64_decoded)?)
|
||||||
|
})?;
|
||||||
|
|
||||||
if verified {
|
if verified {
|
||||||
debug!("verified signature for {}", uri);
|
debug!("verified signature for {}", uri);
|
||||||
|
|
@ -289,7 +296,7 @@ pub mod test {
|
||||||
// use hardcoded date in order to test against hardcoded signature
|
// use hardcoded date in order to test against hardcoded signature
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"date",
|
"date",
|
||||||
HeaderValue::from_str("Tue, 28 Mar 2023 21:03:44 GMT").unwrap(),
|
reqwest::header::HeaderValue::from_str("Tue, 28 Mar 2023 21:03:44 GMT").unwrap(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let request_builder = ClientWithMiddleware::from(Client::new())
|
let request_builder = ClientWithMiddleware::from(Client::new())
|
||||||
|
|
|
||||||
39
src/lib.rs
39
src/lib.rs
|
|
@ -23,7 +23,46 @@ pub mod protocol;
|
||||||
pub(crate) mod reqwest_shim;
|
pub(crate) mod reqwest_shim;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
config::Data,
|
||||||
|
error::Error,
|
||||||
|
fetch::object_id::ObjectId,
|
||||||
|
traits::{ActivityHandler, Actor, Object},
|
||||||
|
};
|
||||||
pub use activitystreams_kinds as kinds;
|
pub use activitystreams_kinds as kinds;
|
||||||
|
|
||||||
|
use serde::{de::DeserializeOwned, Deserialize};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
/// Mime type for Activitypub data, used for `Accept` and `Content-Type` HTTP headers
|
/// Mime type for Activitypub data, used for `Accept` and `Content-Type` HTTP headers
|
||||||
pub static FEDERATION_CONTENT_TYPE: &str = "application/activity+json";
|
pub static FEDERATION_CONTENT_TYPE: &str = "application/activity+json";
|
||||||
|
|
||||||
|
/// Deserialize incoming inbox activity to the given type, perform basic
|
||||||
|
/// validation and extract the actor.
|
||||||
|
async fn parse_received_activity<Activity, ActorT, Datatype>(
|
||||||
|
body: &[u8],
|
||||||
|
data: &Data<Datatype>,
|
||||||
|
) -> Result<(Activity, ActorT), <Activity as ActivityHandler>::Error>
|
||||||
|
where
|
||||||
|
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
|
||||||
|
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
|
||||||
|
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
|
||||||
|
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
|
||||||
|
<ActorT as Object>::Error: From<Error>,
|
||||||
|
Datatype: Clone,
|
||||||
|
{
|
||||||
|
let activity: Activity = serde_json::from_slice(body).map_err(|e| {
|
||||||
|
// Attempt to include activity id in error message
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Id {
|
||||||
|
id: Url,
|
||||||
|
}
|
||||||
|
let id = serde_json::from_slice::<Id>(body).ok();
|
||||||
|
Error::ParseReceivedActivity(e, id.map(|i| i.id))
|
||||||
|
})?;
|
||||||
|
data.config.verify_url_and_domain(&activity).await?;
|
||||||
|
let actor = ObjectId::<ActorT>::from(activity.actor().clone())
|
||||||
|
.dereference(data)
|
||||||
|
.await?;
|
||||||
|
Ok((activity, actor))
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,25 +15,23 @@
|
||||||
//! };
|
//! };
|
||||||
//! let note_with_context = WithContext::new_default(note);
|
//! let note_with_context = WithContext::new_default(note);
|
||||||
//! let serialized = serde_json::to_string(¬e_with_context)?;
|
//! let serialized = serde_json::to_string(¬e_with_context)?;
|
||||||
//! assert_eq!(serialized, r#"{"@context":["https://www.w3.org/ns/activitystreams"],"content":"Hello world"}"#);
|
//! assert_eq!(serialized, r#"{"@context":"https://www.w3.org/ns/activitystreams","content":"Hello world"}"#);
|
||||||
//! Ok::<(), serde_json::error::Error>(())
|
//! Ok::<(), serde_json::error::Error>(())
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use crate::{config::Data, protocol::helpers::deserialize_one_or_many, traits::ActivityHandler};
|
use crate::{config::Data, traits::ActivityHandler};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
/// Default context used in Activitypub
|
/// Default context used in Activitypub
|
||||||
const DEFAULT_CONTEXT: &str = "https://www.w3.org/ns/activitystreams";
|
const DEFAULT_CONTEXT: &str = "https://www.w3.org/ns/activitystreams";
|
||||||
const DEFAULT_SECURITY_CONTEXT: &str = "https://w3id.org/security/v1";
|
|
||||||
|
|
||||||
/// Wrapper for federated structs which handles `@context` field.
|
/// Wrapper for federated structs which handles `@context` field.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct WithContext<T> {
|
pub struct WithContext<T> {
|
||||||
#[serde(rename = "@context")]
|
#[serde(rename = "@context")]
|
||||||
#[serde(deserialize_with = "deserialize_one_or_many")]
|
context: Value,
|
||||||
context: Vec<Value>,
|
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
inner: T,
|
inner: T,
|
||||||
}
|
}
|
||||||
|
|
@ -41,15 +39,12 @@ pub struct WithContext<T> {
|
||||||
impl<T> WithContext<T> {
|
impl<T> WithContext<T> {
|
||||||
/// Create a new wrapper with the default Activitypub context.
|
/// Create a new wrapper with the default Activitypub context.
|
||||||
pub fn new_default(inner: T) -> WithContext<T> {
|
pub fn new_default(inner: T) -> WithContext<T> {
|
||||||
let context = vec![
|
let context = Value::String(DEFAULT_CONTEXT.to_string());
|
||||||
Value::String(DEFAULT_CONTEXT.to_string()),
|
|
||||||
Value::String(DEFAULT_SECURITY_CONTEXT.to_string()),
|
|
||||||
];
|
|
||||||
WithContext::new(inner, context)
|
WithContext::new(inner, context)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new wrapper with custom context. Use this in case you are implementing extensions.
|
/// Create new wrapper with custom context. Use this in case you are implementing extensions.
|
||||||
pub fn new(inner: T, context: Vec<Value>) -> WithContext<T> {
|
pub fn new(inner: T, context: Value) -> WithContext<T> {
|
||||||
WithContext { context, inner }
|
WithContext { context, inner }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,12 +56,12 @@ where
|
||||||
/// #[derive(serde::Deserialize)]
|
/// #[derive(serde::Deserialize)]
|
||||||
/// struct Note {
|
/// struct Note {
|
||||||
/// #[serde(deserialize_with = "deserialize_one")]
|
/// #[serde(deserialize_with = "deserialize_one")]
|
||||||
/// to: Url
|
/// to: [Url; 1]
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
/// let note = serde_json::from_str::<Note>(r#"{"to": ["https://example.com/u/alice"] }"#);
|
/// let note = serde_json::from_str::<Note>(r#"{"to": ["https://example.com/u/alice"] }"#);
|
||||||
/// assert!(note.is_ok());
|
/// assert!(note.is_ok());
|
||||||
pub fn deserialize_one<'de, T, D>(deserializer: D) -> Result<T, D::Error>
|
pub fn deserialize_one<'de, T, D>(deserializer: D) -> Result<[T; 1], D::Error>
|
||||||
where
|
where
|
||||||
T: Deserialize<'de>,
|
T: Deserialize<'de>,
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
|
|
@ -75,8 +75,8 @@ where
|
||||||
|
|
||||||
let result: MaybeArray<T> = Deserialize::deserialize(deserializer)?;
|
let result: MaybeArray<T> = Deserialize::deserialize(deserializer)?;
|
||||||
Ok(match result {
|
Ok(match result {
|
||||||
MaybeArray::Simple(value) => value,
|
MaybeArray::Simple(value) => [value],
|
||||||
MaybeArray::Array([value]) => value,
|
MaybeArray::Array([value]) => [value],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -125,7 +125,7 @@ mod tests {
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
struct Note {
|
struct Note {
|
||||||
#[serde(deserialize_with = "deserialize_one")]
|
#[serde(deserialize_with = "deserialize_one")]
|
||||||
_to: Url,
|
_to: [Url; 1],
|
||||||
}
|
}
|
||||||
|
|
||||||
let note = serde_json::from_str::<Note>(
|
let note = serde_json::from_str::<Note>(
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ use url::Url;
|
||||||
/// Public key of actors which is used for HTTP signatures.
|
/// Public key of actors which is used for HTTP signatures.
|
||||||
///
|
///
|
||||||
/// This needs to be federated in the `public_key` field of all actors.
|
/// This needs to be federated in the `public_key` field of all actors.
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct PublicKey {
|
pub struct PublicKey {
|
||||||
/// Id of this private key.
|
/// Id of this private key.
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ use serde::{Deserialize, Serialize};
|
||||||
/// Media type for markdown text.
|
/// Media type for markdown text.
|
||||||
///
|
///
|
||||||
/// <https://www.iana.org/assignments/media-types/media-types.xhtml>
|
/// <https://www.iana.org/assignments/media-types/media-types.xhtml>
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||||
pub enum MediaTypeMarkdown {
|
pub enum MediaTypeMarkdown {
|
||||||
/// `text/markdown`
|
/// `text/markdown`
|
||||||
#[serde(rename = "text/markdown")]
|
#[serde(rename = "text/markdown")]
|
||||||
|
|
@ -45,7 +45,7 @@ pub enum MediaTypeMarkdown {
|
||||||
/// Media type for HTML text.
|
/// Media type for HTML text.
|
||||||
///
|
///
|
||||||
/// <https://www.iana.org/assignments/media-types/media-types.xhtml>
|
/// <https://www.iana.org/assignments/media-types/media-types.xhtml>
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
|
||||||
pub enum MediaTypeHtml {
|
pub enum MediaTypeHtml {
|
||||||
/// `text/html`
|
/// `text/html`
|
||||||
#[serde(rename = "text/html")]
|
#[serde(rename = "text/html")]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
//! Verify that received data is valid
|
//! Verify that received data is valid
|
||||||
|
|
||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
use anyhow::anyhow;
|
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
/// Check that both urls have the same domain. If not, return UrlVerificationError.
|
/// Check that both urls have the same domain. If not, return UrlVerificationError.
|
||||||
|
|
@ -16,7 +15,7 @@ use url::Url;
|
||||||
/// ```
|
/// ```
|
||||||
pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> {
|
pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> {
|
||||||
if a.domain() != b.domain() {
|
if a.domain() != b.domain() {
|
||||||
return Err(Error::UrlVerificationError(anyhow!("Domains do not match")));
|
return Err(Error::UrlVerificationError("Domains do not match"));
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -33,7 +32,7 @@ pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> {
|
||||||
/// ```
|
/// ```
|
||||||
pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> {
|
pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> {
|
||||||
if a != b {
|
if a != b {
|
||||||
return Err(Error::UrlVerificationError(anyhow!("Urls do not match")));
|
return Err(Error::UrlVerificationError("Urls do not match"));
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,8 @@ use bytes::{BufMut, Bytes, BytesMut};
|
||||||
use futures_core::{ready, stream::BoxStream, Stream};
|
use futures_core::{ready, stream::BoxStream, Stream};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use reqwest::Response;
|
use reqwest::Response;
|
||||||
use serde::de::DeserializeOwned;
|
|
||||||
use std::{
|
use std::{
|
||||||
future::Future,
|
future::Future,
|
||||||
marker::PhantomData,
|
|
||||||
mem,
|
mem,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
|
|
@ -30,10 +28,7 @@ impl Future for BytesFuture {
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
loop {
|
loop {
|
||||||
let this = self.as_mut().project();
|
let this = self.as_mut().project();
|
||||||
if let Some(chunk) = ready!(this.stream.poll_next(cx))
|
if let Some(chunk) = ready!(this.stream.poll_next(cx)).transpose()? {
|
||||||
.transpose()
|
|
||||||
.map_err(Error::other)?
|
|
||||||
{
|
|
||||||
this.aggregator.put(chunk);
|
this.aggregator.put(chunk);
|
||||||
if this.aggregator.len() > *this.limit {
|
if this.aggregator.len() > *this.limit {
|
||||||
return Poll::Ready(Err(Error::ResponseBodyLimit));
|
return Poll::Ready(Err(Error::ResponseBodyLimit));
|
||||||
|
|
@ -49,27 +44,6 @@ impl Future for BytesFuture {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pin_project! {
|
|
||||||
pub struct JsonFuture<T> {
|
|
||||||
_t: PhantomData<T>,
|
|
||||||
#[pin]
|
|
||||||
future: BytesFuture,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Future for JsonFuture<T>
|
|
||||||
where
|
|
||||||
T: DeserializeOwned,
|
|
||||||
{
|
|
||||||
type Output = Result<T, Error>;
|
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
let this = self.project();
|
|
||||||
let bytes = ready!(this.future.poll(cx))?;
|
|
||||||
Poll::Ready(serde_json::from_slice(&bytes).map_err(Error::other))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
pub struct TextFuture {
|
pub struct TextFuture {
|
||||||
#[pin]
|
#[pin]
|
||||||
|
|
@ -83,7 +57,7 @@ impl Future for TextFuture {
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
let bytes = ready!(this.future.poll(cx))?;
|
let bytes = ready!(this.future.poll(cx))?;
|
||||||
Poll::Ready(String::from_utf8(bytes.to_vec()).map_err(Error::other))
|
Poll::Ready(String::from_utf8(bytes.to_vec()).map_err(Error::Utf8))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -97,20 +71,16 @@ impl Future for TextFuture {
|
||||||
/// TODO: Remove this shim as soon as reqwest gets support for size-limited bodies.
|
/// TODO: Remove this shim as soon as reqwest gets support for size-limited bodies.
|
||||||
pub trait ResponseExt {
|
pub trait ResponseExt {
|
||||||
type BytesFuture;
|
type BytesFuture;
|
||||||
type JsonFuture<T>;
|
|
||||||
type TextFuture;
|
type TextFuture;
|
||||||
|
|
||||||
/// Size limited version of `bytes` to work around a reqwest issue. Check [`ResponseExt`] docs for details.
|
/// Size limited version of `bytes` to work around a reqwest issue. Check [`ResponseExt`] docs for details.
|
||||||
fn bytes_limited(self) -> Self::BytesFuture;
|
fn bytes_limited(self) -> Self::BytesFuture;
|
||||||
/// Size limited version of `json` to work around a reqwest issue. Check [`ResponseExt`] docs for details.
|
|
||||||
fn json_limited<T>(self) -> Self::JsonFuture<T>;
|
|
||||||
/// Size limited version of `text` to work around a reqwest issue. Check [`ResponseExt`] docs for details.
|
/// Size limited version of `text` to work around a reqwest issue. Check [`ResponseExt`] docs for details.
|
||||||
fn text_limited(self) -> Self::TextFuture;
|
fn text_limited(self) -> Self::TextFuture;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseExt for Response {
|
impl ResponseExt for Response {
|
||||||
type BytesFuture = BytesFuture;
|
type BytesFuture = BytesFuture;
|
||||||
type JsonFuture<T> = JsonFuture<T>;
|
|
||||||
type TextFuture = TextFuture;
|
type TextFuture = TextFuture;
|
||||||
|
|
||||||
fn bytes_limited(self) -> Self::BytesFuture {
|
fn bytes_limited(self) -> Self::BytesFuture {
|
||||||
|
|
@ -121,13 +91,6 @@ impl ResponseExt for Response {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn json_limited<T>(self) -> Self::JsonFuture<T> {
|
|
||||||
JsonFuture {
|
|
||||||
_t: PhantomData,
|
|
||||||
future: self.bytes_limited(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn text_limited(self) -> Self::TextFuture {
|
fn text_limited(self) -> Self::TextFuture {
|
||||||
TextFuture {
|
TextFuture {
|
||||||
future: self.bytes_limited(),
|
future: self.bytes_limited(),
|
||||||
|
|
|
||||||
|
|
@ -340,12 +340,12 @@ pub trait Collection: Sized {
|
||||||
pub mod tests {
|
pub mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{
|
use crate::{
|
||||||
|
error::Error,
|
||||||
fetch::object_id::ObjectId,
|
fetch::object_id::ObjectId,
|
||||||
http_signatures::{generate_actor_keypair, Keypair},
|
http_signatures::{generate_actor_keypair, Keypair},
|
||||||
protocol::{public_key::PublicKey, verification::verify_domains_match},
|
protocol::{public_key::PublicKey, verification::verify_domains_match},
|
||||||
};
|
};
|
||||||
use activitystreams_kinds::{activity::FollowType, actor::PersonType};
|
use activitystreams_kinds::{activity::FollowType, actor::PersonType};
|
||||||
use anyhow::Error;
|
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
|
@ -356,7 +356,7 @@ pub mod tests {
|
||||||
pub async fn read_post_from_json_id<T>(&self, _: Url) -> Result<Option<T>, Error> {
|
pub async fn read_post_from_json_id<T>(&self, _: Url) -> Result<Option<T>, Error> {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
pub async fn read_local_user(&self, _: String) -> Result<DbUser, Error> {
|
pub async fn read_local_user(&self, _: &str) -> Result<DbUser, Error> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
pub async fn upsert<T>(&self, _: &T) -> Result<(), Error> {
|
pub async fn upsert<T>(&self, _: &T) -> Result<(), Error> {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue