Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions nanapi/routers/waicolle.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@
RE_SYMBOLES,
REROLLS_MAX_RANKS,
RNG,
ROLLS,
TagRoll,
UserRoll,
get_roll,
load_rolls,
)

Expand Down Expand Up @@ -331,12 +331,11 @@ async def player_roll(

# Get Roll
if roll_id is not None:
roll_getter = ROLLS.get(roll_id, None)
if roll_getter is None:
roll = await get_roll(roll_id)
if roll is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail='Roll Not Found'
)
roll = roll_getter()
reason = reason if reason is not None else roll_id
elif coupon_code is not None:
# Check eligibility
Expand Down Expand Up @@ -1453,7 +1452,7 @@ async def export_waifus(edgedb: AsyncIOClient = Depends(get_client_edgedb)):
@router.oauth2.get('/exports/daily', response_model=list[MediasPoolExportResult])
async def export_daily():
"""Export daily media pool."""
roll = TagRoll.get_daily()
roll = await TagRoll.get_daily()
await roll.load(get_edgedb())
assert roll.ids_al
return await medias_pool_export(get_edgedb(), ids_al=roll.ids_al)
176 changes: 80 additions & 96 deletions nanapi/utils/waicolle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import re
from contextlib import suppress
from datetime import date, datetime, timedelta
from functools import partial
from itertools import product
from typing import Any, Callable, Self, cast
from typing import Any, Callable, Coroutine, Self, cast

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -131,7 +132,7 @@ async def _charas_probas_from_pool(
async def after(self, executor: AsyncIOExecutor, discord_id: str):
pass

def __call__(self, *args: Any, **kwds: Any) -> Any:
async def __call__(self, *args: Any, **kwds: Any) -> Any:
return self


Expand Down Expand Up @@ -279,22 +280,18 @@ def get_current_date() -> date:
class TagRoll(BaseMediaRoll):
DAILY_BASE_PRICE = 150
DAILY_NB = 1
daily_rolls: dict[date, Self] = {}
daily_rolls: dict[str, Self] = {}

def __init__(
self,
nb: int,
tag: str,
price: int = 0,
min_rank: Rank | None = None,
max_rank: Rank | None = None,
forced_tag: str | None = None,
tag_date: date | None = None,
discount_first: bool = False,
):
super().__init__(nb, price, min_rank, max_rank, genred=False)
self.tag_date = tag_date
self.tag = forced_tag
self.discount_first = discount_first
self.tag = tag

async def get_name(self, executor: AsyncIOExecutor, discord_id: str) -> str:
pool = await self.get_pool(executor, discord_id)
Expand All @@ -313,49 +310,44 @@ async def get_name(self, executor: AsyncIOExecutor, discord_id: str) -> str:

async def get_price(self, executor: AsyncIOExecutor, discord_id: str) -> int:
price = await super().get_price(executor, discord_id)
if self.discount_first:
redis_player_key = f'{discord_id}_{get_current_date()}'
if not await user_daily_roll.get(redis_player_key, tx=executor):
price //= 2

# discount on first roll
redis_player_key = f'{discord_id}_{get_current_date()}'
if not await user_daily_roll.get(redis_player_key, tx=executor):
price //= 2

return price

async def load(self, executor: AsyncIOExecutor, force: bool = False):
if force or not self.loaded.is_set():
if not self.tag:
if not self.tag_date:
raise Exception('TagRoll must have a tag or a tag_date')
self.tag = await self.get_daily_tag(executor, self.tag_date)

resp = await media_select_ids_by_tag(executor, tag_name=self.tag, min_rank=60)
self.ids_al = [media.id_al for media in resp]
cls = self.__class__
cls.daily_rolls[self.tag] = self
self.loaded.set()

async def after(self, executor: AsyncIOExecutor, discord_id: str):
if self.discount_first:
redis_player_key = f'{discord_id}_{get_current_date()}'
await user_daily_roll.set(True, sub_key=redis_player_key)
redis_player_key = f'{discord_id}_{get_current_date()}'
await user_daily_roll.set(True, sub_key=redis_player_key)

@classmethod
def get_daily(cls) -> Self:
async def get_daily(cls) -> 'TagRoll':
today = get_current_date()
if today not in cls.daily_rolls:
cls.daily_rolls[today] = cls(
price=cls.DAILY_BASE_PRICE, nb=cls.DAILY_NB, tag_date=today, discount_first=True
)
tomorrow = today + timedelta(days=1)
if tomorrow not in cls.daily_rolls:
cls.daily_rolls[tomorrow] = cls(
price=cls.DAILY_BASE_PRICE, nb=cls.DAILY_NB, tag_date=tomorrow, discount_first=True
)
asyncio.create_task(cls.daily_rolls[tomorrow].load(get_edgedb(), force=True))
return cls.daily_rolls[today]

@staticmethod
async def get_daily_tag(executor: AsyncIOExecutor, tag_date: date) -> str:
asyncio.create_task(cls.get_daily_tag(get_edgedb(), tag_date=tomorrow))

return await cls.get_daily_tag(get_edgedb(), tag_date=today)

@classmethod
async def get_daily_tag(cls, executor: AsyncIOExecutor, tag_date: date) -> 'TagRoll':
create_roll = partial(cls, nb=cls.DAILY_NB, price=cls.DAILY_BASE_PRICE)
tag = await daily_tag.get(str(tag_date))

if tag is not None:
return tag
roll = cls.daily_rolls.get(tag, create_roll(tag=tag))
await roll.load(executor)
return roll

yesterday = tag_date - timedelta(days=1)
yesterday_tag = await daily_tag.get(str(yesterday))
Expand All @@ -368,42 +360,37 @@ async def get_daily_tag(executor: AsyncIOExecutor, tag_date: date) -> str:
RNG.shuffle(tags)

for tag in tags:
roll = TagRoll(1, forced_tag=tag)
roll = create_roll(tag=tag)
await roll.load(executor, force=True)
_, rates = await roll._roll(executor)
if len(rates) > 400 and (1 / float(np.max(rates))) > 50:
await daily_tag.set(tag, sub_key=str(tag_date))
return tag
return roll

raise RuntimeError('Could not find daily roll tag')


class SeasonalRoll(BaseMediaRoll):
WEEKLY_BASE_PRICE = 600
WEEKLY_NB = 5
weekly_rolls: dict[tuple[int, int], Self] = {}
weekly_rolls: dict[tuple[int, MEDIA_SELECT_IDS_BY_SEASON_SEASON], 'SeasonalRoll'] = {}

def __init__(
self,
nb: int,
season_year: int,
season: MEDIA_SELECT_IDS_BY_SEASON_SEASON,
price: int = 0,
min_rank: Rank | None = None,
max_rank: Rank | None = None,
week_key: tuple[int, int] | None = None,
season_year: int | None = None,
season: MEDIA_SELECT_IDS_BY_SEASON_SEASON | None = None,
discount_first: bool = False,
):
super().__init__(nb, price, min_rank, max_rank)
self.week_key = week_key
self.season_year = season_year
self.season: MEDIA_SELECT_IDS_BY_SEASON_SEASON | None = season
self.discount_first = discount_first
self.season: MEDIA_SELECT_IDS_BY_SEASON_SEASON = season

async def get_name(self, executor: AsyncIOExecutor, discord_id: str) -> str:
assert self.season_year
assert self.season

pool = await self.get_pool(executor, discord_id)

min_rank = S
Expand All @@ -423,71 +410,60 @@ async def get_name(self, executor: AsyncIOExecutor, discord_id: str) -> str:

async def get_price(self, executor: AsyncIOExecutor, discord_id: str) -> int:
price = await super().get_price(executor, discord_id)
if self.discount_first:
curr_date = get_current_date().isocalendar()
redis_player_key = f'{discord_id}_{(curr_date.year, curr_date.week)}'
if not await user_weekly_roll.get(redis_player_key, tx=executor):
price //= 2

# discount on first roll
curr_date = get_current_date().isocalendar()
redis_player_key = f'{discord_id}_{(curr_date.year, curr_date.week)}'
if not await user_weekly_roll.get(redis_player_key, tx=executor):
price //= 2

return price

async def load(self, executor: AsyncIOExecutor, force: bool = False):
async def load(self, executor: AsyncIOExecutor, force: bool = False) -> None:
if force or not self.loaded.is_set():
if not self.season_year or not self.season:
if not self.week_key:
raise Exception('SeasonalRoll must have a season or a week_key')
self.season_year, self.season = await self.get_weekly_season(
executor, self.week_key
)

resp = await media_select_ids_by_season(
executor,
season_year=self.season_year,
season=self.season,
)
self.ids_al = [media.id_al for media in resp]
cls = self.__class__
cls.weekly_rolls[self.season_year, self.season] = self
self.loaded.set()

async def after(self, executor: AsyncIOExecutor, discord_id: str):
if self.discount_first:
curr_date = get_current_date().isocalendar()
redis_player_key = f'{discord_id}_{(curr_date.year, curr_date.week)}'
await user_weekly_roll.set(True, sub_key=redis_player_key)
curr_date = get_current_date().isocalendar()
redis_player_key = f'{discord_id}_{(curr_date.year, curr_date.week)}'
await user_weekly_roll.set(True, sub_key=redis_player_key)

@classmethod
def get_weekly(cls) -> Self:
async def get_weekly(cls) -> 'SeasonalRoll':
today = get_current_date()
today_iso = today.isocalendar()
week_key = (today_iso.year, today_iso.week)
if week_key not in cls.weekly_rolls:
cls.weekly_rolls[week_key] = cls(
price=cls.WEEKLY_BASE_PRICE,
nb=cls.WEEKLY_NB,
week_key=week_key,
discount_first=True,
)
next_week_key = (
(today_iso.year, today_iso.week + 1)
if today_iso.week < 52
else (today_iso.year + 1, 1)
)
if next_week_key not in cls.weekly_rolls:
cls.weekly_rolls[next_week_key] = cls(
price=cls.WEEKLY_BASE_PRICE,
nb=cls.WEEKLY_NB,
week_key=next_week_key,
discount_first=True,
)
asyncio.create_task(cls.weekly_rolls[next_week_key].load(get_edgedb(), force=True))
return cls.weekly_rolls[week_key]

@staticmethod
next_week = today + timedelta(weeks=1)
next_week_iso = next_week.isocalendar()
next_week_key = next_week.year, next_week_iso.week

asyncio.create_task(cls.get_weekly_season(get_edgedb(), week_key=next_week_key))

return await cls.get_weekly_season(get_edgedb(), week_key)

@classmethod
async def get_weekly_season(
executor: AsyncIOExecutor, week_key: tuple[int, int]
) -> tuple[int, MEDIA_SELECT_IDS_BY_SEASON_SEASON]:
cls, executor: AsyncIOExecutor, week_key: tuple[int, int]
) -> 'SeasonalRoll':
create_roll = partial(cls, nb=cls.WEEKLY_NB, price=cls.WEEKLY_BASE_PRICE)
saved = await weekly_season.get(str(week_key))
if saved:
year, season = saved.split('_')
return (int(year), cast(MEDIA_SELECT_IDS_BY_SEASON_SEASON, season))
year_str, season_str = saved.split('_')
year, season = int(year_str), cast(MEDIA_SELECT_IDS_BY_SEASON_SEASON, season_str)
key = year, season

roll = cls.weekly_rolls.get(key, create_roll(season_year=year, season=season))
await roll.load(executor)
return roll

roll_year, roll_week = week_key
last_week_key = (roll_year, roll_week - 1) if roll_week > 1 else (roll_year - 1, 52)
Expand All @@ -512,18 +488,18 @@ async def get_weekly_season(

RNG.shuffle(seasons)

for year, season in seasons:
roll = SeasonalRoll(1, season_year=year, season=season)
for year_str, season_str in seasons:
roll = create_roll(season_year=year_str, season=season_str)
await roll.load(executor, force=True)
_, rates = await roll._roll(executor)
if len(rates) > 400 and (1 / float(np.max(rates))) > 50:
await weekly_season.set(f'{year}_{season}', sub_key=str(week_key))
return year, season
await weekly_season.set(f'{year_str}_{season_str}', sub_key=str(week_key))
return roll

raise RuntimeError('Could not find weekly roll season')


ROLLS: dict[str, Callable[[], BaseRoll]] = {
ROLLS: dict[str, Callable[[], Coroutine[None, None, BaseRoll]]] = {
'A': UserRoll(price=250, nb=5, max_rank=E),
'B': UserRoll(price=75, nb=1, max_rank=C),
'C': UserRoll(price=300, nb=5, max_rank=C),
Expand All @@ -537,7 +513,15 @@ async def get_weekly_season(
}


async def get_roll(roll_id: str) -> BaseRoll | None:
roll = ROLLS.get(roll_id)
if roll is None:
return None

return await roll()


async def load_rolls():
rolls = {roll_id: roll_getter() for roll_id, roll_getter in ROLLS.items()}
rolls = {roll_id: await roll_getter() for roll_id, roll_getter in ROLLS.items()}
await asyncio.gather(*[roll.load(get_edgedb()) for roll in rolls.values()])
return rolls
Loading