Skip to content

Commit

Permalink
Merge branch 'master' into lightglue
Browse files Browse the repository at this point in the history
  • Loading branch information
tonzowonzo authored Nov 9, 2023
2 parents b344cbd + e97ecc4 commit 9075e44
Show file tree
Hide file tree
Showing 23 changed files with 185 additions and 177 deletions.
98 changes: 53 additions & 45 deletions s2p/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import subprocess
import sys
import copy
import os.path
import json
import datetime
Expand All @@ -33,7 +34,6 @@
from plyflatten import plyflatten_from_plyfiles_list


from s2p.config import cfg
from s2p import common
from s2p import parallel
from s2p import geographiclib
Expand All @@ -46,10 +46,11 @@
from s2p import triangulation
from s2p import fusion
from s2p import visualisation
from s2p import config
from s2p.image_coordinates_to_coordinates import matches_to_geojson


def remove_missing_tiles(tiles):
def remove_missing_tiles(cfg, tiles):
"""Remove tiles where any of the rectified images is missing (also remove from neighborhood_dirs)."""
n = len(cfg['images'])
tiles_new = []
Expand Down Expand Up @@ -79,7 +80,7 @@ def remove_missing_tiles(tiles):
return tiles_new, tiles_pairs


def check_missing_sift(tiles_pairs):
def check_missing_sift(cfg, tiles_pairs):
missing_sift = []
with open(os.path.join(cfg["out_dir"], "missing_sift.txt"), "w") as f:
for tile, i in tiles_pairs:
Expand All @@ -94,20 +95,20 @@ def check_missing_sift(tiles_pairs):
"SIFT matches, this may deteriorate output quality")


def merge_all_match_files():
def merge_all_match_files(out_dir):
"""
Merges together all non-empty match files into one .txt file and saves them in the output directory.
"""
read_files = glob.glob("**/*sift_matches.txt", recursive=True)
with open(os.path.join(cfg["out_dir"], "merged_sift_matches.txt"), "w") as outfile:
with open(os.path.join(out_dir, "merged_sift_matches.txt"), "w") as outfile:
for f in read_files:
if os.stat(f).st_size == 0:
continue
with open(f, "r") as infile:
outfile.write(infile.read())


def pointing_correction(tile, i):
def pointing_correction(tile, i, cfg):
"""
Compute the translation that corrects the pointing error on a pair of tiles.
Expand All @@ -128,7 +129,7 @@ def pointing_correction(tile, i):
A, m = pointing_accuracy.compute_correction(
img1, img2, rpc1, rpc2, x, y, w, h, method,
cfg['sift_match_thresh'], cfg['max_pointing_error'], cfg['matching_method'],
cfg['min_value'], cfg['max_value'], cfg['confidence_threshold']
cfg['min_value'], cfg['max_value'], cfg['confidence_threshold'], cfg
)

if A is not None: # A is the correction matrix
Expand All @@ -141,10 +142,10 @@ def pointing_correction(tile, i):
visualisation.plot_matches(img1, img2, rpc1, rpc2, m,
os.path.join(out_dir,
'sift_matches_pointing.png'),
x, y, w, h)
x, y, w, h, cfg)


def global_pointing_correction(tiles):
def global_pointing_correction(cfg, tiles):
"""
Compute the global pointing corrections for each pair of images.
Expand All @@ -161,7 +162,7 @@ def global_pointing_correction(tiles):
common.remove(os.path.join(d, 'center_keypts_sec.txt'))


def rectification_pair(tile, i):
def rectification_pair(tile, i, cfg):
"""
Rectify a pair of images on a given tile.
Expand Down Expand Up @@ -211,10 +212,14 @@ def rectification_pair(tile, i):
H1, H2, disp_min, disp_max = rectification.rectify_pair(img1, img2,
rpc1, rpc2,
x, y, w, h,
rect1, rect2, A, m,
rect1, rect2,
cfg,
A=A,
sift_matches=m,
method=cfg['rectification_method'],
hmargin=cfg['horizontal_margin'],
vmargin=cfg['vertical_margin'])
vmargin=cfg['vertical_margin'],
)

np.savetxt(os.path.join(out_dir, 'H_ref.txt'), H1, fmt='%12.6f')
np.savetxt(os.path.join(out_dir, 'H_sec.txt'), H2, fmt='%12.6f')
Expand All @@ -225,7 +230,7 @@ def rectification_pair(tile, i):
common.remove(os.path.join(out_dir, 'pointing.txt'))
common.remove(os.path.join(out_dir, 'sift_matches.txt'))

def stereo_matching(tile, i):
def stereo_matching(tile, i, cfg):
"""
Compute the disparity of a pair of images on a given tile.
Expand All @@ -245,7 +250,7 @@ def stereo_matching(tile, i):
disp_min, disp_max = np.loadtxt(os.path.join(out_dir, 'disp_min_max.txt'))

block_matching.compute_disparity_map(rect1, rect2, disp, mask,
cfg['matching_algorithm'], disp_min,
cfg['matching_algorithm'], cfg, disp_min,
disp_max, timeout=cfg['mgm_timeout'],
max_disp_range=cfg['max_disp_range'])

Expand All @@ -258,7 +263,7 @@ def stereo_matching(tile, i):
common.remove(rect2)
common.remove(os.path.join(out_dir, 'disp_min_max.txt'))

def disparity_to_height(tile, i):
def disparity_to_height(tile, i, cfg):
"""
Compute a height map from the disparity map of a pair of image tiles.
Expand Down Expand Up @@ -301,7 +306,7 @@ def disparity_to_height(tile, i):
common.remove(mask)


def disparity_to_ply(tile):
def disparity_to_ply(tile, cfg):
"""
Compute a point cloud from the disparity map of a pair of image tiles.
Expand Down Expand Up @@ -334,7 +339,7 @@ def disparity_to_ply(tile):
with rasterio.open(os.path.join(out_dir, 'pair_1', 'rectified_ref.tif')) as f:
ww, hh = f.width, f.height

colors = common.tmpfile(".tif")
colors = common.tmpfile(cfg['temporary_dir'], ".tif")
success_state = common.image_apply_homography(colors, cfg['images'][0]['clr'],
np.loadtxt(H_ref), ww, hh)
if success_state is False:
Expand Down Expand Up @@ -389,7 +394,7 @@ def disparity_to_ply(tile):
common.remove(os.path.join(out_dir, 'pair_1', 'rectified_ref.tif'))
return tile

def mean_heights(tile):
def mean_heights(tile, cfg):
"""
"""
w, h = tile['coordinates'][2:]
Expand All @@ -411,7 +416,7 @@ def mean_heights(tile):
[np.nanmean(validity_mask * maps[:, :, i]) for i in range(n)])


def global_mean_heights(tiles):
def global_mean_heights(cfg, tiles):
"""
"""
local_mean_heights = [np.loadtxt(os.path.join(t['dir'], 'local_mean_heights.txt'))
Expand All @@ -423,7 +428,7 @@ def global_mean_heights(tiles):
[global_mean_heights[i]])


def heights_fusion(tile):
def heights_fusion(cfg, tile):
"""
Merge the height maps computed for each image pair and generate a ply cloud.
Expand All @@ -449,22 +454,22 @@ def heights_fusion(tile):
# merge the height maps (applying mean offset to register)
fusion.merge_n(os.path.join(tile_dir, 'height_map.tif'), height_maps,
global_mean_heights, averaging=cfg['fusion_operator'],
threshold=cfg['fusion_thresh'])
threshold=cfg['fusion_thresh'], debug=cfg['debug'])

if cfg['clean_intermediate']:
for f in height_maps:
common.remove(f)


def heights_to_ply(tile):
def heights_to_ply(tile, cfg):
"""
Generate a ply cloud.
Args:
tile: a dictionary that provides all you need to process a tile
"""
# merge the n-1 height maps of the tile (n = nb of images)
heights_fusion(tile)
heights_fusion(cfg, tile)

# compute a ply from the merged height map
out_dir = tile['dir']
Expand Down Expand Up @@ -500,7 +505,7 @@ def heights_to_ply(tile):
common.remove(os.path.join(out_dir, 'mask.png'))


def plys_to_dsm(tile):
def plys_to_dsm(tile, cfg):
"""
Generates DSM from plyfiles (cloud.ply)
Expand Down Expand Up @@ -535,6 +540,9 @@ def plys_to_dsm(tile):
roi=roi,
radius=cfg['dsm_radius'],
sigma=cfg['dsm_sigma'])
raster = np.nan_to_num(raster, -9999.)
raster[raster == 0] = -9999.
profile.update(nodata=-9999.)

# save output image with utm georeferencing
common.rasterio_write(out_dsm, raster[:, :, 0], profile=profile)
Expand All @@ -547,7 +555,7 @@ def plys_to_dsm(tile):
common.rasterio_write(out_conf, raster[:, :, 4], profile=profile)


def global_dsm(tiles):
def global_dsm(cfg, tiles):
"""
Merge tilewise DSMs and confidence maps in a global DSM and confidence map.
"""
Expand Down Expand Up @@ -581,7 +589,7 @@ def global_dsm(tiles):
rasterio.merge.merge(dsms,
bounds=bounds,
res=cfg["dsm_resolution"],
nodata=np.nan,
nodata=-9999,
indexes=[1],
dst_path=os.path.join(cfg["out_dir"], "dsm.tif"),
dst_kwds=creation_options)
Expand All @@ -605,18 +613,18 @@ def main(user_cfg, start_from=0, merge_matches=False):
start_from: the step to start from (default: 0)
"""
common.print_elapsed_time.t0 = datetime.datetime.now()
initialization.build_cfg(user_cfg)
initialization.make_dirs()
cfg = initialization.build_cfg(copy.deepcopy(config.cfg), user_cfg)
initialization.make_dirs(cfg)

# multiprocessing setup
nb_workers = multiprocessing.cpu_count() # nb of available cores
if cfg['max_processes'] is not None:
nb_workers = cfg['max_processes']
print(f"Running s2p using {nb_workers} workers.")

tw, th = initialization.adjust_tile_size()
tw, th = initialization.adjust_tile_size(cfg)
tiles_txt = os.path.join(cfg['out_dir'], 'tiles.txt')
tiles = initialization.tiles_full_info(tw, th, tiles_txt, create_masks=True)
tiles = initialization.tiles_full_info(cfg, tw, th, tiles_txt, create_masks=True)
if not tiles:
print('ERROR: the ROI is not seen in two images or is totally masked.')
sys.exit(1)
Expand All @@ -637,20 +645,20 @@ def main(user_cfg, start_from=0, merge_matches=False):
# local-pointing step:
if start_from <= 1:
print('1) correcting pointing locally...')
parallel.launch_calls(pointing_correction, tiles_pairs, nb_workers,
parallel.launch_calls(pointing_correction, tiles_pairs, nb_workers, cfg,
timeout=timeout)
check_missing_sift(tiles_pairs)
check_missing_sift(cfg, tiles_pairs)

# global-pointing step:
if start_from <= 2:
print('2) correcting pointing globally...')
global_pointing_correction(tiles)
global_pointing_correction(cfg, tiles)
common.print_elapsed_time()

# Create matches GeoJSON.
if merge_matches:
print("Creating matches GeoJSON")
merge_all_match_files()
merge_all_match_files(cfg['out_dir'])
matches_to_geojson(f"{cfg['out_dir']}/merged_sift_matches.txt",
cfg['images'][0]['rpcm'],
10,
Expand All @@ -661,13 +669,13 @@ def main(user_cfg, start_from=0, merge_matches=False):
# rectification step:
if start_from <= 3:
print('3) rectifying tiles...')
parallel.launch_calls(rectification_pair, tiles_pairs, nb_workers,
parallel.launch_calls(rectification_pair, tiles_pairs, nb_workers, cfg,
timeout=timeout)

# matching step:
if start_from <= 4:
# Check which tiles were rectified correctly, and skip tiles that have missing files
tiles, tiles_pairs = remove_missing_tiles(tiles)
tiles, tiles_pairs = remove_missing_tiles(cfg, tiles)

if cfg['max_processes_stereo_matching'] is not None:
nb_workers_stereo = cfg['max_processes_stereo_matching']
Expand All @@ -676,7 +684,7 @@ def main(user_cfg, start_from=0, merge_matches=False):
try:

print(f'4) running stereo matching using {nb_workers_stereo} workers...')
parallel.launch_calls(stereo_matching, tiles_pairs, nb_workers_stereo,
parallel.launch_calls(stereo_matching, tiles_pairs, nb_workers_stereo, cfg,
timeout=timeout)
except subprocess.CalledProcessError as e:
print(f'ERROR: stereo matching failed. In case this is due too little RAM set '
Expand All @@ -687,25 +695,25 @@ def main(user_cfg, start_from=0, merge_matches=False):
if n > 2:
# disparity-to-height step:
print('5a) computing height maps...')
parallel.launch_calls(disparity_to_height, tiles_pairs, nb_workers,
parallel.launch_calls(disparity_to_height, tiles_pairs, nb_workers, cfg,
timeout=timeout)

print('5b) computing local pairwise height offsets...')
parallel.launch_calls(mean_heights, tiles, nb_workers, timeout=timeout)
parallel.launch_calls(mean_heights, tiles, nb_workers, cfg, timeout=timeout)

# global-mean-heights step:
print('5c) computing global pairwise height offsets...')
global_mean_heights(tiles)
global_mean_heights(cfg, tiles)

# heights-to-ply step:
print('5d) merging height maps and computing point clouds...')
parallel.launch_calls(heights_to_ply, tiles, nb_workers,
parallel.launch_calls(heights_to_ply, tiles, nb_workers, cfg,
timeout=timeout)
else:
# triangulation step:
print('5) triangulating tiles...')
num_tiles = len(tiles)
tiles = parallel.launch_calls(disparity_to_ply, tiles, nb_workers,
tiles = parallel.launch_calls(disparity_to_ply, tiles, nb_workers, cfg,
timeout=timeout)
tiles = [t for t in tiles if t is not None]
if len(tiles) != num_tiles:
Expand All @@ -715,17 +723,17 @@ def main(user_cfg, start_from=0, merge_matches=False):
# local-dsm-rasterization step:
if start_from <= 6:
print('6) computing DSM by tile...')
parallel.launch_calls(plys_to_dsm, tiles, nb_workers, timeout=timeout)
parallel.launch_calls(plys_to_dsm, tiles, nb_workers, cfg, timeout=timeout)

# global-dsm-rasterization step:
if start_from <= 7:
print('7) computing global DSM...')
global_dsm(tiles)
global_dsm(cfg, tiles)

common.print_elapsed_time()

# cleanup
common.garbage_cleanup()
common.garbage_cleanup(cfg['clean_tmp'])
common.print_elapsed_time(since_first_call=True)


Expand Down
Loading

0 comments on commit 9075e44

Please sign in to comment.