Compare commits

...

7 commits

Author SHA1 Message Date
Felix Ableitner
f8a75d2605 use trait to allow references 2025-06-17 14:28:53 +02:00
Felix Ableitner
63d132d83f testing 2025-06-16 16:31:34 +02:00
Felix Ableitner
4f0af179b1 separate methods 2025-06-16 16:25:11 +02:00
Felix Ableitner
d833d7d716 remove generic 2025-06-16 16:06:22 +02:00
Felix Ableitner
2827ca3030 async working 2025-06-16 16:03:10 +02:00
Felix Ableitner
ac2b7882ae sync version working 2025-06-16 15:35:19 +02:00
Felix Ableitner
9f56f5390c Add hook for incoming activities 2025-06-16 14:46:03 +02:00

View file

@ -20,6 +20,66 @@ pub async fn receive_activity<Activity, ActorT, Datatype>(
body: Bytes, body: Bytes,
data: &Data<Datatype>, data: &Data<Datatype>,
) -> Result<HttpResponse, <Activity as ActivityHandler>::Error> ) -> Result<HttpResponse, <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
let (activity, _) = do_stuff::<Activity, ActorT, Datatype>(request, body, data).await?;
do_more_stuff(activity, data).await
}
/// Workaround required so we can use references for the hook, instead of cloning data.
pub trait ReceiveActivityHook<Activity, ActorT, Datatype>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + Clone + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + Clone + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
/// Called when a new activity is recived
fn hook(
self,
activity: &Activity,
actor: &ActorT,
data: &Data<Datatype>,
) -> impl std::future::Future<Output = Result<(), <Activity as ActivityHandler>::Error>>;
}
/// Same as [receive_activity], only that it calls the provided hook function before
/// calling activity verify and receive functions.
pub async fn receive_activity_with_hook<Activity, ActorT, Datatype>(
request: HttpRequest,
body: Bytes,
hook: impl ReceiveActivityHook<Activity, ActorT, Datatype>,
data: &Data<Datatype>,
) -> Result<HttpResponse, <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + Clone + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + Clone + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
let (activity, actor) = do_stuff::<Activity, ActorT, Datatype>(request, body, data).await?;
hook.hook(&activity, &actor, data).await?;
do_more_stuff(activity, data).await
}
async fn do_stuff<Activity, ActorT, Datatype>(
request: HttpRequest,
body: Bytes,
data: &Data<Datatype>,
) -> Result<(Activity, ActorT), <Activity as ActivityHandler>::Error>
where where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static, Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static, ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
@ -41,6 +101,17 @@ where
let uri = http_compat::uri(request.uri()); let uri = http_compat::uri(request.uri());
verify_signature(&headers, &method, &uri, actor.public_key_pem())?; verify_signature(&headers, &method, &uri, actor.public_key_pem())?;
Ok((activity, actor))
}
async fn do_more_stuff<Activity, Datatype>(
activity: Activity,
data: &Data<Datatype>,
) -> Result<HttpResponse, <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
Datatype: Clone,
{
debug!("Receiving activity {}", activity.id().to_string()); debug!("Receiving activity {}", activity.id().to_string());
activity.verify(data).await?; activity.verify(data).await?;
activity.receive(data).await?; activity.receive(data).await?;
@ -75,15 +146,38 @@ mod test {
} }
#[tokio::test] #[tokio::test]
async fn test_receive_activity() { async fn test_receive_activity_hook() {
let (body, incoming_request, config) = setup_receive_test().await; let (body, incoming_request, config) = setup_receive_test().await;
receive_activity::<Follow, DbUser, DbConnection>( let res = receive_activity_with_hook::<Follow, DbUser, DbConnection>(
incoming_request.to_http_request(), incoming_request.to_http_request(),
body, body,
Dummy,
&config.to_request_data(), &config.to_request_data(),
) )
.await .await;
.unwrap(); assert_eq!(res.err(), Some(Error::Other("test-error".to_string())));
}
struct Dummy;
impl<Activity, ActorT, Datatype> ReceiveActivityHook<Activity, ActorT, Datatype> for Dummy
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + Clone + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + Clone + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
async fn hook(
self,
_activity: &Activity,
_actor: &ActorT,
_data: &Data<Datatype>,
) -> Result<(), <Activity as ActivityHandler>::Error> {
// ensure that hook gets called by returning this value
Err(Error::Other("test-error".to_string()).into())
}
} }
#[tokio::test] #[tokio::test]