Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 70 additions & 37 deletions assign_roles/assign_roles.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import discord

from redbot.core import commands # Changed from discord.ext
Expand Down Expand Up @@ -34,6 +36,25 @@ def __init__(self, bot: Red):
self.config = Config.get_conf(self, identifier=73600, force_registration=True)
self.config.register_guild(roles={})

# Lock to prevent race conditions when updating roles in Config
self._config_locks = {} # {guild_id: asyncio.Lock}
self._locks_creation_lock = asyncio.Lock() # Lock for creating per-guild locks

def _get_guild_lock(self, guild_id: int) -> asyncio.Lock:
"""Get or create a lock for a specific guild to prevent config race conditions.

Args:
guild_id: The ID of the guild

Returns:
An asyncio.Lock for the guild
"""
if guild_id not in self._config_locks:
# Use a lock to ensure only one lock is created per guild
# Note: We can't await here, but this is safe because dict access is atomic
self._config_locks[guild_id] = asyncio.Lock()
return self._config_locks[guild_id]

# Events

# Commands
Expand Down Expand Up @@ -92,25 +113,33 @@ async def authorise(self, ctx, authorised_role: discord.Role, giveable_role: dis
await ctx.defer(ephemeral=True)

gld = ctx.guild
server_dict = await self.config.guild(gld).roles()

author_max_role = max(r for r in ctx.author.roles)
authorised_id = str(authorised_role.id)
giveable_id = str(giveable_role.id)

if authorised_role.is_default(): # Role to be authorised should not be @everyone.
notice = self.AUTHORISE_NO_EVERYONE
elif giveable_role.is_default(): # Same goes for role to be given.
notice = self.AUTHORISE_NOT_DEFAULT
elif authorised_role >= author_max_role and ctx.author != gld.owner: # Hierarchical role order check.
notice = self.AUTHORISE_NO_HIGHER
# Check if "pair" already exists.
elif giveable_id in server_dict and authorised_id in server_dict[giveable_id]:
notice = self.AUTHORISE_EXISTS
else: # Role authorisation is valid.
server_dict.setdefault(giveable_id, []).append(authorised_id)
await self.config.guild(gld).roles.set(server_dict)
notice = self.AUTHORISE_SUCCESS.format(authorised_role.name, giveable_role.name)
# Use lock to prevent race conditions when multiple users authorize roles simultaneously
lock = self._get_guild_lock(gld.id)
async with lock:
server_dict = await self.config.guild(gld).roles()

author_max_role = max(r for r in ctx.author.roles)
authorised_id = str(authorised_role.id)
giveable_id = str(giveable_role.id)

if authorised_role.is_default(): # Role to be authorised should not be @everyone.
notice = self.AUTHORISE_NO_EVERYONE
elif giveable_role.is_default(): # Same goes for role to be given.
notice = self.AUTHORISE_NOT_DEFAULT
elif authorised_role >= author_max_role and ctx.author != gld.owner: # Hierarchical role order check.
notice = self.AUTHORISE_NO_HIGHER
# Check if "pair" already exists.
elif giveable_id in server_dict and authorised_id in server_dict[giveable_id]:
notice = self.AUTHORISE_EXISTS
else: # Role authorisation is valid.
if giveable_id not in server_dict:
server_dict[giveable_id] = []
# Double-check for duplicates before appending (safety check)
if authorised_id not in server_dict[giveable_id]:
server_dict[giveable_id].append(authorised_id)
await self.config.guild(gld).roles.set(server_dict)
notice = self.AUTHORISE_SUCCESS.format(authorised_role.name, giveable_role.name)
await ctx.send(notice, ephemeral=True)

@commands.guild_only()
Expand All @@ -131,26 +160,30 @@ async def deauthorise(self, ctx, authorised_role: discord.Role, giveable_role: d
await ctx.defer(ephemeral=True)

gld = ctx.guild
server_dict = await self.config.guild(gld).roles()

author_max_role = max(r for r in ctx.author.roles)
authorised_id = str(authorised_role.id)
giveable_id = str(giveable_role.id)

if authorised_role.is_default(): # Role to be de-authorised should not be @everyone.
notice = self.AUTHORISE_NO_EVERYONE
elif giveable_role.is_default(): # Same goes for role to be given.
notice = self.AUTHORISE_NOT_DEFAULT
elif authorised_role >= author_max_role and ctx.author != gld.owner: # Hierarchical role order check.
notice = self.AUTHORISE_NO_HIGHER
elif giveable_id not in server_dict:
notice = self.AUTHORISE_EMPTY.format(giveable_role.name)
elif authorised_id not in server_dict[giveable_id]:
notice = self.AUTHORISE_MISMATCH.format(authorised_role.name, giveable_role.name)
else: # Role de-authorisation is valid.
server_dict[giveable_id].remove(authorised_id)
await self.config.guild(gld).roles.set(server_dict)
notice = self.DEAUTHORISE_SUCCESS.format(authorised_role.name, giveable_role.name)
# Use lock to prevent race conditions when multiple users deauthorize roles simultaneously
lock = self._get_guild_lock(gld.id)
async with lock:
server_dict = await self.config.guild(gld).roles()

author_max_role = max(r for r in ctx.author.roles)
authorised_id = str(authorised_role.id)
giveable_id = str(giveable_role.id)

if authorised_role.is_default(): # Role to be de-authorised should not be @everyone.
notice = self.AUTHORISE_NO_EVERYONE
elif giveable_role.is_default(): # Same goes for role to be given.
notice = self.AUTHORISE_NOT_DEFAULT
elif authorised_role >= author_max_role and ctx.author != gld.owner: # Hierarchical role order check.
notice = self.AUTHORISE_NO_HIGHER
elif giveable_id not in server_dict:
notice = self.AUTHORISE_EMPTY.format(giveable_role.name)
elif authorised_id not in server_dict[giveable_id]:
notice = self.AUTHORISE_MISMATCH.format(authorised_role.name, giveable_role.name)
else: # Role de-authorisation is valid.
server_dict[giveable_id].remove(authorised_id)
await self.config.guild(gld).roles.set(server_dict)
notice = self.DEAUTHORISE_SUCCESS.format(authorised_role.name, giveable_role.name)
await ctx.send(notice, ephemeral=True)

@commands.guild_only()
Expand Down
186 changes: 109 additions & 77 deletions party/party.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import secrets
from datetime import datetime, timezone
Expand Down Expand Up @@ -127,16 +128,19 @@ async def on_submit(self, interaction: discord.Interaction):
new_description = self.description_input.value.strip() or None

# Update the party data
async with self.cog.config.guild(interaction.guild).parties() as parties:
if self.party_id not in parties:
await interaction.followup.send("❌ Party not found.", ephemeral=True)
return
# Use lock to prevent race conditions
lock = self.cog._get_guild_lock(interaction.guild.id)
async with lock:
async with self.cog.config.guild(interaction.guild).parties() as parties:
if self.party_id not in parties:
await interaction.followup.send("❌ Party not found.", ephemeral=True)
return

old_title = parties[self.party_id]['name']
old_description = parties[self.party_id].get('description')
old_title = parties[self.party_id]['name']
old_description = parties[self.party_id].get('description')

parties[self.party_id]['name'] = new_title
parties[self.party_id]['description'] = new_description
parties[self.party_id]['name'] = new_title
parties[self.party_id]['description'] = new_description

# Update the party message
await self.cog.update_party_message(interaction.guild.id, self.party_id)
Expand Down Expand Up @@ -271,8 +275,11 @@ async def on_submit(self, interaction: discord.Interaction):
party["signups"][role] = []

# Save the party
async with self.cog.config.guild(interaction.guild).parties() as parties:
parties[party_id] = party
# Use lock to prevent race conditions when multiple parties are created simultaneously
lock = self.cog._get_guild_lock(interaction.guild.id)
async with lock:
async with self.cog.config.guild(interaction.guild).parties() as parties:
parties[party_id] = party

# Create the party embed
embed = await self.cog.create_party_embed(party, interaction.guild)
Expand All @@ -285,9 +292,12 @@ async def on_submit(self, interaction: discord.Interaction):
message = await channel.send(embed=embed, view=view)

# Save the message ID and channel ID
async with self.cog.config.guild(interaction.guild).parties() as parties:
parties[party_id]["message_id"] = message.id
parties[party_id]["channel_id"] = channel.id
# Use lock to prevent race conditions when multiple parties are created simultaneously
lock = self.cog._get_guild_lock(interaction.guild.id)
async with lock:
async with self.cog.config.guild(interaction.guild).parties() as parties:
parties[party_id]["message_id"] = message.id
parties[party_id]["channel_id"] = channel.id

# Create modlog entry
await self.cog.create_party_modlog(
Expand Down Expand Up @@ -737,11 +747,30 @@ def __init__(self, bot):
}
self.config.register_guild(**default_guild)

# Lock to prevent race conditions when updating parties in Config
self._config_locks = {} # {guild_id: asyncio.Lock}
self._locks_creation_lock = asyncio.Lock() # Lock for creating per-guild locks

# Load persistent views for existing parties
self.bot.loop.create_task(self._register_persistent_views())
# Register custom modlog casetypes
self.bot.loop.create_task(self._register_casetypes())

def _get_guild_lock(self, guild_id: int) -> asyncio.Lock:
"""Get or create a lock for a specific guild to prevent config race conditions.

Args:
guild_id: The ID of the guild

Returns:
An asyncio.Lock for the guild
"""
if guild_id not in self._config_locks:
# Use a lock to ensure only one lock is created per guild
# Note: We can't await here, but this is safe because dict access is atomic
self._config_locks[guild_id] = asyncio.Lock()
return self._config_locks[guild_id]

@staticmethod
def parse_allow_multiple(allow_multiple_text: str) -> tuple[bool, Optional[str]]:
"""Parse and validate allow_multiple_per_role setting.
Expand Down Expand Up @@ -1021,74 +1050,77 @@ async def signup_user(
guild_id = interaction.guild.id
user_id = str(interaction.user.id)

async with self.config.guild_from_id(guild_id).parties() as parties:
if party_id not in parties:
if disabled_view:
# Edit the original message to show error and remove the select view
if deferred:
await interaction.edit_original_response(
content="❌ Party not found.",
view=None
)
else:
await interaction.response.edit_message(
content="❌ Party not found.",
view=None
)
else:
if deferred:
await interaction.followup.send(
"❌ Party not found.",
ephemeral=True
)
# Use lock to prevent race conditions when multiple users sign up simultaneously
lock = self._get_guild_lock(guild_id)
async with lock:
async with self.config.guild_from_id(guild_id).parties() as parties:
if party_id not in parties:
if disabled_view:
# Edit the original message to show error and remove the select view
if deferred:
await interaction.edit_original_response(
content="❌ Party not found.",
view=None
)
else:
await interaction.response.edit_message(
content="❌ Party not found.",
view=None
)
else:
await interaction.response.send_message(
"❌ Party not found.",
ephemeral=True
)
return

party = parties[party_id]
allow_multiple = party.get("allow_multiple_per_role", True)
if deferred:
await interaction.followup.send(
"❌ Party not found.",
ephemeral=True
)
else:
await interaction.response.send_message(
"❌ Party not found.",
ephemeral=True
)
return

# Remove user from any existing role first
for role_name, users in party["signups"].items():
if user_id in users:
party["signups"][role_name].remove(user_id)

# Check if role exists in signups, if not create it
if role not in party["signups"]:
party["signups"][role] = []

# Check if multiple signups allowed
if not allow_multiple and len(party["signups"][role]) > 0:
if disabled_view:
# Edit the original message to show error and remove the select view
if deferred:
await interaction.edit_original_response(
content=f"❌ The role **{role}** is already full (multiple signups not allowed).",
view=None
)
party = parties[party_id]
allow_multiple = party.get("allow_multiple_per_role", True)

# Remove user from any existing role first
for role_name, users in party["signups"].items():
if user_id in users:
party["signups"][role_name].remove(user_id)

# Check if role exists in signups, if not create it
if role not in party["signups"]:
party["signups"][role] = []

# Check if multiple signups allowed
if not allow_multiple and len(party["signups"][role]) > 0:
if disabled_view:
# Edit the original message to show error and remove the select view
if deferred:
await interaction.edit_original_response(
content=f"❌ The role **{role}** is already full (multiple signups not allowed).",
view=None
)
else:
await interaction.response.edit_message(
content=f"❌ The role **{role}** is already full (multiple signups not allowed).",
view=None
)
else:
await interaction.response.edit_message(
content=f"❌ The role **{role}** is already full (multiple signups not allowed).",
view=None
)
else:
if deferred:
await interaction.followup.send(
f"❌ The role **{role}** is already full (multiple signups not allowed).",
ephemeral=True
)
else:
await interaction.response.send_message(
f"❌ The role **{role}** is already full (multiple signups not allowed).",
ephemeral=True
)
return
if deferred:
await interaction.followup.send(
f"❌ The role **{role}** is already full (multiple signups not allowed).",
ephemeral=True
)
else:
await interaction.response.send_message(
f"❌ The role **{role}** is already full (multiple signups not allowed).",
ephemeral=True
)
return

# Add user to the role
party["signups"][role].append(user_id)
# Add user to the role
party["signups"][role].append(user_id)

# Send success response
if disabled_view:
Expand Down
Loading
Loading