diff --git a/src/bot/commands.rs b/src/bot/commands.rs index f451c1a..917761d 100644 --- a/src/bot/commands.rs +++ b/src/bot/commands.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, collections::HashMap}; +use std::collections::HashMap; use anyhow::anyhow; use serenity::all::{ @@ -59,8 +59,6 @@ impl Handler { return Ok(()); } - // if let CommandType interaction.data.kind {} - for option in interaction.data.options() { if option.name == "script" { let script = get_guild(&self.db, guild_id) @@ -171,6 +169,7 @@ if reactions["⭐"] >= 3 { .guild() .and_then(|x| x.parent_id) .map(|x| x.to_string()), + false, ) .unwrap_or_else(|err| (None, format!("run failed: {err:?}"))); diff --git a/src/bot/handler.rs b/src/bot/handler.rs index 5c25aad..522a262 100644 --- a/src/bot/handler.rs +++ b/src/bot/handler.rs @@ -15,6 +15,7 @@ use serenity::{ small_fixed_array::{FixedArray, FixedString}, }; use sqlx::{PgExecutor, PgPool}; +use tokio::sync::Semaphore; use crate::{ commands, @@ -49,7 +50,7 @@ impl From for CloneBuilderMessage { pub struct Handler { pub db: PgPool, - pub message_lock: Arc>>>, + pub message_lock: Arc>>, } #[async_trait] @@ -71,25 +72,12 @@ impl EventHandler for Handler { } } FullEvent::ReactionAdd { add_reaction, .. } => { - let lock = self - .message_lock - .entry(add_reaction.message_id) - .or_insert_with(|| Arc::new(Mutex::new(()))); - - { - let _guard = lock.lock().await; - - if let Err(e) = self.reaction_add(context, add_reaction).await { - error!("Error while processing reaction add event: {e:?}"); - } - } - - if Arc::strong_count(&lock) == 2 { - self.message_lock - .remove_if(&add_reaction.message_id, |_, val| { - Arc::strong_count(val) <= 2 - }); - } + self.process_reaction(context, add_reaction, false).await; + } + FullEvent::ReactionRemove { + removed_reaction, .. + } => { + self.process_reaction(context, removed_reaction, true).await; } FullEvent::InteractionCreate { interaction } => { if let Err(e) = self.interaction_create(context, interaction).await { @@ -129,8 +117,37 @@ impl Handler { Ok(()) } - async fn reaction_add(&self, ctx: &Context, reaction: &Reaction) -> anyhow::Result<()> { - if reaction.user(&ctx.http).await?.bot() { + async fn process_reaction(&self, ctx: &Context, reaction: &Reaction, is_remove: bool) { + let semaphore = { + self.message_lock + .entry(reaction.message_id) + .or_insert_with(|| Arc::new(Semaphore::new(1))) + .value() + .clone() + }; + + let _permit = match semaphore.acquire().await { + Ok(p) => p, + Err(_) => return, + }; + + if let Err(e) = self.process_reaction_inner(ctx, reaction, is_remove).await { + error!("Error while processing reaction add event: {e:?}"); + } + + if Arc::strong_count(&semaphore) <= 2 { + self.message_lock + .remove_if(&reaction.message_id, |_, val| Arc::strong_count(val) <= 2); + } + } + + async fn process_reaction_inner( + &self, + ctx: &Context, + reaction: &Reaction, + is_remove: bool, + ) -> anyhow::Result<()> { + if reaction.user((ctx.cache(), ctx.http())).await?.bot() { return Ok(()); } let Some(guild_id) = reaction.guild_id else { @@ -145,7 +162,11 @@ impl Handler { let existing_message = get_message(&self.db, reaction.message_id).await?; - let msg = reaction.message(&ctx.http).await?; + if is_remove && existing_message.is_none() { + return Ok(()); + } + + let msg = reaction.message((ctx.cache(), ctx.http())).await?; let reactions: HashMap = msg .reactions @@ -159,7 +180,7 @@ impl Handler { CloneBuilderMessage::from(msg.clone()) }; - let channel = reaction.channel(&ctx.http).await?; + let channel = reaction.channel((ctx.cache(), ctx.http())).await?; debug!("channel: {channel:?}"); @@ -171,6 +192,7 @@ impl Handler { .guild() .and_then(|x| x.parent_id) .map(|x| x.to_string()), + is_remove, ) .map(|x| x.0) .inspect_err(|res| { @@ -200,7 +222,7 @@ impl Handler { .parse() .context("unable to parse counter channel id")?; - let existing_message = ctx.http.get_message(chn_id, msg_id).await?; + let existing_message = chn_id.message((ctx.cache(), ctx.http()), msg_id).await?; for comp in existing_message .components @@ -311,7 +333,7 @@ impl Handler { EditWebhookMessage::new().components(components), ) .await?; - } else { + } else if !is_remove { let counter_msg = webhook .execute( &ctx.http, diff --git a/src/bot/main.rs b/src/bot/main.rs index 5a1114d..6818f25 100644 --- a/src/bot/main.rs +++ b/src/bot/main.rs @@ -57,7 +57,9 @@ async fn main() -> anyhow::Result<()> { Token::from_str(config.bot.token.expose_secret()).unwrap(), GatewayIntents::GUILD_MESSAGES | GatewayIntents::GUILD_MESSAGE_REACTIONS - | GatewayIntents::GUILDS, + | GatewayIntents::GUILDS + | GatewayIntents::GUILD_MESSAGES + | GatewayIntents::MESSAGE_CONTENT, ) .event_handler(Arc::new(Handler { db, diff --git a/src/script.rs b/src/script.rs index bb1fc53..a0c6867 100644 --- a/src/script.rs +++ b/src/script.rs @@ -20,19 +20,21 @@ pub fn check( reactions: HashMap, channel: String, category: Option, + is_removed: bool, ) -> anyhow::Result<(Option, String)> { debug!("script: {script}"); let mut engine = Engine::new(); let buffer = Arc::new(RwLock::new(String::new())); engine.set_max_operations(1000); engine.register_type::(); - engine.register_fn("result", |webhook_url: String, count: i64, icon: String| { - ReactionResult { + engine.register_fn( + "result", + |webhook_url: String, count: Dynamic, icon: String| ReactionResult { channel_id: webhook_url, - count, + count: count.as_int().unwrap_or(0), icon, - } - }); + }, + ); let logger = buffer.clone(); @@ -56,6 +58,7 @@ pub fn check( let mut scope = Scope::new(); scope.set_value("reactions", Dynamic::from(emotes_input)); scope.set_value("channel", Dynamic::from(channel)); + scope.set_value("removed", Dynamic::from(is_removed)); scope.set_value( "category",