Skip to content

Commit

Permalink
Merge pull request #6 from nilsleh/msjr-test_patching
Browse files Browse the repository at this point in the history
PR to make patching and stitching agnostic to coordinate direction
  • Loading branch information
nilsleh authored Aug 21, 2024
2 parents c7a6172 + 6dd9b3a commit a705549
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 71 deletions.
136 changes: 120 additions & 16 deletions deepsensor/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def __init__(
) = self.infer_context_and_target_var_IDs()

self.coord_bounds = self._compute_global_coordinate_bounds()
self.coord_directions = self._compute_x1x2_direction()

def _set_config(self):
"""Instantiate a config dictionary for the TaskLoader object"""
Expand Down Expand Up @@ -829,6 +830,52 @@ def _compute_global_coordinate_bounds(self) -> List[float]:
x2_max = var_x2_max

return [x1_min, x1_max, x2_min, x2_max]

def _compute_x1x2_direction(self) -> str:
"""
Compute whether the x1 and x2 coords are ascending or descending.
Returns
-------
coord_directions: dict(str)
String containing two booleans: x1_ascend and x2_ascend,
defining if these coordings increase or decrease from top left corner.
"""

for var in itertools.chain(self.context, self.target):
if isinstance(var, (xr.Dataset, xr.DataArray)):
coord_x1_left= var.x1[0]
coord_x1_right= var.x1[-1]
coord_x2_top= var.x2[0]
coord_x2_bottom= var.x2[-1]
#Todo- what to input for pd.dataframe
elif isinstance(var, (pd.DataFrame, pd.Series)):
var_x1_min = var.index.get_level_values("x1").min()
var_x1_max = var.index.get_level_values("x1").max()
var_x2_min = var.index.get_level_values("x2").min()
var_x2_max = var.index.get_level_values("x2").max()

x1_ascend = True
x2_ascend = True
if coord_x1_left < coord_x1_right:
x1_ascend = True
if coord_x1_left > coord_x1_right:
x1_ascend = False

if coord_x2_top < coord_x2_bottom:
x2_ascend = True
if coord_x2_top > coord_x2_bottom:
x2_ascend = False



coord_directions = {
"x1": x1_ascend,
"x2": x2_ascend,
}

return coord_directions

def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]:
"""
Expand Down Expand Up @@ -1413,30 +1460,87 @@ def sample_sliding_window(
stride = patch_size

dy, dx = stride

# Calculate the global bounds of context and target set.
x1_min, x1_max, x2_min, x2_max = self.coord_bounds

## start with first patch top left hand corner at x1_min, x2_min
patch_list = []

for y in np.arange(x1_min, x1_max, dy):
for x in np.arange(x2_min, x2_max, dx):
if y + x1_extend > x1_max:
y0 = x1_max - x1_extend
else:
y0 = y
if x + x2_extend > x2_max:
x0 = x2_max - x2_extend
else:
x0 = x
# Todo: simplify these elif statements
if self.coord_directions['x1'] == False and self.coord_directions['x2'] == True:
for y in np.arange(x1_max, x1_min, -dy):
for x in np.arange(x2_min, x2_max, dx):
if y - x1_extend < x1_min:
y0 = x1_min + x1_extend
else:
y0 = y
if x + x2_extend > x2_max:
x0 = x2_max - x2_extend
else:
x0 = x

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0 - x1_extend, y0, x0, x0 + x2_extend]
patch_list.append(bbox)

elif self.coord_directions['x1'] == False and self.coord_directions['x2'] == False:
for y in np.arange(x1_max, x1_min, -dy):
for x in np.arange(x2_max, x2_min, -dx):
if y - x1_extend < x1_min:
y0 = x1_min + x1_extend
else:
y0 = y
if x - x2_extend < x2_min:
x0 = x2_min + x2_extend
else:
x0 = x

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0 - x1_extend, y0, x0 - x2_extend, x0]
patch_list.append(bbox)

elif self.coord_directions['x1'] == True and self.coord_directions['x2'] == False:
for y in np.arange(x1_min, x1_max, dy):
for x in np.arange(x2_max, x2_min, -dx):
if y + x1_extend > x1_max:
y0 = x1_max - x1_extend
else:
y0 = y
if x - x2_extend < x2_min:
x0 = x2_min + x2_extend
else:
x0 = x

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0, y0 + x1_extend, x0 - x2_extend, x0]
patch_list.append(bbox)
else:
for y in np.arange(x1_min, x1_max, dy):
for x in np.arange(x2_min, x2_max, dx):
if y + x1_extend > x1_max:
y0 = x1_max - x1_extend
else:
y0 = y
if x + x2_extend > x2_max:
x0 = x2_max - x2_extend
else:
x0 = x

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend]

# bbox of x1_min, x1_max, x2_min, x2_max per patch
bbox = [y0, y0 + x1_extend, x0, x0 + x2_extend]
patch_list.append(bbox)

patch_list.append(bbox)
# Remove duplicate patches while preserving order
seen = set()
unique_patch_list = []
for lst in patch_list:
# Convert list to tuple for immutability
tuple_lst = tuple(lst)
if tuple_lst not in seen:
seen.add(tuple_lst)
unique_patch_list.append(lst)

return patch_list
return unique_patch_list

def __call__(
self,
Expand Down
Loading

0 comments on commit a705549

Please sign in to comment.