Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

*.dill
*.pkl
*.feather
*.pb
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ Currently, the dataloader supports interfacing with the following datasets:
| UCY - Zara1 | `eupeds_zara1` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `cyprus` | The Zara1 scene from the UCY Pedestrians dataset | 0.4s (2.5Hz) | |
| UCY - Zara2 | `eupeds_zara2` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `cyprus` | The Zara2 scene from the UCY Pedestrians dataset | 0.4s (2.5Hz) | |
| Stanford Drone Dataset | `sdd` | `train`, `val`, `test` | `stanford` | Stanford Drone Dataset (60 scenes, randomly split 42/9/9 (70%/15%/15%) for training/validation/test) | 0.0333...s (30Hz) | |
| USDZ Format | `usdz` | N/A | N/A | Universal Scene Description (USDZ) format datasets | Configurable | :white_check_mark: |

### Adding New Datasets
The code that interfaces the original datasets (dealing with their unique formats) can be found in `src/trajdata/dataset_specific`.
Expand Down
2 changes: 1 addition & 1 deletion src/trajdata/caching/df_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def is_map_cached(
vector_map_path,
kdtrees_path,
rtrees_path,
) = DataFrameCache.get_map_paths(cache_path, env_name, map_name, resolution)
) = DataFrameCache.get_map_paths(cache_path, env_name, map_name)

# TODO(bivanovic): For now, rtrees are optional to have in the cache.
# In the future, they may be required (likely after we develop an
Expand Down
11 changes: 9 additions & 2 deletions src/trajdata/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
max_agent_num: Optional[int] = None,
max_neighbor_num: Optional[int] = None,
ego_only: Optional[bool] = False,
dataset_kwargs: Optional[Dict[str, Any]] = None,
data_dirs: Dict[str, str] = {
"eupeds_eth": "~/datasets/eth_ucy_peds",
"eupeds_hotel": "~/datasets/eth_ucy_peds",
Expand Down Expand Up @@ -258,7 +259,12 @@ def __init__(
s.lower() for s in self.scene_description_contains
]

self.envs: List[RawDataset] = env_utils.get_raw_datasets(data_dirs)
# Pass dataset-specific kwargs to raw datasets
dataset_kwargs = dataset_kwargs or {}
self.dataset_kwargs = dataset_kwargs # Save for later use in parallel preprocessing
self.envs: List[RawDataset] = env_utils.get_raw_datasets(
data_dirs, **dataset_kwargs
)
self.envs_dict: Dict[str, RawDataset] = {env.name: env for env in self.envs}

matching_datasets: List[SceneTag] = self._get_matching_scene_tags(desired_data)
Expand All @@ -273,7 +279,7 @@ def __init__(
if self.incl_vector_map:
self._map_api = MapAPI(
self.cache_path,
keep_in_memory=vector_map_params.get("keep_in_memory", True),
keep_in_memory=self.vector_map_params.get("keep_in_memory", True),
)

self.cache_lane_graphs = cache_lane_graphs
Expand Down Expand Up @@ -974,6 +980,7 @@ def _preprocess_scene_data(
self.desired_dt,
self.cache_class,
self.rebuild_cache,
self.dataset_kwargs,
)

# Done with this list. Cutting memory usage because
Expand Down
98 changes: 93 additions & 5 deletions src/trajdata/dataset_specific/nuplan/nuplan_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,40 @@


class NuplanDataset(RawDataset):
def __init__(
self,
name: str,
data_dir: str,
parallelizable: bool = True,
has_maps: bool = True,
central_tokens_config: Optional[List[Dict[str, Any]]] = None,
num_timesteps_before: Optional[int] = None,
num_timesteps_after: Optional[int] = None,
use_central_tokens: bool = False,
) -> None:
"""
Args:
name: Dataset name
data_dir: Data directory path
parallelizable: Whether dataset is parallelizable
has_maps: Whether dataset has maps
central_tokens_config: Optional central tokens configuration
num_timesteps_before: Number of timesteps before central token
num_timesteps_after: Number of timesteps after central token
use_central_tokens: Whether to use central token mode (default: False for backward compatibility)
If central_tokens_config is provided, this will be set to True automatically
"""
super().__init__(name, data_dir, parallelizable, has_maps)
self._central_tokens_config = central_tokens_config
self._num_timesteps_before = num_timesteps_before if num_timesteps_before is not None else 30
self._num_timesteps_after = num_timesteps_after if num_timesteps_after is not None else 80

# Auto-enable central token mode if config is provided
if central_tokens_config is not None:
use_central_tokens = True
self._use_central_tokens = use_central_tokens


def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
all_log_splits: Dict[str, List[str]] = nuplan_utils.create_splits_logs()

Expand Down Expand Up @@ -75,10 +109,38 @@ def load_dataset_obj(self, verbose: bool = False) -> None:

if self.name == "nuplan_mini":
subfolder = "mini"
elif self.name == "nuplan_test":
subfolder = "test"
elif self.name.startswith("nuplan"):
subfolder = "trainval"

self.dataset_obj = nuplan_utils.NuPlanObject(self.metadata.data_dir, subfolder)
# Create NuPlanObject with central token configuration
self.dataset_obj = nuplan_utils.NuPlanObject(
self.metadata.data_dir,
subfolder,
central_tokens_config=self._central_tokens_config,
num_timesteps_before=self._num_timesteps_before,
num_timesteps_after=self._num_timesteps_after,
use_central_tokens=self._use_central_tokens,
)

def _parse_originating_log(self, scene_name: str) -> str:
"""
Extract originating log name from scene name.

Supports both formats:
- Old format: logfile=token
- New format: logfile-token

Args:
scene_name: Scene name string

Returns:
Originating log name
"""
if "=" in scene_name:
return scene_name.split("=")[0]
return scene_name.rsplit("-", 1)[0]

def _get_matching_scenes_from_obj(
self,
Expand All @@ -93,7 +155,7 @@ def _get_matching_scenes_from_obj(
scenes_list: List[SceneMetadata] = list()
for idx, scene_record in enumerate(self.dataset_obj.scenes):
scene_name: str = scene_record["name"]
originating_log: str = scene_name.split("=")[0]
originating_log: str = self._parse_originating_log(scene_name)
# scene_desc: str = scene_record["description"].lower()
scene_location: str = scene_record["location"]
scene_split: str = self.metadata.scene_split_map.get(
Expand Down Expand Up @@ -194,30 +256,38 @@ def get_scene(self, scene_info: SceneMetadata) -> Scene:
scene_record: Dict[str, str] = self.dataset_obj.scenes[data_idx]

scene_name: str = scene_record["name"]
originating_log: str = scene_name.split("=")[0]
originating_log: str = self._parse_originating_log(scene_name)
# scene_desc: str = scene_record["description"].lower()
scene_location: str = scene_record["location"]
scene_split: str = self.metadata.scene_split_map.get(
originating_log, default_split
)
scene_length: int = scene_record["num_timesteps"]

# Store start_idx and end_idx in data_access_info for later use.
data_access_info = scene_record.copy() if "start_idx" in scene_record or "end_idx" in scene_record else scene_record

return Scene(
self.metadata,
scene_name,
scene_location,
scene_split,
scene_length,
data_idx,
scene_record,
data_access_info,
# scene_desc,
)

def get_agent_info(
self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache]
) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]:
# instantiate VectorMap from map_api if necessary
self.dataset_obj.open_db(scene.name.split("=")[0] + ".db")
# Support both old format (logfile=token) and new format (logfile-token)
if "=" in scene.name:
log_filename = scene.name.split("=")[0]
else:
log_filename = scene.name.rsplit("-", 1)[0]
self.dataset_obj.open_db(log_filename + ".db")

ego_agent_info: AgentMetadata = AgentMetadata(
name="ego",
Expand All @@ -233,6 +303,7 @@ def get_agent_info(
[ego_agent_info] for _ in range(scene.length_timesteps)
]

# Automatically select the correct method based on scene type
all_frames: pd.DataFrame = self.dataset_obj.get_scene_frames(scene)

ego_df = (
Expand All @@ -252,6 +323,23 @@ def get_agent_info(
agents_df: pd.DataFrame = self.dataset_obj.get_detected_agents(lpc_tokens)
tls_df: pd.DataFrame = self.dataset_obj.get_traffic_light_status(lpc_tokens)

# Extract sensor calibration information.
# Support both old format (logfile=token) and new format (logfile-token)
if "=" in scene.name:
log_filename = scene.name.split("=")[0]
else:
log_filename = scene.name.rsplit("-", 1)[0]
sensor_calib = self.dataset_obj.get_sensor_calibration(log_filename)

# Store calibration information into Scene.data_access_info.
if scene.data_access_info is None:
scene.data_access_info = {}
elif not isinstance(scene.data_access_info, dict):
# If there is other existing non-dict data, wrap it into a dictionary.
scene.data_access_info = {'original': scene.data_access_info}

scene.data_access_info['sensor_calibration'] = sensor_calib

self.dataset_obj.close_db()

agents_df["scene_ts"] = agents_df["lidar_pc_token"].map(
Expand Down
Loading