aniketde/DomainGeneralizationMarginal
Domain Generalization by Marginal Transfer Learning - Under Submission at JMLR

This code compares the marginal predictor (transfer learning) method with the pooling method for regression and binary classification. For speed, it uses kernel approximation and then liblinear as the solver.

%% Instructions
1) Download the liblinear library and put it in the main workspace.
2) Compile the library from the matlab folder inside it.
3) For a demo, run demo_transfer_pooling.

%% Files
1) demo_transfer_pooling = Demo for both regression and binary classification
2) cross_validation_pooling_transfer = The main file. It automatically chooses the optimization parameter and the kernel bandwidths (you can also supply your own parameters). After choosing the bandwidths, it compares the errors of pooling and transfer learning.
3) util_mrg_datasets = Merges all datasets
4) util_mrg_datasets_CV = Merges the training datasets and splits them into training and validation sets
5) pooled_CV = Cross-validation for the pooling method
6) transfer_CV = Cross-validation for the transfer learning method
7) transfer_test_accuracy_CV = Calculates the accuracy on the validation set for the transfer learning method
8) pooled_transfer_comparison = After the parameters are set, calculates the actual training and test errors and compares the two methods
9) pooled_test_accuracy = Calculates the error for the pooling method
10) transfer_test_accuracy = Calculates the error for the transfer learning method
11) synth_binary_data = Synthetic data for binary classification

%% Demo
Run demo_transfer_pooling. It compares the marginal predictor method [1] with the pooling method, using kernel approximation followed by liblinear as the solver.

input
- train and test datasets = datasets in cell array format. datasets{i}.testx contains all features and datasets{i}.testy contains all labels. Don't be misled by the name testx: it contains all data points.
- numberOfTrainingUser = number of datasets used for training. The first numberOfTrainingUser entries of datasets are treated as training datasets.
- numberOfExamplesPerTask
- rand_perm_test = datasets to be used as test datasets
- task_type = 'regression' or 'binary'
- cross_val = 1 to run cross-validation, otherwise 0
- L, Q, D = numbers of random Fourier features used to approximate the kernels

output
- res_avg_test = test error of the marginal predictor method
- res_avg_train = training error of the marginal predictor method
- res_avg_test_pooled = test error of the pooling method
- res_avg_train_pooled = training error of the pooling method
- The error is the squared error for regression and the 0-1 error (in %) for binary classification.

Warning - THIS METHOD WORKS ONLY WITH FIVE-FOLD CROSS-VALIDATION, SO numberOfTrainingUser MUST BE A MULTIPLE OF 5.
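The difference between pooling and the marginal predictor method can be illustrated with a small feature-construction sketch. This is a minimal Python sketch, not the repository's MATLAB code, and it rests on assumptions: Gaussian kernels, random Fourier features as the kernel approximation, each dataset's marginal represented by the average of its feature vectors (an empirical kernel mean embedding), and a product kernel on (marginal, point) pairs, whose feature map is the Kronecker product of the two feature maps. The function names (make_rff, mean_embedding, pooled_features, marginal_features) are hypothetical, and the sketch does not claim how the repository's L, Q, D map onto its three feature dimensions. A linear solver such as liblinear would then be trained on either feature set.

```python
import math
import random


def make_rff(dim_in, num_features, bandwidth, rng):
    """Return a random Fourier feature map z whose inner products approximate
    the Gaussian kernel k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))."""
    # Frequencies w ~ N(0, bandwidth^-2 I), phases b ~ Uniform[0, 2*pi]
    ws = [[rng.gauss(0.0, 1.0 / bandwidth) for _ in range(dim_in)]
          for _ in range(num_features)]
    bs = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(num_features)]
    scale = math.sqrt(2.0 / num_features)

    def z(x):
        return [scale * math.cos(sum(wi * xi for wi, xi in zip(w, x)) + b)
                for w, b in zip(ws, bs)]

    return z


def mean_embedding(z, points):
    """Empirical kernel mean embedding of one dataset: the average of z over its points."""
    feats = [z(p) for p in points]
    n = len(feats)
    return [sum(col) / n for col in zip(*feats)]


def pooled_features(z_x, points):
    """Pooling: each point is mapped by z_x alone, so dataset identity is ignored."""
    return [z_x(p) for p in points]


def marginal_features(z_x, z_emb, z_p, points):
    """Marginal transfer (hypothetical sketch): pair each point with its dataset's
    embedding. A product kernel k_P * k_X corresponds to the Kronecker product
    of the embedding features and the point features."""
    mu = mean_embedding(z_emb, points)  # embedding of this dataset's marginal
    phi_p = z_p(mu)                     # random features on the embedding itself
    return [[a * c for a in phi_p for c in z_x(p)] for p in points]
```

Because every point in a dataset shares the same embedding features, a linear model on the Kronecker features acts on z_x(x) with a weight vector that depends on the dataset's embedding, whereas a linear model on the pooled features uses one weight vector for all datasets.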
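The five-fold restriction is easiest to see if each fold holds out whole training datasets. The sketch below assumes exactly that; the README does not say how util_mrg_datasets_CV splits the data, so treat the dataset-level split as an assumption. It also shows the two reported error measures, assuming the squared error is averaged over points and that binary labels are +1/-1 with predictions thresholded at zero. The function names five_fold_dataset_splits and task_error are hypothetical.

```python
def five_fold_dataset_splits(number_of_training_user, num_folds=5):
    """Return (train_ids, val_ids) pairs over dataset indices
    0..number_of_training_user-1. Each fold holds out whole datasets,
    which is why the count must be a positive multiple of num_folds."""
    if number_of_training_user <= 0 or number_of_training_user % num_folds != 0:
        raise ValueError(
            f"numberOfTrainingUser ({number_of_training_user}) must be a "
            f"positive multiple of {num_folds}")
    fold_size = number_of_training_user // num_folds
    ids = list(range(number_of_training_user))
    splits = []
    for k in range(num_folds):
        val_ids = ids[k * fold_size:(k + 1) * fold_size]
        train_ids = ids[:k * fold_size] + ids[(k + 1) * fold_size:]
        splits.append((train_ids, val_ids))
    return splits


def task_error(task_type, y_true, y_pred):
    """Squared error averaged over points for 'regression';
    0-1 error in percent for 'binary' (labels +1/-1, predictions thresholded at 0)."""
    n = len(y_true)
    if task_type == "regression":
        return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n
    if task_type == "binary":
        mistakes = sum(1 for t, p in zip(y_true, y_pred)
                       if (1 if p >= 0 else -1) != t)
        return 100.0 * mistakes / n
    raise ValueError("task_type must be 'regression' or 'binary'")
```

For example, with numberOfTrainingUser = 10, each fold validates on two datasets and trains on the remaining eight; with 12 the split is rejected, matching the warning above.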