Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion mmai25_hackathon/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
BaseSampler: Template for custom samplers, e.g., for multimodal sampling.
"""

import load_data
import pandas as pd
from torch.utils.data import Dataset, Sampler
from torch_geometric.data import DataLoader

__all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"]
__all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler", "MultimodalDataLoader"]


class BaseDataset(Dataset):
Expand Down Expand Up @@ -51,6 +53,7 @@ def __getitem__(self, idx: int):

def __repr__(self) -> str:
    """Build the ``ClassName(details)`` text for this dataset, where details come from ``extra_repr()``."""
    class_name = self.__class__.__name__
    details = self.extra_repr()
    return f"{class_name}({details})"

def extra_repr(self) -> str:
Expand All @@ -73,6 +76,9 @@ def __add__(self, other):
"""
raise NotImplementedError("Subclasses may implement __add__ method if needed.")

def __subject_list__(self):
    """
    Return a list with all the subject ids found in the dataset.

    This is an optional hook with no implementation here; subclasses that
    need it are expected to override it.

    Raises:
        NotImplementedError: Always, until a subclass overrides this method.
    """
    # Fail loudly, like the optional `__add__` hook does, instead of silently
    # returning None where the docstring promises a list.
    raise NotImplementedError("Subclasses may implement __subject_list__ method if needed.")

@classmethod
def prepare_data(cls, *args, **kwargs):
"""
Expand Down Expand Up @@ -101,13 +107,27 @@ class CXRDataset(BaseDataset):

def __init__(self, *args, **kwargs):
    """
    Placeholder constructor for the chest X-ray dataset.

    Forwards every argument to the parent initializer and then raises, because
    no loading logic exists in this constructor yet.

    Args:
        *args: Positional arguments forwarded unchanged to the parent initializer.
        **kwargs: Keyword arguments forwarded unchanged to the parent initializer.

    Raises:
        NotImplementedError: Always, once the parent initializer has returned.
    """
    super().__init__(*args, **kwargs)
    # Presumably the dataset folder the eventual implementation will read from;
    # it is not used yet, hence the lint suppression.
    relative_path = "mimic-iv/mimic-cxr-jpg-chest-radiographs-with-structured-labels-2.1.0/"  # noqa: F841
    # Name the thing that is actually missing; the earlier text referred to the
    # prepare_data class method, which is unrelated to constructing the dataset.
    raise NotImplementedError("CXRDataset.__init__ is not implemented yet.")

def extract_cxr(self):
    """
    Load chest X-ray data through ``load_data.load_chest_xray_image``.

    Returns:
        Whatever ``load_data.load_chest_xray_image()`` returns.
    """
    # NOTE(review): the loader is called without arguments — confirm its signature.
    # `load_data` is imported as a top-level module; inside the mmai25_hackathon
    # package this presumably needs to be a relative import — TODO confirm.
    return load_data.load_chest_xray_image()


class ECGDataset(BaseDataset):
    """Example subclass for an ECG dataset."""

    def __init__(self, *args, **kwargs):
        """
        Placeholder constructor for the ECG dataset.

        Forwards every argument to the parent initializer and then raises,
        because no loading logic exists in this constructor yet.

        Args:
            *args: Positional arguments forwarded unchanged to the parent initializer.
            **kwargs: Keyword arguments forwarded unchanged to the parent initializer.

        Raises:
            NotImplementedError: Always, once the parent initializer has returned.
        """
        super().__init__(*args, **kwargs)
        # Name the thing that is actually missing; the earlier text referred to the
        # prepare_data class method, which is unrelated to constructing the dataset.
        raise NotImplementedError("ECGDataset.__init__ is not implemented yet.")


class ClinicalNotes(BaseDataset):
    """Example subclass for a Clinical Notes dataset."""

    def __init__(self, *args, **kwargs):
        """
        Placeholder constructor for the clinical notes dataset.

        Forwards every argument to the parent initializer and then raises,
        because no loading logic exists in this constructor yet.

        Args:
            *args: Positional arguments forwarded unchanged to the parent initializer.
            **kwargs: Keyword arguments forwarded unchanged to the parent initializer.

        Raises:
            NotImplementedError: Always, once the parent initializer has returned.
        """
        super().__init__(*args, **kwargs)
        # Name the thing that is actually missing; the earlier text referred to the
        # prepare_data class method, which is unrelated to constructing the dataset.
        raise NotImplementedError("ClinicalNotes.__init__ is not implemented yet.")


class BaseDataLoader(DataLoader):
Expand All @@ -132,6 +152,16 @@ class BaseDataLoader(DataLoader):
Note: This is not a hard requirement. Consider it a future-facing idea you can evolve.
"""

def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=None, exclude_keys=None, **kwargs):
    """
    Initialise the loader by handing every argument to the parent ``DataLoader``.

    Args:
        dataset: Passed positionally to the parent initializer.
        batch_size: Forwarded as the ``batch_size`` keyword (default 1).
        shuffle: Forwarded as the ``shuffle`` keyword (default False).
        follow_batch: Forwarded as the ``follow_batch`` keyword (default None).
        exclude_keys: Forwarded as the ``exclude_keys`` keyword (default None).
        **kwargs: Any further keyword arguments, forwarded unchanged.
    """
    # Gather the named options first so the parent call stays on one line.
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "follow_batch": follow_batch,
        "exclude_keys": exclude_keys,
    }
    super().__init__(dataset, **loader_options, **kwargs)


class MultimodalDataLoader(BaseDataLoader):
"""Example dataloader for handling multiple data modalities."""
Expand All @@ -141,6 +171,27 @@ def __init__(self, data_list, *args, **kwargs):
self.data_list = data_list


class ID_Mapper(BaseDataLoader):
    """ID mapper for validating all IDs found across the modalities list."""

    def __init__(self, data_list, *args, **kwargs):
        """
        Store the modalities list and start with an empty mapping table.

        Args:
            data_list: Stored as ``self.data_list`` (presumably one entry per modality — confirm).
            *args: Positional arguments forwarded unchanged to the parent initializer.
            **kwargs: Keyword arguments forwarded unchanged to the parent initializer.
        """
        super().__init__(*args, **kwargs)
        self.data_list = data_list
        # Starts empty; presumably filled by `__create_mapper_` once that is implemented.
        self.mapper = pd.DataFrame()

    def __create_mapper_(self):
        """
        Create the mapper (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        # Name the thing that is actually missing; the earlier text referred to the
        # prepare_data class method, which this class does not define.
        raise NotImplementedError("ID_Mapper.__create_mapper_ is not implemented yet.")

    def __len__(self) -> int:
        """
        Return the number of samples in the dataset.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement __len__ method.")

    def __getitem__(self, idx: int):
        """
        Return a single sample from the dataset.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement __getitem__ method.")


class BaseSampler(Sampler):
"""
Base sampler to extend for custom sampling strategies.
Expand All @@ -153,3 +204,15 @@ class BaseSampler(Sampler):
balanced or paired sampling before passing to `BaseDataLoader`.
Note: This is optional and meant as a design hint, not a constraint.
"""


class MultimodalDataSampler(BaseSampler):
    """Example sampler for handling multiple data modalities."""

    def __init__(self, data_list, *args, **kwargs):
        """
        Store the modalities list after initialising the parent sampler.

        Args:
            data_list: Stored as ``self.data_list`` (presumably one entry per modality — confirm).
            *args: Positional arguments forwarded unchanged to the parent initializer.
            **kwargs: Keyword arguments forwarded unchanged to the parent initializer.
        """
        super().__init__(*args, **kwargs)
        self.data_list = data_list

    def subject_extract(self):
        """
        Return all the info from a specific subject (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        # Name the thing that is actually missing; the earlier text referred to the
        # prepare_data class method, which this class does not define.
        raise NotImplementedError("MultimodalDataSampler.subject_extract is not implemented yet.")
9 changes: 9 additions & 0 deletions mmai25_hackathon/load_data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from . import cxr, ecg, text # noqa: F403
from .cxr import * # noqa: F403 F401
from .ecg import * # noqa: F403 F401
from .text import * # noqa: F403 F401

# Public API of the package: the exported names of cxr, text and ecg, concatenated in that order.
__all__ = [*cxr.__all__, *text.__all__, *ecg.__all__]