From a1cd0e6a2591701e85ec87a35b1f6bb0ad76e6c8 Mon Sep 17 00:00:00 2001 From: minneelyyyy Date: Fri, 27 Jun 2025 13:30:38 -0400 Subject: [PATCH] use sqlx's query macro in some places --- .gitignore | 3 +- src/commands/gambling/balance.rs | 4 +- src/commands/gambling/daily.rs | 83 ++++++++++++++-------------- src/commands/gambling/leaderboard.rs | 19 ++++--- src/commands/gambling/mod.rs | 28 +++++----- src/commands/gambling/shop.rs | 4 +- src/commands/self_roles/mod.rs | 30 ++++------ src/commands/self_roles/whois.rs | 6 +- 8 files changed, 82 insertions(+), 95 deletions(-) diff --git a/.gitignore b/.gitignore index 0b745e2..ac1e5a0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target -.env \ No newline at end of file +.env +.vscode diff --git a/src/commands/gambling/balance.rs b/src/commands/gambling/balance.rs index fd39da0..728fe63 100644 --- a/src/commands/gambling/balance.rs +++ b/src/commands/gambling/balance.rs @@ -6,9 +6,9 @@ use poise::serenity_prelude as serenity; #[poise::command(slash_command, prefix_command, aliases("bal", "b"))] pub async fn balance(ctx: Context<'_>, user: Option) -> Result<(), Error> { let user = user.as_ref().unwrap_or(ctx.author()); - let db = &ctx.data().database; + let mut tx = ctx.data().database.begin().await?; - let wealth = super::get_balance(user.id, db).await?; + let wealth = super::get_balance(user.id, &mut tx).await?; common::no_ping_reply(&ctx, format!("{} **{}** token(s).", if user.id == ctx.author().id { diff --git a/src/commands/gambling/daily.rs b/src/commands/gambling/daily.rs index 69157d7..f2139e5 100644 --- a/src/commands/gambling/daily.rs +++ b/src/commands/gambling/daily.rs @@ -1,58 +1,54 @@ use crate::{Context, Error}; use poise::serenity_prelude::{UserId, User}; -use sqlx::{types::chrono::{DateTime, Utc, TimeZone}, PgExecutor, Row}; +use sqlx::{types::chrono::{DateTime, TimeZone, Utc}, PgConnection}; use std::time::Duration; -async fn get_streak<'a, E>(db: E, user: UserId) -> Result, Error> -where - E: PgExecutor<'a>, -{ - match sqlx::query( - "SELECT streak FROM dailies WHERE userid = $1" - ).bind(user.get() as i64).fetch_one(db).await - { - Ok(row) => Ok(Some(row.get(0))), - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(Box::new(e)), - } +pub async fn get_streak(conn: &mut PgConnection, user: UserId) -> Result, Error> { + let result = sqlx::query!( + "SELECT streak FROM dailies WHERE userid = $1", + user.get() as i64 + ) + .fetch_optional(conn) + .await?; + + Ok(result.map(|r| r.streak).unwrap_or(None)) } -async fn set_streak<'a, E>(db: E, user: UserId, streak: i32) -> Result<(), Error> -where - E: PgExecutor<'a>, -{ - sqlx::query("INSERT INTO dailies (userid, streak) VALUES ($1, $2) ON CONFLICT (userid) DO UPDATE SET streak = EXCLUDED.streak") - .bind(user.get() as i64) - .bind(streak) - .execute(db).await?; +pub async fn set_streak(conn: &mut PgConnection, user: UserId, streak: i32) -> Result<(), Error> { + sqlx::query!( + "INSERT INTO dailies (userid, streak) VALUES ($1, $2) + ON CONFLICT (userid) DO UPDATE SET streak = EXCLUDED.streak", + user.get() as i64, + streak + ) + .execute(conn) + .await?; Ok(()) } -async fn get_last<'a, E>(db: E, user: UserId) -> Result>, Error> -where - E: PgExecutor<'a>, -{ - match sqlx::query( - "SELECT last FROM dailies WHERE userid = $1" - ).bind(user.get() as i64).fetch_one(db).await - { - Ok(row) => Ok(Some(row.get(0))), - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(Box::new(e)), - } +pub async fn get_last(conn: &mut PgConnection, user: UserId) -> Result>, Error> { + let result = sqlx::query!( + "SELECT last FROM dailies WHERE userid = $1", + user.get() as i64 + ) + .fetch_optional(conn) + .await?; + + Ok(result.map(|r| r.last).unwrap_or(None)) } -async fn set_last<'a, E>(db: E, user: UserId, last: DateTime) -> Result<(), Error> -where - E: PgExecutor<'a>, -{ - sqlx::query("INSERT INTO dailies (userid, last) VALUES ($1, $2) ON CONFLICT (userid) DO UPDATE SET last = EXCLUDED.last") - .bind(user.get() as i64) - .bind(last) - .execute(db).await?; +pub async fn set_last(conn: &mut PgConnection, user: UserId, last: DateTime) -> Result<(), Error> { + sqlx::query!( + "INSERT INTO dailies (userid, last) VALUES ($1, $2) + ON CONFLICT (userid) DO UPDATE SET last = EXCLUDED.last", + user.get() as i64, + last + ) + .execute(conn) + .await?; Ok(()) } @@ -60,13 +56,14 @@ where /// Tells you what your current daily streak is #[poise::command(slash_command, prefix_command)] pub async fn streak(ctx: Context<'_>, user: Option) -> Result<(), Error> { - let db = &ctx.data().database; + let mut tx = ctx.data().database.begin().await?; + let (user, who) = match user { Some(user) => (user.id, format!("{} has", user.display_name())), None => (ctx.author().id, "You have".to_string()), }; - ctx.reply(format!("{who} a daily streak of **{}**", get_streak(db, user).await?.unwrap_or(0))).await?; + ctx.reply(format!("{who} a daily streak of **{}**", get_streak(&mut tx, user).await?.unwrap_or(0))).await?; Ok(()) } diff --git a/src/commands/gambling/leaderboard.rs b/src/commands/gambling/leaderboard.rs index 42301f5..525fd09 100644 --- a/src/commands/gambling/leaderboard.rs +++ b/src/commands/gambling/leaderboard.rs @@ -1,7 +1,6 @@ use crate::common::{Context, Error}; use poise::serenity_prelude::UserId; -use sqlx::Row; enum LeaderboardType { Tokens(usize), @@ -13,15 +12,16 @@ async fn display_leaderboard(ctx: Context<'_>, t: LeaderboardType) -> Result<(), match t { LeaderboardType::Tokens(count) => { - let rows = sqlx::query( + let rows = sqlx::query!( r#" SELECT id, balance FROM bank ORDER BY balance DESC LIMIT $1 - "# - ).bind(count as i32).fetch_all(db).await?; + "#, + count as i32 + ).fetch_all(db).await?; - let users: Vec<(_, i32)> = rows.iter().map(|row| (UserId::new(row.get::(0) as u64), row.get(1))).collect(); + let users: Vec<(_, i32)> = rows.iter().map(|row| (UserId::new(row.id as u64), row.balance.unwrap_or(100))).collect(); let mut output = String::new(); for (id, balance) in users { @@ -32,15 +32,16 @@ async fn display_leaderboard(ctx: Context<'_>, t: LeaderboardType) -> Result<(), ctx.reply(format!("```\n{output}```")).await?; } LeaderboardType::Dailies(count) => { - let rows = sqlx::query( + let rows = sqlx::query!( r#" SELECT userid, streak FROM dailies ORDER BY streak DESC LIMIT $1 - "# - ).bind(count as i32).fetch_all(db).await?; + "#, + count as i32 + ).fetch_all(db).await?; - let users: Vec<(_, i32)> = rows.iter().map(|row| (UserId::new(row.get::(0) as u64), row.get(1))).collect(); + let users: Vec<(_, i32)> = rows.iter().map(|row| (UserId::new(row.userid as u64), row.streak.unwrap_or(0))).collect(); let mut output = String::new(); for (id, streak) in users { diff --git a/src/commands/gambling/mod.rs b/src/commands/gambling/mod.rs index dc09d4e..c7c3044 100644 --- a/src/commands/gambling/mod.rs +++ b/src/commands/gambling/mod.rs @@ -9,7 +9,7 @@ pub mod blackjack; use crate::{inventory::{self, Inventory}, common::{Context, Error}}; use poise::serenity_prelude::{self as serenity, futures::StreamExt, UserId}; -use sqlx::{Row, PgExecutor}; +use sqlx::PgConnection; use std::collections::HashMap; #[derive(Clone)] @@ -79,16 +79,15 @@ mod items { } } -pub async fn get_balance<'a, E>(id: UserId, db: E) -> Result -where - E: PgExecutor<'a>, +pub async fn get_balance(id: UserId, db: &mut PgConnection) -> Result { - let row = sqlx::query("SELECT balance FROM bank WHERE id = $1") - .bind(id.get() as i64) - .fetch_one(db).await.ok(); + let row = sqlx::query!( + "SELECT balance FROM bank WHERE id = $1", + id.get() as i64 + ).fetch_one(db).await.ok(); let balance = if let Some(row) = row { - row.try_get("balance")? + row.balance.unwrap_or(100) } else { 100 }; @@ -96,14 +95,13 @@ where Ok(balance) } -pub async fn change_balance<'a, E>(id: UserId, balance: i32, db: E) -> Result<(), Error> -where - E: PgExecutor<'a>, +pub async fn change_balance(id: UserId, balance: i32, db: &mut PgConnection) -> Result<(), Error> { - sqlx::query("INSERT INTO bank (id, balance) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET balance = EXCLUDED.balance") - .bind(id.get() as i64) - .bind(balance) - .execute(db).await?; + sqlx::query!( + r#"INSERT INTO bank (id, balance) VALUES ($1, $2) + ON CONFLICT (id) DO UPDATE SET balance = EXCLUDED.balance"#, + id.get() as i64, balance + ).execute(db).await?; Ok(()) } diff --git a/src/commands/gambling/shop.rs b/src/commands/gambling/shop.rs index ecec288..12243c5 100644 --- a/src/commands/gambling/shop.rs +++ b/src/commands/gambling/shop.rs @@ -16,8 +16,8 @@ async fn autocomplete_shop<'a>( ctx: Context<'_>, partial: &'a str, ) -> impl Iterator + use<'a> { - let db = &ctx.data().database; - let balance = super::get_balance(ctx.author().id, db).await; + let mut tx = ctx.data().database.begin().await.unwrap(); + let balance = super::get_balance(ctx.author().id, &mut *tx).await; ITEMS.values() .filter(move |(_, item)| item.name.contains(partial)) diff --git a/src/commands/self_roles/mod.rs b/src/commands/self_roles/mod.rs index 8a3c6ac..5d698d5 100644 --- a/src/commands/self_roles/mod.rs +++ b/src/commands/self_roles/mod.rs @@ -1,6 +1,6 @@ use crate::common::{Context, Error}; -use sqlx::{PgConnection, Row}; +use sqlx::PgConnection; use poise::serenity_prelude::{EditRole, GuildId, Permissions, RoleId, UserId}; mod whois; @@ -83,18 +83,14 @@ async fn create_role( /// Remove a row concerning a user's self role from the database pub async fn remove_user_role(user: UserId, guild: GuildId, db: &mut PgConnection) -> Result<(), Error> { - sqlx::query("DELETE FROM selfroles WHERE userid = $1 AND guildid = $2") - .bind(user.get() as i64) - .bind(guild.get() as i64) + sqlx::query!("DELETE FROM selfroles WHERE userid = $1 AND guildid = $2", user.get() as i64, guild.get() as i64) .execute(db).await?; Ok(()) } pub async fn remove_role(role: RoleId, guild: GuildId, db: &mut PgConnection) -> Result<(), Error> { - sqlx::query("DELETE FROM selfroles WHERE roleid = $1 AND guildid = $2") - .bind(role.get() as i64) - .bind(guild.get() as i64) + sqlx::query!("DELETE FROM selfroles WHERE roleid = $1 AND guildid = $2", role.get() as i64, guild.get() as i64) .execute(db).await?; Ok(()) @@ -102,10 +98,10 @@ pub async fn remove_role(role: RoleId, guild: GuildId, db: &mut PgConnection) -> /// Replace a user's custom role with a new one pub async fn update_user_role(user: UserId, guild: GuildId, role: RoleId, db: &mut PgConnection) -> Result<(), Error> { - sqlx::query("INSERT INTO selfroles (userid, guildid, roleid) VALUES($1, $2, $3) ON CONFLICT (userid, guildid) DO UPDATE SET roleid = EXCLUDED.roleid") - .bind(user.get() as i64) - .bind(guild.get() as i64) - .bind(role.get() as i64) + sqlx::query!( + r#"INSERT INTO selfroles (userid, guildid, roleid) VALUES($1, $2, $3) + ON CONFLICT (userid, guildid) DO UPDATE SET roleid = EXCLUDED.roleid"#, + user.get() as i64, guild.get() as i64, role.get() as i64) .execute(db).await?; Ok(()) @@ -113,12 +109,10 @@ pub async fn update_user_role(user: UserId, guild: GuildId, role: RoleId, db: &m /// Get a user's personal role id from the database pub async fn get_user_role(user: UserId, guild: GuildId, db: &mut PgConnection) -> Result, Error> { - match sqlx::query("SELECT roleid FROM selfroles WHERE userid = $1 AND guildid = $2") - .bind(user.get() as i64) - .bind(guild.get() as i64) + match sqlx::query!("SELECT roleid FROM selfroles WHERE userid = $1 AND guildid = $2", user.get() as i64, guild.get() as i64) .fetch_one(db).await { - Ok(row) => Ok(Some(RoleId::new(row.try_get::(0)? as u64))), + Ok(row) => Ok(Some(RoleId::new(row.roleid.unwrap() as u64))), Err(sqlx::Error::RowNotFound) => Ok(None), Err(e) => return Err(Box::new(e)), } @@ -126,12 +120,10 @@ pub async fn get_user_role(user: UserId, guild: GuildId, db: &mut PgConnection) /// Get a user from the role id pub async fn get_user_by_role(role: RoleId, guild: GuildId, db: &mut PgConnection) -> Result, Error> { - match sqlx::query("SELECT userid FROM selfroles WHERE roleid = $1 AND guildid = $2") - .bind(role.get() as i64) - .bind(guild.get() as i64) + match sqlx::query!("SELECT userid FROM selfroles WHERE roleid = $1 AND guildid = $2", role.get() as i64, guild.get() as i64) .fetch_one(db).await { - Ok(row) => Ok(Some(UserId::new(row.try_get::(0)? as u64))), + Ok(row) => Ok(Some(UserId::new(row.userid as u64))), Err(sqlx::Error::RowNotFound) => Ok(None), Err(e) => return Err(Box::new(e)), } diff --git a/src/commands/self_roles/whois.rs b/src/commands/self_roles/whois.rs index ff7f502..1302f93 100644 --- a/src/commands/self_roles/whois.rs +++ b/src/commands/self_roles/whois.rs @@ -3,7 +3,6 @@ use crate::common::{self, Context, Error}; use poise::serenity_prelude as serenity; use serenity::UserId; -use sqlx::Row; /// Let you know who is the owner of a role. #[poise::command(slash_command, prefix_command)] @@ -11,11 +10,10 @@ pub async fn whois(ctx: Context<'_>, role: serenity::Role) -> Result<(), Error> let db = &ctx.data().database; if let Some(guild) = ctx.guild_id() { - let user = match sqlx::query("SELECT userid FROM selfroles WHERE roleid = $1") - .bind(role.id.get() as i64) + let user = match sqlx::query!("SELECT userid FROM selfroles WHERE roleid = $1", role.id.get() as i64) .fetch_one(db).await { - Ok(row) => UserId::new(row.try_get::(0)? as u64), + Ok(row) => UserId::new(row.userid as u64), Err(sqlx::Error::RowNotFound) => { ctx.reply("This role is not owned by anyone.").await?; return Ok(());