use trait to allow references

This commit is contained in:
Felix Ableitner 2025-06-17 14:28:53 +02:00
parent 63d132d83f
commit f8a75d2605

View file

@ -1,7 +1,5 @@
//! Handles incoming activities, verifying HTTP signatures and other checks //! Handles incoming activities, verifying HTTP signatures and other checks
use std::future::Future;
use super::http_compat; use super::http_compat;
use crate::{ use crate::{
config::Data, config::Data,
@ -35,12 +33,31 @@ where
do_more_stuff(activity, data).await do_more_stuff(activity, data).await
} }
/// Workaround required so we can use references for the hook, instead of cloning data.
pub trait ReceiveActivityHook<Activity, ActorT, Datatype>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + Clone + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + Clone + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
/// Called when a new activity is recived
fn hook(
self,
activity: &Activity,
actor: &ActorT,
data: &Data<Datatype>,
) -> impl std::future::Future<Output = Result<(), <Activity as ActivityHandler>::Error>>;
}
/// Same as [receive_activity], only that it calls the provided hook function before /// Same as [receive_activity], only that it calls the provided hook function before
/// calling activity verify and receive functions. /// calling activity verify and receive functions.
pub async fn receive_activity_with_hook<Activity, ActorT, Datatype, Fut>( pub async fn receive_activity_with_hook<Activity, ActorT, Datatype>(
request: HttpRequest, request: HttpRequest,
body: Bytes, body: Bytes,
hook: impl FnOnce(Activity, ActorT, Data<Datatype>) -> Fut, hook: impl ReceiveActivityHook<Activity, ActorT, Datatype>,
data: &Data<Datatype>, data: &Data<Datatype>,
) -> Result<HttpResponse, <Activity as ActivityHandler>::Error> ) -> Result<HttpResponse, <Activity as ActivityHandler>::Error>
where where
@ -50,11 +67,10 @@ where
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>, <Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>, <ActorT as Object>::Error: From<Error>,
Datatype: Clone, Datatype: Clone,
Fut: Future<Output = Result<(), <Activity as ActivityHandler>::Error>>,
{ {
let (activity, actor) = do_stuff::<Activity, ActorT, Datatype>(request, body, data).await?; let (activity, actor) = do_stuff::<Activity, ActorT, Datatype>(request, body, data).await?;
hook(activity.clone(), actor.clone(), data.clone()).await?; hook.hook(&activity, &actor, data).await?;
do_more_stuff(activity, data).await do_more_stuff(activity, data).await
} }
@ -132,23 +148,36 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_receive_activity_hook() { async fn test_receive_activity_hook() {
let (body, incoming_request, config) = setup_receive_test().await; let (body, incoming_request, config) = setup_receive_test().await;
let res = receive_activity_with_hook::<Follow, DbUser, DbConnection, _>( let res = receive_activity_with_hook::<Follow, DbUser, DbConnection>(
incoming_request.to_http_request(), incoming_request.to_http_request(),
body, body,
inbox_activity_hook, Dummy,
&config.to_request_data(), &config.to_request_data(),
) )
.await; .await;
assert_eq!(res.err(), Some(Error::Other("test-error".to_string()))); assert_eq!(res.err(), Some(Error::Other("test-error".to_string())));
} }
async fn inbox_activity_hook<Activity: ActivityHandler + Send + Sync, ActorT>( struct Dummy;
_activity: Activity,
_actor: ActorT, impl<Activity, ActorT, Datatype> ReceiveActivityHook<Activity, ActorT, Datatype> for Dummy
_context: Data<DbConnection>, where
) -> Result<(), Error> { Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + Clone + 'static,
// ensure that hook gets called by returning this value ActorT: Object<DataType = Datatype> + Actor + Send + Clone + 'static,
Err(Error::Other("test-error".to_string())) for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
async fn hook(
self,
_activity: &Activity,
_actor: &ActorT,
_data: &Data<Datatype>,
) -> Result<(), <Activity as ActivityHandler>::Error> {
// ensure that hook gets called by returning this value
Err(Error::Other("test-error".to_string()).into())
}
} }
#[tokio::test] #[tokio::test]