1) sequence recovery (sequence design) code request 2) generating .mdb file 3) PEFT 4) best model selection after training #73

Open
sj584 opened this issue Nov 18, 2024 · 1 comment

@sj584

sj584 commented Nov 18, 2024

Hi,

I successfully ran the fine-tuning code using config/pretrain/saprot.py and config/Thermostability/saprot.py.
Afterwards, a few new questions came up.

I would really appreciate it if you could answer these.




1. Could you share sequence recovery (or sequence design) code?
I implemented it my own way, but I'm not sure it is correct.

In pseudocode:

Given:
initial_tokens = ['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv'] (initial sequence)
input_tokens = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv'] (amino-acid subtoken masked at some positions)
(masked tokens: ##, #p, #y, #d)

At each masked position, the model predicts a single full token (amino acid + structure subtoken), so the predicted structure subtoken may be wrong.
(e.g. ## -> Gr, #p -> Gp, #y -> Gp, #d -> Sd)

I then keep only the amino-acid part of each predicted token and combine it with the original structure subtoken, which stays unchanged.

input_tokens     = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv']
recovered_tokens = ['G#', 'Ev', 'Gp', 'Qp', 'L#', 'Gy', 'Sd', 'Ya', 'Kv']
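The reconstruction step described above can be sketched in Python as follows. This is a minimal sketch, assuming each token is a two-character string (amino-acid letter, then structure letter), that `#` marks a masked amino acid, and that `predicted_tokens` holds the model's top prediction at every position (the values reuse the example predictions from the text). `recover_sequence` is a hypothetical helper, not part of the SaProt codebase.

```python
from typing import List

MASK = "#"  # placeholder used for a masked amino-acid subtoken


def recover_sequence(input_tokens: List[str], predicted_tokens: List[str]) -> List[str]:
    """Fill masked amino acids from the predictions, keeping the input structure subtokens.

    Each token is assumed to be a 2-character string: amino acid + structure letter.
    """
    if len(input_tokens) != len(predicted_tokens):
        raise ValueError("input and predicted token lists must have the same length")

    recovered = []
    for inp, pred in zip(input_tokens, predicted_tokens):
        aa, struc = inp[0], inp[1]
        if aa == MASK:
            # Take only the amino acid from the prediction; the predicted
            # structure subtoken is discarded in favor of the input one.
            aa = pred[0]
        recovered.append(aa + struc)
    return recovered


input_tokens = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv']
# Hypothetical model output; only the masked positions matter here.
predicted_tokens = ['Gr', 'Ev', 'Gp', 'Qp', 'L#', 'Gp', 'Sd', 'Ya', 'Kv']
print(recover_sequence(input_tokens, predicted_tokens))
# ['G#', 'Ev', 'Gp', 'Qp', 'L#', 'Gy', 'Sd', 'Ya', 'Kv']
```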





2. I also wrote code to generate the .mdb dataset file.
I checked that it runs, but I'm not sure whether the IDs (keys) can be arbitrary (e.g. 550, 5500).
I would appreciate it if you could check this code against yours.

Generating .mdb file

```python
import json

import lmdb

# Example data
data = {
    "550": {"description": "A0A0J6SSW7", "seq": "M#R#A#A#A#T#L#L#V#T#L#C#V#V#G#A#N#E#A#R#A#GfIwLe..."},
    "5500": {"description": "A0A535NFD5", "seq": "AdAvRvEvAvLvRvAvSvGvHdPdFdVdEdAdPpGpEpAaAdFp..."},
    # Add more entries here
}

# Open (or create) an LMDB environment
env = lmdb.open("my_lmdb_file", map_size=int(1e9))  # map_size is the maximum size (in bytes) of the DB, as an int
with env.begin(write=True) as txn:
    # Store the dataset length; SaprotFoldseekDataset reads it with int(self._get("length"))
    length = len(data)
    txn.put("length".encode("utf-8"), str(length).encode("utf-8"))
    for key, value in data.items():
        # Convert the value to a JSON string
        value_json = json.dumps(value)
        # Store key-value pairs in the database; keys and values must be bytes
        txn.put(key.encode("utf-8"), value_json.encode("utf-8"))

# Close the LMDB environment
env.close()
```
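Whether arbitrary IDs such as `550` work depends on how the dataset class looks entries up. If it fetches entries by an integer index in `range(length)` (an assumption worth checking against the dataset's `__getitem__`), the keys would need to be `"0"` to `str(length - 1)`. The sketch below re-keys an ID-indexed dict into consecutive indices and keeps the original ID inside each record under a hypothetical `orig_id` field; `build_lmdb_records` is an illustrative helper, not a SaProt function.

```python
import json
from typing import Dict, List, Tuple


def build_lmdb_records(data: Dict[str, dict]) -> List[Tuple[bytes, bytes]]:
    """Turn {arbitrary_id: entry} into (key, value) byte pairs keyed 0..n-1, plus "length"."""
    records = []
    for index, (orig_id, entry) in enumerate(data.items()):
        # Keep the original ID inside the record so it is not lost after re-keying.
        value = dict(entry, orig_id=orig_id)
        records.append((str(index).encode("utf-8"), json.dumps(value).encode("utf-8")))
    # The "length" entry is what the dataset's __len__ reads.
    records.append((b"length", str(len(data)).encode("utf-8")))
    return records


data = {
    "550": {"description": "A0A0J6SSW7", "seq": "M#R#A#"},
    "5500": {"description": "A0A535NFD5", "seq": "AdAvRv"},
}
for key, value in build_lmdb_records(data):
    print(key, value)
```

Inside the write transaction, the pairs could then be stored with `for key, value in build_lmdb_records(data): txn.put(key, value)`.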

Reading .mdb file

```python
import lmdb

env = lmdb.open("my_lmdb_file/", readonly=True)

with env.begin() as txn:
    cursor = txn.cursor()
    for key, value in cursor:
        print(key, value)

env.close()
```
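If the dataset class is assumed to look entries up by integer keys `0` to `length - 1` and to read a separate `length` entry (an assumption, not confirmed here), a finished database can be checked by passing the `(key, value)` pairs that the cursor yields to a small validator. `check_lmdb_records` is a hypothetical helper; because it only needs an iterable of byte pairs, it can be tried without opening a database.

```python
import json
from typing import Iterable, List, Tuple


def check_lmdb_records(pairs: Iterable[Tuple[bytes, bytes]]) -> List[str]:
    """Return a list of problems; an empty list means the layout looks consistent."""
    entries = {}
    length = None
    for key, value in pairs:
        k = key.decode("utf-8")
        if k == "length":
            length = int(value.decode("utf-8"))
        else:
            entries[k] = json.loads(value.decode("utf-8"))

    problems = []
    if length is None:
        problems.append('missing "length" key')
        return problems
    if length != len(entries):
        problems.append(f"length={length} but {len(entries)} entries stored")
    # Assumed lookup scheme: every index 0..length-1 must exist as a key.
    missing = [str(i) for i in range(length) if str(i) not in entries]
    if missing:
        problems.append(f"missing index keys: {missing}")
    # Mirrors the example data, where every entry carries a "seq" field.
    for k, entry in entries.items():
        if "seq" not in entry:
            problems.append(f'entry {k} has no "seq" field')
    return problems
```

With the database open, `print(check_lmdb_records(txn.cursor()))` inside the `with env.begin() as txn:` block would report gaps; for the example data keyed `550` and `5500`, it would flag indices `0` and `1` as missing.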





3. I once asked whether PEFT is possible, and you kindly answered that it is supported in SaprotBaseModel.py.
In the code, I can see that LoRA can be used for downstream tasks.

In my case, I was hoping to first use LoRA for MLM fine-tuning on a specific protein domain,
and then fine-tune further on a downstream task.

I have written some code for this, but I don't think this approach was previously supported.
So I'd like your opinion on whether it is a viable approach.

The steps would be:

  1. Load the SaProt model weights
  2. Fine-tune with LoRA on the MLM objective
  3. Load (SaProt model weights + LoRA MLM fine-tuning weights)
  4. Fine-tune on the downstream task
  5. Load (SaProt model weights + LoRA MLM fine-tuning weights + LoRA downstream fine-tuning weights)
  6. Predict on the downstream task
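Steps 3 and 5 rely on a LoRA adapter being an additive low-rank update to a weight matrix, W' = W + (alpha / r) * B A, so an MLM adapter can be merged into the base weights and a second (downstream) adapter applied on top. The pure-Python sketch below checks on toy matrices that merging the first adapter and then adding the second gives the same weights as adding both updates to the original matrix. The sizes, rank and scaling are arbitrary illustrative values, not SaProt's LoRA settings. In the Hugging Face `peft` library, merging an adapter into the base weights is done with `merge_and_unload()`.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def scale(a: Matrix, s: float) -> Matrix:
    return [[s * x for x in row] for row in a]


def lora_delta(b: Matrix, a: Matrix, alpha: float, r: int) -> Matrix:
    """Low-rank update (alpha / r) * B @ A, with B: d_out x r and A: r x d_in."""
    return scale(matmul(b, a), alpha / r)


# Toy 2x2 base weight W and two rank-1 adapters (values chosen arbitrarily).
W = [[1.0, 0.0], [0.0, 1.0]]
mlm_B, mlm_A = [[1.0], [2.0]], [[0.5, -1.0]]   # adapter from MLM fine-tuning (step 2)
task_B, task_A = [[0.0], [1.0]], [[2.0, 1.0]]  # adapter from downstream fine-tuning (step 4)
alpha, r = 2.0, 1

# Step 3: merge the MLM adapter into the base weights.
W_merged = add(W, lora_delta(mlm_B, mlm_A, alpha, r))
# Step 5: add the downstream adapter on top of the merged weights.
W_final = add(W_merged, lora_delta(task_B, task_A, alpha, r))

# Same result as adding both updates to the original weights in one go.
W_both = add(W, add(lora_delta(mlm_B, mlm_A, alpha, r), lora_delta(task_B, task_A, alpha, r)))
x = [[1.0], [1.0]]  # a column input vector
print(matmul(W_final, x), matmul(W_both, x))
```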

Alternatively, since the steps above are rather complicated, the downstream task could simply use embeddings extracted from
(SaProt model weights + LoRA MLM fine-tuning weights).
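For the embedding route, one common choice (assumed here, not taken from SaProt's code) is to mean-pool the per-residue hidden states while ignoring padding positions, and feed the resulting fixed-size vector to a lightweight downstream model. Whether special start/end tokens should also be excluded is a further choice left open. A pure-Python sketch of masked mean pooling, with toy numbers standing in for real hidden states; `masked_mean_pool` is a hypothetical helper:

```python
from typing import List


def masked_mean_pool(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Average the per-token vectors whose mask value is 1 (padding positions have 0)."""
    kept = [vec for vec, m in zip(hidden, mask) if m]
    if not kept:
        raise ValueError("mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


# Three residue embeddings of size 2, followed by one padding position.
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]
print(masked_mean_pool(hidden, mask))  # [3.0, 4.0]
```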





4. When I ran the code using config/pretrain/saprot.py or config/Thermostability/saprot.py,
it seems that only one model is saved after training.
If so, how can I tell whether the saved model is the best one?

I noticed that the Trainer config has enable_checkpointing: false.
Should I set it to true, track the results with wandb, and pick the best model from there?
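Selecting the "best" model generally means monitoring a validation metric and keeping the checkpoint(s) with the best value. PyTorch Lightning's `ModelCheckpoint` callback does this through its `monitor`, `mode` and `save_top_k` arguments. The pure-Python sketch below mirrors that selection rule so the logic is easy to see; `top_k_checkpoints` is a hypothetical helper, and the metric names and values are made up for illustration.

```python
from typing import Dict, List, Tuple


def top_k_checkpoints(history: List[Dict[str, float]], monitor: str,
                      mode: str = "min", k: int = 1) -> List[Tuple[int, float]]:
    """Return (epoch, value) pairs for the k best epochs according to `monitor`."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    scored = [(epoch, metrics[monitor]) for epoch, metrics in enumerate(history)
              if monitor in metrics]
    # Lower is better for "min" (e.g. a loss), higher is better for "max" (e.g. a correlation).
    scored.sort(key=lambda item: item[1], reverse=(mode == "max"))
    return scored[:k]


# Hypothetical per-epoch validation results logged during fine-tuning.
history = [
    {"valid_loss": 0.91, "valid_spearman": 0.55},
    {"valid_loss": 0.74, "valid_spearman": 0.63},
    {"valid_loss": 0.78, "valid_spearman": 0.66},
]
print(top_k_checkpoints(history, "valid_loss", mode="min"))      # [(1, 0.74)]
print(top_k_checkpoints(history, "valid_spearman", mode="max"))  # [(2, 0.66)]
```

The two calls pick different epochs, which is why the monitored metric should be the one that matters for the downstream task.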





Thank you for reading these long questions. Your answers will be very helpful to me :)

@LTEnjoy
Contributor

LTEnjoy commented Nov 18, 2024

Hi,

Glad to see you digging so deeply into the code!

  1. Could you share sequence recovery (or sequence design) code?

Of course! I have uploaded a new model file named saprot_if_model.py, which is used for protein inverse folding. The overall pipeline is nearly the same as the one you described above; see its predict function for details. You can follow this example to run inverse folding:

```python
from model.saprot.saprot_if_model import SaProtIFModel

# Load model
config = {
    "config_path": "/your/path/to/SaProt_650M_AF2_inverse_folding", # Please download the weights from https://huggingface.co/westlake-repl/SaProt_650M_AF2_inverse_folding
    "load_pretrained": True,
}

device = "cuda"
model = SaProtIFModel(**config)
model = model.to(device)

aa_seq = "##########" # All masked amino acids will be predicted. You could also partially mask the amino acids.
struc_seq = "dddddddddd"

# Predict amino acids given the structure sequence
pred_aa_seq = model.predict(aa_seq, struc_seq)
print(pred_aa_seq)
```
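The comment in the example notes that amino acids can also be partially masked. A small helper to build such an input from a known sequence might look like this; `mask_positions` is a hypothetical helper that assumes `#` is the amino-acid mask character, as in the example. The result would be passed as `aa_seq` to `model.predict` together with a structure sequence of the same length, as in the example.

```python
from typing import Iterable


def mask_positions(aa_seq: str, positions: Iterable[int], mask_char: str = "#") -> str:
    """Replace the amino acids at the given 0-based positions with the mask character."""
    chars = list(aa_seq)
    for pos in positions:
        if not 0 <= pos < len(chars):
            raise IndexError(f"position {pos} is outside a sequence of length {len(chars)}")
        chars[pos] = mask_char
    return "".join(chars)


aa_seq = mask_positions("MEVQLVQYKV", [0, 2, 5, 6])
print(aa_seq)  # #E#QL##YKV
```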
  2. About the generation of the .mdb file

Sorry, I missed your earlier issue asking for the code to generate the .mdb file. You can refer to this reply in #72 to generate your own .mdb dataset.

  3. Using LoRA for MLM training

I don't think it is necessary to first fine-tune SaProt with the MLM objective and then fine-tune it on the downstream task. In my opinion, if you already have labeled data, you can fine-tune the model directly on it without an MLM pre-training step. I expect the final performance to be comparable.

  4. The strategy for saving checkpoints

I believe issue #69 answers this question :)

Thanks again for such good questions! If you have any others, let me know and I'd be happy to help :)
