
Fashion MNIST using IMPROVE unified interface #108

Open · rajeeja wants to merge 6 commits into base: develop

Conversation

rajeeja (Collaborator) commented Jan 18, 2024

Demo example for new community models.
This is a first draft that uses IMPROVE to initialize parameters and CANDLE for checkpointing.

rajeeja requested review from brettin, wilke, and adpartin on January 18, 2024 at 21:20
# Get the data directory, batch size, and other hyperparameters from params
## IMPROVE
batch_size = params["batch_size"]
learning_rate = params["learning_rate"]

Contributor:

Why are these parameters for inference?

Contributor:

@rajeeja I agree with @wilke; it's odd that these parameters are defined in the inference script. Generally, inference is done without knowledge of training settings (e.g., training batch size, learning rate, optimizer). Is there a reason why these are defined here?

Collaborator Author:

@adpartin @wilke

  1. We need to load the test data in batches, hence batch_size (see line 50); a minimal sketch of this loop follows below.
  2. I'm getting the model from the ckpt method; we could get it via other means as well. The ckpt object needs an optimizer to instantiate, hence optimizer (see line 64).
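
A minimal sketch of the batched inference being described, assuming the model, testloader, and device built elsewhere in this PR:

# Hedged sketch: batched inference over the FashionMNIST test split.
# Assumes `model`, `testloader`, and `device` exist as in this PR.
import torch

model.eval()                      # disable dropout/batch-norm updates
predictions = []
with torch.no_grad():             # gradients are not needed for inference
    for images, labels in testloader:
        images = images.to(device)
        outputs = model(images)
        predictions.append(outputs.argmax(dim=1))  # predicted class per image
predictions = torch.cat(predictions)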

Contributor:

Sorry, but it is unclear why you must load the data for inference in batches. Is this in any way faster than a simple

for v in file:
    label = infer(v)

Contributor:

There should be a better way of loading the model. We want the model from the input directory. No optimizer is needed. If this is a problem, I suggest writing a load_model_weights function as a wrapper.
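
A minimal sketch of such a load_model_weights wrapper, assuming PyTorch and the Net class from this PR; the weights filename is hypothetical:

# Hedged sketch of the suggested wrapper: load weights without an optimizer.
import os
import torch

def load_model_weights(model_dir, device="cpu"):
    # Rebuild the architecture, then load saved weights from the input directory.
    model = Net().to(device)
    weights_path = os.path.join(model_dir, "model.pt")  # hypothetical filename
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()  # inference mode
    return model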

Contributor:

@rajeeja @adpartin If we use this as an example, we have to make it clean. This is a great hack but not a sustainable solution. Please come up with a better solution or hide it in a function call. These are constants in this case.


# NOTE: train=False selects the 10k-image test split; download=False assumes the data was already fetched during preprocessing
testset = torchvision.datasets.FashionMNIST(root=dataset_dir, train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

Contributor:

Loading test data for inference? I would expect to load the model weights.

Collaborator Author:

See line 72; how are you going to get images to test?
outputs = model(images)

Contributor:

I don't understand, or is this a problem of naming conventions? We are doing label prediction in this script, not testing. Do you have a specific use case in mind?

Contributor:

@rajeeja Any thoughts?

model = Net().to(device)

# Define optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

Contributor:

Same here

Collaborator Author:

Same as above :) We can get the model weights in some other fashion; do you know another way to get the model weights to perform inference where the optimizer or learning rate is not needed? It is a minor thing and can be ignored, IMO. The overall logic is to get the model weights and infer with what was produced in the training step.

Contributor:

I agree: get the model weights and infer on any input data in the input data directory. The model weights should be located there as well. The outputs of training are model weights and learning metrics.


## IMPROVE
# Note: some of these are similar to the previous section and may be adjusted per model requirements
model_infer_params = [

Contributor:

Which parameter is for loading model weights?

Collaborator Author:

I'm loading it from the ckpt files, so there is no specific parameter.

This could be done by using a specific directory to save the model weights and load from there, something along the lines of test_ml_data_dir; see the sketch below.
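
A hedged sketch of what such a parameter could look like, following the dictionary style that CANDLE/IMPROVE use for additional parameter definitions; the model_dir name and default are hypothetical:

# Hypothetical entry to add to the model_infer_params list shown above,
# pointing inference at the saved-weights directory.
{"name": "model_dir",
 "type": str,
 "default": "./out_model",  # hypothetical default
 "help": "Directory to load trained model weights from for inference."}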

## IMPROVE
def run(params):
##
# Define transformations for data preprocessing

Contributor:

If this is for data preprocessing, please move it to the preprocessing script.

Collaborator Author:

Training uses the 60k images, but there are 10k images we need to infer on, and those have to be loaded here. This is loading only, not the entire preprocessing.

loss = running_loss / len(trainloader)
ckpt.ckpt_epoch(epoch, loss)
print('Training finished.')

Contributor:

I don't see where you export the final weights.

Collaborator Author:

ckpt does that for you; if you run it once, you will see it in the directory.
See ckpt_epoch on line 94 and the documentation here: https://candle-lib.readthedocs.io/en/latest/api_ckpt_pytorch_utils/_autosummary/candle.ckpt_pytorch_utils.CandleCkptPyTorch.html
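
For context, a hedged sketch of the CANDLE checkpointing pattern based on the linked documentation; the set_model and restart calls are assumptions taken from those docs, not code from this PR:

# Hedged sketch of CANDLE checkpointing per the linked CandleCkptPyTorch docs.
import candle

ckpt = candle.CandleCkptPyTorch(params)
ckpt.set_model({"model": model, "optimizer": optimizer})

initial_epoch = 0
J = ckpt.restart(model)           # restores state if a prior checkpoint exists
if J is not None:
    initial_epoch = J["epoch"]

for epoch in range(initial_epoch, epochs):
    # ... training loop accumulates running_loss ...
    loss = running_loss / len(trainloader)
    ckpt.ckpt_epoch(epoch, loss)  # writes a checkpoint for this epoch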

Contributor:

The final weights should be in the top-level output directory.
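
A minimal sketch of exporting final weights to a top-level output directory, assuming PyTorch; the output_dir parameter name and model.pt filename are hypothetical:

# Hypothetical final export after the training loop completes.
import os
import torch

output_dir = params["output_dir"]  # hypothetical parameter name
os.makedirs(output_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))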

adpartin (Contributor):

@rajeeja @wilke The code structure doesn't follow the IMPROVE structure, so it's hard to go through the code and fix everything. https://jdacs4c-improve.github.io/docs/content/unified_interface.html

rajeeja (Collaborator Author) commented Jan 19, 2024

> @rajeeja @wilke The code structure doesn't follow the IMPROVE structure, so it's hard to go through the code and fix everything. https://jdacs4c-improve.github.io/docs/content/unified_interface.html

That link is not helpful. The data is different: this model doesn't use genetics or drug data, so the structure won't be exactly the same. The .py files are very simple and easy to understand. Also, the notebook is very standard Fashion-MNIST.
