forked from fwitmer/CoastlineExtraction
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlabel_inputs.py
More file actions
103 lines (87 loc) · 4.45 KB
/
label_inputs.py
File metadata and controls
103 lines (87 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import rasterio as rio
from rasterio import merge
from rasterio.enums import Resampling
from rasterio.warp import calculate_default_transform, reproject
from rasterio.io import MemoryFile
from rasterio.features import shapes
from rasterio.mask import mask
from datetime import datetime
from datetime import timedelta
import numpy as np
import glob
import re
import os
# creates a rasterio dataset in memory from a data array and corresponding CRS and transform
# defaults to single-band datasets with nodata value of 0
# adapted from https://medium.com/analytics-vidhya/python-for-geosciences-raster-merging-clipping-and-reprojection-with-rasterio-9f05f012b88a
def create_dataset(data, crs, transform):
    """Wrap an array in an open, in-memory, single-band GeoTIFF dataset.

    Args:
        data: Array written as band 1; its first two dimensions are used as
            the raster height and width, and its dtype as the raster dtype.
        crs: Coordinate reference system assigned to the dataset.
        transform: Georeferencing transform assigned to the dataset.

    Returns:
        The opened in-memory dataset, with a nodata value of 0. It is
        returned open; closing it is left to the caller.
    """
    in_memory = MemoryFile()
    profile = {
        "driver": "GTiff",
        "height": data.shape[0],
        "width": data.shape[1],
        "count": 1,
        "crs": crs,
        "transform": transform,
        "dtype": data.dtype,
        "nodata": 0,
    }
    band_ds = in_memory.open(**profile)
    band_ds.write(data, 1)
    return band_ds
# takes a path for an input image and a path for a corresponding label image
# upscales the label image to match the resolution of the input image and merges them into a 5-banded image
def add_labels(input_path, label_path, output_path):
    """Write a 5-band GeoTIFF combining an input image with its label layer.

    Band 1 of the label raster is reprojected to the input image's CRS and
    resolution, cropped to the input's extent, set to 0 wherever band 1 of
    the input is 0, and written as band 5. Output bands 1-3 are input bands
    1-3, and output band 4 is the input's last band.

    Args:
        input_path: Path of the GeoTIFF image to be labelled.
        label_path: Path of the GeoTIFF whose first band holds the labels.
        output_path: Path the merged 5-band GeoTIFF is written to.
    """
    with rio.open(label_path, 'r', driver='GTiff') as label, \
            rio.open(input_path, 'r', driver='GTiff') as image:
        # copying metadata and updating for the new band count
        last_band = image.count
        out_meta = image.meta
        out_meta.update(count=5)

        # band 1 serves as crop template, nodata mask and output band 1,
        # so it is read from disk only once
        first_band = image.read(1)

        # reprojecting label layer to match the CRS and resolution of input
        # NOTE(review): cubic spline interpolates between label values --
        # confirm this is intended if the labels are categorical
        label_reproj, label_reproj_trans = reproject(
            source=rio.band(label, 1),
            dst_crs=image.crs,
            dst_resolution=image.res,
            resampling=rio.enums.Resampling.cubic_spline)

        # the in-memory dataset is only needed for the crop; the context
        # manager closes it instead of leaking it
        with create_dataset(label_reproj[0], image.crs,
                            label_reproj_trans) as label_ds:
            # cropping reprojected labels to input image's extent
            extents, _ = next(shapes(np.zeros_like(first_band),
                                     transform=image.transform))
            cropped_label, _ = mask(label_ds, [extents], crop=True)

        # trimming to the input's shape, then updating the label layer to
        # have no data where the input image has no data
        cropped_label_array = cropped_label[0][:image.shape[0], :image.shape[1]]
        cropped_label_array = np.where(first_band == 0, 0, cropped_label_array)

        # NOTE(review): the output keeps the input's dtype from its metadata
        # while the label band is cast to uint16 -- presumably inputs are
        # uint16; verify
        with rio.open(output_path, 'w', **out_meta) as dst:
            dst.write_band(1, first_band)
            dst.write_band(2, image.read(2))
            dst.write_band(3, image.read(3))
            dst.write_band(4, image.read(last_band))
            dst.write_band(5, cropped_label_array.astype(rio.uint16))
# returns the date from a filename in YYYY-MM-DD string format
def parse_date(filename):
    """Return the first date-like token found in *filename*.

    A full ``YYYY-MM-DD`` date is preferred; when none is present, a
    ``YYYY_MM`` year/month token is returned instead.

    Args:
        filename: File name or path to search.

    Returns:
        The matched text, either ``YYYY-MM-DD`` or ``YYYY_MM``.

    Raises:
        ValueError: If neither pattern occurs in *filename*.
    """
    # raw strings without backslashes: "\-" and "\_" are invalid string
    # escapes, and neither "-" nor "_" needs escaping outside a character class
    found = (re.search(r"[0-9]{4}-[0-9]{2}-[0-9]{2}", filename)
             or re.search(r"[0-9]{4}_[0-9]{2}", filename))
    if found is None:
        raise ValueError(
            "no YYYY-MM-DD or YYYY_MM date found in {!r}".format(filename))
    return found.group(0)
def match_labels(input_path, label_path):
    """Merge every input image with the label file closest to it in time.

    Input dates are parsed from the file paths as ``YYYY-MM-DD`` and label
    dates as ``YYYY_MM``. Each input that does not yet have an output file is
    paired with the nearest label and merged via ``add_labels`` into
    ``data/labeled_inputs/<YYYY-MM-DD>_labeled.tif``. Progress is printed.

    Args:
        input_path: Prefix for input images; ``"*.tif"`` is appended directly,
            so a directory must end with a path separator.
        label_path: Prefix for label images, handled the same way.

    Raises:
        ValueError: If an input needs labelling but no label files were found.
    """
    input_files = glob.glob(input_path + "*.tif")
    label_files = glob.glob(label_path + "*.tif")

    # preparing labels for comparison
    labels_by_date = {}
    for label_file in label_files:
        # labels carry only year and month; 14 days are added to the parsed
        # first-of-month date before comparing against input dates
        label_date = datetime.strptime(parse_date(label_file), "%Y_%m") + timedelta(days=14)
        labels_by_date[label_date] = label_file

    # comparing each input file to the label files to find the closest match
    for input_file in input_files:
        print("Input file:", os.path.basename(input_file))
        input_date = datetime.strptime(parse_date(input_file), "%Y-%m-%d")
        out_name = input_date.strftime("%Y-%m-%d") + "_labeled.tif"
        out_path = "data/labeled_inputs/" + out_name

        # checking if file already exists
        if os.path.exists(out_path):
            print("Labeled file already exists at:", out_path)
            print()
            continue

        if not labels_by_date:
            raise ValueError(
                "no label files found matching " + label_path + "*.tif")

        # min() keeps the first of equally distant dates
        closest_date = min(labels_by_date, key=lambda d: abs(d - input_date))
        closest_label = labels_by_date[closest_date]
        print("Matching label:", os.path.basename(closest_label))

        # make sure the output directory exists before writing into it
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        # flush so the progress line is visible while the merge is running
        print("Merging as: {}...".format(out_name), end="", flush=True)
        add_labels(input_file, closest_label, out_path)
        print("DONE")
        print()
# script entry point: this call is a top-level statement, so it runs whenever
# the module is executed or imported, using the hard-coded data directories
match_labels("data/input/", "data/labels/")