diff --git a/preprocessing/sports/SAR_data/SAR_class.py b/preprocessing/sports/SAR_data/SAR_class.py index 432381a..eb1de28 100644 --- a/preprocessing/sports/SAR_data/SAR_class.py +++ b/preprocessing/sports/SAR_data/SAR_class.py @@ -12,14 +12,21 @@ def __new__(cls, data_provider, state_def, *args, **kwargs): elif data_provider == "statsbomb": # For 'statsbomb', raise a NotImplementedError indicating it is not implemented raise NotImplementedError("StatsBomb SAR data is not implemented yet.") - elif data_provider == "robocup_2d": - # Add a new clause for 'robocup_2d' that raises a NotImplementedError for RL usage - raise NotImplementedError("RoboCup 2D SAR data is not implemented for RL. Please use a supported data provider.") + elif data_provider == "robocup_2d" and state_def in cls.state_list: + preprocess_method = kwargs.get("preprocess_method", "SAR") + if preprocess_method != "SAR2RL": + raise NotImplementedError( + "RoboCup 2D SAR data is only supported for preprocess_method='SAR2RL'." + ) + from .soccer.soccer_SAR_class import Soccer_SAR_data + + return Soccer_SAR_data(data_provider, state_def, *args, **kwargs) else: # If the data_provider is unrecognized or state_def is unrecongnized, raise a ValueError raise ValueError( f"Unsupported data provider '{data_provider}' or state definition '{state_def}'. " - f"Supported providers: {cls.sports}, Supported states: {cls.state_list}." + f"Supported providers: {cls.sports + ['robocup_2d (SAR2RL only)']}, " + f"Supported states: {cls.state_list}." )