Skip to content

Commit bd54a80

Browse files
authored
Merge pull request #85 from open-starlab/akshatgarg06-patch-3
Add support for RoboCup 2D SAR data with restrictions
2 parents 0eee5c3 + 54e4e07 commit bd54a80

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

preprocessing/sports/SAR_data/SAR_class.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,21 @@ def __new__(cls, data_provider, state_def, *args, **kwargs):
1212
elif data_provider == "statsbomb":
1313
# For 'statsbomb', raise a NotImplementedError indicating it is not implemented
1414
raise NotImplementedError("StatsBomb SAR data is not implemented yet.")
15-
elif data_provider == "robocup_2d":
16-
# Add a new clause for 'robocup_2d' that raises a NotImplementedError for RL usage
17-
raise NotImplementedError("RoboCup 2D SAR data is not implemented for RL. Please use a supported data provider.")
15+
elif data_provider == "robocup_2d" and state_def in cls.state_list:
16+
preprocess_method = kwargs.get("preprocess_method", "SAR")
17+
if preprocess_method != "SAR2RL":
18+
raise NotImplementedError(
19+
"RoboCup 2D SAR data is only supported for preprocess_method='SAR2RL'."
20+
)
21+
from .soccer.soccer_SAR_class import Soccer_SAR_data
22+
23+
return Soccer_SAR_data(data_provider, state_def, *args, **kwargs)
1824
else:
1925
# If the data_provider is unrecognized or state_def is unrecongnized, raise a ValueError
2026
raise ValueError(
2127
f"Unsupported data provider '{data_provider}' or state definition '{state_def}'. "
22-
f"Supported providers: {cls.sports}, Supported states: {cls.state_list}."
28+
f"Supported providers: {cls.sports + ['robocup_2d (SAR2RL only)']}, "
29+
f"Supported states: {cls.state_list}."
2330
)
2431

2532

0 commit comments

Comments
 (0)