Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aeronet_raster/aeronet_raster/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Version components as a tuple of ints.
# NOTE(review): this view shows both sides of a one-line diff (1 addition, 1 deletion);
# the first assignment is the removed value, the second is its replacement and is the
# one that takes effect.
VERSION = (0, 2, 3)
VERSION = (0, 2, 4)

# Dotted string form of VERSION, e.g. (0, 2, 4) -> '0.2.4'.
__version__ = '.'.join(map(str, VERSION))
2 changes: 1 addition & 1 deletion aeronet_raster/aeronet_raster/band/bandsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BandSample(GeoObject):
It implements all the interfaces of the GeoObject, and stores the raster data in memory

Args:
name (str): a name of the sample, which is used as a defaule name for saving to file
name (str): a name of the sample, which is used as a default name for saving to file
raster (np.array): the raster data
crs: geographical coordinate reference system, as :obj:`CRS` or string representation
transform (Affine): affine transform
Expand Down
4 changes: 2 additions & 2 deletions aeronet_raster/aeronet_raster/collectionprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from multiprocessing.pool import ThreadPool
from threading import Lock
from tqdm import tqdm
import cv2

from typing import Union, Optional, Callable, List, Tuple
from .band.band import Band
from .bandcollection.bandcollection import BandCollection
Expand Down Expand Up @@ -389,7 +389,7 @@ def __init__(self,
input_channels: List[str],
output_labels: List[str],
processing_fn: Callable,
sample_size: Tuple[int] = (1024, 1024),
sample_size: Tuple[int, int] = (1024, 1024),
bound: int = 256,
src_nodata=0,
nodata=None, dst_nodata=0,
Expand Down
Empty file.
65 changes: 65 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/abstractadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
from ..utils.utils import validate_coord


class AbstractAdapter:
    """Base class giving an arbitrary data source a numpy-array-like indexing interface.

    Subclasses supply ``shape``, ``dtype``, ``ndim``, ``__len__``, ``fetch()`` and
    ``write()``; this class only normalizes the index and dispatches to them.
    """

    @property
    def shape(self):
        raise NotImplementedError

    @property
    def dtype(self):
        raise NotImplementedError

    @property
    def ndim(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def parse_item(self, item):
        """Normalize an index expression into a list with one entry per axis.

        Handling, in order:
            - a list or numpy array is treated as a tuple of per-axis indexes;
            - any other non-tuple value becomes a one-element tuple;
            - more entries than ``self.ndim`` raises IndexError;
            - missing trailing axes are filled with ``None``;
            - every entry is passed through ``validate_coord(entry, axis_size)``.

        Returns:
            list: the validated per-axis coordinates, ``self.ndim`` entries long.

        Raises:
            IndexError: if the index has more entries than the data has axes.
        """
        if isinstance(item, (list, np.ndarray)):
            index = tuple(item)
        elif isinstance(item, tuple):
            index = item
        else:
            index = (item, )

        ndim = self.ndim
        if len(index) > ndim:
            raise IndexError(f"Index={index} has more dimensions than data={self.shape}")

        # None stands for "axis not specified"; validate_coord decides what it expands to.
        padded = list(index) + [None] * (ndim - len(index))
        return [validate_coord(coord, self.shape[axis]) for axis, coord in enumerate(padded)]

    # Read -------------------------------------------------------------------------------------------------------------
    def __getitem__(self, item):
        return self.fetch(self.parse_item(item))

    def fetch(self, item):
        """Datasource-specific read of the already parsed index, e.g. rasterio.read()"""
        raise NotImplementedError

    # Write ------------------------------------------------------------------------------------------------------------
    def __setitem__(self, item, data):
        parsed = self.parse_item(item)
        self.write(parsed, data)

    def write(self, item, data):
        """Datasource-specific write of ``data`` at the already parsed index."""
        raise NotImplementedError




53 changes: 53 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/boundsafemixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging

import numpy as np


class BoundSafeMixin:
    """
    Redefines __getitem__() and __setitem__() so it works even if the coordinates are out of bounds.

    Reads clip every slice to the data extent, fetch the in-bounds part and pad the result back to
    the requested size with ``np.pad(mode=padding_mode)``. Writes crop the incoming data to the
    in-bounds part before handing it to ``write()``.

    The host class must provide ``parse_item()``, ``fetch()``, ``write()``, ``shape`` and ``ndim``.

    Args:
        padding_mode (str): ``mode`` argument forwarded to ``np.pad`` on reads.
        **kwargs: forwarded to the next ``__init__`` in the MRO.
    """
    def __init__(self, padding_mode: str = 'constant', **kwargs):
        super().__init__(**kwargs)
        self.padding_mode = padding_mode

    def _split_bounds(self, item):
        """Split parsed coordinates into per-axis (before, after) margins and in-bounds coordinates.

        Margins are the number of requested positions falling before index 0 and after the end of
        the axis. They are computed in index units, so they match the element count only for
        step == 1.  # NOTE(review): steps other than 1 are passed through unchanged - confirm.

        Raises:
            ValueError: if an axis entry is neither a slice nor a list/tuple.
        """
        margins, safe_coords = [], []
        for axis, coords in enumerate(item):
            # coords can be either slice or tuple at this point (after parse_item)
            if isinstance(coords, (list, tuple)):  # coords = (coord1, coord2, ...)
                # indexing out of bounds makes sense only with slices
                margins.append((0, 0))
                safe_coords.append(coords)
            elif isinstance(coords, slice):  # coords = (min:max:step)
                size = self.shape[axis]
                length = max(coords.stop - coords.start, 0)
                # Cap each margin at the slice length: a slice lying entirely outside the data
                # must yield exactly `length` padded positions, not the distance to the data.
                before = min(max(-coords.start, 0), length)
                after = min(max(coords.stop - size, 0), length)
                # Clamp both ends into [0, size] so a negative stop is never reinterpreted
                # as "count from the end" by the underlying storage.
                safe = slice(min(max(coords.start, 0), size),
                             min(max(coords.stop, 0), size),
                             coords.step)
                if safe.start >= safe.stop:
                    logging.warning(f'Probably incorrect slice {safe}')
                margins.append((before, after))
                safe_coords.append(safe)
            else:
                raise ValueError(f'Can not parse coords={coords} at axis={axis}')
        return margins, safe_coords

    def __getitem__(self, item):
        """Fetch the in-bounds part of ``item`` and pad it to the requested extent."""
        pads, safe_coords = self._split_bounds(self.parse_item(item))
        res = self.fetch(safe_coords)
        return np.pad(res, pads, mode=self.padding_mode)

    def __setitem__(self, item, data):
        """Write the part of ``data`` that falls inside the data extent; the rest is dropped."""
        item = self.parse_item(item)
        # Kept as an assert so the exception type seen by existing callers does not change.
        assert data.ndim == self.ndim == len(item)
        crops, safe_coords = self._split_bounds(item)
        inside = tuple(slice(lo, size - hi, 1) for (lo, hi), size in zip(crops, data.shape))
        self.write(safe_coords, data[inside])

39 changes: 39 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/filemixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
class FileMixin:
    """Abstract class, provides interface to work with a file (open, close and context manager).

    Subclasses implement ``open()``, which is expected to set ``self._descriptor`` (an object
    with a ``close()`` method) and ``self._shape``. Indexing and ``shape`` raise ValueError
    while no descriptor is set.

    Args:
        path: location of the file; stored and used in error messages.
        **kwargs: forwarded to the next ``__init__`` in the MRO.
    """

    def __init__(self, path, **kwargs):
        super().__init__(**kwargs)
        self._path = path
        self._descriptor = None
        self._shape = None

    def _ensure_open(self):
        """Raise ValueError unless a descriptor is currently set."""
        if not self._descriptor:
            raise ValueError(f'File {self._path} is not opened')

    def open(self):
        raise NotImplementedError

    def close(self):
        """Close the descriptor and forget the cached shape.

        Safe to call when the file was never opened or is already closed.
        """
        descriptor = self._descriptor
        if descriptor is None:
            return
        try:
            descriptor.close()
        finally:
            # Reset even if closing raised, so the adapter never keeps a half-closed handle.
            self._descriptor = None
            self._shape = None

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_val, traceback):
        self.close()

    def __getitem__(self, item):
        self._ensure_open()
        return super().__getitem__(item)

    def __setitem__(self, item, data):
        self._ensure_open()
        super().__setitem__(item, data)

    @property
    def shape(self):
        self._ensure_open()
        return self._shape
20 changes: 20 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/imageadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from .abstractadapter import AbstractAdapter
from .boundsafemixin import BoundSafeMixin


class ImageAdapter(BoundSafeMixin, AbstractAdapter):
    """Abstract adapter for 3-axis data, described by the original authors as
    (channels, height, width).

    parse_item() is extended so that axis 0 given as a slice is expanded into an explicit
    list of indexes, while axes 1 and 2 must remain slices.
    """
    @property
    def ndim(self):
        return 3

    def parse_item(self, item):
        """Parse the index with the parent implementation, then enforce the image layout.

        Raises:
            ValueError: if the parsed index does not have exactly 3 entries.
            AssertionError: if axis 1 or axis 2 is not a slice.
        """
        parsed = super().parse_item(item)
        if len(parsed) != 3:
            raise ValueError(f"Image must be indexed with 3 axes, got {parsed}")

        first = parsed[0]
        if isinstance(first, slice):
            # Expand the slice into the concrete indexes it selects.
            parsed[0] = [*range(first.start, first.stop, first.step)]

        spatial_ok = all(isinstance(axis, slice) for axis in (parsed[1], parsed[2]))
        assert spatial_ok, f"Image spatial axes (1 and 2) must be indexed with slices, got {parsed}"
        return parsed
34 changes: 34 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/numpyadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from .abstractadapter import AbstractAdapter
from .boundsafemixin import BoundSafeMixin


class NumpyAdapter(BoundSafeMixin, AbstractAdapter):
    """Bound-safe adapter around an in-memory array.

    ``ndim``, ``shape`` and ``dtype`` are read straight from the wrapped object; reads and
    writes index it directly.
    """
    def __init__(self, data, padding_mode='constant', **kwargs):
        super().__init__(padding_mode, **kwargs)
        self._data = data

    @property
    def shape(self):
        return self._data.shape

    @property
    def ndim(self):
        return self._data.ndim

    @property
    def dtype(self):
        return self._data.dtype

    def __len__(self):
        return self.shape[0]

    def fetch(self, item):
        """Return the wrapped data at ``item``; a list of per-axis indexes is used as a tuple."""
        index = tuple(item) if isinstance(item, list) else item
        return self._data[index]

    def write(self, item, data):
        """Assign ``data`` at ``item``; a list of per-axis indexes is used as a tuple."""
        index = tuple(item) if isinstance(item, list) else item
        self._data[index] = data
40 changes: 40 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/piladapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from .imageadapter import ImageAdapter
from .filemixin import FileMixin
import numpy as np
import pkg_resources

if 'pillow' in {pkg.key for pkg in pkg_resources.working_set}:
from PIL import Image


class PilAdapter(FileMixin, ImageAdapter):
    """Provides numpy array-like interface to PIL-compatible image file.

    The file is opened with ``Image.open``; the reported shape is
    (number of bands, height, width). The adapter is read-only.
    """

    def open(self):
        """Open the image at the stored path and cache its (bands, height, width) shape."""
        self._descriptor = Image.open(self._path)
        self._shape = len(self._descriptor.getbands()), self._descriptor.height, self._descriptor.width

    def fetch(self, item):
        """Read a window of the image as a channels-first array.

        Args:
            item: parsed index ``(channels, y, x)`` where ``y`` and ``x`` are slices and
                ``channels`` selects along the first axis of the result.

        Returns:
            np.ndarray: the cropped pixels indexed by ``channels`` on axis 0.

        Raises:
            ValueError: if the file is not opened.
        """
        if not self._descriptor:
            raise ValueError(f'File {self._path} is not opened')
        channels, y, x = item
        window = np.array(self._descriptor.crop((x.start, y.start, x.stop, y.stop)))
        if window.ndim == 2:
            # Single-band images convert to a (height, width) array with no channel axis;
            # add one so the result matches the channels-first layout reported by `shape`.
            window = window[np.newaxis]
        else:
            window = window.transpose(2, 0, 1)
        return window[channels]

    def write(self, key, value):
        raise AttributeError('PIL Image is not writable. Use NumpyAdapter and save it as Image ')

    @property
    def dtype(self):
        # NOTE(review): fixed to uint8 regardless of the image mode - confirm this is intended
        # for modes whose pixels are not 8-bit.
        return np.uint8

    @property
    def ndim(self):
        return 3

    @property
    def shape(self):
        return self._shape

    def __len__(self):
        return self._shape[0]

78 changes: 78 additions & 0 deletions aeronet_raster/aeronet_raster/dataadapters/rasterioadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from .imageadapter import ImageAdapter
from .filemixin import FileMixin
import numpy as np
import rasterio

# File open modes accepted by RasterioAdapter; its constructor raises ValueError for any other value.
RASTERIO_OPEN_MODES = {'r', 'r+', 'w', 'w+'}


class RasterioAdapter(FileMixin, ImageAdapter):
    """Provides numpy array-like interface to geotiff file via rasterio.

    Args:
        path: file location passed to ``rasterio.open``.
        mode (str): one of RASTERIO_OPEN_MODES.
        profile: keyword arguments for ``rasterio.open``; required when ``mode`` starts with 'w'.
            For other modes it is replaced by the opened dataset's profile in ``open()``.
        padding_mode (str): forwarded to the parent constructor.
        **kwargs: accepted but not forwarded.

    Raises:
        ValueError: if ``mode`` is not supported, or a write mode is given without a profile.
    """

    def __init__(self, path, mode='r', profile=None, padding_mode: str = 'constant', **kwargs):
        super().__init__(path=path, padding_mode=padding_mode)
        if mode not in RASTERIO_OPEN_MODES:
            raise ValueError(f'Mode must be one of {RASTERIO_OPEN_MODES}')
        creates_file = mode.startswith('w')
        if creates_file and not profile:
            raise ValueError(f'Profile must be specified for mode={mode}')
        self._mode = mode
        self._profile = profile

    def open(self):
        """Open the dataset and cache its (count, height, width) shape."""
        if not self._mode.startswith('w'):
            self._descriptor = rasterio.open(self._path, self._mode)
            # Existing file: take the profile from the dataset itself.
            self._profile = self._descriptor.profile
        else:
            self._descriptor = rasterio.open(self._path, self._mode, **self._profile)
        dataset = self._descriptor
        self._shape = dataset.count, dataset.shape[0], dataset.shape[1]

    @property
    def shape(self):
        return self._shape

    @property
    def ndim(self):
        return 3

    def __len__(self):
        return self._shape[0]

    @property
    def profile(self):
        return self._profile

    @property
    def crs(self):
        if self._descriptor:
            return self._descriptor.crs
        raise ValueError(f'File {self._path} is not opened')

    @property
    def res(self):
        if self._descriptor:
            return self._descriptor.res
        raise ValueError(f'File {self._path} is not opened')

    @property
    def count(self):
        if self._descriptor:
            return self._descriptor.count
        raise ValueError(f'File {self._path} is not opened')

    @property
    def dtype(self):
        if self._descriptor:
            return self._descriptor.profile['dtype']
        raise ValueError(f'File {self._path} is not opened')

    def fetch(self, item):
        """Read the requested channels inside the (y, x) window from the dataset."""
        channels, y, x = item
        # The channel indexes received here are 0-based; rasterio band indexes start at 1.
        bands = [index + 1 for index in channels]
        window = ((y.start, y.stop), (x.start, x.stop))
        return self._descriptor.read(bands, window=window)

    def write(self, item, data):
        """Write ``data`` into the requested channels inside the (y, x) window."""
        channels, y, x = item
        bands = [index + 1 for index in channels]
        window = ((y.start, y.stop), (x.start, x.stop))
        self._descriptor.write(data, bands, window=window)
Loading