Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Adding Domain adaptation #51

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open

Conversation

bruce-edelman
Copy link
Contributor

!!This PR is still a WIP!!

Adding Domain Adaptation following what was done for SIA and ReLEARN from https://www.biorxiv.org/content/10.1101/2023.03.01.529396v1 (their code lives at https://github.com/ziyimo/popgen-dom-adapt)

This requires two major changes to diploshic:

  • Forking the network architecture after feature extraction
    • Discriminator fork of model has GRL to encourage the training to do a bad job at discriminating real/fake data
    • use masked loss functions so that each task is only done for data that makes sense
  • Adjusting the data generators to enable the inclusion of empirical data (target data) that your simulated data (source data) wants to 'adapt' to
    • this involves setting up a second 'Y' or target values for the discriminator prediction outputs

That should be it for the major implementation changes. The rest of this PR is small changes to the interfacing script that handles the logic of using the original model by default and then switching to the domain adaptive model with the CLI argument --domain-adaptation

Currently by default if you turn on domain adaptation then the code assumes that you have .fvec feature vector files created from your target domain data and stored in your training directory named empirical.fvec

Current steps left undone:

  • Construct different simulated data that is 'mis-matched' with current training data to see test the increased performance with domain adaptation if there is a mis-specification of your simulated data and data you want to do predictions on. -- this needs to be simulated data so we have labels to evaluate any changes in performance
  • Compare DA model with original in the mis-specification experiment
  • Compare predictions on the REAL data from soup to nuts example with original and DA data.

@bruce-edelman bruce-edelman self-assigned this Jul 6, 2023
@andrewkern andrewkern self-requested a review July 6, 2023 23:37
@bruce-edelman bruce-edelman marked this pull request as draft July 7, 2023 18:43
@bruce-edelman
Copy link
Contributor Author

Just added small fixes to the bugs you found @andrewkern -- one of the bugs was because train_test_split needs all the same length arrays input so this requires the number of your observations in emprical.fvec need to be the same as your training sets.

For the current hack of using the neut.fvec as our fake target domain data I just copied these 2000 data points 5 times to give 10000 obs to match the simulations.

With this change and a few array shape fixes the model begins training with --domain-adaptation on just fine for me now

if argsDict["domain_adaptation"]:
empirical = np.loadtxt(trainingDir + "empirical.fvec", skiprows=1)
emp = np.reshape(empirical, (empirical.shape[0], nDims, numSubWins))
emp1 = np.concatenate((emp,emp,emp,emp,emp))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the copy 5x line that should be removed in the future when user passes in empirical target domain data the same length of their training set simulations

@andrewkern
Copy link
Member

running this now! one warning I'm getting is

WARNING:tensorflow:Early stopping conditioned on metric `val_accuracy` which is not available. Available metrics are: loss,predictor_loss,discriminator_loss,predictor_accuracy,discriminator_accuracy

this has to do with the metrics on the early stopping criterion.

@bruce-edelman
Copy link
Contributor Author

Fixed the callback issue -- have code change from 'val_accuracy' to 'val_predictor_accuracy' for checkpointing and early stopping when using domain adaptation

@bruce-edelman bruce-edelman marked this pull request as ready for review July 7, 2023 21:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants