Skip to content

Commit 8bd8cb4

Browse files
authored
Merge pull request #35 from chenyangkang/chenyangkang-JOSS-review
fix test bug and add docstrings
2 parents 7436f33 + 0974836 commit 8bd8cb4

15 files changed

+185
-178
lines changed

stemflow/model/SphereAdaSTEM.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
ensemble_fold: int = 10,
8484
min_ensemble_required: int = 7,
8585
grid_len_upper_threshold: Union[float, int] = 8000,
86-
grid_len_lower_threshold: Union[float, int] = 100,
86+
grid_len_lower_threshold: Union[float, int] = 500,
8787
points_lower_threshold: int = 50,
8888
stixel_training_size_threshold: int = None,
8989
temporal_start: Union[float, int] = 1,
@@ -600,7 +600,7 @@ def __init__(
600600
ensemble_fold=10,
601601
min_ensemble_required=7,
602602
grid_len_upper_threshold=8000,
603-
grid_len_lower_threshold=100,
603+
grid_len_lower_threshold=500,
604604
points_lower_threshold=50,
605605
stixel_training_size_threshold=None,
606606
temporal_start=1,
@@ -759,7 +759,7 @@ def __init__(
759759
ensemble_fold=10,
760760
min_ensemble_required=7,
761761
grid_len_upper_threshold=8000,
762-
grid_len_lower_threshold=100,
762+
grid_len_lower_threshold=500,
763763
points_lower_threshold=50,
764764
stixel_training_size_threshold=None,
765765
temporal_start=1,

stemflow/utils/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
from .validation import check_random_state

stemflow/utils/jitterrotation/jitterrotator.py

+6-56
Original file line numberDiff line numberDiff line change
@@ -2,65 +2,13 @@
22

33
import numpy as np
44

5-
# import geopandas as gpd
6-
75

86
class JitterRotator:
7+
"""2D jitter rotator."""
8+
99
def __init__():
1010
pass
1111

12-
# @classmethod
13-
# def rotate_jitter_gpd(cls,
14-
# df: gpd.geodataframe.GeoDataFrame,
15-
# rotation_angle: Union[int, float],
16-
# calibration_point_x_jitter: Union[int, float],
17-
# calibration_point_y_jitter: Union[int, float]
18-
# ) -> gpd.geodataframe.GeoDataFrame:
19-
# """Rotate Normal lng, lat to jittered, rotated space
20-
21-
# Args:
22-
# x_array (np.ndarray): input lng/x
23-
# y_array (np.ndarray): input lat/y
24-
# rotation_angle (Union[int, float]): rotation angle
25-
# calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter
26-
# calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter
27-
28-
# Returns:
29-
# tuple(np.ndarray, np.ndarray): newx, newy
30-
# """
31-
# transformed_series = df.rotate(
32-
# rotation_angle, origin=(0,0)
33-
# ).affine_transform(
34-
# [1,0,0,1,calibration_point_x_jitter,calibration_point_y_jitter]
35-
# )
36-
37-
# df1 = gpd.GeoDataFrame(df, geometry=transformed_series)
38-
39-
# return df1
40-
41-
# @classmethod
42-
# def inverse_jitter_rotate_gpd(cls,
43-
# df_rotated: gpd.geodataframe.GeoDataFrame,
44-
# rotation_angle: Union[int, float],
45-
# calibration_point_x_jitter: Union[int, float],
46-
# calibration_point_y_jitter: Union[int, float]
47-
# ) -> gpd.geodataframe.GeoDataFrame:
48-
# """reverse jitter and rotation
49-
50-
# Args:
51-
# x_array_rotated (np.ndarray): input lng/x
52-
# y_array_rotated (np.ndarray): input lng/x
53-
# rotation_angle (Union[int, float]): rotation angle
54-
# calibration_point_x_jitter (Union[int, float]): calibration_point_x_jitter
55-
# calibration_point_y_jitter (Union[int, float]): calibration_point_y_jitter
56-
# """
57-
58-
# return df_rotated.affine_transform(
59-
# [1,0,0,1,-calibration_point_x_jitter,-calibration_point_y_jitter]
60-
# ).rotate(
61-
# -rotation_angle, origin=(0,0)
62-
# )
63-
6412
@classmethod
6513
def rotate_jitter(
6614
cls,
@@ -124,10 +72,12 @@ def inverse_jitter_rotate(
12472

12573

12674
class Sphere_Jitterrotator:
75+
"""3D jitter rotator"""
76+
12777
def __init__(self) -> None:
12878
pass
12979

130-
def rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int]):
80+
def rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int]) -> np.ndarray:
13181
"""_summary_
13282
13383
Args:
@@ -136,7 +86,7 @@ def rotate_jitter(point: np.ndarray, axis: np.ndarray, angle: Union[float, int])
13686
angle (Union[float, int]): angle in degree
13787
13888
Returns:
139-
_type_: _description_
89+
np.ndarray: _description_
14090
"""
14191
u = np.array(axis)
14292
u = u / np.linalg.norm(u)

stemflow/utils/quadtree.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
1-
# import libraries
1+
"A function module to get quadtree results for 2D indexing system. Returns ensemble_df and plotting axes."
2+
23
import os
34
import warnings
4-
5-
# from collections.abc import Sequence
6-
# from functools import partial
7-
# from itertools import repeat
8-
# from multiprocessing import Pool
95
from typing import Tuple, Union
106

117
import matplotlib
12-
13-
# import matplotlib.patches as patches
148
import matplotlib.pyplot as plt # plotting libraries
159
import numpy as np
1610
import pandas
@@ -21,11 +15,6 @@
2115
from ..gridding.QuadGrid import QuadGrid
2216
from .validation import check_transform_spatio_bin_jitter_magnitude, check_transform_temporal_bin_start_jitter
2317

24-
# from tqdm.contrib.concurrent import process_map
25-
# from .generate_soft_colors import generate_soft_color
26-
# from .validation import check_random_state
27-
28-
2918
os.environ["MKL_NUM_THREADS"] = "1"
3019
os.environ["NUMEXPR_NUM_THREADS"] = "1"
3120
os.environ["OMP_NUM_THREADS"] = "1"

stemflow/utils/sphere/Icosahedron.py

+41-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
"Functions for the initial icosahedron in spherical indexing system"
2+
13
import numpy as np
24

35
from .coordinate_transform import lonlat_cartesian_3D_transformer
46

57

6-
def get_Icosahedron_vertices():
8+
def get_Icosahedron_vertices() -> np.ndarray:
9+
"""Return the 12 vertices of icosahedron
10+
11+
Returns:
12+
np.ndarray: (n_vertices, 3D_coordinates)
13+
"""
714
phi = (1 + np.sqrt(5)) / 2
815
vertices = np.array(
916
[
@@ -24,7 +31,17 @@ def get_Icosahedron_vertices():
2431
return vertices
2532

2633

27-
def calc_and_judge_distance(v1, v2, v3):
34+
def calc_and_judge_distance(v1: np.ndarray, v2: np.ndarray, v3: np.ndarray) -> bool:
35+
"""Determine if the three points have same distance with each other
36+
37+
Args:
38+
v1 (np.ndarray): point 1
39+
v2 (np.ndarray): point 1
40+
v3 (np.ndarray): point 1
41+
42+
Returns:
43+
bool: Whether have same pair-wise distance
44+
"""
2845
d1 = np.sum((np.array(v1) - np.array(v2)) ** 2) ** (1 / 2)
2946
d2 = np.sum((np.array(v1) - np.array(v3)) ** 2) ** (1 / 2)
3047
d3 = np.sum((np.array(v2) - np.array(v3)) ** 2) ** (1 / 2)
@@ -34,7 +51,12 @@ def calc_and_judge_distance(v1, v2, v3):
3451
return False
3552

3653

37-
def get_Icosahedron_faces():
54+
def get_Icosahedron_faces() -> np.ndarray:
55+
"""Get icosahedron faces
56+
57+
Returns:
58+
np.ndarray: shape (20,3,3). (faces, point, 3d_dimension)
59+
"""
3860
vertices = get_Icosahedron_vertices()
3961

4062
face_list = []
@@ -51,7 +73,12 @@ def get_Icosahedron_faces():
5173
return face_list
5274

5375

54-
def get_earth_Icosahedron_vertices_and_faces_lonlat():
76+
def get_earth_Icosahedron_vertices_and_faces_lonlat() -> [np.ndarray, np.ndarray]:
77+
"""Get vertices and faces in lon, lat
78+
79+
Returns:
80+
[np.ndarray, np.ndarray]: vertices, faces
81+
"""
5582
# earth_radius_km=6371.0
5683
# get Icosahedron vertices and faces
5784
vertices = get_Icosahedron_vertices()
@@ -68,7 +95,16 @@ def get_earth_Icosahedron_vertices_and_faces_lonlat():
6895
return np.stack([vertices_lng, vertices_lat], axis=-1), np.stack([faces_lng, faces_lat], axis=-1)
6996

7097

71-
def get_earth_Icosahedron_vertices_and_faces_3D(radius=1):
98+
def get_earth_Icosahedron_vertices_and_faces_3D(radius=1) -> [np.ndarray, np.ndarray]:
99+
"""Get vertices and faces in lon, lat
100+
101+
Args:
102+
radius (Union[int, float]): radius of earth in km.
103+
104+
Returns:
105+
[np.ndarray, np.ndarray]: vertices, faces
106+
"""
107+
72108
# earth_radius_km=6371.0
73109
# get Icosahedron vertices and faces
74110
vertices = get_Icosahedron_vertices()

stemflow/utils/sphere/coordinate_transform.py

+53-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
from collections.abc import Sequence
2+
from typing import Tuple, Union
23

34
import numpy as np
45

56
from ...gridding.Q_blocks import QPoint_3D
67

78

89
class lonlat_cartesian_3D_transformer:
10+
"""Transformer between longitude,latitude and 3d dimension (x,y,z)."""
11+
912
def __init__(self) -> None:
1013
pass
1114

12-
def transform(lng, lat, radius=6371):
15+
def transform(lng: np.ndarray, lat: np.ndarray, radius: float = 6371.0) -> Tuple[np.ndarray, np.ndarray]:
16+
"""Transform lng, lat to x,y,z
17+
18+
Args:
19+
lng (np.ndarray): lng
20+
lat (np.ndarray): lat
21+
radius (float, optional): radius of earth in km. Defaults to 6371.
22+
23+
Returns:
24+
Tuple[np.ndarray, np.ndarray]: x,y,z
25+
"""
26+
1327
# Convert latitude and longitude from degrees to radians
1428
lat_rad = np.radians(lat)
1529
lng_rad = np.radians(lng)
@@ -21,15 +35,38 @@ def transform(lng, lat, radius=6371):
2135

2236
return x, y, z
2337

24-
def inverse_transform(x, y, z, r=None):
38+
def inverse_transform(
39+
x: np.ndarray, y: np.ndarray, z: np.ndarray, r: float = None
40+
) -> Tuple[np.ndarray, np.ndarray]:
41+
"""transform x,y,z to lon, lat
42+
43+
Args:
44+
x (np.ndarray): x
45+
y (np.ndarray): y
46+
z (np.ndarray): z
47+
r (float, optional): Radius of your spherical coordinate. If not given, calculate from x,y,z. Defaults to None.
48+
49+
Returns:
50+
Tuple[np.ndarray, np.ndarray]: longitude, latitude
51+
"""
2552
if r is None:
2653
r = np.sqrt(x**2 + y**2 + z**2)
2754
latitude = np.degrees(np.arcsin(z / r))
2855
longitude = np.degrees(np.arctan2(y, x))
2956
return longitude, latitude
3057

3158

32-
def get_midpoint_3D(p1, p2, radius=6371):
59+
def get_midpoint_3D(p1: QPoint_3D, p2: QPoint_3D, radius: float = 6371.0) -> QPoint_3D:
60+
"""Get the mid-point of three QPoint_3D objet (vector)
61+
62+
Args:
63+
p1 (QPoint_3D): p1
64+
p2 (QPoint_3D): p2
65+
radius (float, optional): radius of earth in km. Defaults to 6371.0.
66+
67+
Returns:
68+
QPoint_3D: mid-point.
69+
"""
3370
v1 = np.array([p1.x, p1.y, p1.z])
3471
v2 = np.array([p2.x, p2.y, p2.z])
3572

@@ -41,7 +78,19 @@ def get_midpoint_3D(p1, p2, radius=6371):
4178
return p3
4279

4380

44-
def continuous_interpolation_3D_plotting(p1, p2, radius=6371):
81+
def continuous_interpolation_3D_plotting(
82+
p1: np.ndarray, p2: np.ndarray, radius: float = 6371.0
83+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
84+
"""interpolate 10 points on earth surface between the given two points. For plotting.
85+
86+
Args:
87+
p1 (np.ndarray): p1
88+
p2 (np.ndarray): p2
89+
radius (float, optional): radius of earth in km. Defaults to 6371.0.
90+
91+
Returns:
92+
Tuple[np.ndarray, np.ndarray, np.ndarray]: 10 x, 10 y, 10 z
93+
"""
4594
v1 = np.array([p1[0], p1[1], p1[2]])
4695
v2 = np.array([p2[0], p2[1], p2[2]])
4796

stemflow/utils/sphere/discriminant_formula.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,40 @@
1-
import numpy as np
2-
3-
# def sign(target, p2, p3):
4-
# return np.sign((target[:,0] - p3[0]) * (p2[1] - p3[1]) - (p2[0] - p3[0]) * (target[:,1] - p3[1]))
1+
from typing import Union
52

6-
# def point_in_triangle(targets, p1, p2, p3):
3+
import numpy as np
74

8-
# d1 = sign(targets, p1, p2)
9-
# d2 = sign(targets, p2, p3)
10-
# d3 = sign(targets, p3, p1)
115

12-
# signs = np.column_stack([d1<0,d2<0,d3<0])
13-
# has_neg = signs.sum(axis=1)
14-
# has_pos = -signs.sum(axis=1)
15-
# return np.logical_not(np.logical_and(has_neg, has_pos))
6+
def is_point_inside_triangle(point: np.ndarray, A: np.ndarray, B: np.ndarray, C: np.ndarray) -> np.ndarray:
7+
"""Check if a point is inside a triangle
168
9+
Args:
10+
point (np.ndarray): point in vector. Shape (X, dimension).
11+
A (np.ndarray): point A of triangle. Shape (dimension).
12+
B (np.ndarray): point B of triangle. Shape (dimension).
13+
C (np.ndarray): point C of triangle. Shape (dimension).
1714
18-
def is_point_inside_triangle(point, A, B, C):
15+
Returns:
16+
np.ndarray: inside or not
17+
"""
1918
u = np.cross(C - B, point - B) @ np.cross(C - B, A - B)
2019
v = np.cross(A - C, point - C) @ np.cross(A - C, B - C)
2120
w = np.cross(B - A, point - A) @ np.cross(B - A, C - A)
2221

2322
return (u >= 0) & (v >= 0) & (w >= 0)
2423

2524

26-
def intersect_triangle_plane(P0, V, A, B, C):
25+
def intersect_triangle_plane(P0: np.ndarray, V: np.ndarray, A: np.ndarray, B: np.ndarray, C: np.ndarray) -> np.ndarray:
26+
"""Get if the ray go through the triangle of A,B,C
27+
28+
Args:
29+
P0 (np.ndarray): start point of ray
30+
V (np.ndarray): A point that the ray go through
31+
A (np.ndarray): point A of triangle. Shape (dimension).
32+
B (np.ndarray): point A of triangle. Shape (dimension).
33+
C (np.ndarray): point A of triangle. Shape (dimension).
34+
35+
Returns:
36+
np.ndarray: Whether the point go through triangle ABC
37+
"""
2738
# Calculate the normal vector of the plane
2839
N = np.cross(B - A, C - A)
2940

0 commit comments

Comments
 (0)