Matrix factorization router is bloated #9

Open
KTibow opened this issue Jul 6, 2024 · 2 comments
KTibow commented Jul 6, 2024

The current matrix factorization router (MFModel) is unnecessarily complex. Given that all operations in the forward pass are linear with no activations, we can significantly simplify this model.

Currently, we're doing several steps:

  1. Embedding model IDs
  2. Normalizing model embeddings
  3. Projecting text embeddings
  4. Element-wise multiplication
  5. Linear classification
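For reference, the five steps can be sketched in NumPy as follows. This is a minimal illustration, not the project's actual MFModel code. The sizes `NUM_MODELS`, `DIM` and `TEXT_DIM`, the random weights, and the absence of bias terms are assumptions. The key names in the comments only mirror the state-dict keys read by the author's snippet later in this thread.

```python
import numpy as np

# Hypothetical sizes: illustrative only, not the router's real configuration.
NUM_MODELS, DIM, TEXT_DIM = 4, 8, 16
rng = np.random.default_rng(0)

# Random stand-ins for the learned weights (bias-free layers assumed).
P = rng.normal(size=(NUM_MODELS, DIM))      # model-ID embedding table ("P.weight")
W_proj = rng.normal(size=(DIM, TEXT_DIM))   # text projection ("text_proj.0.weight")
w_cls = rng.normal(size=(1, DIM))           # single-logit classifier ("classifier.0.weight")


def mf_score(model_id: int, prompt_embed: np.ndarray) -> float:
    """Score one model for one prompt embedding, step by step."""
    p = P[model_id]                  # 1. embed the model ID
    p = p / np.linalg.norm(p)        # 2. L2-normalize the model embedding
    q = W_proj @ prompt_embed        # 3. project the text embedding down to DIM
    h = p * q                        # 4. element-wise multiplication
    return float((w_cls @ h)[0])     # 5. linear classification (one logit)


e = rng.normal(size=TEXT_DIM)  # stand-in for a prompt embedding
print(mf_score(0, e))
```

The model ID only enters through steps 1 and 2, so for a fixed model everything after that is linear in `prompt_embed`. That property is what allows the pipeline to be collapsed.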

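For a fixed model, every step is linear in the prompt embedding. The normalization is not linear, but it only acts on the model embedding, which is fixed per model and can be precomputed. So the whole pipeline collapses into a single dot product between the prompt embedding and one precomputed vector per model, or a single matrix multiplication when scoring all models at once. This would: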

  • Reduce code complexity
  • Improve inference speed, since scoring a prompt becomes one dot product per model
  • Decrease the number of parameters
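Written out for a single model $m$, the collapse is a one-line derivation. Here $\hat p_m$ is the L2-normalized embedding of model $m$, $W$ is the text projection, $w$ holds the classifier weights, and $e$ is the prompt embedding. The notation is ours, and bias-free layers are assumed, as the author's snippet later in the thread implies:

```math
s(m, e) = w^\top\left(\hat p_m \odot W e\right)
        = \sum_i w_i\, \hat p_{m,i}\, (W e)_i
        = \left(w \odot \hat p_m\right)^\top W e
        = v_m^\top e,
\qquad
v_m = W^\top\left(w \odot \hat p_m\right)
```

$v_m$ depends only on the model, so it can be computed once ahead of time.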

For example, here is the collapsed vector for the `mixtral-8x7b-instruct-v0.1` model:
[0.017317, 0.01118, 0.303653, 0.579448, -0.376634, -0.208742, -0.193392, 0.639659, -0.085909, 0.107312, 0.300785, -0.349391, -0.384368, -0.145022, 0.317397, -0.063074, -0.128751, 0.243364, -0.181707, 0.808825, 0.275169, 0.666149, -0.115858, 0.155953, 0.24292, -0.197154, -0.157491, 0.11632, 0.197647, 0.040279, 0.409797, -1.24056, -0.511287, -0.393113, -0.108808, 0.039914, 0.366597, 0.135737, 0.198802, 0.119974, 0.153426, -0.22505, 0.674797, 0.284063, -0.196429, 0.155066, -0.212335, -0.363016, 0.212736, 0.211674, -0.372157, 0.010955, 0.037939, -0.066029, -0.07933, -0.101132, -0.311588, 0.077285, -0.207608, 0.125983, 0.510143, -0.255973, -0.096116, 0.229892, -0.434007, 0.344456, -0.137472, 0.41125, -0.052777, 0.06959, -0.043151, -0.062137, 0.162818, 0.041656, 0.077479, 0.126347, 0.061875, 0.116124, 0.247373, 0.453157, -0.101855, 0.040579, -0.552021, 0.12112, -0.823787, -0.296899, -0.46667, 0.022095, -0.310721, -0.401873, 0.016014, 0.548683, -0.438079, 0.239599, 0.445288, 0.132415, 0.160069, 0.509489, 0.058122, 0.108559, -0.005905, -0.425724, -0.189577, 0.053441, 0.535769, -0.008355, 0.142684, 0.009374, -0.219168, 0.033156, -0.420615, -0.145288, -0.135326, -0.172469, -0.371276, -0.215616, -0.413526, -0.300192, 0.005224, -0.410654, -0.407338, -0.086193, 0.244957, -0.122847, -0.180609, -0.221066, -0.014492, -0.125457, 0.077016, -0.481223, 0.571043, -0.406598, 0.58677, -0.793018, -0.574046, 0.168964, -0.140509, 0.062438, -0.574914, 0.517542, 0.305398, -0.040312, 0.133368, 0.227152, 0.194301, -0.204837, 0.117291, 0.12243, -0.357704, -0.12873, -0.305825, 0.041006, -0.307598, 0.295055, -0.178294, -0.434973, 0.395057, 0.150889, -0.384117, 0.362086, -0.52113, 0.616021, 0.011602, -0.140464, -0.408201, 0.412563, -0.194249, -0.455087, -0.418446, 0.00887, -0.351214, 0.140768, 0.436393, 0.102469, 0.367923, -0.026261, 0.122515, -0.436895, -0.119046, -0.242179, -0.215407, 0.443827, -0.048456, -0.215488, -0.181659, 0.28717, 0.001767, 0.122558, -0.494734, 0.113986, -0.307884, 0.331145, 
0.101143, 0.196236, 0.120081, -0.567916, 0.431674, 0.210675, 0.397218, -0.003222, -0.124574, 0.163897, -0.012514, -0.437356, -0.083122, -0.277771, -0.072012, 0.215686, 0.06413, -0.13142, 0.094984, -0.38486, -0.072067, 0.495137, -0.393166, -0.230718, -0.116048, 0.712394, 0.279401, 0.238164, 0.041076, 0.148722, 0.580803, 0.614566, -0.147473, 0.432371, 0.713287, -0.012816, 0.17443, 0.122719, -0.168159, 0.062227, -0.511618, -0.242144, -0.15323, 0.176365, -0.331397, -0.130046, 0.520083, -0.236528, -0.234034, -0.201974, -0.235412, 0.408897, -0.56152, 0.197764, -0.57766, -0.745011, -0.153192, 0.378314, 0.060145, -0.132778, 0.742793, 0.08398, -0.493689, 0.071289, 0.147504, -0.078614, 0.068797, -0.36456, -0.150841, -0.128449, -0.523257, -0.515868, -0.293334, -0.1087, -0.216722, -0.37464, -0.362562, 0.145622, 0.273712, -0.309493, 0.331044, -0.169746, -0.116288, 0.106025, 0.075529, -0.081049, 0.245917, 0.180469, -0.562014, -0.29112, 0.020882, 0.134045, 0.106949, 0.326906, 0.184262, 0.028326, 0.03369, -0.251042, 0.196618, -0.420975, -0.021204, -0.00376, 0.19101, -0.335425, -0.217719, 0.111878, -0.016975, 0.30771, 0.433765, 0.150516, -0.073278, 0.171964, -0.305194, 0.080526, 0.08366, -0.170164, 0.442168, 0.106601, -0.04912, -0.071456, 0.259819, -0.111718, 0.138566, -0.60584, 0.147761, 0.152774, 0.057143, -0.759514, 0.069949, -0.825877, 0.335864, 0.199449, -0.266243, -0.403074, 0.20985, -0.188599, -0.019244, -0.375069, -0.421033, -0.1462, -0.220748, -0.061277, -0.135211, -0.141422, -0.215439, 0.091185, -0.007891, -0.274731, -0.594342, -0.451199, 0.021285, 0.272036, -0.007255, 0.172055, -0.032696, 0.376444, -0.175173, 0.255335, -0.264267, 0.083475, 0.179118, 0.091082, 0.260392, 0.171118, 0.421613, -0.558687, -0.341742, -0.279588, -0.10411, -0.058658, 0.086409, 0.492655, -0.210353, 0.551876, -0.128579, -0.1514, 0.193864, 0.246684, -0.30106, 0.512475, -0.348025, 0.269122, -0.478439, 0.593487, 0.375225, 0.332428, 0.340556, 0.264401, 0.087356, 0.632642, 0.088945, -0.560939, 0.390676, 
0.162052, 0.411639, -0.289915, -0.632261, -0.413713, 0.355988, -0.485467, 0.383603, 0.303537, -0.381534, -0.092763, 0.417598, 0.803573, -0.405532, -0.347625, 0.285436, -0.178229, -0.400952, -0.26588, 0.369776, 0.145592, -0.439068, -0.263021, -0.086975, -0.229377, 0.395024, -0.386163, 0.785113, -0.064416, 0.236719, -0.15956, 0.083073, -0.106436, 0.145759, 0.221116, 0.033651, -0.142519, -0.135173, -0.163156, 0.111862, 0.309598, 0.234952, -0.487281, 0.028245, -0.869042, 0.15329, 0.262507, 0.154243, 0.327218, -0.018497, -0.247377, -0.144596, 0.131668, 0.129669, 0.204863, 0.201405, 0.348766, -0.056677, 0.067078, -0.473868, 0.084789, 0.342684, 0.190402, -0.130603, -0.294488, -0.053648, -0.393713, -0.288915, -0.081103, 0.031378, 0.284718, 0.303025, 0.199114, 0.252887, 0.023813, 0.317988, -0.108757, 0.021954, 0.25852, 0.133369, 0.033529, 0.348155, 0.484368, -0.082799, -0.478818, 0.139661, -0.112743, 0.242792, 0.214462, -0.035537, 0.198981, 0.720355, 0.105166, 0.118839, -0.398867, 0.056433, 0.290868, -0.166652, -0.187253, -0.025869, 0.348516, -0.194342, -0.018384, -0.707275, -0.212753, -0.342487, 0.366088, 0.502167, -0.481192, 0.002395, -0.303666, -0.496359, -0.437127, 0.070926, -0.164983, 0.493785, 0.254769, -0.330188, 0.69813, 0.110217, -0.165885, -0.09634, 0.056251, 0.016961, 0.357136, -0.484442, 0.201198, 0.066619, -0.262838, -0.326917, 0.323244, 0.05816, 0.367642, -0.142585, -0.375839, -0.334615, 0.190125, -0.029278, 0.099811, 0.143965, -0.275176, -0.58817, -0.144267, -0.359548, -0.310475, 0.534627, -0.576121, -0.194085, 0.052187, -0.187249, -0.235678, -0.122017, 0.375999, -0.086289, -0.127235, 0.080998, 0.181145, 0.157067, -0.26179, -0.451746, -0.135946, -0.236, 0.052646, 0.336281, 0.21719, 0.186457, -0.002216, -0.056215, -0.369369, 0.442009, -0.228632, -0.175233, -0.292619, 0.18085, -0.222465, -0.071054, -0.036178, 0.42584, -0.052242, -0.186202, 0.438149, -0.189797, -0.16556, 0.036239, -0.02704, -0.254496, -0.069539, -0.232275, -0.695319, 0.460565, 0.33609, 0.51992, 
-0.283644, 0.143016, 0.185549, 0.012047, -0.222176, -0.130095, -0.261126, -0.422626, 0.286046, 0.318453, -0.25702, 0.280548, 0.066077, 0.205378, 0.221395, 0.134313, 0.202538, -0.112085, 0.112352, -0.311995, -0.114661, -0.305415, 0.163122, -0.162758, 0.064207, 0.100317, -0.297041, 0.153704, 0.412633, -0.236838, -0.213884, 0.043544, 0.078991, 0.026837, 0.399776, -0.292028, -0.702604, 0.238641, -0.057664, -0.338922, 0.101509, -0.030345, -0.092672, 0.189603, -0.184702, -0.224473, 0.232278, 0.167241, 0.204301, -0.074669, -0.31327, -0.069146, 0.169052, 0.34982, 0.001693, 0.495445, 0.169925, -0.079298, -0.00096, 0.068827, -0.110808, 0.049159, -0.156822, 0.033281, -0.138699, 0.064114, -0.183973, 0.299447, 0.020633, -0.394375, 0.22391, 0.29888, -0.162223, -0.154018, 0.0686, 0.091588, 0.010075, 0.177063, 0.337276, -0.258455, -0.172135, -0.309286, 0.11186, -0.063176, -0.131384, -0.117094, -0.025922, 0.217625, 0.064211, 0.097853, 0.21063, 0.209421, -0.003702, -0.12937, 0.568447, 0.056538, 0.071752, 0.131685, 0.265961, 0.13205, -0.342845, -0.14158, 0.327599, 0.206992, 0.380256, -0.092596, -0.077388, -0.19744, 0.0181, 0.287433, 0.088687, 0.097779, -0.044891, -0.404558, 0.147617, 0.422414, 0.11152, 0.308355, -0.106925, 0.204491, 0.043149, 0.065036, -0.753266, 0.122351, 0.336833, -0.00801, -0.262349, -0.193282, -0.103019, -0.089863, 0.171337, 0.309414, 0.014423, 0.098344, -0.110209, -0.169665, -0.030896, -0.097471, 0.00666, 0.101595, 0.061852, 0.176964, -0.21323, -0.099782, 0.228022, -0.262198, -0.425247, 0.417079, 0.017299, -0.191564, 0.004748, -0.250221, 0.234701, -0.271065, -0.057453, 0.304677, 0.4701, 0.250589, -0.087086, -0.429968, -0.26403, -0.387913, -0.464612, -0.342326, -0.071384, 0.056032, 0.187852, 0.380555, 0.189432, 0.34011, 0.266143, 0.009143, -0.317522, -0.234059, 0.276891, 0.174809, 0.140528, -0.105288, -0.65848, 0.084518, -0.234592, 0.318019, 0.510351, 0.006479, 0.537869, -0.392096, -0.411233, -0.189889, 0.134191, -0.075683, -0.169409, 0.125705, -0.327027, 
-0.066445, -0.52144, -0.097577, -0.177766, 0.232948, -0.135097, -0.343601, -0.091137, 0.062618, 0.053287, 0.18644, -0.6094, 0.048837, 0.267879, -0.413453, -0.141747, 0.207981, -0.04925, -0.174698, -0.509869, -0.476397, 0.068638, -0.152651, 0.104868, 0.197331, -0.064872, -0.1051, -1.40418, -0.194817, 0.208227, -0.045253, -0.232286, 0.073835, 0.12477, 0.393212, 0.347051, -0.187002, 0.079182, -0.27366, -0.215268, 0.375153, 0.270839, -0.334651, -0.126299, 0.34891, -0.174526, 0.234166, -0.317101, 0.057596, -0.157946, 0.15384, 0.16841, 0.158807, -0.192711, 0.192967, -0.262208, 0.108206, 0.238273, 0.236885, -0.399003, 0.221671, 0.038937, -0.107384, 0.288186, 0.160961, -0.086901, 0.055572, -0.190251, -0.233012, -0.054056, -0.080065, 0.111019, -0.044721, 0.036763, 0.068096, -0.017873, 0.261569, 0.346434, 0.065229, -0.023851, -0.330086, 0.213761, 0.128141, -0.138356, -0.062674, 0.195684, 0.215495, 0.194634, -0.339133, -0.268465, -0.298594, -0.362164, -0.253306, -0.168292, 0.199113, -0.524123, -0.090773, -0.096247, 0.046664, -0.046513, 0.13497, 0.114262, -0.488398, -0.2347, 0.26051, 0.031243, -0.152594, 0.258885, -0.064539, -0.176934, -0.027078, 0.197796, -0.050404, 0.004199, -0.020745, -0.127675, 0.053641, 0.515427, 0.131214, 0.353022, 0.284469, 0.01992, 0.120054, -0.318418, -0.026164, 0.306722, 0.035191, 0.425452, 0.046934, 0.010072, -0.134704, -0.118026, 0.033954, 0.444288, 0.004718, 0.035425, -0.030341, 0.394551, -0.165347, -0.115437, -0.017297, -0.585792, 0.17584, 0.377414, 0.421793, 0.188193, 0.307312, 0.610973, -0.196335, -0.29751, -0.105334, 0.199592, -0.195532, -0.095663, 0.142824, 0.130411, -0.080841, 0.202719, 0.471838, -0.072826, 0.246151, 0.109777, -0.101721, 0.169312, 0.54931, -0.074526, 0.021988, -0.096728, -0.223985, -0.058271, 0.23175, -0.332564, 0.169538, -0.225755, 0.046639, 0.136866, -0.158008, 0.114861, 0.065593, -0.117845, 0.490567, -0.378452, 0.408763, 0.048036, 0.315145, -0.041749, 0.309414, 0.031155, 0.347439, -0.051953, -0.201888, 0.179567, 0.17787, 
0.152476, -0.050791, 0.420996, -0.111863, 0.110077, 0.268456, -0.074361, -0.144558, 0.119518, 0.188343, 0.396397, -0.381355, 0.012706, 0.245918, 0.26378, 0.207468, 0.06862, 0.268775, 0.503796, -0.042588, 0.299801, 0.264099, 0.567906, 0.343754, 0.112813, -0.058419, 0.151873, 0.105714, 0.013268, -0.104881, 0.179048, 0.103319, 0.155907, -0.207802, -0.594822, 0.001902, 0.334797, -0.128813, 0.02412, 0.158227, 0.232278, -0.168783, -0.101024, 0.001426, -0.334838, -0.25871, -0.281469, 0.175912, 0.173545, 0.199818, 0.156694, -0.202074, -0.528855, 0.341782, -0.294037, -0.567092, 0.042527, 0.229844, -0.274017, 0.111275, 0.022757, -0.276101, 0.432179, 0.322151, -0.11445, 0.865446, 0.367544, 0.267589, 0.00913, -0.410267, 0.137246, -0.013712, 0.620266, -0.091809, -0.297659, -0.373554, 0.207084, -0.421513, -0.183964, -0.156403, 0.219091, -0.508866, 0.516564, -0.361563, -0.201876, 0.202988, 0.183052, -0.22674, 0.057602, 0.041183, -0.211405, 0.247517, 0.204372, 0.042675, -0.214661, -0.111943, 0.009249, -0.014273, -0.351459, 0.070249, -0.315316, 0.133022, -0.073426, -0.180068, -0.333467, -0.067528, 0.357887, 0.430013, 0.131229, 0.298485, 0.373571, -0.302588, -0.04142, -0.344667, -0.283525, 0.640575, 0.317337, 0.401381, 0.189486, 0.073186, 0.02416, -0.215443, 0.056143, 0.120336, -0.231008, -0.105986, -0.453503, -0.219785, -0.030274, -0.367342, -0.113358, 0.196147, 0.291157, 0.326472, 0.446857, -0.085561, 0.010959, 0.066616, 0.15023, -0.209559, -0.112984, 0.072598, -0.427699, -0.260073, 0.032521, 0.081192, -0.014159, 0.143266, 0.197289, 0.067981, 0.173343, -0.155237, 0.193014, -0.033441, -0.270513, 0.12482, -0.140087, -0.524852, -0.142413, -0.197585, 0.069683, 0.00106, -0.060416, 0.241788, -0.273508, 0.014679, -0.066452, -0.355985, -0.262008, 0.26785, -0.009632, 0.163352, -0.068926, 0.46138, -0.317769, -0.397394, 0.224559, 0.352467, -0.097191, -0.287376, 0.408935, 0.345993, 0.09068, 0.2473, 2.3111, -0.14702, -0.111799, -0.052716, 0.230692, 0.225265, -0.35181, 0.094639, -0.154193, 
0.185283, -0.315491, -0.077438, 0.24265, -0.103315, -0.156623, -0.086985, -0.316301, 0.000796, -0.025065, -0.097864, -0.362233, -0.448295, -0.403811, 0.258856, -0.100113, -0.055167, 0.294756, 0.024366, 0.102181, -0.106253, 0.023481, 0.160745, 0.063656, 0.155556, -0.336469, 0.325614, -0.266145, -0.074525, 0.201849, 0.441004, -0.174538, 0.131324, 0.284181, -0.261139, 0.098757, -0.019434, -0.194059, -0.108849, -0.072083, -0.093592, -0.285213, -0.176247, 0.069006, 0.297378, -0.025485, 0.268425, -0.101778, 0.018244, 0.776521, 0.297483, 0.251349, -0.167599, -0.30711, 0.070886, 0.01418, 0.285411, -0.430578, -0.237813, 0.059797, 0.027026, -0.0401, 0.143306, -0.469388, 0.055392, 0.137084, 0.284571, 0.189084, -0.405384, 0.135162, -0.680802, -0.434545, -0.210474, 0.30213, 0.114895, 0.167591, -0.307093, -0.255949, 0.242898, 0.187186, 0.3594, -0.125649, 0.174752, 0.301497, -0.150837, 0.118552, 0.144685, 0.023964, 0.20746, -0.186843, 0.230801, 0.11998, 0.099391, -0.390997, 0.242291, -0.209336, -0.369022, 0.225537, -0.254627, -0.19489, 0.007398, 0.30297, -0.100568, -0.039901, -0.267365, 0.17685, 0.032181, -0.051405, -0.003954, 0.061989, -0.398622, -0.102953, 0.230554, 0.369276, -0.32691, 0.121757, 0.282954, 0.275177, 0.301383, -0.048143, -0.102173, 0.270449, 0.326503, 0.356696, 0.198148, 0.566387, 0.118633, 0.069914, 0.049507, 0.264942, -0.021149, -0.315653, 0.195143, -0.037403, -0.560274, 0.036958, 0.226462, -0.187307, 0.00932, 0.06245, 0.158091, -0.02271, 0.303259, -0.281134, 0.229444, 0.202054, -0.022002, -0.175618, -0.035272, -0.416639, -0.079588, -0.190756, 0.237299, 0.128946, -0.025495, 0.31631, 0.165038, -0.036987, -0.056892, -0.472618, -0.240427, 0.258912, 0.142983, -0.017613, 0.09934, 0.301944, -0.317137, -0.045731, 0.176888, -0.237915, 0.034828, -0.244753, -0.262084, 0.007381, 0.179293, 0.012775, 0.134795, -0.16332, -0.444582, -0.080167, 0.024672, -0.090209, -0.09143, 0.177423, 0.066397, -0.464973, 0.473688, 0.156524, -0.011874, -0.018553, 0.049021, -0.058733, -0.16094, 
-0.055641, 0.084314, -0.180604, -0.147321, 0.507487, 0.259353, 0.214523, 0.136566, 0.10569, -0.117942, 0.207137, 0.524199, 0.176873, 0.319673, 0.065076, 0.200993, 0.067377, -0.128274, -0.148678, -0.369512, -0.073067, 0.022234, -0.376015, -0.161213, -0.004808, -0.385252, -0.063738, 0.172607, -0.040167, -0.120519, 0.296494, -0.195137, 0.055634, 0.323904, -0.638334, -0.255347, -0.100382, 0.251132, -0.055979, 0.004391, -0.289993, -0.004406, 0.050617, 0.410566, 0.452379, -0.556643, 0.081581, 0.137408, 0.254382, 0.251986, 0.082583, -0.024478, -0.477649, 0.310222, 0.211715, 0.022005, 0.063267, -0.130571, 0.155438, 0.380635, 0.231092, 0.099042, -0.391679, -0.058661, -0.540002, -0.358878, -0.324142, 0.243863, -0.400055, 0.103157, -0.262598, -0.044676, -0.444585, 0.030034, 0.01668, 0.311564, 0.543531, -0.047709, -0.113976, -0.304748, -0.150807, -0.274888, 0.024604, -0.183968, 0.024504, 0.393683, -0.430544, -0.323938, 0.306146, -0.039433, -0.189903, 0.057104, 0.19676, 0.036725, 0.079969, -0.205473, -0.314785, 0.030175, -0.049927, 0.061419, -0.36235, -0.056072, 0.159138, 0.456674, 0.007084, 0.441482, -0.175448, 0.061765, 0.412505, -0.402356, -0.084174, 0.085337, -0.180057, 0.284374, 0.031825, 0.15114, 0.045856, 0.362218, 0.371848, 0.142496, 0.376347, 0.309523, 0.437986, -0.178713, -0.200895, -0.046065, 0.183416, -0.31115, 0.299963, -0.005362, 0.397519, -0.025268, 0.382294, -0.424654, -0.169118, 0.246686, -0.017109, -0.480841, -0.132066, 0.066515, -0.014366, 0.487456, -0.023139, 0.006938, 0.314802, 0.340747, -0.010792, 0.064729, 0.304637, 0.072488, -0.257531, -0.164407, -0.238009, 0.251726, 0.442151, -0.439882, -0.096664, 0.030146, -0.100694, -0.168094, -0.193923, 0.46795, 0.080172, 0.063586, -0.328571, -0.16416, -0.259619, 0.293085, -0.279067, 0.232538, 0.033095, -0.198362, -0.305268, -0.361208, 0.034213, 0.427696, -0.033954, -0.227259, 0.01694, -0.551509, -0.055286, -0.099024, 0.267421, 0.104194, 0.000865, -0.088973, 0.200319]

KTibow (author) commented Jul 6, 2024

Code used:

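import numpy as np

# `loaded`: the router's weights as NumPy arrays; row 36 of P.weight is presumably the model shown above
v1 = loaded["P.weight"][36] / np.sqrt(np.sum(loaded["P.weight"][36] ** 2))  # L2-normalized model embedding
v2 = v1 * loaded["classifier.0.weight"][0]  # fold in the classifier weights
v3 = v2 @ loaded["text_proj.0.weight"]  # fold in the text projection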
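One self-contained way to check the claim is to compare the step-by-step forward pass with the collapsed vector computed as in the snippet above. The sketch uses random weights of hypothetical sizes, stored under the same state-dict key names. It does not load the real checkpoint, and it assumes no bias terms.

```python
import numpy as np

# Stand-in for the router's state dict: random weights with hypothetical sizes,
# stored under the same key names the snippet above reads.
NUM_MODELS, DIM, TEXT_DIM = 4, 8, 16
rng = np.random.default_rng(42)
loaded = {
    "P.weight": rng.normal(size=(NUM_MODELS, DIM)),
    "text_proj.0.weight": rng.normal(size=(DIM, TEXT_DIM)),
    "classifier.0.weight": rng.normal(size=(1, DIM)),
}


def original_score(model_id: int, e: np.ndarray) -> float:
    """Step-by-step forward pass: embed, normalize, project, multiply, classify."""
    p = loaded["P.weight"][model_id]
    p = p / np.linalg.norm(p)
    h = p * (loaded["text_proj.0.weight"] @ e)
    return float(loaded["classifier.0.weight"][0] @ h)


def collapsed_vector(model_id: int) -> np.ndarray:
    """Precompute one TEXT_DIM-length vector per model, as in the snippet above."""
    p = loaded["P.weight"][model_id]
    v1 = p / np.sqrt(np.sum(p ** 2))
    v2 = v1 * loaded["classifier.0.weight"][0]
    return v2 @ loaded["text_proj.0.weight"]


# Stacking the per-model vectors turns scoring every model into one matmul.
V = np.stack([collapsed_vector(m) for m in range(NUM_MODELS)])
e = rng.normal(size=TEXT_DIM)  # stand-in for a prompt embedding
scores = V @ e
for m in range(NUM_MODELS):
    assert np.isclose(scores[m], original_score(m, e))
print(scores)
```

With real weights, `V` would hold one row per routed model, and scoring a prompt embedding against all of them is a single `V @ e`.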

iojw (collaborator) commented Jul 7, 2024

Hi there, thank you for raising this! You are right that, currently, all the operations can be collapsed into a single matrix multiplication. We broke them up because we were also experimenting with nonlinear variants when training the matrix factorization router.
