Skip to content

Commit

Permalink
add curves, modify saving functions (curve and km details)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Jan 5, 2024
1 parent fb950ed commit c072d31
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 39 deletions.
8 changes: 5 additions & 3 deletions examples/advanced/kaplan-meier-he/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ python baseline_kaplan_meier.py
```
By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory.

Then we run the federated job with simulator
Then we run a 5-client federated job with simulator
```commandline
nvflare simulator -w workspace_km_he -n 2 -t 2 jobs/kaplan-meier-he
nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he
```
By default, this will generate a KM curve image `km_curve_fl.png` under each client's directory.

## Display Result

By comparing the two curves, we can observe that the two are identical.
By comparing the two curves, we can observe that the two are identical:
![KM survival baseline](figs/km_curve_baseline.png)
![KM survival fl](figs/km_curve_fl.png)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
format_version = 2

# This is the application script which will be invoked. Client can replace this script with user's own training script.
app_script = "km_train.py"
app_script = "kaplan_meier_train.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
task_name = "train"
wf_class_path = "kaplan_meier_wf.KM",
wf_args {
min_clients = 2
min_clients = 5
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
np.random.seed(77)


def prepare_data(num_of_clients: int = 2, bin_days: int = 7):
def prepare_data(num_of_clients: int = 5, bin_days: int = 7):
# Load data
data_x, data_y = load_veterans_lung_cancer()
# Get total data count
Expand All @@ -57,11 +57,43 @@ def prepare_data(num_of_clients: int = 2, bin_days: int = 7):
return event_clients, time_clients


def save(result: dict):
def details_save(kmf):
# Get the survival function at all observed time points
survival_function_at_all_times = kmf.survival_function_
# Get the timeline (time points)
timeline = survival_function_at_all_times.index.values
# Get the KM estimate
km_estimate = survival_function_at_all_times["KM_estimate"].values
# Get the event count at each time point
event_count = kmf.event_table.iloc[:, 0].values # Assuming the first column is the observed events
# Get the survival rate at each time point (using the 1st column of the survival function)
survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values
# Return the results
results = {
"timeline": timeline.tolist(),
"km_estimate": km_estimate.tolist(),
"event_count": event_count.tolist(),
"survival_rate": survival_rate.tolist(),
}
file_path = os.path.join(os.getcwd(), "km_global.json")
print(f"save the result to {file_path} \n")
print(f"save the details of KM analysis result to {file_path} \n")
with open(file_path, "w") as json_file:
json.dump(result, json_file, indent=4)
json.dump(results, json_file, indent=4)


def plot_and_save(kmf):
# Plot and save the Kaplan-Meier survival curve
plt.figure()
plt.title("Federated HE")
kmf.plot_survival_function()
plt.ylim(0, 1)
plt.ylabel("prob")
plt.xlabel("time")
plt.legend("", frameon=False)
plt.tight_layout()
file_path = os.path.join(os.getcwd(), "km_curve_fl.png")
print(f"save the curve plot to {file_path} \n")
plt.savefig(file_path)


def main():
Expand Down Expand Up @@ -164,36 +196,12 @@ def main():
# Fit the model
kmf.fit(durations=time_unfold, event_observed=event_unfold)

# Plot and save the Kaplan-Meier survival curve
plt.figure()
plt.title("Federated HE")
kmf.plot_survival_function()
plt.ylim(0, 1)
plt.ylabel("prob")
plt.xlabel("time")
plt.legend("", frameon=False)
plt.tight_layout()
plt.savefig(os.path.join(os.getcwd(), "km_curve_fl.png"))

# Save global result to a json file
# Get the survival function at all observed time points
survival_function_at_all_times = kmf.survival_function_
# Get the timeline (time points)
timeline = survival_function_at_all_times.index.values
# Get the KM estimate
km_estimate = survival_function_at_all_times["KM_estimate"].values
# Get the event count at each time point
event_count = kmf.event_table.iloc[:, 0].values # Assuming the first column is the observed events
# Get the survival rate at each time point (using the 1st column of the survival function)
survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values
# Return the results
results = {
"timeline": timeline.tolist(),
"km_estimate": km_estimate.tolist(),
"event_count": event_count.tolist(),
"survival_rate": survival_rate.tolist(),
}
save(results)
# Plot and save the KM curve
print("plot KM curve!!!!!!!!!!!!")
plot_and_save(kmf)

# Save details of the KM result to a json file
details_save(kmf)

# Send a simple response to server
response = FLModel(params={}, params_type=ParamsType.FULL)
Expand Down

0 comments on commit c072d31

Please sign in to comment.