Skip to content

Commit 167606b

Browse files
committed
fix bugs
1 parent d7ff524 commit 167606b

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

stemflow/model/AdaSTEM.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -485,18 +485,28 @@ def SAC_ensemble_training(self, index_df: pd.core.frame.DataFrame, data: pd.core
485485
bootstrap_random_state = index_df['bootstrap_random_state'].iloc[0]
486486
rng = np.random.default_rng(bootstrap_random_state) # NumPy's random generator
487487
bootstrap_indices = rng.choice(data.index, size=len(data), replace=True) # Full bootstrap sample
488+
else:
489+
bootstrap_indices = None # Place holder
488490

489491
res_list = []
490492
for start in unique_start_indices:
491-
valid_index_window_data_df = data.index[
492-
(data[self.Temporal1] >= start) & (data[self.Temporal1] < start + self.temporal_bin_interval)
493-
]
494-
window_data_df_index = bootstrap_indices[np.isin(bootstrap_indices, valid_index_window_data_df)]
495-
window_data_df = data.loc[window_data_df_index] # So that we don't need to make a whole copy of the data
493+
494+
if self.ensemble_bootstrap:
495+
valid_index_window_data_df = data.index[
496+
(data[self.Temporal1] >= start) & (data[self.Temporal1] < start + self.temporal_bin_interval)
497+
]
498+
window_data_df_index = bootstrap_indices[np.isin(bootstrap_indices, valid_index_window_data_df)]
499+
window_data_df = data.loc[window_data_df_index] # So that we don't need to make a whole copy of the data
500+
del window_data_df_index, valid_index_window_data_df
496501

502+
else:
503+
window_data_df = data[
504+
(data[self.Temporal1] >= start) & (data[self.Temporal1] < start + self.temporal_bin_interval)
505+
]
506+
497507
window_data_df = transform_pred_set_to_STEM_quad(self.Spatio1, self.Spatio2, window_data_df, index_df)
498508
window_index_df = index_df[index_df[f"{self.Temporal1}_start"] == start]
499-
509+
500510
# Merge
501511
def find_belonged_points(df, df_a):
502512
return df_a[

stemflow/utils/quadtree.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,21 @@ def get_one_ensemble_quadtree(
191191
bootstrap_random_state = rng.integers(1e9)
192192
rng = np.random.default_rng(bootstrap_random_state) # NumPy's random generator
193193
bootstrap_indices = rng.choice(data.index, size=len(data), replace=True) # Full bootstrap sample
194+
else:
195+
bootstrap_indices = None # Place holder
194196

195197
ensemble_all_df_list = []
196198

197199
for time_block_index, bin_ in enumerate(temporal_bins):
198200
time_start = bin_[0]
199201
time_end = bin_[1]
200-
valid_index_sub_data = data.index[(data[Temporal1] >= time_start) & (data[Temporal1] < time_end)]
201-
sub_data_index = bootstrap_indices[np.isin(bootstrap_indices, valid_index_sub_data)]
202-
sub_data = data.loc[sub_data_index] # So that we don't need to make a whole copy of the data
202+
203+
if ensemble_bootstrap:
204+
valid_index_sub_data = data.index[(data[Temporal1] >= time_start) & (data[Temporal1] < time_end)]
205+
sub_data_index = bootstrap_indices[np.isin(bootstrap_indices, valid_index_sub_data)]
206+
sub_data = data.loc[sub_data_index] # So that we don't need to make a whole copy of the data
207+
else:
208+
sub_data = data[(data[Temporal1] >= time_start) & (data[Temporal1] < time_end)]
203209

204210
if len(sub_data) == 0:
205211
continue

0 commit comments

Comments
 (0)