Skip to content

Commit

Permalink
7-31 version, 3 groups
Browse files Browse the repository at this point in the history
  • Loading branch information
pvarsh6 committed Jul 31, 2024
1 parent 6c3dd8f commit a971833
Show file tree
Hide file tree
Showing 7 changed files with 43,098 additions and 1,989 deletions.
9,509 changes: 7,644 additions & 1,865 deletions FIBERS_Survival_Demo.ipynb

Large diffs are not rendered by default.

35,179 changes: 35,179 additions & 0 deletions demo_new.ipynb

Large diffs are not rendered by default.

69 changes: 42 additions & 27 deletions src/skfibers/fibers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class FIBERS(BaseEstimator, TransformerMixin):
def __init__(self, outcome_label="Duration",outcome_type="survival",iterations=100,pop_size=50,tournament_prop=0.2,crossover_prob=0.5,min_mutation_prob=0.1,
max_mutation_prob=0.5,merge_prob=0.1,new_gen=1.0,elitism=0.1,diversity_pressure=0,min_bin_size=1,max_bin_size=None,max_bin_init_size=10,fitness_metric="log_rank",
log_rank_weighting=None,censor_label="Censoring",group_strata_min=0.2,penalty=0.5,group_thresh=0,min_thresh=0,max_thresh=5,
log_rank_weighting=None,censor_label="Censoring",group_strata_min=0.2,penalty=0.5,group_thresh_list = [1,4],min_thresh=0,max_thresh=5,
int_thresh=True,thresh_evolve_prob=0.5,manual_bin_init=None,covariates=None,pop_clean=None,report=None,random_seed=None,verbose=False):

"""
Expand Down Expand Up @@ -155,9 +155,10 @@ def __init__(self, outcome_label="Duration",outcome_type="survival",iterations=1
if penalty < 0 or penalty > 1:
raise Exception("'penalty' param must be an int or float from 0 - 1")

if not self.check_is_int(group_thresh) and not self.check_is_float(group_thresh) and group_thresh != None:
if group_thresh_list is not None and ((not self.check_is_int(group_thresh_list[0]) and not self.check_is_float(group_thresh_list[0])) or
(not self.check_is_int(group_thresh_list[1]) and not self.check_is_float(group_thresh_list[1]))):
raise Exception("'group_thresh' param must be a non-negative int or float, or None, for adaptive thresholding")
if group_thresh != None and group_thresh < 0:
if group_thresh_list is not None and (group_thresh_list[0] < 0 or group_thresh_list[1] < 0):
raise Exception("'group_thresh' param must be a non-negative int or float, or None, for adaptive thresholding")

if not self.check_is_int(min_thresh) and not self.check_is_float(min_thresh) or min_thresh < 0:
Expand Down Expand Up @@ -215,9 +216,9 @@ def __init__(self, outcome_label="Duration",outcome_type="survival",iterations=1
self.censor_label = censor_label
self.group_strata_min = group_strata_min
self.penalty = penalty
self.group_thresh = group_thresh
self.min_thresh = min_thresh
self.max_thresh = max_thresh
self.group_thresh_list = group_thresh_list
self.min_thresh = min_thresh
self.max_thresh = max_thresh
self.int_thresh = int_thresh
self.thresh_evolve_prob = thresh_evolve_prob
self.manual_bin_init = manual_bin_init
Expand Down Expand Up @@ -328,7 +329,7 @@ def fit(self, x, y=None):
#Initialize bin population
threshold_evolving = False #Adaptive thresholding - evolving thresholds is off by default for bin initialization
self.set = BIN_SET(self.manual_bin_init,self.df,self.feature_names,self.pop_size,
self.min_bin_size,self.max_bin_init_size,self.group_thresh,self.min_thresh,self.max_thresh,
self.min_bin_size,self.max_bin_init_size,self.group_thresh_list,self.min_thresh,self.max_thresh,
self.int_thresh,self.outcome_type,self.fitness_metric,self.log_rank_weighting,self.group_strata_min,
self.outcome_label,self.censor_label,threshold_evolving,self.penalty,self.iterations,0,self.residuals,self.covariates,random)
#Global fitness update
Expand All @@ -348,7 +349,7 @@ def fit(self, x, y=None):
#EVOLUTIONARY LEARNING ITERATIONS
for iteration in tqdm(range(1, self.iterations+ 1)):
# print('Iteration: '+str(iteration))
if self.group_thresh == None:
if self.group_thresh_list is None:
evolve = random.random()
if self.thresh_evolve_prob > evolve:
threshold_evolving = True
Expand All @@ -358,7 +359,7 @@ def fit(self, x, y=None):
#Occelating Mutation Rate
mutation_prob = (transform_value(iteration-1,cycle_length)*(self.max_mutation_prob-self.min_mutation_prob)/cycle_length)+self.min_mutation_prob

# GENETIC ALGORITHM
# GENETIC ALGORITHM
target_offspring_count = int(self.pop_size*self.new_gen) #Determine number of offspring to generate
while len(self.set.offspring_pop) < target_offspring_count: #Generate offspring until we hit the target number
# Parent Selection
Expand All @@ -368,7 +369,7 @@ def fit(self, x, y=None):
self.set.generate_offspring(self.crossover_prob,mutation_prob,self.merge_prob,self.iterations,iteration,parent_list,self.feature_names,
threshold_evolving,self.min_bin_size,self.max_bin_size,self.max_bin_init_size,self.min_thresh,self.max_thresh,
self.df,self.outcome_type,self.fitness_metric,self.log_rank_weighting,self.outcome_label,self.censor_label,self.int_thresh,
self.group_thresh,self.group_strata_min,self.penalty,self.residuals,self.covariates,random)
self.group_thresh_list,self.group_strata_min,self.penalty,self.residuals,self.covariates,random)
# Add Offspring to Population
self.set.add_offspring_into_pop(iteration)

Expand Down Expand Up @@ -506,6 +507,7 @@ def predict(self, x, bin_number=None):

# Count
bt_vote = [0]*len(temp_df) #votesum stored for each instance
mt_vote = [0]*len(temp_df) #votesum stored for each instance
at_vote = [0]*len(temp_df) #votesum stored for each instance

# Iterate through each row of the DataFrame
Expand All @@ -514,8 +516,10 @@ def predict(self, x, bin_number=None):
bin_count = 0
# Iterate through each value in the row
for value in row:
if value <= self.set.bin_pop[bin_count].group_threshold:
if value <= self.set.bin_pop[bin_count].group_threshold_list[0]:
bt_vote[row_count] += self.set.bin_pop[bin_count].pre_fitness
elif value > self.set.bin_pop[bin_count].group_threshold_list[0] and value <= self.set.bin_pop[bin_count].group_threshold_list[1]:
mt_vote[row_count] += self.set.bin_pop[bin_count].pre_fitness
else:
at_vote[row_count] += self.set.bin_pop[bin_count].pre_fitness
bin_count += 1
Expand All @@ -539,15 +543,15 @@ def performance_tracking(self,initialize,iteration):
#self.set.bin_pop = sorted(self.set.bin_pop, key=lambda x: x.fitness,reverse=True)
top_bin = self.set.bin_pop[0]
if initialize:
col_list = ['Iteration','Top Bin', 'Threshold', 'Fitness', 'Pre-Fitness', 'Log-Rank Score', 'Log-Rank p-value', 'Bin Size', 'Group Ratio', 'Count At/Below Threshold',
'Count Below Threshold','Birth Iteration','Residuals Score','Residuals p-value','Elapsed Time']
col_list = ['Iteration','Top Bin', 'Low Threshold', 'High Threshold', 'Fitness', 'Pre-Fitness', 'Log-Rank Score', 'Log-Rank p-value', 'Bin Size', 'Group Ratio',
'Count At/Below Threshold', 'Count Between Thresholds','Count Above Threshold','Birth Iteration','Residuals Score','Residuals p-value','Elapsed Time']
self.perform_track_df = pd.DataFrame(columns=col_list)
if self.verbose:
print(col_list)

tracking_values = [iteration,top_bin.feature_list,top_bin.group_threshold,top_bin.fitness,top_bin.pre_fitness,top_bin.log_rank_score,top_bin.log_rank_p_value,top_bin.bin_size,
top_bin.group_strata_prop,top_bin.count_bt,top_bin.count_at,top_bin.birth_iteration,top_bin.residuals_score,
top_bin.residuals_p_value,self.elapsed_time]
tracking_values = [iteration,top_bin.feature_list,top_bin.group_threshold_list[0], top_bin.group_threshold_list[1],top_bin.fitness,top_bin.pre_fitness,
top_bin.log_rank_score,top_bin.log_rank_p_value,top_bin.bin_size,top_bin.group_strata_prop,top_bin.count_bt,top_bin.count_mt,
top_bin.count_at,top_bin.birth_iteration,top_bin.residuals_score,top_bin.residuals_p_value,self.elapsed_time]
if self.verbose:
print(tracking_values)
# Add the row to the DataFrame
Expand Down Expand Up @@ -591,7 +595,7 @@ def get_bin_groups(self, x, y=None, bin_index=0):
:param bin_index: population index of the bin to return group information for
:return: low_outcome, high_outcome, low_censor, and high_censor
:return: low_outcome, mid_outcome, high_outcome, low_censor, mid_censor, and high_censor
"""
if not self.hasTrained:
raise Exception("FIBERS must be fit first")
Expand All @@ -608,15 +612,19 @@ def get_bin_groups(self, x, y=None, bin_index=0):
# Create evaluation dataframe including bin sum feature with
bin_df = pd.concat([bin_df,df.loc[:,self.outcome_label],df.loc[:,self.censor_label]],axis=1)

low_df = bin_df[bin_df['Bin_'+str(bin_index)] <= self.set.bin_pop[bin_index].group_threshold]
high_df = bin_df[bin_df['Bin_'+str(bin_index)] > self.set.bin_pop[bin_index].group_threshold]
low_df = bin_df[bin_df['Bin_'+str(bin_index)] <= self.set.bin_pop[bin_index].group_threshold_list[0]]
# mid_df = bin_df[(bin_df['Bin_'+str(bin_index)] > self.set.bin_pop[bin_index].group_threshold_list[0]) & (bin_df['Bin_'+str(bin_index)] <= self.set.bin_pop[bin_index].group_threshold_list[1])]
mid_df = bin_df[(bin_df['Bin_'+str(bin_index)] > self.set.bin_pop[bin_index].group_threshold_list[0]) & (bin_df['Bin_'+str(bin_index)] <= self.set.bin_pop[bin_index].group_threshold_list[1])]
high_df = bin_df[bin_df['Bin_'+str(bin_index)] > self.set.bin_pop[bin_index].group_threshold_list[1]]

low_outcome = low_df[self.outcome_label].to_list()
mid_outcome = mid_df[self.outcome_label].to_list()
high_outcome = high_df[self.outcome_label].to_list()
low_censor = low_df[self.censor_label].to_list()
high_censor =high_df[self.censor_label].to_list()
mid_censor = mid_df[self.censor_label].to_list()
high_censor = high_df[self.censor_label].to_list()
df = None
return low_outcome, high_outcome, low_censor, high_censor
return low_outcome, mid_outcome, high_outcome, low_censor, mid_censor, high_censor


def get_cox_prop_hazard_unadjust(self,x, y=None, bin_index=0, use_bin_sums=False):
Expand All @@ -633,7 +641,9 @@ def get_cox_prop_hazard_unadjust(self,x, y=None, bin_index=0, use_bin_sums=False

if not use_bin_sums:
# Transform bin feature values according to respective bin threshold
bin_df['Bin_'+str(bin_index)] = bin_df['Bin_'+str(bin_index)].apply(lambda x: 0 if x <= self.set.bin_pop[bin_index].group_threshold else 1)
bin_df['Bin_'+str(bin_index)] = bin_df['Bin_'+str(bin_index)].apply(
lambda x: 0 if x <= self.set.bin_pop[bin_index].group_threshold_list[0] else
(1 if x <= self.set.bin_pop[bin_index].group_threshold_list[1] else 2))

bin_df = pd.concat([bin_df,df.loc[:,self.outcome_label],df.loc[:,self.censor_label]],axis=1)
summary = None
Expand Down Expand Up @@ -665,7 +675,9 @@ def get_cox_prop_hazard_adjusted(self,x, y=None, bin_index=0, use_bin_sums=False

if not use_bin_sums:
# Transform bin feature values according to respective bin threshold
bin_df['Bin_'+str(bin_index)] = bin_df['Bin_'+str(bin_index)].apply(lambda x: 0 if x <= self.set.bin_pop[bin_index].group_threshold else 1)
bin_df['Bin_'+str(bin_index)] = bin_df['Bin_'+str(bin_index)].apply(
lambda x: 0 if x <= self.set.bin_pop[bin_index].group_threshold_list[0] else
(1 if x <= self.set.bin_pop[bin_index].group_threshold_list[1] else 2))

bin_df = pd.concat([bin_df,df.loc[:,self.outcome_label],df.loc[:,self.censor_label]],axis=1)
summary = None
Expand Down Expand Up @@ -746,7 +758,9 @@ def calculate_cox_prop_hazards(self,x, y=None, use_bin_sums=False):

if not use_bin_sums:
# Transform bin feature values according to respective bin threshold
bin_df['Bin'] = bin_df['Bin'].apply(lambda x: 0 if x <= bin.group_threshold else 1)
bin_df['Bin'] = bin_df['Bin'].apply(
lambda x: 0 if x <= bin.group_threshold_list[0] else
(1 if x <= bin.group_threshold_list[1] else 2))

# Create evaluation dataframe including bin sum feature, outcome, and censoring alone
bin_df = pd.concat([bin_df,df.loc[:,self.outcome_label],df.loc[:,self.censor_label]],axis=1)
Expand Down Expand Up @@ -802,8 +816,8 @@ def get_feature_tracking_plot(self,max_features=50,show=True,save=False,output_f


def get_kaplan_meir(self,data,bin_index,show=True,save=False,output_folder=None,data_name=None):
low_outcome, high_outcome, low_censor, high_censor = self.get_bin_groups(data, bin_index)
plot_kaplan_meir(low_outcome,low_censor,high_outcome, high_censor,show=show,save=save,output_folder=output_folder,data_name=data_name)
low_outcome, mid_outcome, high_outcome, low_censor, mid_censor, high_censor = self.get_bin_groups(data, bin_index)
plot_kaplan_meir(low_outcome,low_censor,mid_outcome, mid_censor,high_outcome, high_censor,show=show,save=save,output_folder=output_folder,data_name=data_name)


def get_fitness_progress_plot(self,show=True,save=False,output_folder=None,data_name=None):
Expand Down Expand Up @@ -868,7 +882,8 @@ def save_run_params(self,filename):
file.write(f"censor_label: {self.censor_label}\n")
file.write(f"group_strata_min: {self.group_strata_min}\n")
file.write(f"penalty: {self.penalty}\n")
file.write(f"group_thresh: {self.group_thresh}\n")
file.write(f"low_thresh: {self.group_thresh_list[0]}\n")
file.write(f"high_thresh: {self.group_thresh_list[1]}\n")
file.write(f"min_thresh: {self.min_thresh}\n")
file.write(f"max_thresh: {self.max_thresh}\n")
file.write(f"int_thresh: {self.int_thresh}\n")
Expand Down
Loading

0 comments on commit a971833

Please sign in to comment.