-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
43 lines (32 loc) · 1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import rasterio
from pathlib import Path
def read_tif(file: Path):
    """Read a GeoTIFF file into a (height, width, bands) array plus georeferencing.

    Args:
        file: Path to the raster file. A plain ``str`` is also accepted and is
            converted to :class:`pathlib.Path`.

    Returns:
        A tuple ``(arr, transform, crs)``:
        - ``arr``: the raster data returned by ``dataset.read()``, with axes
          reordered from (bands, height, width) to (height, width, bands).
        - ``transform``: the dataset's ``transform`` as reported by rasterio.
        - ``crs``: the dataset's ``crs`` as reported by rasterio.

    Raises:
        FileNotFoundError: If ``file`` does not exist.
    """
    # Coerce so callers may pass str as well as Path; .exists() needs a Path.
    file = Path(file)
    if not file.exists():
        raise FileNotFoundError(f'File {file} not found')
    with rasterio.open(file) as dataset:
        arr = dataset.read()  # (bands X height X width)
        transform = dataset.transform
        crs = dataset.crs
    # Move the band axis last: (bands, H, W) -> (H, W, bands).
    return arr.transpose((1, 2, 0)), transform, crs
def write_tif(file: Path, arr, transform, crs):
    """Write an array to a GeoTIFF file.

    Args:
        file: Destination path (``str`` or :class:`pathlib.Path`). Any missing
            parent directories, including intermediate ones, are created.
        arr: Array laid out as (height, width, bands), the same layout that
            ``read_tif`` returns. A 2-D (height, width) array is written as a
            single band.
        transform: Affine transform passed through to ``rasterio.open``.
        crs: Coordinate reference system passed through to ``rasterio.open``.
    """
    file = Path(file)
    # parents=True creates intermediate directories; exist_ok=True avoids the
    # race between a separate existence check and mkdir().
    file.parent.mkdir(parents=True, exist_ok=True)
    if len(arr.shape) == 2:
        # Promote a single-band (H, W) array to (H, W, 1) so the unpack below works.
        arr = arr[:, :, None]
    height, width, bands = arr.shape
    with rasterio.open(
        file,
        'w',
        driver='GTiff',
        height=height,
        width=width,
        count=bands,
        dtype=arr.dtype,
        crs=crs,
        transform=transform,
    ) as dst:
        for band in range(bands):
            # rasterio band indexes are 1-based.
            dst.write(arr[:, :, band], band + 1)
def to_numpy(tensor: torch.Tensor):
    """Convert a tensor to a NumPy array on the host.

    The tensor is detached from the autograd graph *before* being moved to
    the CPU, so autograd does not record the device-to-host copy. For a
    tensor that is already on the CPU, the returned array shares memory with
    the tensor (standard ``Tensor.numpy`` behavior). In-place edits to one
    are therefore visible in the other.

    Args:
        tensor: Any tensor, on any device, with or without ``requires_grad``.

    Returns:
        A NumPy array with the tensor's values.
    """
    return tensor.detach().cpu().numpy()