Skip to content

Commit be6c8cc

Browse files
Merge pull request #63 from DoctorRabbit55/rbm_data_set_features
Enhance RBMDataSet with load function and new attributes
2 parents 2b62177 + 532597d commit be6c8cc

5 files changed

Lines changed: 52 additions & 6 deletions

File tree

swvo/io/RBMDataSet/RBMDataSet.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,15 @@ def __getattr__(self, name: str) -> NDArray[np.float64]:
228228

229229
raise AttributeError(msg)
230230

231+
def load(self, name_or_var: str | VariableEnum) -> None:
232+
""" Load data into memory """
233+
234+
if isinstance(name_or_var, VariableEnum):
235+
getattr(self, name_or_var.var_name)
236+
else:
237+
getattr(self, name_or_var)
238+
239+
231240
def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[str, Any]]:
232241
levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""}
233242
sat_variable = None

swvo/io/RBMDataSet/RBMNcDataSet.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,28 @@ def __init__(
110110
verbose=verbose,
111111
)
112112

113+
mfm_str = mfm if isinstance(mfm, str) else mfm.mfm_name
114+
115+
self.variable_lut = {
116+
"time": "time",
117+
"datetime": "datetime",
118+
"flux/FEDU": "Flux",
119+
"flux/alpha_eq": "alpha_eq_model",
120+
"flux/energy": "energy_channels",
121+
"flux/alpha_local": "alpha_local",
122+
"position/xGEO": "xGEO",
123+
"psd/PSD": "PSD",
124+
"density/density_local": "density",
125+
126+
f"position/{mfm_str}/MLT": "MLT",
127+
f"position/{mfm_str}/R0": "R0",
128+
f"position/{mfm_str}/Lstar": "Lstar",
129+
f"position/{mfm_str}/Lm": "Lm",
130+
f"mag_field/{mfm_str}/B_local": "B_total",
131+
f"psd/{mfm_str}/inv_mu": "InvMu",
132+
f"psd/{mfm_str}/inv_K": "InvK",
133+
}
134+
113135
def _create_file_path_stem(self) -> Path:
114136
# implement special cases here
115137
# if self._satellite == SatelliteEnum.THEMIS:
@@ -196,20 +218,23 @@ def _load_variable(self, var: Variable | VariableEnum) -> None:
196218
if var_name == "datetime":
197219
loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # ty:ignore[invalid-assignment]
198220

199-
rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) # ty:ignore[invalid-argument-type]
221+
rbm_var_names = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) # ty:ignore[invalid-argument-type]
200222

201-
if rbm_var_name is not None:
202-
setattr(self, rbm_var_name, loaded_var_arrs[var_name])
223+
if rbm_var_names is not None:
224+
for name in rbm_var_names:
225+
setattr(self, name, loaded_var_arrs[var_name])
203226

204227
@classmethod
205-
def _get_rbm_name(cls, var_name: str, mag_field: MfmEnumLiteral) -> VariableLiteral | None:
228+
def _get_rbm_name(cls, var_name: str, mag_field: MfmEnumLiteral) -> VariableLiteral | None | list[VariableLiteral]:
206229
match var_name:
207230
case "time":
208231
return "time"
209232
case "datetime":
210233
return "datetime"
211234
case "flux/FEDU":
212-
return "Flux"
235+
return ["Flux", "FEDU"]
236+
case "flux/FEIU":
237+
return ["Flux", "FEIU"]
213238
case "flux/alpha_eq":
214239
return "alpha_eq_model"
215240
case "flux/energy":

swvo/io/RBMDataSet/custom_enums.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class VariableEnum(Variable, Enum):
5757
P = "P", "mlt", with_B
5858
R_0 = "R0", "R0", with_B
5959
DENSITY = "density", "density", without_B
60+
# NC only variables
61+
FEDU = "FEDU", "", without_B
62+
FEIU = "FEIU", "", without_B
63+
Lm = "Lm", "", without_B
6064

6165

6266
VariableLiteral = Literal[
@@ -72,6 +76,9 @@ class VariableEnum(Variable, Enum):
7276
"InvV",
7377
"Lstar",
7478
"Flux",
79+
"FEDU",
80+
"FEIU",
81+
"Lm",
7582
"PSD",
7683
"MLT",
7784
"B_SM",

swvo/io/RBMDataSet/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,13 @@ def join_var(var1: NDArray[np.generic], var2: NDArray[np.generic]) -> NDArray[np
3131
def get_file_path_any_format(folder_path: Path, file_stem: str, preferred_ext: str) -> Path | None:
3232
"""Get the file path for a given file stem and preferred extension."""
3333
pattern = re.compile(fnmatch.translate(file_stem + ".*"), re.IGNORECASE)
34+
35+
if not folder_path.exists():
36+
return None
37+
3438
all_files = [p for p in folder_path.iterdir() if pattern.match(p.name)]
3539

40+
3641
if len(all_files) == 0:
3742
warnings.warn(f"File not found: {folder_path / (file_stem + '.*')}", stacklevel=2)
3843
return None

tests/io/RBMDataSet/test_RBMNcDataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def test_load_variable_real_file():
260260

261261
assert hasattr(dataset, "alpha_local"), "Dataset should have 'alpha_local' attribute after loading."
262262
assert isinstance(dataset.alpha_local, np.ndarray), "'alpha_local' should be a NumPy array."
263-
263+
assert hasattr(dataset, "FEDU")
264264

265265
def test_all_variables_in_dir(mock_dataset: RBMNcDataSet):
266266
vars = [

0 commit comments

Comments
 (0)