-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathemotion_image_selector.py
More file actions
140 lines (112 loc) · 4.04 KB
/
emotion_image_selector.py
File metadata and controls
140 lines (112 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# emotion_image_selector.py
"""
Moduł dobierający najbliższe zdjęcie z datasetu OASIS
na podstawie znormalizowanej (0–1) walencji i pobudzenia
oraz generujący pośrednie kroki (checkpointy) w przestrzeni PAD.
Zakładamy plik `final_eeg_dataset.csv` o strukturze:
Theme,Class_Label,Valence_mean,Arousal_mean
Dog 6.jpg,Joy,6.49,5.03
...
"""
from __future__ import annotations
import math
from typing import Tuple, Sequence
import pandas as pd
# Stałe zakresów z opisu
VALENCE_MIN, VALENCE_MAX = 1.11, 6.49
AROUSAL_MIN, AROUSAL_MAX = 1.69, 5.47
DATASET_PATH = "final_eeg_dataset.csv"
def _scale_to_dataset_range(v_norm: float, a_norm: float) -> Tuple[float, float]:
"""
Skaluje walencję i pobudzenie z zakresu [0, 1] do zakresów datasetu OASIS.
Wartości spoza [0, 1] są przycinane (clamp).
"""
v_norm_clamped = max(0.0, min(1.0, float(v_norm)))
a_norm_clamped = max(0.0, min(1.0, float(a_norm)))
v_real = VALENCE_MIN + v_norm_clamped * (VALENCE_MAX - VALENCE_MIN)
a_real = AROUSAL_MIN + a_norm_clamped * (AROUSAL_MAX - AROUSAL_MIN)
return v_real, a_real
def load_oasis_dataset(path: str = DATASET_PATH) -> pd.DataFrame:
"""
Ładuje dataset OASIS do DataFrame i pilnuje, żeby były wymagane kolumny.
"""
try:
df = pd.read_csv(path)
except FileNotFoundError:
# Fallback - tworzymy pusty DF, żeby kod się nie wywalił od razu
print(f"UWAGA: Nie znaleziono pliku {path}. Tworzę przykładowy dataset.")
return pd.DataFrame({
"Theme": ["Astronaut 1.jpg"],
"Valence_mean": [3.5],
"Arousal_mean": [3.5]
})
required_cols = {"Theme", "Valence_mean", "Arousal_mean"}
missing = required_cols - set(df.columns)
if missing:
raise ValueError(f"Brakuje kolumn w CSV: {missing}")
return df
def get_closest_theme(
v_norm: float,
a_norm: float,
df: pd.DataFrame | None = None,
*,
recent_themes: Sequence[str] | None = None,
avoid_repeats: bool = True,
repeat_window: int = 3,
) -> str:
"""
Zwraca nazwę pliku z kolumny `Theme` z datasetu `final_eeg_dataset.csv`.
"""
if df is None:
df = load_oasis_dataset()
v_target, a_target = _scale_to_dataset_range(v_norm, a_norm)
distances = (
(df["Valence_mean"] - v_target) ** 2
+ (df["Arousal_mean"] - a_target) ** 2
) ** 0.5
# Jeśli nie unikamy powtórek albo nie mamy historii – klasyczne zachowanie
if (not avoid_repeats) or not recent_themes or repeat_window <= 0:
closest_idx = distances.idxmin()
return df.loc[closest_idx, "Theme"]
# bierzemy tylko ostatnie repeat_window motywów
recent_slice = list(recent_themes)[-repeat_window:]
recent_set = set(recent_slice)
# sortujemy indeksy po odległości (od najbliższych)
sorted_idx = distances.sort_values().index
# szukamy najbliższego Theme, którego nie ma w ostatnich N wyświetleniach
for idx in sorted_idx:
theme = df.loc[idx, "Theme"]
if theme not in recent_set:
return theme
# fallback: jeśli wszystkie są w historii, bierzemy najbliższy
closest_idx = distances.idxmin()
return df.loc[closest_idx, "Theme"]
def generate_path(
tar_a: float,
tar_v: float,
tar_d: float, # ignorowane, ale zostawione dla spójności interfejsu
cur_a: float,
cur_v: float,
cur_d: float, # ignorowane
step: float,
) -> Tuple[float, float, bool]:
"""
Generuje pojedynczy krok w stronę punktu docelowego w przestrzeni (Arousal, Valence).
"""
if step <= 0:
raise ValueError("Parametr 'step' musi być dodatni.")
da = tar_a - cur_a
dv = tar_v - cur_v
dist = math.sqrt(da * da + dv * dv)
if dist == 0 or dist <= step:
next_a = tar_a
next_v = tar_v
reached = True
else:
scale = step / dist
next_a = cur_a + da * scale
next_v = cur_v + dv * scale
reached = False
next_a = max(0.0, min(1.0, next_a))
next_v = max(0.0, min(1.0, next_v))
return next_a, next_v, reached