diff --git a/xbout/load.py b/xbout/load.py index 8213a653..c7742661 100644 --- a/xbout/load.py +++ b/xbout/load.py @@ -542,24 +542,58 @@ def collect( ds : numpy.ndarray """ - from os.path import join + from pathlib import Path as _Path - datapath = join(path, prefix + "*.nc") + datapath_glob = str(_Path(path) / (prefix + "*.nc")) - ds, _ = _auto_open_mfboutdataset( - datapath, keep_xboundaries=xguards, keep_yboundaries=yguards, info=info - ) + # Fast path: use lazy loader which only opens one file for metadata. + # Falls back to open_mfdataset if the directory cannot be detected or + # the variable is not supported by the lazy loader. + try: + path_obj = _Path(path) + if path_obj.is_dir(): + ds = lazyload.lazy_open_boutdataset( + path, + keep_xboundaries=xguards, + keep_yboundaries=yguards, + info=info, + prefix=prefix, + ) + else: + raise ValueError("path is not a directory") + + if varname not in ds: + raise KeyError( + "No variable, {} was found in {}.".format(varname, datapath_glob) + ) + + da = ds[varname] + dims = list(da.dims) - if varname not in ds: - raise KeyError("No variable, {} was found in {}.".format(varname, datapath)) + except Exception: + # Fall back to the slow multi-file open + ds, _ = _auto_open_mfboutdataset( + datapath_glob, + keep_xboundaries=xguards, + keep_yboundaries=yguards, + info=info, + ) - dims = list(ds.dims) - inds = [tind, xind, yind, zind] + if varname not in ds: + raise KeyError( + "No variable, {} was found in {}.".format(varname, datapath_glob) + ) + + da = ds[varname] + dims = list(ds.dims) + + inds = {"t": tind, "x": xind, "y": yind, "z": zind} selection = {} # Convert indexing values to an isel suitable format - for dim, ind in zip(dims, inds): + for dim in dims: + ind = inds.get(dim) if isinstance(ind, int): indexer = [ind] elif isinstance(ind, list): @@ -570,25 +604,26 @@ def collect( else: indexer = None - if indexer: + if indexer is not None: selection[dim] = indexer try: - version = ds["BOUT_VERSION"] - except KeyError: - # If BOUT Version is not saved in the dataset + version = ds.attrs.get("metadata", {}).get("BOUT_VERSION", 0) + if version == 0 and "BOUT_VERSION" in ds: + version = float(ds["BOUT_VERSION"].values) + except Exception: version = 0 # Subtraction of z-dimensional data occurs in boutdata.collect # if BOUT++ version is old - same feature added here if (version < 3.5) and ("z" in dims): - zsize = int(ds["nz"]) - 1 - ds = ds.isel(z=slice(zsize)) + zsize = int(ds.attrs.get("metadata", {}).get("nz", da.sizes["z"])) + da = da.isel(z=slice(zsize)) if selection: - ds = ds.isel(selection) + da = da.isel(selection) - result = ds[varname].values + result = da.values # Close netCDF files to ensure they are not locked if collect is called again ds.close()