Add hook for incoming activities (#146)

* Add hook for incoming activities

* sync version working

* async working

* remove generic

* separate methods

* testing

* use trait to allow references
This commit is contained in:
Nutomic 2025-06-18 09:38:34 +00:00 committed by GitHub
parent 7e876dd5ce
commit 659a6a3cff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -20,6 +20,66 @@ pub async fn receive_activity<Activity, ActorT, Datatype>(
body: Bytes,
data: &Data<Datatype>,
) -> Result<HttpResponse, <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
let (activity, _) = do_stuff::<Activity, ActorT, Datatype>(request, body, data).await?;
do_more_stuff(activity, data).await
}
/// Workaround required so we can use references for the hook, instead of cloning data.
pub trait ReceiveActivityHook<Activity, ActorT, Datatype>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + Clone + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + Clone + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
/// Called when a new activity is recived
fn hook(
self,
activity: &Activity,
actor: &ActorT,
data: &Data<Datatype>,
) -> impl std::future::Future<Output = Result<(), <Activity as ActivityHandler>::Error>>;
}
/// Same as [receive_activity], only that it calls the provided hook function before
/// calling activity verify and receive functions.
pub async fn receive_activity_with_hook<Activity, ActorT, Datatype>(
request: HttpRequest,
body: Bytes,
hook: impl ReceiveActivityHook<Activity, ActorT, Datatype>,
data: &Data<Datatype>,
) -> Result<HttpResponse, <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + Clone + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + Clone + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
let (activity, actor) = do_stuff::<Activity, ActorT, Datatype>(request, body, data).await?;
hook.hook(&activity, &actor, data).await?;
do_more_stuff(activity, data).await
}
async fn do_stuff<Activity, ActorT, Datatype>(
request: HttpRequest,
body: Bytes,
data: &Data<Datatype>,
) -> Result<(Activity, ActorT), <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
@ -41,6 +101,17 @@ where
let uri = http_compat::uri(request.uri());
verify_signature(&headers, &method, &uri, actor.public_key_pem())?;
Ok((activity, actor))
}
async fn do_more_stuff<Activity, Datatype>(
activity: Activity,
data: &Data<Datatype>,
) -> Result<HttpResponse, <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
Datatype: Clone,
{
debug!("Receiving activity {}", activity.id().to_string());
activity.verify(data).await?;
activity.receive(data).await?;
@ -75,15 +146,38 @@ mod test {
}
#[tokio::test]
async fn test_receive_activity() {
async fn test_receive_activity_hook() {
let (body, incoming_request, config) = setup_receive_test().await;
receive_activity::<Follow, DbUser, DbConnection>(
let res = receive_activity_with_hook::<Follow, DbUser, DbConnection>(
incoming_request.to_http_request(),
body,
Dummy,
&config.to_request_data(),
)
.await
.unwrap();
.await;
assert_eq!(res.err(), Some(Error::Other("test-error".to_string())));
}
struct Dummy;
impl<Activity, ActorT, Datatype> ReceiveActivityHook<Activity, ActorT, Datatype> for Dummy
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + Clone + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + Clone + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<Error> + From<<ActorT as Object>::Error>,
<ActorT as Object>::Error: From<Error>,
Datatype: Clone,
{
async fn hook(
self,
_activity: &Activity,
_actor: &ActorT,
_data: &Data<Datatype>,
) -> Result<(), <Activity as ActivityHandler>::Error> {
// ensure that hook gets called by returning this value
Err(Error::Other("test-error".to_string()).into())
}
}
#[tokio::test]