Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hirundo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
)
from .dataset_qa import (
ClassificationRunArgs,
Domain,
HirundoError,
ModalityType,
ObjectDetectionRunArgs,
QADataset,
RunArgs,
Expand Down Expand Up @@ -43,7 +43,7 @@
"KeylabsObjSegImages",
"KeylabsObjSegVideo",
"QADataset",
"Domain",
"ModalityType",
"RunArgs",
"ClassificationRunArgs",
"ObjectDetectionRunArgs",
Expand Down
27 changes: 15 additions & 12 deletions hirundo/dataset_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,27 +128,27 @@ class AugmentationName(str, Enum):
GAUSSIAN_BLUR = "GaussianBlur"


class Domain(str, Enum):
class ModalityType(str, Enum):
RADAR = "RADAR"
VISION = "VISION"
SPEECH = "SPEECH"
TABULAR = "TABULAR"


DOMAIN_TO_SUPPORTED_LABELING_TYPES = {
Domain.RADAR: [
MODALITY_TO_SUPPORTED_LABELING_TYPES = {
ModalityType.RADAR: [
LabelingType.SINGLE_LABEL_CLASSIFICATION,
LabelingType.OBJECT_DETECTION,
],
Domain.VISION: [
ModalityType.VISION: [
LabelingType.SINGLE_LABEL_CLASSIFICATION,
LabelingType.OBJECT_DETECTION,
LabelingType.OBJECT_SEGMENTATION,
LabelingType.SEMANTIC_SEGMENTATION,
LabelingType.PANOPTIC_SEGMENTATION,
],
Domain.SPEECH: [LabelingType.SPEECH_TO_TEXT],
Domain.TABULAR: [LabelingType.SINGLE_LABEL_CLASSIFICATION],
ModalityType.SPEECH: [LabelingType.SPEECH_TO_TEXT],
ModalityType.TABULAR: [LabelingType.SINGLE_LABEL_CLASSIFICATION],
}


Expand Down Expand Up @@ -206,9 +206,9 @@ class QADataset(BaseModel):
For audio datasets, this field is ignored.
If no value is provided, all augmentations are applied to vision datasets.
"""
domain: Domain = Domain.VISION
modality: ModalityType = ModalityType.VISION
"""
Used to define the domain of the dataset.
Used to define the modality of the dataset.
Defaults to VISION.
"""

Expand All @@ -221,13 +221,16 @@ class QADataset(BaseModel):

@model_validator(mode="after")
def validate_dataset(self):
if self.domain not in DOMAIN_TO_SUPPORTED_LABELING_TYPES:
if self.modality not in MODALITY_TO_SUPPORTED_LABELING_TYPES:
raise ValueError(
f"Domain {self.domain} is not supported. Supported domains are: {list(DOMAIN_TO_SUPPORTED_LABELING_TYPES.keys())}"
f"Modality {self.modality} is not supported. Supported modalities are: {list(MODALITY_TO_SUPPORTED_LABELING_TYPES.keys())}"
)
if self.labeling_type not in DOMAIN_TO_SUPPORTED_LABELING_TYPES[self.domain]:
if (
self.labeling_type
not in MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]
):
raise ValueError(
f"Labeling type {self.labeling_type} is not supported for domain {self.domain}. Supported labeling types are: {DOMAIN_TO_SUPPORTED_LABELING_TYPES[self.domain]}"
f"Labeling type {self.labeling_type} is not supported for modality {self.modality}. Supported labeling types are: {MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]}"
)
if self.storage_config is None and self.storage_config_id is None:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions on-prem/on_prem_audio_ar_test_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
"import os\n",
"\n",
"from hirundo import (\n",
" Domain,\n",
" GitPlainAuth,\n",
" GitRepo,\n",
" HirundoCSV,\n",
" LabelingType,\n",
" ModalityType,\n",
" QADataset,\n",
" StorageConfig,\n",
" StorageGit,\n",
Expand Down Expand Up @@ -67,7 +67,7 @@
")\n",
"test_dataset = QADataset(\n",
" name=f\"TEST-STT-MASC-dataset{unique_id}\",\n",
" domain=Domain.SPEECH,\n",
" modality=ModalityType.SPEECH,\n",
" labeling_type=LabelingType.SPEECH_TO_TEXT,\n",
" language=\"ar\",\n",
" storage_config=StorageConfig(\n",
Expand Down
4 changes: 2 additions & 2 deletions on-prem/on_prem_audio_he_small_test_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
"import os\n",
"\n",
"from hirundo import (\n",
" Domain,\n",
" GitPlainAuth,\n",
" GitRepo,\n",
" HirundoCSV,\n",
" LabelingType,\n",
" ModalityType,\n",
" QADataset,\n",
" StorageConfig,\n",
" StorageGit,\n",
Expand Down Expand Up @@ -51,7 +51,7 @@
")\n",
"test_dataset = QADataset(\n",
" name=f\"TEST-STT-RoboShaulGolden-dataset{unique_id}\",\n",
" domain=Domain.SPEECH,\n",
" modality=ModalityType.SPEECH,\n",
" labeling_type=LabelingType.SPEECH_TO_TEXT,\n",
" language=\"he\",\n",
" storage_config=StorageConfig(\n",
Expand Down
4 changes: 2 additions & 2 deletions on-prem/on_prem_audio_he_test_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
"import os\n",
"\n",
"from hirundo import (\n",
" Domain,\n",
" GitPlainAuth,\n",
" GitRepo,\n",
" HirundoCSV,\n",
" LabelingType,\n",
" ModalityType,\n",
" QADataset,\n",
" StorageConfig,\n",
" StorageGit,\n",
Expand Down Expand Up @@ -51,7 +51,7 @@
")\n",
"test_dataset = QADataset(\n",
" name=f\"TEST-STT-RoboShaul-dataset{unique_id}\",\n",
" domain=Domain.SPEECH,\n",
" modality=ModalityType.SPEECH,\n",
" labeling_type=LabelingType.SPEECH_TO_TEXT,\n",
" language=\"he\",\n",
" storage_config=StorageConfig(\n",
Expand Down
Loading
Loading