nadavlab/FederatedBEHRT

BEHRT federated learning: a PyTorch implementation of the FederatedAveraging algorithm. Code for the paper Federated Learning of Medical Concepts Embedding using BEHRT (https://arxiv.org/abs/2305.13052).

**If you use our code, please cite us:**

```bibtex
@article{shoham2023federated,
  title={Federated Learning of Medical Concepts Embedding using BEHRT},
  author={Shoham, Ofir Ben and Rappoport, Nadav},
  journal={arXiv preprint arXiv:2305.13052},
  year={2023}
}
```

This repository implements, in PyTorch, the FederatedAveraging algorithm proposed in the paper Communication-Efficient Learning of Deep Networks from Decentralized Data. Thanks to https://github.com/vaseline555/Federated-Learning-PyTorch for the initial code base.
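
For context, the server step of FedAvg produces a new global model by averaging the clients' parameters, weighted by each client's number of training samples. A minimal sketch of that step (illustrative names, not this repository's implementation):

```python
import torch

def fedavg_aggregate(client_states, client_sizes):
    """Average client model state_dicts, weighted by each client's sample count."""
    total = sum(client_sizes)
    return {
        key: sum(
            state[key].float() * (n / total)
            for state, n in zip(client_states, client_sizes)
        )
        for key in client_states[0]
    }

# Toy example: two clients holding 3 and 1 samples respectively.
a = {"w": torch.ones(2)}
b = {"w": torch.zeros(2)}
print(fedavg_aggregate([a, b], client_sizes=[3, 1]))  # {'w': tensor([0.7500, 0.7500])}
```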

Requirements

  • See requirements.txt

Configurations

  • See config.yaml:
      • data_dir_path: directory containing the multi-center data (CSV files) in BEHRT format; each center's data is one CSV file. For an example of the format, see behrt_nextvisit_example_data.csv.
      • test_path: path to the test CSV file, in the same format as behrt_nextvisit_example_data.csv.
      • vocab_pickle_path: path to the pickle file that contains the vocab.
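
As a minimal sketch of those entries, loaded with PyYAML to show the expected shape (the paths are placeholders, and config.yaml may contain additional keys not shown here):

```python
import yaml

# Placeholder paths; only the three keys described above are shown.
example_config = """
data_dir_path: /path/to/multi_center_data/   # one CSV file per center, BEHRT format
test_path: /path/to/test.csv                 # same format as behrt_nextvisit_example_data.csv
vocab_pickle_path: /path/to/vocab.pkl        # pickle containing the vocabulary
"""
config = yaml.safe_load(example_config)
print(config["data_dir_path"])
```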

You can use this script to create your vocab pickle: https://github.com/Ofir408/BEHRT/blob/master/preprocess/bert_vocab_builder.py

To create token2idx, you can use the following script:

```python
import json
from typing import Dict, List

import pandas as pd


def get_all_codes(df: pd.DataFrame, codes_to_ignore: List[str]) -> List[str]:
    """Collect the unique codes from the 'code' column (a list of codes per
    row), excluding any codes in codes_to_ignore."""
    codes = []
    for df_list_codes in list(df['code']):
        codes.extend(df_list_codes)
    return list(set(codes) - set(codes_to_ignore))


def get_bert_tokens() -> Dict[str, int]:
    """Reserved BERT special tokens with fixed indices."""
    return {
        "PAD": 0,
        "UNK": 1,
        "SEP": 2,
        "CLS": 3,
        "MASK": 4,
    }


def build_token2index_dict(df: pd.DataFrame) -> Dict[str, int]:
    """Map each code to a unique index, starting after the special tokens."""
    token2inx_dict = get_bert_tokens()
    next_index = max(token2inx_dict.values()) + 1

    codes = get_all_codes(df=df, codes_to_ignore=list(token2inx_dict.keys()))
    for code in codes:
        token2inx_dict[str(code)] = next_index
        next_index += 1
    return token2inx_dict


def create_token2index_file(df: pd.DataFrame, output_file_path: str):
    """Build the token2idx mapping and write it to a JSON file."""
    token2inx_dict = build_token2index_dict(df=df)
    with open(output_file_path, 'w') as f:
        json.dump(token2inx_dict, f)
    print(f'token2inx was created, path={output_file_path}')
```
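
A hypothetical usage example (the DataFrame contents are illustrative; each entry of the code column is a list of codes, as get_all_codes expects):

```python
import pandas as pd

# Illustrative rows: each 'code' entry is a list of medical codes.
df = pd.DataFrame({"code": [["I10", "E11.9"], ["E11.9", "J45"]]})
create_token2index_file(df=df, output_file_path="token2idx.json")
# token2idx.json now maps PAD/UNK/SEP/CLS/MASK plus each code to an index.
```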

Run

  1. Install Dependencies

     Run the following command to install the required packages: pip install -r requirements.txt

  2. Preprocess the MIMIC-IV Dataset

     Use the provided notebook to preprocess the MIMIC-IV dataset: MIMIC-IV Feature Extraction Notebook

  3. Train the MLM via Federated Learning

     Switch to the fl_mlm_training branch for the training code: git checkout fl_mlm_training

  4. Edit Configuration

     Modify the config.yaml file. Update the data_config section to reflect the paths of the datasets saved from the preprocessing notebook.

  5. Run Training

     Start the training process by running: python main.py <path_to_config.yaml>

  6. Switch Back to the Main Branch

     Switch back to the main branch to run federated learning for the next-visit prediction model: git checkout main

  7. Update Configuration

     In the configuration file, set pretrained_model_path to the federated learning MLM model saved in the previous step (see the sketch after this list).

  8. Run Next-Visit Model Training

     Run the training for the next-visit model: python main.py <path_to_config.yaml>
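
For step 7, a minimal sketch of the relevant configuration entry, loaded here with PyYAML to show the expected shape (the path and surrounding config.yaml structure are placeholders, not this repository's exact schema):

```python
import yaml

# Hypothetical path; the key name pretrained_model_path comes from this README,
# the value and any surrounding config.yaml keys are illustrative.
snippet = "pretrained_model_path: ./outputs/fl_mlm_model.pt"
print(yaml.safe_load(snippet))  # {'pretrained_model_path': './outputs/fl_mlm_model.pt'}
```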
