Skip to content

Commit

Permalink
Add field_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Mar 13, 2024
1 parent 9f50bec commit 56f54c0
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 29 deletions.
17 changes: 11 additions & 6 deletions ecml_tools/commands/inspect/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ def frequency(self):
def resolution(self):
return self.metadata["resolution"]

@property
def field_shape(self):
return self.metadata.get("field_shape")

@property
def shape(self):
if self.data and hasattr(self.data, "shape"):
Expand All @@ -170,15 +174,16 @@ def uncompressed_data_size(self):

def info(self, detailed, size):
print()
print(f'📅 Start : {self.first_date.strftime("%Y-%m-%d %H:%M")}')
print(f'📅 End : {self.last_date.strftime("%Y-%m-%d %H:%M")}')
print(f"⏰ Frequency : {self.frequency}h")
print(f'📅 Start : {self.first_date.strftime("%Y-%m-%d %H:%M")}')
print(f'📅 End : {self.last_date.strftime("%Y-%m-%d %H:%M")}')
print(f"⏰ Frequency : {self.frequency}h")
if self.n_missing_dates is not None:
print(f"🚫 Missing : {self.n_missing_dates:,}")
print(f"🌎 Resolution: {self.resolution}")
print(f"🚫 Missing : {self.n_missing_dates:,}")
print(f"🌎 Resolution : {self.resolution}")
print(f"🌎 Field shape: {self.field_shape}")

print()
shape_str = "📐 Shape : "
shape_str = "📐 Shape : "
if self.shape:
shape_str += " × ".join(["{:,}".format(s) for s in self.shape])
if self.uncompressed_data_size:
Expand Down
10 changes: 10 additions & 0 deletions ecml_tools/create/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def _build_coords(self):
self._grid_points = grid_points
self._resolution = first_field.resolution
self._grid_values = grid_values
self._field_shape = first_field.shape

@cached_property
def variables(self):
Expand All @@ -216,6 +217,11 @@ def grid_points(self):
self._build_coords
return self._grid_points

@cached_property
def field_shape(self):
self._build_coords
return self._field_shape


class HasCoordsMixin:
@cached_property
Expand All @@ -238,6 +244,10 @@ def grid_values(self):
def grid_points(self):
return self._coords.grid_points

@cached_property
def field_shape(self):
return self._coords.field_shape

@cached_property
def shape(self):
return [
Expand Down
1 change: 1 addition & 0 deletions ecml_tools/create/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def initialise(self, check_name=True):
metadata["variables"] = variables
metadata["variables_with_nans"] = variables_with_nans
metadata["resolution"] = resolution
metadata["field_shape"] = self.minimal_input.field_shape

metadata["licence"] = self.main_config["licence"]
metadata["copyright"] = self.main_config["copyright"]
Expand Down
18 changes: 12 additions & 6 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,10 @@ def statistics(self):
def resolution(self):
return self.z.attrs["resolution"]

@property
def field_shape(self):
return tuple(self.z.attrs["field_shape"])

@property
def frequency(self):
try:
Expand Down Expand Up @@ -610,6 +614,10 @@ def dates(self):
def resolution(self):
return self.forward.resolution

@property
def field_shape(self):
return self.forward.field_shape

@property
def frequency(self):
return self.forward.frequency
Expand Down Expand Up @@ -912,12 +920,10 @@ def __init__(self, forward, thinning, method):
self.thinning = thinning
self.method = method

assert method is None, f"Thinning method not supported: {method}"
latitudes = sorted(set(forward.latitudes))
longitudes = sorted(set(forward.longitudes))

latitudes = set(latitudes[::thinning])
longitudes = set(longitudes[::thinning])
latitudes = forward.latitudes.reshape(forward.field_shape)
longitudes = forward.longitudes.reshape(forward.field_shape)
latitudes = latitudes[::thinning, ::thinning].flatten()
longitudes = longitudes[::thinning, ::thinning].flatten()

mask = [lat in latitudes and lon in longitudes for lat, lon in zip(forward.latitudes, forward.longitudes)]
mask = np.array(mask, dtype=bool)
Expand Down
64 changes: 47 additions & 17 deletions ecml_tools/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,35 @@
def plot_mask(path, mask, lats, lons, global_lats, global_lons):
import matplotlib.pyplot as plt

middle = (np.amin(lons) + np.amax(lons)) / 2
print("middle", middle)
s = 1

# gmiddle = (np.amin(global_lons)+ np.amax(global_lons))/2

# print('gmiddle', gmiddle)
# global_lons = global_lons-gmiddle+middle
global_lons[global_lons >= 180] -= 360

plt.figure(figsize=(10, 5))
plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r")
plt.scatter(global_lons, global_lats, s=s, marker="o", c="r")
plt.savefig(path + "-global.png")

plt.figure(figsize=(10, 5))
plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k")
plt.scatter(global_lons[mask], global_lats[mask], s=s, c="k")
plt.savefig(path + "-cutout.png")

plt.figure(figsize=(10, 5))
plt.scatter(lons, lats, s=0.01)
plt.scatter(lons, lats, s=s)
plt.savefig(path + "-lam.png")
# plt.scatter(lons, lats, s=0.01)

plt.figure(figsize=(10, 5))
plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r")
plt.scatter(lons, lats, s=s)
plt.savefig(path + "-both.png")
# plt.scatter(lons, lats, s=0.01)


def latlon_to_xyz(lat, lon, radius=1.0):
# https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates
Expand Down Expand Up @@ -64,27 +80,27 @@ def intersect(self, ray_origin, ray_direction):
a = np.dot(self.v1 - self.v0, h)

if -epsilon < a < epsilon:
return None
return False

f = 1.0 / a
s = ray_origin - self.v0
u = f * np.dot(s, h)

if u < 0.0 or u > 1.0:
return None
return False

q = np.cross(s, self.v1 - self.v0)
v = f * np.dot(ray_direction, q)

if v < 0.0 or u + v > 1.0:
return None
return False

t = f * np.dot(self.v2 - self.v0, q)

if t > epsilon:
return t
return True

return None
return False


def cropping_mask(lats, lons, north, west, south, east):
Expand All @@ -106,7 +122,7 @@ def cutout_mask(
global_lats,
global_lons,
cropping_distance=2.0,
min_distance=0.0,
min_distance_km=0.0,
plot=None,
):
"""
Expand All @@ -115,6 +131,8 @@ def cutout_mask(

# TODO: transform min_distance from lat/lon to xyz

min_distance = min_distance_km / 6371.0

assert global_lats.ndim == 1
assert global_lons.ndim == 1
assert lats.ndim == 1
Expand All @@ -140,31 +158,43 @@ def cutout_mask(
)

# return mask
# mask = np.array([True] * len(global_lats), dtype=bool)
global_lats_masked = global_lats[mask]
global_lons_masked = global_lons[mask]

global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked)
global_points = np.array(global_xyx).transpose()

xyx = latlon_to_xyz(lats, lons)
points = np.array(xyx).transpose()
lam_points = np.array(xyx).transpose()

# Use a KDTree to find the nearest points
kdtree = KDTree(points)
kdtree = KDTree(lam_points)
distances, indices = kdtree.query(global_points, k=3)

zero = np.array([0.0, 0.0, 0.0])
ok = []
for i, (global_point, distance, index) in enumerate(zip(global_points, distances, indices)):
t = Triangle3D(points[index[0]], points[index[1]], points[index[2]])
distance = np.min(distance)
t = Triangle3D(lam_points[index[0]], lam_points[index[1]], lam_points[index[2]])
# distance = np.min(distance)
# The point is inside the triangle if the intersection with the ray
# from the point to the center of the Earth is not None
# (the direction of the ray is not important)
ok.append(
(t.intersect(zero, global_point) or t.intersect(global_point, zero))
# and (distance >= min_distance)
)

intersect = t.intersect(zero, global_point) or t.intersect(global_point, zero)
close = np.min(distance) <= min_distance

if not intersect and False:

if 0 <= global_lons_masked[i] <= 30:
if 55 <= global_lats_masked[i] <= 70:
print(global_lats_masked[i], global_lons_masked[i], distance, intersect, close)
print(lats[index[0]], lons[index[0]])
print(lats[index[1]], lons[index[1]])
print(lats[index[2]], lons[index[2]])
assert False

ok.append(intersect and not close)

j = 0
ok = np.array(ok)
Expand Down

0 comments on commit 56f54c0

Please sign in to comment.