From 853b65dfcdcd61ad09b8b4949e68bfbb3f7cb56f Mon Sep 17 00:00:00 2001
From: Keegan Evans
Date: Tue, 12 Apr 2022 15:34:04 -0700
Subject: [PATCH 01/27] DRAFT: first draft of model

---
 examples/refactor/041222pytorchdraft.ipynb | 495 +++++++++++++++++++++
 examples/refactor/soil_metabolites.biom    | Bin 0 -> 53681 bytes
 examples/refactor/soil_microbes.biom       | Bin 0 -> 81409 bytes
 3 files changed, 495 insertions(+)
 create mode 100644 examples/refactor/041222pytorchdraft.ipynb
 create mode 100644 examples/refactor/soil_metabolites.biom
 create mode 100644 examples/refactor/soil_microbes.biom

diff --git a/examples/refactor/041222pytorchdraft.ipynb b/examples/refactor/041222pytorchdraft.ipynb
new file mode 100644
index 0000000..968710a
--- /dev/null
+++ b/examples/refactor/041222pytorchdraft.ipynb
@@ -0,0 +1,495 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "213bcdfc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.utils.data import Dataset\n",
+    "from torch.distributions import Multinomial\n",
+    "import biom"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "382bb9ce",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# some example data\n",
+    "microbes = biom.load_table(\"./soil_microbes.biom\")\n",
+    "metabolites = biom.load_table(\"./soil_metabolites.biom\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "96fac3bf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MicrobeMetaboliteData(Dataset):\n",
+    "    def __init__(self, microbes: biom.Table, metabolites: biom.Table):\n",
+    "        # arrange as samples x features dataframes\n",
+    "        self.microbes = microbes.to_dataframe().T\n",
+    "        self.metabolites = metabolites.to_dataframe().T\n",
+    "        \n",
+    "        # keep only samples that have metabolite results\n",
+    "        self.microbes = self.microbes.loc[self.metabolites.index]\n",
+    "        \n",
+    "        # convert to tensors/final form\n",
+    "        self.microbes = torch.tensor(self.microbes.values, dtype=torch.int)\n",
+    "        self.metabolites = torch.tensor(self.metabolites.values, dtype=torch.int64)\n",
+    "        \n",
+    "        # counts\n",
+    "        self.microbe_count = self.microbes.shape[1]\n",
+    "        self.metabolite_count = self.metabolites.shape[1]\n",
+    "        \n",
+    "        # relative frequencies\n",
+    "        self.microbe_relative_frequency = (self.microbes.T\n",
+    "                                           / self.microbes.sum(1)\n",
+    "                                          ).T\n",
+    "        \n",
+    "        self.metabolite_relative_frequency = (self.metabolites.T\n",
+    "                                              / self.metabolites.sum(1)\n",
+    "                                             ).T\n",
+    "        \n",
+    "        self.total_microbe_observations = self.microbes.sum()\n",
+    "    \n",
+    "    def __len__(self):\n",
+    "        return int(self.total_microbe_observations)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "234ccc47",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "example_data = MicrobeMetaboliteData(microbes, metabolites)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "0ab12e60",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "424846"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "example_data.total_microbe_observations.item()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "f106a231",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MMVec(nn.Module):\n",
+    "    def __init__(self, num_microbes, num_metabolites, latent_dim):\n",
+    "        super().__init__()\n",
+    "        # one latent vector per microbe, decoded to metabolite probabilities\n",
+    "        self.encoder = nn.Embedding(num_microbes, latent_dim)\n",
+    "        self.decoder = nn.Sequential(\n",
+    "            nn.Linear(latent_dim, num_metabolites),\n",
+    "            # [batch, sample, metabolite]\n",
+    "            nn.Softmax(dim=2)\n",
+    "        )\n",
+    "        \n",
+    "    # X = batch of microbe indices (one draw per sample)\n",
+    "    # Y = expected metabolite relative frequencies\n",
+    "    def forward(self, X, Y):\n",
+    "        \n",
+    "        # pass our random draws to our embedding\n",
+    "        z = self.encoder(X)\n",
+    "        \n",
+    "        # from latent dimensions in the embedding through our linear\n",
+    "        # function to predicted metabolite frequencies, which we then\n",
+    "        # normalize with softmax\n",
+    "        y_pred = self.decoder(z)\n",
+    "        \n",
+    "        # total_count=0 and validate_args=False let us skip the total count\n",
+    "        # when calling log_prob; floating point issues otherwise lead to\n",
+    "        # \"incorrect\" total counts. This multinomial is parameterized by\n",
+    "        # the predictions coming out of the decoder.\n",
+    "        forward_dist = Multinomial(total_count=0,\n",
+    "                                   validate_args=False,\n",
+    "                                   probs=y_pred)\n",
+    "        \n",
+    "        # the log probability of drawing our expected results from our \"predictions\"\n",
+    "        forward_dist = forward_dist.log_prob(Y)\n",
+    "        \n",
+    "        # average over the draws in the zeroth dimension, leaving one\n",
+    "        # log probability per sample\n",
+    "        forward_dist = forward_dist.mean(0)\n",
+    "        \n",
+    "        # total log probability across all samples\n",
+    "        lp = forward_dist.mean()\n",
+    "\n",
+    "        return lp"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "b74bdf61",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mmvec_model = MMVec(example_data.microbe_count, example_data.metabolite_count, 15)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "cbc8d647",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_loop(dataset, model, optimizer, batch_size):\n",
+    "    \n",
+    "    # because we want to look at all of the samples together, we handle our\n",
+    "    # own batching for now. This method currently leads to slight over-\n",
+    "    # sampling but can be refined.\n",
+    "    n_batches = torch.div(dataset.total_microbe_observations.item(),\n",
+    "                          batch_size,\n",
+    "                          rounding_mode='floor') + 1\n",
+    "    \n",
+    "    # We will want to implement batching functionality later for\n",
+    "    # parallelizability, but for now this works running on the CPU.\n",
+    "    # Note that `epochs` is read from the enclosing notebook scope.\n",
+    "    for batch in range(n_batches * epochs):\n",
+    "        \n",
+    "        # the draws we will be training on each batch, which will\n",
+    "        # be fed to all samples in our model. 
This step will probably be\n", + " # moved to a sampler or collate_fn somewhere in the dataset/dataloader\n", + " # but how exactly that will work is not clear at the moment\n", + " draws = torch.multinomial(dataset.microbe_relative_frequency,\n", + " batch_size,\n", + " replacement=True).T\n", + " \n", + " # \"forward step\", our model generates our \"predictions\", so there is no need to\n", + " # call `forward` separately.\n", + " lp = model(draws,\n", + " dataset.metabolite_relative_frequency)\n", + " \n", + " # this location is idiomatic but flexible\n", + " optimizer.zero_grad()\n", + " \n", + " # the typical training bit.\n", + " lp.backward()\n", + " optimizer.step()\n", + " \n", + " if batch % 100 == 0:\n", + " print(f\"loss: {lp.item()}\\nBatch #: {batch}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfb75b21", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: -4.114527225494385\n", + "Batch #: 0\n", + "loss: -3.6144325733184814\n", + "Batch #: 100\n", + "loss: -3.0469698905944824\n", + "Batch #: 200\n", + "loss: -2.70939564704895\n", + "Batch #: 300\n", + "loss: -2.5499744415283203\n", + "Batch #: 400\n", + "loss: -2.473045587539673\n", + "Batch #: 500\n", + "loss: -2.4374732971191406\n", + "Batch #: 600\n", + "loss: -2.421781539916992\n", + "Batch #: 700\n", + "loss: -2.4101920127868652\n", + "Batch #: 800\n", + "loss: -2.4041030406951904\n", + "Batch #: 900\n", + "loss: -2.4012131690979004\n", + "Batch #: 1000\n", + "loss: -2.397974967956543\n", + "Batch #: 1100\n", + "loss: -2.3931915760040283\n", + "Batch #: 1200\n", + "loss: -2.3923048973083496\n", + "Batch #: 1300\n", + "loss: -2.389982223510742\n", + "Batch #: 1400\n", + "loss: -2.3868303298950195\n", + "Batch #: 1500\n", + "loss: -2.3855628967285156\n", + "Batch #: 1600\n", + "loss: -2.382643222808838\n", + "Batch #: 1700\n", + "loss: -2.381664991378784\n", + "Batch #: 1800\n", + "loss: -2.3774473667144775\n", + "Batch #: 1900\n", + "loss: -2.378610372543335\n", + "Batch #: 2000\n", + "loss: -2.3776485919952393\n", + "Batch #: 2100\n", + "loss: -2.376375675201416\n", + "Batch #: 2200\n", + "loss: -2.3723671436309814\n", + "Batch #: 2300\n", + "loss: -2.372851848602295\n", + "Batch #: 2400\n", + "loss: -2.373134136199951\n", + "Batch #: 2500\n", + "loss: -2.3704051971435547\n", + "Batch #: 2600\n", + "loss: -2.37052059173584\n", + "Batch #: 2700\n", + "loss: -2.371293306350708\n", + "Batch #: 2800\n", + "loss: -2.3711659908294678\n", + "Batch #: 2900\n", + "loss: -2.3693435192108154\n", + "Batch #: 3000\n", + "loss: -2.370833396911621\n", + "Batch #: 3100\n", + "loss: -2.36956787109375\n", + "Batch #: 3200\n", + "loss: -2.3683981895446777\n", + "Batch #: 3300\n", + "loss: -2.368025064468384\n", + "Batch #: 3400\n", + "loss: -2.3673665523529053\n", + "Batch #: 3500\n", + "loss: -2.3669538497924805\n", + "Batch #: 3600\n", + "loss: -2.364877700805664\n", + "Batch #: 3700\n", + "loss: -2.3676393032073975\n", + "Batch #: 3800\n", + "loss: -2.3655707836151123\n", + "Batch #: 3900\n", + "loss: -2.365952253341675\n", + "Batch #: 4000\n", + "loss: -2.366527557373047\n", + "Batch #: 4100\n", + "loss: -2.364421844482422\n", + "Batch #: 4200\n", + "loss: -2.363978385925293\n", + "Batch #: 4300\n", + "loss: -2.3649704456329346\n", + "Batch #: 4400\n", + "loss: -2.364382743835449\n", + "Batch #: 4500\n", + "loss: -2.361299991607666\n", + "Batch #: 4600\n", + "loss: -2.3609752655029297\n", + "Batch #: 4700\n", + "loss: -2.3623459339141846\n", + "Batch 
#: 4800\n", + "loss: -2.3606176376342773\n", + "Batch #: 4900\n", + "loss: -2.3621227741241455\n", + "Batch #: 5000\n", + "loss: -2.3601856231689453\n", + "Batch #: 5100\n", + "loss: -2.3616325855255127\n", + "Batch #: 5200\n", + "loss: -2.3607864379882812\n", + "Batch #: 5300\n", + "loss: -2.3603267669677734\n", + "Batch #: 5400\n", + "loss: -2.3611979484558105\n", + "Batch #: 5500\n", + "loss: -2.36138653755188\n", + "Batch #: 5600\n", + "loss: -2.3617565631866455\n", + "Batch #: 5700\n", + "loss: -2.3602635860443115\n", + "Batch #: 5800\n", + "loss: -2.3588624000549316\n", + "Batch #: 5900\n", + "loss: -2.363048791885376\n", + "Batch #: 6000\n", + "loss: -2.357430934906006\n", + "Batch #: 6100\n", + "loss: -2.359692335128784\n", + "Batch #: 6200\n", + "loss: -2.359476327896118\n", + "Batch #: 6300\n", + "loss: -2.358708381652832\n", + "Batch #: 6400\n", + "loss: -2.3578848838806152\n", + "Batch #: 6500\n", + "loss: -2.3591620922088623\n", + "Batch #: 6600\n", + "loss: -2.3596458435058594\n", + "Batch #: 6700\n", + "loss: -2.358290672302246\n", + "Batch #: 6800\n", + "loss: -2.3569066524505615\n", + "Batch #: 6900\n", + "loss: -2.3586177825927734\n", + "Batch #: 7000\n", + "loss: -2.359415054321289\n", + "Batch #: 7100\n", + "loss: -2.358649969100952\n", + "Batch #: 7200\n", + "loss: -2.35966420173645\n", + "Batch #: 7300\n", + "loss: -2.358867883682251\n", + "Batch #: 7400\n", + "loss: -2.3568341732025146\n", + "Batch #: 7500\n", + "loss: -2.3596749305725098\n", + "Batch #: 7600\n", + "loss: -2.359412670135498\n", + "Batch #: 7700\n", + "loss: -2.357198476791382\n", + "Batch #: 7800\n", + "loss: -2.358001947402954\n", + "Batch #: 7900\n", + "loss: -2.3569891452789307\n", + "Batch #: 8000\n", + "loss: -2.3587193489074707\n", + "Batch #: 8100\n", + "loss: -2.3581130504608154\n", + "Batch #: 8200\n", + "loss: -2.3578381538391113\n", + "Batch #: 8300\n", + "loss: -2.357231855392456\n", + "Batch #: 8400\n", + "loss: -2.3578529357910156\n", + "Batch #: 8500\n", + "loss: -2.3557262420654297\n", + "Batch #: 8600\n", + "loss: -2.355126142501831\n", + "Batch #: 8700\n", + "loss: -2.3567700386047363\n", + "Batch #: 8800\n", + "loss: -2.3553476333618164\n", + "Batch #: 8900\n", + "loss: -2.356520175933838\n", + "Batch #: 9000\n", + "loss: -2.3572936058044434\n", + "Batch #: 9100\n", + "loss: -2.358710527420044\n", + "Batch #: 9200\n", + "loss: -2.3547816276550293\n", + "Batch #: 9300\n", + "loss: -2.3565027713775635\n", + "Batch #: 9400\n", + "loss: -2.3561108112335205\n", + "Batch #: 9500\n", + "loss: -2.356635808944702\n", + "Batch #: 9600\n", + "loss: -2.356121301651001\n", + "Batch #: 9700\n", + "loss: -2.3586411476135254\n", + "Batch #: 9800\n", + "loss: -2.3572912216186523\n", + "Batch #: 9900\n", + "loss: -2.35567045211792\n", + "Batch #: 10000\n", + "loss: -2.3584144115448\n", + "Batch #: 10100\n", + "loss: -2.3562276363372803\n", + "Batch #: 10200\n", + "loss: -2.3546085357666016\n", + "Batch #: 10300\n", + "loss: -2.3559350967407227\n", + "Batch #: 10400\n", + "loss: -2.356455087661743\n", + "Batch #: 10500\n", + "loss: -2.3574140071868896\n", + "Batch #: 10600\n", + "loss: -2.3562002182006836\n", + "Batch #: 10700\n", + "loss: -2.35746169090271\n", + "Batch #: 10800\n", + "loss: -2.3548736572265625\n", + "Batch #: 10900\n", + "loss: -2.3564090728759766\n", + "Batch #: 11000\n", + "loss: -2.3564658164978027\n", + "Batch #: 11100\n", + "loss: -2.3554699420928955\n", + "Batch #: 11200\n", + "loss: -2.3563244342803955\n", + "Batch #: 11300\n", + "loss: -2.357598066329956\n", + "Batch #: 
11400\n", + "loss: -2.35477614402771\n", + "Batch #: 11500\n", + "loss: -2.3572442531585693\n", + "Batch #: 11600\n", + "loss: -2.357273817062378\n", + "Batch #: 11700\n", + "loss: -2.3560562133789062\n", + "Batch #: 11800\n", + "loss: -2.355698823928833\n", + "Batch #: 11900\n", + "loss: -2.3559463024139404\n", + "Batch #: 12000\n", + "loss: -2.35664439201355\n", + "Batch #: 12100\n", + "loss: -2.355379104614258\n", + "Batch #: 12200\n", + "loss: -2.354964256286621\n", + "Batch #: 12300\n" + ] + } + ], + "source": [ + "learning_rate = 1e-3\n", + "batch_size = 500\n", + "epochs = 25\n", + "optimizer = torch.optim.Adam(mmvec_model.parameters(), lr=learning_rate, maximize=True)\n", + "\n", + "# run the training loop \n", + "train_loop(dataset=example_data, model=mmvec_model, optimizer=optimizer, batch_size=batch_size)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/refactor/soil_metabolites.biom b/examples/refactor/soil_metabolites.biom new file mode 100644 index 0000000000000000000000000000000000000000..4bd76fbb2142e9e518a6109cd5d6ac3741b1c0f2 GIT binary patch literal 53681 zcmeFa2Ut|gvM){!B0&(zNhIeaAW@N^Bmv1ua%RXdWKck|1O=2J2%=;p=L`}g3zBor zIS%t5V8*>S=k9mzd;j~sch0@6@6%r0Rb5^6>+0%Wy_jBsQjaAtZxG*rJHOD-;E>>O zE^cA}fuUD^&t(ZL52L?;-oQiC*UvScv%n&<$JCuE)ECb zgvp1c&}wilPysMGu=4Ble+B|l;t!vlTdD>%mlB$Ohx!jR3kT4*H?nimb1(x~UliHv zS=m?`oqN{h#Rp^jPo}q^{(9A~Z^DT}#>%-|M7Y~=Y{)$!98{i@k)8bov^WqOF3k|VXaF8!_xT0uj;krm{y+4@ywKCc z$lAzG&%wx$MIU%x|DW0YXZ+6b|3ei%@GcXG40b4vNRWFpIBRRy3(7ypgYrK?f!_R%n6Ug`Q+SB?oMQ@$ zhcG{#JU8O}f)f5&T=gcsWQRg|$=PIeA{7IqF64xZ;6_aE?ZJz(c&<>Key<>tCzhl33h^vf)qiKCu1 zz~0RIe9WI;sFw+3%w1llKpTLIp0$H1lm+JRFg`PAZJ4~v1rZ#anLWVL$kE`pU|2b< zo~53F!+E5E+4V=c<9S&7-Tv=#s4z^ZUNAc>ftSi-LCaz7KyWgJ1l-^4V?*h%ac>2H zD6VHWa27mX14a5fk4yH!dPz#Bm^2m{9;2(0Spmdo3+dCNPojdB?tMwcW zAWr|gIXG%49fpI0p5sM>_pa1)Fts!Kle@yv`~w{baf#nB!O>ooYYL%5@TG&6!~Dg; z4rt>5urbxUT-L$SL+LO(93c3fk01tUISd!bC^RsGtbrJz<;c+Va&*GMF+uY%yC7Cw zXvYjKhxL!69l${EPk69E>9BU3ApLV5d|097Fuh<2bUMDeLX{? 
[... remaining base85-encoded binary data for soil_metabolites.biom truncated ...]

literal 0
HcmV?d00001

diff --git a/examples/refactor/soil_microbes.biom b/examples/refactor/soil_microbes.biom
new file mode 100644
index 0000000000000000000000000000000000000000..0791df714fe0f8b6c108f88ca965ab59e6f6cd3a
GIT binary patch
literal 81409
[... base85-encoded binary data for soil_microbes.biom truncated ...]

literal 0
HcmV?d00001
z+@ptM$Rdf^Yd-B|R^TWgYUkqlo2$OvmmPt>6iZTXx)7*s~B3u@&@gRh{Mr`jXDY z6RV%)E``mhi!W;w-r_7AyvLb8c$BpE;WoV|sv&uwzbJNRm=%|qSJ<#jRb-#^>NNADsx z2H!4l-;xCr#%X){H0+l@K8dM#ytU8@euxq2uj}(`&x*QeLT~ZJp3^?E<-^*LodXtU zgb@W+$yTsC2n^~7Pj+|ouI53ZwV1Nh{GfPQvbz;uRs)YLaCf#zT1K>T$_I2%zfOI5 zu&m3bqK%oN?jz6GqeLZ8NGDvDP}lrg(B_5a`Bf( z*N4I+J*SulaH??&v|sx~`@?qKcu=$8MidwzDK}Z~^MDe3C`rY&L!*z}4UCW^-;x3E z1$x&=+gD}qJ6HNS@9J5VyXSUkOrz1NTJhB@Q>0q4uLU^m{`X%+x}U<5t_NdR)CQOu z;N4YfJjK~dVx`|dbhSZZ7;Q1*I!#%!p}+Cj=dZ{x36!hMErDQ7?sDhxhzH!fRkg0r z*^P7<=v|#1 zqw%7oFC7WlSB7jU<~e95BeQTz?9yS#WNb_xZ3zr=mE+S<-uIR!Z?2xuS*k89Y;II( zlsff*Yh>ts2HBZ>;oP7?7abpkA`xB~>SXTLgCZ3VXguYL{g@;>Bc;JU(g^aVW)7!J zQesZ?uCGU$(wpmD1z`G^Qm*%Dqi-9L*ydaYcu2`k7Z@z%5JuST6JvW&kTXOgO7pHZ zJ&UZGj*d)p#%yJqKG%T}23v*^m#=ii?B5Nldcnt}-DZDq^f`F>^?+WwLK(2pM4>XZ z=Fqwjx6~w6frv3xpV-IluSN_^;>YxGlhkwUsCLYLBao>y8_Q08n!0G$_cE>sOoJD6B zWQV=~*m*(w3m`|eRaNEVzPEGRmtLR0XJPoqEf)gbZf;+AT`hQy`K|H`i`st~ofo`- z7;aI1@-3(xptUF>Ec|16?@8aiYt*CXScfgUuxYP))ZFUu{NDM#JJ$eg=Vspe_Hnp( z$M^|f^}QM?sub0P@-ySV`7SG3<1+76xPG}q?=@fDqF);35N}>2FlO*Zrr;zO7*AW$9-2nSwvaD3gee=@}2u+p}~BO zu=6=ICf!mky^tZ@&;iNeHksLmp<>;Gi911TgYom0m-jWnjdV%me&(ymP1>HR*xpj zbcWQ&7kOSL8hcl=YHLBqa#4#8hRpA`b2s>QJPFFeO_E0{_*mK zfPfgn!mg+0$FZu(Qet88=Ses`02_I$tid8$@zvQ3xoB_0x}J>e^zteXrP4$3EeliX$FdqX~T?`Q6{KvfllH~1D? z(}x@*9RkE~rUCjAC%zz2{kO;6>6<^sZydp)$-HT5{v11!&%J&&PF1 zJ@Ks9DW>a-flUtz%WpG$zNnd=A}OETNG+3!m;xH`sDcT!6Zi1Ts5XM6Do~!V%xF;9 zsAL!95C_2n&d`NByiD6fOveJrar?*Hs#|wsl@0k}j|-#JE5PY$btaOJ#}if)FDw z;kILy^As03G%u#p+29KjJ>e7gdOa>V`(=t(*?#(&Mw6-DcElt1d^#s3%cp?Fi34{{ z9qq6W?)rS1L$JQ(e)@eIZjtRVnzwu6qxuHZ;B$jr8MR374yQ-nsh;OuaTZUqlBzhz zsym>`p2?`h5#zQ?ubRXq>dt0(lzxv_Yo(=)&@rE>5bhR2O6<6gu$swILt)aL3~7eY z{3)y7wzb_m=|K?unZC9oC@cw4G}z_)vW2dCc`H|iU)X}a5YgJBN{e`sN&mnsOYF$v z)KjvZpL4n{7Yk_@Tauq6{m6r|!Jgj6FAz)JMJa%C(o(t4-ETcH=5OFalSfS2BR~wC%?ymwn&Qp;M~$wWS=zF)CfE; zKsqMz1MS6>>uqkHBxMV#a64;wt8lydZh2I4;Dd*J zInmA!O&eVv!|YFU+S)RehWCjjh9mw(lQzje|!4 zxd_Xhno0L~(Mh5=#R2*Da2o_Y*p9Q&B*Cfnh9*()v^LelcFGR-$kWy*j@{!yX3{*j za2{n56Md?<73rrr%>xN6U4oHC%RZuObGOga3GGPwifxdI6evs zXhgq^Y}khol_DH7?i(htdglcg4(e*Z#5NJ01Yvgxa~t-GTXhGyhy4gUM?_0L*!h!( zmXWoabf;{XPfI+yUr6Kk5Y_?fhtp5jOO}WXjn-aqwes!TM!OQ zw#buhNk?bEj2qN_a@vQBa`)J@petQ72zU^czG_l_2W{&bO?MN~1~5+_^5Bn$-zXZZ zg|yLJePF+{W8^o`xAwj|5YXg#Ay+SSz|(HzDsFD*i0kkb&(&{}C&oq^^;ITmk5oeW z#sGH7d9k9Ut-XaYm%m|Zx9NW1;q!_fJDsTarO#gg^2tq0LQfH%4@K|}bRCtCt>Y_J z2e`lT^!o0Xec$wK&k0YsH;5W7ZWW=!-TbU74XWi8I@aW;0zyn98CG%qIlS|9+Br%h zHCKXusKeHLLqcwaojulZ)YCSwP&Tz6RQ$dUrqg+TLK;z`MGm+ETVh8W8G2Kay7PAu z&w03rLED)&l(lCK8UB{ub3UE$+VNeJIgAP4iL@Alfnt(yO zLRZQ!M*As!r9;y7aICk_o;%{~&SzpWtet`Jg$jqqIcDsHIcH`pZ}s^2;7u`HLL zI`W{JkgILqR$T|z7OK(x#mR2jW6MplWWm*N*#eKuGX?){?|9~bn3+`Nd!X7MyP^LD zd&hBFKet(bi28GD_5Yu}<9`jf`*-20Uitj<688*Fe=6|L?AhEb0<#FrA~1`i^>aod0S+ zUH_Zkdg?Fs(+p03HvDhynbv_Nk8GTdV?GBkUG5+GKaOmiS(j1&Xa1i@HZGW^tNkng56Z%6 z`iy+j8R63hPOARp{vZ55EC0@$3H&{i&hTgSoq0v!kAp1#;ULNx`b>QB|KR@=Jk#Ep zdQ5Iy_4z~aU)npv|FhlG{F!=Z^!oGuX}xCfnkn}O|Bv?lEHH!jOgsOBKP_N7Su<17 oKa*#3vk1&0FpIz}0<#FrA~1` Date: Fri, 15 Apr 2022 13:21:05 -0700 Subject: [PATCH 02/27] REFACTOR: beginning of refactor archived old code, new files, basic model rough-in --- .../041222pytorchdraft-checkpoint.ipynb | 495 ++++++++++++++++++ .../041422pytorchdraft-checkpoint.ipynb | 495 ++++++++++++++++++ examples/refactor/041222pytorchdraft.ipynb | 2 +- examples/refactor/041422pytorchdraft.ipynb | 495 ++++++++++++++++++ mmvec/multimodal.py | 297 +---------- mmvec/old_multimodal.py | 283 ++++++++++ mmvec/q2/__init__.py | 12 - mmvec/q2/_method.py | 126 ----- mmvec/q2/_stats.py | 30 -- mmvec/q2/_summary.py | 
[diffs creating examples/refactor/.ipynb_checkpoints/041222pytorchdraft-checkpoint.ipynb and examples/refactor/.ipynb_checkpoints/041422pytorchdraft-checkpoint.ipynb omitted: both carry the same blob (index 019843d) as examples/refactor/041422pytorchdraft.ipynb below]
11400\n", + "loss: -2.35477614402771\n", + "Batch #: 11500\n", + "loss: -2.3572442531585693\n", + "Batch #: 11600\n", + "loss: -2.357273817062378\n", + "Batch #: 11700\n", + "loss: -2.3560562133789062\n", + "Batch #: 11800\n", + "loss: -2.355698823928833\n", + "Batch #: 11900\n", + "loss: -2.3559463024139404\n", + "Batch #: 12000\n", + "loss: -2.35664439201355\n", + "Batch #: 12100\n", + "loss: -2.355379104614258\n", + "Batch #: 12200\n", + "loss: -2.354964256286621\n", + "Batch #: 12300\n" + ] + } + ], + "source": [ + "learning_rate = 1e-3\n", + "batch_size = 500\n", + "epochs = 25\n", + "optimizer = torch.optim.Adam(mmvec_model.parameters(), lr=learning_rate, maximize=True)\n", + "\n", + "# run the training loop \n", + "train_loop(dataset=example_data, model=mmvec_model, optimizer=optimizer, batch_size=batch_size)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/refactor/041222pytorchdraft.ipynb b/examples/refactor/041222pytorchdraft.ipynb index 968710a..019843d 100644 --- a/examples/refactor/041222pytorchdraft.ipynb +++ b/examples/refactor/041222pytorchdraft.ipynb @@ -487,7 +487,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.8.12" } }, "nbformat": 4, diff --git a/examples/refactor/041422pytorchdraft.ipynb b/examples/refactor/041422pytorchdraft.ipynb new file mode 100644 index 0000000..019843d --- /dev/null +++ b/examples/refactor/041422pytorchdraft.ipynb @@ -0,0 +1,495 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "213bcdfc", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset\n", + "from torch.distributions import Multinomial\n", + "import biom" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "382bb9ce", + "metadata": {}, + "outputs": [], + "source": [ + "# some example data\n", + "microbes = biom.load_table(\"./soil_microbes.biom\")\n", + "metabolites = biom.load_table(\"./soil_metabolites.biom\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "96fac3bf", + "metadata": {}, + "outputs": [], + "source": [ + "class MicrobeMetaboliteData(Dataset):\n", + " def __init__(self, microbes: biom.table, metabolites: biom.table):\n", + " # arrange\n", + " self.microbes = microbes.to_dataframe().T \n", + " self.metabolites = metabolites.to_dataframe().T\n", + " \n", + " # only samples that have results\n", + " self.microbes = self.microbes.loc[self.metabolites.index]\n", + " \n", + " # convert to tensors/final form\n", + " self.microbes = torch.tensor(self.microbes.values, dtype=torch.int)\n", + " self.metabolites = torch.tensor(self.metabolites.values, dtype=torch.int64)\n", + " \n", + " # counts\n", + " self.microbe_count = self.microbes.shape[1]\n", + " self.metabolite_count = self.metabolites.shape[1]\n", + " \n", + " # relative frequencies\n", + " self.microbe_relative_frequency = (self.microbes.T\n", + " / self.microbes.sum(1)\n", + " ).T\n", + " \n", + " self.metabolite_relative_frequency = (self.metabolites.T\n", + " / self.metabolites.sum(1)\n", + " 
).T\n", + " \n", + " self.total_microbe_observations = self.microbes.sum()\n", + " \n", + " def __len__(self):\n", + " return self.total_microbe_observations" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "234ccc47", + "metadata": {}, + "outputs": [], + "source": [ + "example_data = MicrobeMetaboliteData(microbes, metabolites)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0ab12e60", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "424846" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "example_data.total_microbe_observations.item()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f106a231", + "metadata": {}, + "outputs": [], + "source": [ + "class MMVec(nn.Module):\n", + " def __init__(self, num_microbes, num_metabolites, latent_dim):\n", + " super().__init__()\n", + " #\n", + " self.encoder = nn.Embedding(num_microbes, latent_dim)\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(latent_dim, num_metabolites),\n", + " # [batch, sample, metabolite]\n", + " nn.Softmax(dim=2)\n", + " )\n", + " \n", + " # X = batch_size of microbe indexes\n", + " # Y = expected metabolite data\n", + " def forward(self, X, Y):\n", + " \n", + " # pass our random draws to our embedding\n", + " z = self.encoder(X)\n", + " \n", + " # from latent dimensions in embedding through\n", + " # our linear function to predicted metabolite frequencies which\n", + " # we then normalize with softmax\n", + " y_pred = self.decoder(z)\n", + " \n", + " # total_count=0 and validate_args=False allows skipping total count when calling log_prob\n", + " # as there having floating point issues leading to \"incorrect\" total counts.\n", + " # This multinomial is generated from the output of the single\n", + " forward_dist = Multinomial(total_count=0,\n", + " validate_args=False,\n", + " probs=y_pred)\n", + " \n", + " # the log probability of drawing our expected results from our \"predictions\"\n", + " forward_dist = forward_dist.log_prob(Y)\n", + " \n", + " # get sample loss, a sample in each \"row\"/ zeroeth dimension of the tensor\n", + " forward_dist = forward_dist.mean(0)\n", + " \n", + " # total log probability loss in regards to all samples\n", + " lp = forward_dist.mean()\n", + "\n", + " return lp" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b74bdf61", + "metadata": {}, + "outputs": [], + "source": [ + "mmvec_model = MMVec(example_data.microbe_count, example_data.metabolite_count, 15)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cbc8d647", + "metadata": {}, + "outputs": [], + "source": [ + "def train_loop(dataset, model, optimizer, batch_size):\n", + " \n", + " # because we are wanting to look at all of the samples together we are having to \n", + " # handle our own batching for now. This method currently leads to slight over-\n", + " # sampling but can be refined.\n", + " n_batches = torch.div(dataset.total_microbe_observations.item(),\n", + " batch_size,\n", + " rounding_mode = 'floor') + 1\n", + " \n", + " # We will want to implement batching functionality later for\n", + " # paralizability, but for now running on cpu this works.\n", + " for batch in range(n_batches * epochs):\n", + " \n", + " # the draws we will be training each batch on that will\n", + " # be fed to all samples in our model. 
This step will probably be\n", + " # moved to a sampler or collate_fn somewhere in the dataset/dataloader\n", + " # but how exactly that will work is not clear at the moment\n", + " draws = torch.multinomial(dataset.microbe_relative_frequency,\n", + " batch_size,\n", + " replacement=True).T\n", + " \n", + " # \"forward step\", our model generates our \"predictions\", so there is no need to\n", + " # call `forward` separately.\n", + " lp = model(draws,\n", + " dataset.metabolite_relative_frequency)\n", + " \n", + " # this location is idiomatic but flexible\n", + " optimizer.zero_grad()\n", + " \n", + " # the typical training bit.\n", + " lp.backward()\n", + " optimizer.step()\n", + " \n", + " if batch % 100 == 0:\n", + " print(f\"loss: {lp.item()}\\nBatch #: {batch}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfb75b21", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: -4.114527225494385\n", + "Batch #: 0\n", + "loss: -3.6144325733184814\n", + "Batch #: 100\n", + "loss: -3.0469698905944824\n", + "Batch #: 200\n", + "loss: -2.70939564704895\n", + "Batch #: 300\n", + "loss: -2.5499744415283203\n", + "Batch #: 400\n", + "loss: -2.473045587539673\n", + "Batch #: 500\n", + "loss: -2.4374732971191406\n", + "Batch #: 600\n", + "loss: -2.421781539916992\n", + "Batch #: 700\n", + "loss: -2.4101920127868652\n", + "Batch #: 800\n", + "loss: -2.4041030406951904\n", + "Batch #: 900\n", + "loss: -2.4012131690979004\n", + "Batch #: 1000\n", + "loss: -2.397974967956543\n", + "Batch #: 1100\n", + "loss: -2.3931915760040283\n", + "Batch #: 1200\n", + "loss: -2.3923048973083496\n", + "Batch #: 1300\n", + "loss: -2.389982223510742\n", + "Batch #: 1400\n", + "loss: -2.3868303298950195\n", + "Batch #: 1500\n", + "loss: -2.3855628967285156\n", + "Batch #: 1600\n", + "loss: -2.382643222808838\n", + "Batch #: 1700\n", + "loss: -2.381664991378784\n", + "Batch #: 1800\n", + "loss: -2.3774473667144775\n", + "Batch #: 1900\n", + "loss: -2.378610372543335\n", + "Batch #: 2000\n", + "loss: -2.3776485919952393\n", + "Batch #: 2100\n", + "loss: -2.376375675201416\n", + "Batch #: 2200\n", + "loss: -2.3723671436309814\n", + "Batch #: 2300\n", + "loss: -2.372851848602295\n", + "Batch #: 2400\n", + "loss: -2.373134136199951\n", + "Batch #: 2500\n", + "loss: -2.3704051971435547\n", + "Batch #: 2600\n", + "loss: -2.37052059173584\n", + "Batch #: 2700\n", + "loss: -2.371293306350708\n", + "Batch #: 2800\n", + "loss: -2.3711659908294678\n", + "Batch #: 2900\n", + "loss: -2.3693435192108154\n", + "Batch #: 3000\n", + "loss: -2.370833396911621\n", + "Batch #: 3100\n", + "loss: -2.36956787109375\n", + "Batch #: 3200\n", + "loss: -2.3683981895446777\n", + "Batch #: 3300\n", + "loss: -2.368025064468384\n", + "Batch #: 3400\n", + "loss: -2.3673665523529053\n", + "Batch #: 3500\n", + "loss: -2.3669538497924805\n", + "Batch #: 3600\n", + "loss: -2.364877700805664\n", + "Batch #: 3700\n", + "loss: -2.3676393032073975\n", + "Batch #: 3800\n", + "loss: -2.3655707836151123\n", + "Batch #: 3900\n", + "loss: -2.365952253341675\n", + "Batch #: 4000\n", + "loss: -2.366527557373047\n", + "Batch #: 4100\n", + "loss: -2.364421844482422\n", + "Batch #: 4200\n", + "loss: -2.363978385925293\n", + "Batch #: 4300\n", + "loss: -2.3649704456329346\n", + "Batch #: 4400\n", + "loss: -2.364382743835449\n", + "Batch #: 4500\n", + "loss: -2.361299991607666\n", + "Batch #: 4600\n", + "loss: -2.3609752655029297\n", + "Batch #: 4700\n", + "loss: -2.3623459339141846\n", + "Batch 
#: 4800\n", + "loss: -2.3606176376342773\n", + "Batch #: 4900\n", + "loss: -2.3621227741241455\n", + "Batch #: 5000\n", + "loss: -2.3601856231689453\n", + "Batch #: 5100\n", + "loss: -2.3616325855255127\n", + "Batch #: 5200\n", + "loss: -2.3607864379882812\n", + "Batch #: 5300\n", + "loss: -2.3603267669677734\n", + "Batch #: 5400\n", + "loss: -2.3611979484558105\n", + "Batch #: 5500\n", + "loss: -2.36138653755188\n", + "Batch #: 5600\n", + "loss: -2.3617565631866455\n", + "Batch #: 5700\n", + "loss: -2.3602635860443115\n", + "Batch #: 5800\n", + "loss: -2.3588624000549316\n", + "Batch #: 5900\n", + "loss: -2.363048791885376\n", + "Batch #: 6000\n", + "loss: -2.357430934906006\n", + "Batch #: 6100\n", + "loss: -2.359692335128784\n", + "Batch #: 6200\n", + "loss: -2.359476327896118\n", + "Batch #: 6300\n", + "loss: -2.358708381652832\n", + "Batch #: 6400\n", + "loss: -2.3578848838806152\n", + "Batch #: 6500\n", + "loss: -2.3591620922088623\n", + "Batch #: 6600\n", + "loss: -2.3596458435058594\n", + "Batch #: 6700\n", + "loss: -2.358290672302246\n", + "Batch #: 6800\n", + "loss: -2.3569066524505615\n", + "Batch #: 6900\n", + "loss: -2.3586177825927734\n", + "Batch #: 7000\n", + "loss: -2.359415054321289\n", + "Batch #: 7100\n", + "loss: -2.358649969100952\n", + "Batch #: 7200\n", + "loss: -2.35966420173645\n", + "Batch #: 7300\n", + "loss: -2.358867883682251\n", + "Batch #: 7400\n", + "loss: -2.3568341732025146\n", + "Batch #: 7500\n", + "loss: -2.3596749305725098\n", + "Batch #: 7600\n", + "loss: -2.359412670135498\n", + "Batch #: 7700\n", + "loss: -2.357198476791382\n", + "Batch #: 7800\n", + "loss: -2.358001947402954\n", + "Batch #: 7900\n", + "loss: -2.3569891452789307\n", + "Batch #: 8000\n", + "loss: -2.3587193489074707\n", + "Batch #: 8100\n", + "loss: -2.3581130504608154\n", + "Batch #: 8200\n", + "loss: -2.3578381538391113\n", + "Batch #: 8300\n", + "loss: -2.357231855392456\n", + "Batch #: 8400\n", + "loss: -2.3578529357910156\n", + "Batch #: 8500\n", + "loss: -2.3557262420654297\n", + "Batch #: 8600\n", + "loss: -2.355126142501831\n", + "Batch #: 8700\n", + "loss: -2.3567700386047363\n", + "Batch #: 8800\n", + "loss: -2.3553476333618164\n", + "Batch #: 8900\n", + "loss: -2.356520175933838\n", + "Batch #: 9000\n", + "loss: -2.3572936058044434\n", + "Batch #: 9100\n", + "loss: -2.358710527420044\n", + "Batch #: 9200\n", + "loss: -2.3547816276550293\n", + "Batch #: 9300\n", + "loss: -2.3565027713775635\n", + "Batch #: 9400\n", + "loss: -2.3561108112335205\n", + "Batch #: 9500\n", + "loss: -2.356635808944702\n", + "Batch #: 9600\n", + "loss: -2.356121301651001\n", + "Batch #: 9700\n", + "loss: -2.3586411476135254\n", + "Batch #: 9800\n", + "loss: -2.3572912216186523\n", + "Batch #: 9900\n", + "loss: -2.35567045211792\n", + "Batch #: 10000\n", + "loss: -2.3584144115448\n", + "Batch #: 10100\n", + "loss: -2.3562276363372803\n", + "Batch #: 10200\n", + "loss: -2.3546085357666016\n", + "Batch #: 10300\n", + "loss: -2.3559350967407227\n", + "Batch #: 10400\n", + "loss: -2.356455087661743\n", + "Batch #: 10500\n", + "loss: -2.3574140071868896\n", + "Batch #: 10600\n", + "loss: -2.3562002182006836\n", + "Batch #: 10700\n", + "loss: -2.35746169090271\n", + "Batch #: 10800\n", + "loss: -2.3548736572265625\n", + "Batch #: 10900\n", + "loss: -2.3564090728759766\n", + "Batch #: 11000\n", + "loss: -2.3564658164978027\n", + "Batch #: 11100\n", + "loss: -2.3554699420928955\n", + "Batch #: 11200\n", + "loss: -2.3563244342803955\n", + "Batch #: 11300\n", + "loss: -2.357598066329956\n", + "Batch #: 
11400\n", + "loss: -2.35477614402771\n", + "Batch #: 11500\n", + "loss: -2.3572442531585693\n", + "Batch #: 11600\n", + "loss: -2.357273817062378\n", + "Batch #: 11700\n", + "loss: -2.3560562133789062\n", + "Batch #: 11800\n", + "loss: -2.355698823928833\n", + "Batch #: 11900\n", + "loss: -2.3559463024139404\n", + "Batch #: 12000\n", + "loss: -2.35664439201355\n", + "Batch #: 12100\n", + "loss: -2.355379104614258\n", + "Batch #: 12200\n", + "loss: -2.354964256286621\n", + "Batch #: 12300\n" + ] + } + ], + "source": [ + "learning_rate = 1e-3\n", + "batch_size = 500\n", + "epochs = 25\n", + "optimizer = torch.optim.Adam(mmvec_model.parameters(), lr=learning_rate, maximize=True)\n", + "\n", + "# run the training loop \n", + "train_loop(dataset=example_data, model=mmvec_model, optimizer=optimizer, batch_size=batch_size)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mmvec/multimodal.py b/mmvec/multimodal.py index eb54ded..1f50449 100644 --- a/mmvec/multimodal.py +++ b/mmvec/multimodal.py @@ -1,280 +1,27 @@ -import os -import time -from tqdm import tqdm -import numpy as np -import tensorflow as tf -from tensorflow.contrib.distributions import Multinomial, Normal -import datetime - - -class MMvec(object): - - def __init__(self, u_mean=0, u_scale=1, v_mean=0, v_scale=1, - batch_size=50, latent_dim=3, - learning_rate=0.1, beta_1=0.8, beta_2=0.9, - clipnorm=10., device_name='/cpu:0', save_path=None): - """ Build a tensorflow model for microbe-metabolite vectors - - Returns - ------- - loss : tf.Tensor - The log loss of the model. - - Notes - ----- - To enable a GPU, set the device to '/device:GPU:x' - where x is 0 or greater - """ - p = latent_dim - self.device_name = device_name - if save_path is None: - basename = "logdir" - suffix = datetime.datetime.now().strftime("%y%m%d_%H%M%S") - save_path = "_".join([basename, suffix]) - - self.p = p - self.u_mean = u_mean - self.u_scale = u_scale - self.v_mean = v_mean - self.v_scale = v_scale - self.batch_size = batch_size - self.latent_dim = latent_dim - - self.learning_rate = learning_rate - self.beta_1 = beta_1 - self.beta_2 = beta_2 - self.clipnorm = clipnorm - self.save_path = save_path - - def __call__(self, session, trainX, trainY, testX, testY): - """ Initialize the actual graph - - Parameters - ---------- - session : tf.Session - Tensorflow session - trainX : sparse array in coo format - Test input OTU table, where rows are samples and columns are - observations - trainY : np.array - Test output metabolite table - testX : sparse array in coo format - Test input OTU table, where rows are samples and columns are - observations. This is mainly for cross validation. - testY : np.array - Test output metabolite table. This is mainly for cross validation. 
- """ - self.session = session - self.nnz = len(trainX.data) - self.d1 = trainX.shape[1] - self.d2 = trainY.shape[1] - self.cv_size = len(testX.data) - - # keep the multinomial sampling on the cpu - # https://github.com/tensorflow/tensorflow/issues/18058 - with tf.device('/cpu:0'): - X_ph = tf.SparseTensor( - indices=np.array([trainX.row, trainX.col]).T, - values=trainX.data, - dense_shape=trainX.shape) - Y_ph = tf.constant(trainY, dtype=tf.float32) - - X_holdout = tf.SparseTensor( - indices=np.array([testX.row, testX.col]).T, - values=testX.data, - dense_shape=testX.shape) - Y_holdout = tf.constant(testY, dtype=tf.float32) - - total_count = tf.reduce_sum(Y_ph, axis=1) - batch_ids = tf.multinomial( - tf.log(tf.reshape(X_ph.values, [1, -1])), - self.batch_size) - batch_ids = tf.squeeze(batch_ids) - X_samples = tf.gather(X_ph.indices, 0, axis=1) - X_obs = tf.gather(X_ph.indices, 1, axis=1) - sample_ids = tf.gather(X_samples, batch_ids) - - Y_batch = tf.gather(Y_ph, sample_ids) - X_batch = tf.gather(X_obs, batch_ids) - - with tf.device(self.device_name): - self.qUmain = tf.Variable( - tf.random_normal([self.d1, self.p]), name='qU') - self.qUbias = tf.Variable( - tf.random_normal([self.d1, 1]), name='qUbias') - self.qVmain = tf.Variable( - tf.random_normal([self.p, self.d2-1]), name='qV') - self.qVbias = tf.Variable( - tf.random_normal([1, self.d2-1]), name='qVbias') - - qU = tf.concat( - [tf.ones([self.d1, 1]), self.qUbias, self.qUmain], axis=1) - qV = tf.concat( - [self.qVbias, tf.ones([1, self.d2-1]), self.qVmain], axis=0) - - # regression coefficents distribution - Umain = Normal(loc=tf.zeros([self.d1, self.p]) + self.u_mean, - scale=tf.ones([self.d1, self.p]) * self.u_scale, - name='U') - Ubias = Normal(loc=tf.zeros([self.d1, 1]) + self.u_mean, - scale=tf.ones([self.d1, 1]) * self.u_scale, - name='biasU') - - Vmain = Normal(loc=tf.zeros([self.p, self.d2-1]) + self.v_mean, - scale=tf.ones([self.p, self.d2-1]) * self.v_scale, - name='V') - Vbias = Normal(loc=tf.zeros([1, self.d2-1]) + self.v_mean, - scale=tf.ones([1, self.d2-1]) * self.v_scale, - name='biasV') - - du = tf.gather(qU, X_batch, axis=0, name='du') - dv = tf.concat([tf.zeros([self.batch_size, 1]), - du @ qV], axis=1, name='dv') - - tc = tf.gather(total_count, sample_ids) - Y = Multinomial(total_count=tc, logits=dv, name='Y') - num_samples = trainX.shape[0] - norm = num_samples / self.batch_size - logprob_vmain = tf.reduce_sum( - Vmain.log_prob(self.qVmain), name='logprob_vmain') - logprob_vbias = tf.reduce_sum( - Vbias.log_prob(self.qVbias), name='logprob_vbias') - logprob_umain = tf.reduce_sum( - Umain.log_prob(self.qUmain), name='logprob_umain') - logprob_ubias = tf.reduce_sum( - Ubias.log_prob(self.qUbias), name='logprob_ubias') - logprob_y = tf.reduce_sum(Y.log_prob(Y_batch), name='logprob_y') - self.log_loss = - ( - logprob_y * norm + - logprob_umain + logprob_ubias + - logprob_vmain + logprob_vbias - ) - - # keep the multinomial sampling on the cpu - # https://github.com/tensorflow/tensorflow/issues/18058 - with tf.device('/cpu:0'): - # cross validation - with tf.name_scope('accuracy'): - cv_batch_ids = tf.multinomial( - tf.log(tf.reshape(X_holdout.values, [1, -1])), - self.cv_size) - cv_batch_ids = tf.squeeze(cv_batch_ids) - X_cv_samples = tf.gather(X_holdout.indices, 0, axis=1) - X_cv = tf.gather(X_holdout.indices, 1, axis=1) - cv_sample_ids = tf.gather(X_cv_samples, cv_batch_ids) - - Y_cvbatch = tf.gather(Y_holdout, cv_sample_ids) - X_cvbatch = tf.gather(X_cv, cv_batch_ids) - holdout_count = tf.reduce_sum(Y_cvbatch, 
axis=1) - cv_du = tf.gather(qU, X_cvbatch, axis=0, name='cv_du') - pred = tf.reshape( - holdout_count, [-1, 1]) * tf.nn.softmax( - tf.concat([tf.zeros([ - self.cv_size, 1]), - cv_du @ qV], axis=1, name='pred') - ) - - self.cv = tf.reduce_mean( - tf.squeeze(tf.abs(pred - Y_cvbatch)) +import torch +import torch.nn as nn +from torch.distributions import Multinomial + +class MMvec(nn.Module): + def __init__(self, num_microbes, num_metabolites, latent_dim): + super().__init__() + + self.encoder = nn.Embedding(num_microbes, latent_dim) + self.decoder = nn.Sequential( + nn.Linear(latent_dim, num_metabolites), + nn.Softmax(dim=2) ) - # keep all summaries on the cpu - with tf.device('/cpu:0'): - tf.summary.scalar('logloss', self.log_loss) - tf.summary.scalar('cv_rmse', self.cv) - tf.summary.histogram('qUmain', self.qUmain) - tf.summary.histogram('qVmain', self.qVmain) - tf.summary.histogram('qUbias', self.qUbias) - tf.summary.histogram('qVbias', self.qVbias) - self.merged = tf.summary.merge_all() - - self.writer = tf.summary.FileWriter( - self.save_path, self.session.graph) - - with tf.device(self.device_name): - with tf.name_scope('optimize'): - optimizer = tf.train.AdamOptimizer( - self.learning_rate, beta1=self.beta_1, beta2=self.beta_2) - - gradients, self.variables = zip( - *optimizer.compute_gradients(self.log_loss)) - self.gradients, _ = tf.clip_by_global_norm( - gradients, self.clipnorm) - self.train = optimizer.apply_gradients( - zip(self.gradients, self.variables)) - - tf.global_variables_initializer().run() - - def ranks(self): - modelU = np.hstack( - (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) - modelV = np.vstack( - (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) - - res = np.hstack((np.zeros((self.U.shape[0], 1)), modelU @ modelV)) - res = res - res.mean(axis=1).reshape(-1, 1) - return res - - def fit(self, epoch=10, summary_interval=1000, checkpoint_interval=3600, - testX=None, testY=None): - """ Fits the model. - - Parameters - ---------- - epoch : int - Number of epochs to train - summary_interval : int - Number of seconds until a summary is recorded - checkpoint_interval : int - Number of seconds until a checkpoint is recorded - - Returns - ------- - loss: float - log likelihood loss. 
- cv : float - cross validation loss - """ - iterations = epoch * self.nnz // self.batch_size - losses, cvs = [], [] - cv = None - last_checkpoint_time = 0 - last_summary_time = 0 - saver = tf.train.Saver() - now = time.time() - for i in tqdm(range(0, iterations)): - if now - last_summary_time > summary_interval: - - res = self.session.run( - [self.train, self.merged, self.log_loss, self.cv, - self.qUmain, self.qUbias, - self.qVmain, self.qVbias] - ) - train_, summary, loss, cv, rU, rUb, rV, rVb = res - self.writer.add_summary(summary, i) - last_summary_time = now - else: - res = self.session.run( - [self.train, self.log_loss, - self.qUmain, self.qUbias, - self.qVmain, self.qVbias] - ) - train_, loss, rU, rUb, rV, rVb = res - losses.append(loss) - cvs.append(cv) - cv = None + def forward(self, X, Y): + z = self.encoder(X) + y_pred = self.decoder(z) + + forward_dist = Multinomial(total_count=0, + validate_args=False, + probs=y_pred) - # checkpoint model - now = time.time() - if now - last_checkpoint_time > checkpoint_interval: - saver.save(self.session, - os.path.join(self.save_path, "model.ckpt"), - global_step=i) - last_checkpoint_time = now + forward_dist = forward_dist.log_prob(Y) - self.U = rU - self.V = rV - self.Ubias = rUb - self.Vbias = rVb + lp = forward_dist.mean(0).mean() - return losses, cvs + return lp diff --git a/mmvec/old_multimodal.py b/mmvec/old_multimodal.py new file mode 100644 index 0000000..ff741e8 --- /dev/null +++ b/mmvec/old_multimodal.py @@ -0,0 +1,283 @@ +import torch +import torch.nn as nn +from torch.distributions import Multinomial +#import os +#import time +#from tqdm import tqdm +#import numpy as np +#import tensorflow as tf +#from tensorflow.contrib.distributions import Multinomial, Normal +#import datetime +# +# +#class Old_MMvec(object): +# +# def __init__(self, u_mean=0, u_scale=1, v_mean=0, v_scale=1, +# batch_size=50, latent_dim=3, +# learning_rate=0.1, beta_1=0.8, beta_2=0.9, +# clipnorm=10., device_name='/cpu:0', save_path=None): +# """ Build a tensorflow model for microbe-metabolite vectors +# +# Returns +# ------- +# loss : tf.Tensor +# The log loss of the model. +# +# Notes +# ----- +# To enable a GPU, set the device to '/device:GPU:x' +# where x is 0 or greater +# """ +# p = latent_dim +# self.device_name = device_name +# if save_path is None: +# basename = "logdir" +# suffix = datetime.datetime.now().strftime("%y%m%d_%H%M%S") +# save_path = "_".join([basename, suffix]) +# +# self.p = p +# self.u_mean = u_mean +# self.u_scale = u_scale +# self.v_mean = v_mean +# self.v_scale = v_scale +# self.batch_size = batch_size +# self.latent_dim = latent_dim +# +# self.learning_rate = learning_rate +# self.beta_1 = beta_1 +# self.beta_2 = beta_2 +# self.clipnorm = clipnorm +# self.save_path = save_path +# +# def __call__(self, session, trainX, trainY, testX, testY): +# """ Initialize the actual graph +# +# Parameters +# ---------- +# session : tf.Session +# Tensorflow session +# trainX : sparse array in coo format +# Test input OTU table, where rows are samples and columns are +# observations +# trainY : np.array +# Test output metabolite table +# testX : sparse array in coo format +# Test input OTU table, where rows are samples and columns are +# observations. This is mainly for cross validation. +# testY : np.array +# Test output metabolite table. This is mainly for cross validation. 
+# """ +# self.session = session +# self.nnz = len(trainX.data) +# self.d1 = trainX.shape[1] +# self.d2 = trainY.shape[1] +# self.cv_size = len(testX.data) +# +# # keep the multinomial sampling on the cpu +# # https://github.com/tensorflow/tensorflow/issues/18058 +# with tf.device('/cpu:0'): +# X_ph = tf.SparseTensor( +# indices=np.array([trainX.row, trainX.col]).T, +# values=trainX.data, +# dense_shape=trainX.shape) +# Y_ph = tf.constant(trainY, dtype=tf.float32) +# +# X_holdout = tf.SparseTensor( +# indices=np.array([testX.row, testX.col]).T, +# values=testX.data, +# dense_shape=testX.shape) +# Y_holdout = tf.constant(testY, dtype=tf.float32) +# +# total_count = tf.reduce_sum(Y_ph, axis=1) +# batch_ids = tf.multinomial( +# tf.log(tf.reshape(X_ph.values, [1, -1])), +# self.batch_size) +# batch_ids = tf.squeeze(batch_ids) +# X_samples = tf.gather(X_ph.indices, 0, axis=1) +# X_obs = tf.gather(X_ph.indices, 1, axis=1) +# sample_ids = tf.gather(X_samples, batch_ids) +# +# Y_batch = tf.gather(Y_ph, sample_ids) +# X_batch = tf.gather(X_obs, batch_ids) +# +# with tf.device(self.device_name): +# self.qUmain = tf.Variable( +# tf.random_normal([self.d1, self.p]), name='qU') +# self.qUbias = tf.Variable( +# tf.random_normal([self.d1, 1]), name='qUbias') +# self.qVmain = tf.Variable( +# tf.random_normal([self.p, self.d2-1]), name='qV') +# self.qVbias = tf.Variable( +# tf.random_normal([1, self.d2-1]), name='qVbias') +# +# qU = tf.concat( +# [tf.ones([self.d1, 1]), self.qUbias, self.qUmain], axis=1) +# qV = tf.concat( +# [self.qVbias, tf.ones([1, self.d2-1]), self.qVmain], axis=0) +# +# # regression coefficents distribution +# Umain = Normal(loc=tf.zeros([self.d1, self.p]) + self.u_mean, +# scale=tf.ones([self.d1, self.p]) * self.u_scale, +# name='U') +# Ubias = Normal(loc=tf.zeros([self.d1, 1]) + self.u_mean, +# scale=tf.ones([self.d1, 1]) * self.u_scale, +# name='biasU') +# +# Vmain = Normal(loc=tf.zeros([self.p, self.d2-1]) + self.v_mean, +# scale=tf.ones([self.p, self.d2-1]) * self.v_scale, +# name='V') +# Vbias = Normal(loc=tf.zeros([1, self.d2-1]) + self.v_mean, +# scale=tf.ones([1, self.d2-1]) * self.v_scale, +# name='biasV') +# +# du = tf.gather(qU, X_batch, axis=0, name='du') +# dv = tf.concat([tf.zeros([self.batch_size, 1]), +# du @ qV], axis=1, name='dv') +# +# tc = tf.gather(total_count, sample_ids) +# Y = Multinomial(total_count=tc, logits=dv, name='Y') +# num_samples = trainX.shape[0] +# norm = num_samples / self.batch_size +# logprob_vmain = tf.reduce_sum( +# Vmain.log_prob(self.qVmain), name='logprob_vmain') +# logprob_vbias = tf.reduce_sum( +# Vbias.log_prob(self.qVbias), name='logprob_vbias') +# logprob_umain = tf.reduce_sum( +# Umain.log_prob(self.qUmain), name='logprob_umain') +# logprob_ubias = tf.reduce_sum( +# Ubias.log_prob(self.qUbias), name='logprob_ubias') +# logprob_y = tf.reduce_sum(Y.log_prob(Y_batch), name='logprob_y') +# self.log_loss = - ( +# logprob_y * norm + +# logprob_umain + logprob_ubias + +# logprob_vmain + logprob_vbias +# ) +# +# # keep the multinomial sampling on the cpu +# # https://github.com/tensorflow/tensorflow/issues/18058 +# with tf.device('/cpu:0'): +# # cross validation +# with tf.name_scope('accuracy'): +# cv_batch_ids = tf.multinomial( +# tf.log(tf.reshape(X_holdout.values, [1, -1])), +# self.cv_size) +# cv_batch_ids = tf.squeeze(cv_batch_ids) +# X_cv_samples = tf.gather(X_holdout.indices, 0, axis=1) +# X_cv = tf.gather(X_holdout.indices, 1, axis=1) +# cv_sample_ids = tf.gather(X_cv_samples, cv_batch_ids) +# +# Y_cvbatch = tf.gather(Y_holdout, 
cv_sample_ids) +# X_cvbatch = tf.gather(X_cv, cv_batch_ids) +# holdout_count = tf.reduce_sum(Y_cvbatch, axis=1) +# cv_du = tf.gather(qU, X_cvbatch, axis=0, name='cv_du') +# pred = tf.reshape( +# holdout_count, [-1, 1]) * tf.nn.softmax( +# tf.concat([tf.zeros([ +# self.cv_size, 1]), +# cv_du @ qV], axis=1, name='pred') +# ) +# +# self.cv = tf.reduce_mean( +# tf.squeeze(tf.abs(pred - Y_cvbatch)) +# ) +# +# # keep all summaries on the cpu +# with tf.device('/cpu:0'): +# tf.summary.scalar('logloss', self.log_loss) +# tf.summary.scalar('cv_rmse', self.cv) +# tf.summary.histogram('qUmain', self.qUmain) +# tf.summary.histogram('qVmain', self.qVmain) +# tf.summary.histogram('qUbias', self.qUbias) +# tf.summary.histogram('qVbias', self.qVbias) +# self.merged = tf.summary.merge_all() +# +# self.writer = tf.summary.FileWriter( +# self.save_path, self.session.graph) +# +# with tf.device(self.device_name): +# with tf.name_scope('optimize'): +# optimizer = tf.train.AdamOptimizer( +# self.learning_rate, beta1=self.beta_1, beta2=self.beta_2) +# +# gradients, self.variables = zip( +# *optimizer.compute_gradients(self.log_loss)) +# self.gradients, _ = tf.clip_by_global_norm( +# gradients, self.clipnorm) +# self.train = optimizer.apply_gradients( +# zip(self.gradients, self.variables)) +# +# tf.global_variables_initializer().run() +# +# def ranks(self): +# modelU = np.hstack( +# (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) +# modelV = np.vstack( +# (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) +# +# res = np.hstack((np.zeros((self.U.shape[0], 1)), modelU @ modelV)) +# res = res - res.mean(axis=1).reshape(-1, 1) +# return res +# +# def fit(self, epoch=10, summary_interval=1000, checkpoint_interval=3600, +# testX=None, testY=None): +# """ Fits the model. +# +# Parameters +# ---------- +# epoch : int +# Number of epochs to train +# summary_interval : int +# Number of seconds until a summary is recorded +# checkpoint_interval : int +# Number of seconds until a checkpoint is recorded +# +# Returns +# ------- +# loss: float +# log likelihood loss. 
+# cv : float +# cross validation loss +# """ +# iterations = epoch * self.nnz // self.batch_size +# losses, cvs = [], [] +# cv = None +# last_checkpoint_time = 0 +# last_summary_time = 0 +# saver = tf.train.Saver() +# now = time.time() +# for i in tqdm(range(0, iterations)): +# if now - last_summary_time > summary_interval: +# +# res = self.session.run( +# [self.train, self.merged, self.log_loss, self.cv, +# self.qUmain, self.qUbias, +# self.qVmain, self.qVbias] +# ) +# train_, summary, loss, cv, rU, rUb, rV, rVb = res +# self.writer.add_summary(summary, i) +# last_summary_time = now +# else: +# res = self.session.run( +# [self.train, self.log_loss, +# self.qUmain, self.qUbias, +# self.qVmain, self.qVbias] +# ) +# train_, loss, rU, rUb, rV, rVb = res +# losses.append(loss) +# cvs.append(cv) +# cv = None +# +# # checkpoint model +# now = time.time() +# if now - last_checkpoint_time > checkpoint_interval: +# saver.save(self.session, +# os.path.join(self.save_path, "model.ckpt"), +# global_step=i) +# last_checkpoint_time = now +# +# self.U = rU +# self.V = rV +# self.Ubias = rUb +# self.Vbias = rVb +# +# return losses, cvs diff --git a/mmvec/q2/__init__.py b/mmvec/q2/__init__.py deleted file mode 100644 index c8d78a9..0000000 --- a/mmvec/q2/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from ._stats import (Conditional, ConditionalDirFmt, ConditionalFormat, - MMvecStats, MMvecStatsFormat, MMvecStatsDirFmt) -from ._method import paired_omics -from ._visualizers import heatmap, paired_heatmap -from ._summary import summarize_single, summarize_paired - - -__all__ = ['paired_omics', - 'Conditional', 'ConditionalFormat', 'ConditionalDirFmt', - 'MMvecStats', 'MMvecStatsFormat', 'MMvecStatsDirFmt', - 'heatmap', 'paired_heatmap', - 'summarize_single', 'summarize_paired'] diff --git a/mmvec/q2/_method.py b/mmvec/q2/_method.py deleted file mode 100644 index bb6dfbf..0000000 --- a/mmvec/q2/_method.py +++ /dev/null @@ -1,126 +0,0 @@ -import biom -import pandas as pd -import numpy as np -import tensorflow as tf -from skbio import OrdinationResults -import qiime2 -from qiime2.plugin import Metadata -from mmvec.multimodal import MMvec -from mmvec.util import split_tables -from scipy.sparse import coo_matrix -from scipy.sparse.linalg import svds - - -def paired_omics(microbes: biom.Table, - metabolites: biom.Table, - metadata: Metadata = None, - training_column: str = None, - num_testing_examples: int = 5, - min_feature_count: int = 10, - epochs: int = 100, - batch_size: int = 50, - latent_dim: int = 3, - input_prior: float = 1, - output_prior: float = 1, - learning_rate: float = 1e-3, - equalize_biplot: float = False, - arm_the_gpu: bool = False, - summary_interval: int = 60) -> ( - pd.DataFrame, OrdinationResults, qiime2.Metadata - ): - - if metadata is not None: - metadata = metadata.to_dataframe() - - if arm_the_gpu: - # pick out the first GPU - device_name = '/device:GPU:0' - else: - device_name = '/cpu:0' - - # Note: there are a couple of biom -> pandas conversions taking - # place here. This is currently done on purpose, since we - # haven't figured out how to handle sparse matrix multiplication - # in the context of this algorithm. That is a future consideration. 
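# (Aside on the note above: a minimal sketch of the sparse path it defers,
# assuming `microbes` is a biom.Table whose `.matrix_data` attribute exposes
# the underlying scipy sparse matrix; the variable names are illustrative
# and not part of this changeset.)
#
#     import numpy as np
#     import torch
#
#     coo = microbes.matrix_data.tocoo()
#     indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
#     values = torch.tensor(coo.data, dtype=torch.float32)
#     X = torch.sparse_coo_tensor(indices, values, size=coo.shape)
#     # torch.sparse.mm computes sparse @ dense without densifying X --
#     # the multiplication the comment above says is still unsolved here.
#     y = torch.sparse.mm(X, dense_weights)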
- res = split_tables( - microbes, metabolites, - metadata=metadata, training_column=training_column, - num_test=num_testing_examples, - min_samples=min_feature_count) - - (train_microbes_df, test_microbes_df, - train_metabolites_df, test_metabolites_df) = res - - train_microbes_coo = coo_matrix(train_microbes_df.values) - test_microbes_coo = coo_matrix(test_microbes_df.values) - - with tf.Graph().as_default(), tf.Session() as session: - model = MMvec( - latent_dim=latent_dim, - u_scale=input_prior, v_scale=output_prior, - batch_size=batch_size, - device_name=device_name, - learning_rate=learning_rate) - model(session, - train_microbes_coo, train_metabolites_df.values, - test_microbes_coo, test_metabolites_df.values) - - loss, cv = model.fit(epoch=epochs, summary_interval=summary_interval) - ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns, - columns=train_metabolites_df.columns) - if latent_dim > 0: - u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim) - else: - # fake it until you make it - u, s, v = svds(ranks - ranks.mean(axis=0), k=1) - - ranks = ranks.T - ranks.index.name = 'featureid' - s = s[::-1] - u = u[:, ::-1] - v = v[::-1, :] - if equalize_biplot: - microbe_embed = u @ np.sqrt(np.diag(s)) - metabolite_embed = v.T @ np.sqrt(np.diag(s)) - else: - microbe_embed = u @ np.diag(s) - metabolite_embed = v.T - - pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] - features = pd.DataFrame( - microbe_embed, columns=pc_ids, - index=train_microbes_df.columns) - samples = pd.DataFrame( - metabolite_embed, columns=pc_ids, - index=train_metabolites_df.columns) - short_method_name = 'mmvec biplot' - long_method_name = 'Multiomics mmvec biplot' - eigvals = pd.Series(s, index=pc_ids) - proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids) - biplot = OrdinationResults( - short_method_name, long_method_name, eigvals, - samples=samples, features=features, - proportion_explained=proportion_explained) - - its = np.arange(len(loss)) - convergence_stats = pd.DataFrame( - { - 'loss': loss, - 'cross-validation': cv, - 'iteration': its - } - ) - - convergence_stats.index.name = 'id' - convergence_stats.index = convergence_stats.index.astype(np.str) - - c = convergence_stats['loss'].astype(np.float) - convergence_stats['loss'] = c - - c = convergence_stats['cross-validation'].astype(np.float) - convergence_stats['cross-validation'] = c - - c = convergence_stats['iteration'].astype(np.int) - convergence_stats['iteration'] = c - - return ranks, biplot, qiime2.Metadata(convergence_stats) diff --git a/mmvec/q2/_stats.py b/mmvec/q2/_stats.py deleted file mode 100644 index 980e937..0000000 --- a/mmvec/q2/_stats.py +++ /dev/null @@ -1,30 +0,0 @@ -from qiime2.plugin import SemanticType, model -from q2_types.feature_data import FeatureData -from q2_types.sample_data import SampleData - - -Conditional = SemanticType('Conditional', - variant_of=FeatureData.field['type']) - - -class ConditionalFormat(model.TextFileFormat): - def validate(*args): - pass - - -ConditionalDirFmt = model.SingleFileDirectoryFormat( - 'ConditionalDirFmt', 'conditionals.tsv', ConditionalFormat) - - -# songbird stats summarizing loss and cv error -MMvecStats = SemanticType('MMvecStats', - variant_of=SampleData.field['type']) - - -class MMvecStatsFormat(model.TextFileFormat): - def validate(*args): - pass - - -MMvecStatsDirFmt = model.SingleFileDirectoryFormat( - 'MMvecStatsDirFmt', 'stats.tsv', MMvecStatsFormat) diff --git a/mmvec/q2/_summary.py b/mmvec/q2/_summary.py deleted file mode 100644 
index 524b7a9..0000000 --- a/mmvec/q2/_summary.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -import qiime2 -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt - - -def _convergence_plot(model, baseline, ax0, ax1): - iterations = np.array(model['iteration']) - cv_model = model.dropna() - ax0.plot(cv_model['iteration'][1:], - np.array(cv_model['cross-validation'].values)[1:], - label='model') - ax0.set_ylabel('Cross validation score', fontsize=14) - ax0.set_xlabel('# Iterations', fontsize=14) - - ax1.plot(iterations[1:], - np.array(model['loss'])[1:], label='model') - ax1.set_ylabel('Loss', fontsize=14) - ax1.set_xlabel('# Iterations', fontsize=14) - - if baseline is not None: - iterations = baseline['iteration'] - cv_baseline = baseline.dropna() - ax0.plot(cv_baseline['iteration'][1:], - np.array(cv_baseline['cross-validation'].values)[1:], - label='baseline') - ax0.set_ylabel('Cross validation score', fontsize=14) - ax0.set_xlabel('# Iterations', fontsize=14) - ax0.legend() - - ax1.plot(iterations[1:], - np.array(baseline['loss'])[1:], label='baseline') - ax1.set_ylabel('Loss', fontsize=14) - ax1.set_xlabel('# Iterations', fontsize=14) - ax1.legend() - - -def _summarize(output_dir: str, model: pd.DataFrame, - baseline: pd.DataFrame = None): - - """ Helper method for generating summary pages - Parameters - ---------- - output_dir : str - Name of output directory - model : pd.DataFrame - Model summary with column names - ['loss', 'cross-validation'] - baseline : pd.DataFrame - Baseline model summary with column names - ['loss', 'cross-validation']. Defaults to None (i.e. if only a single - set of model stats will be summarized). - Note - ---- - There may be synchronizing issues if different summary intervals - were used between analyses. For predictable results, try to use the - same summary interval. - """ - fig, ax = plt.subplots(2, 1, figsize=(10, 10)) - if baseline is None: - _convergence_plot(model, None, ax[0], ax[1]) - q2 = None - else: - - _convergence_plot(model, baseline, ax[0], ax[1]) - - # this provides a pseudo-r2 commonly provided in the context - # of logistic / multinomail model (proposed by Cox & Snell) - # http://www3.stat.sinica.edu.tw/statistica/oldpdf/a16n39.pdf - end = min(10, len(model.index)) - # trim only the last 10 numbers - - # compute a q2 score, which is commonly used in - # partial least squares for cross validation - cv_model = model.dropna() - cv_baseline = baseline.dropna() - - l0 = np.mean(cv_baseline['cross-validation'][-end:]) - lm = np.mean(cv_model['cross-validation'][-end:]) - q2 = 1 - lm / l0 - - plt.tight_layout() - fig.savefig(os.path.join(output_dir, 'convergence-plot.svg')) - fig.savefig(os.path.join(output_dir, 'convergence-plot.pdf')) - - index_fp = os.path.join(output_dir, 'index.html') - with open(index_fp, 'w') as index_f: - index_f.write('\n') - index_f.write('

<h1>Convergence summary</h1>\n') -        index_f.write( -            "<p>If you don't see anything in these plots, you probably need " -            "to decrease your --p-summary-interval. Try setting " -            "--p-summary-interval 1, which will record the loss at " -            "every second.</p>\n" -        ) - -        if q2 is not None: -            index_f.write( -                '<h2>' -                'Pseudo Q-squared: %f</h2>\n' % q2 -            ) - -        index_f.write( -            '<img src="convergence-plot.svg" alt="convergence_plots">' -        ) -        index_f.write('<a href="convergence-plot.pdf">') -        index_f.write('Download as PDF</a><br>
\n') - - -def summarize_single(output_dir: str, model_stats: qiime2.Metadata): - _summarize(output_dir, model_stats.to_dataframe()) - - -def summarize_paired(output_dir: str, - model_stats: qiime2.Metadata, - baseline_stats: qiime2.Metadata): - _summarize(output_dir, - model_stats.to_dataframe(), - baseline_stats.to_dataframe()) diff --git a/mmvec/q2/_transformer.py b/mmvec/q2/_transformer.py deleted file mode 100644 index 7c304df..0000000 --- a/mmvec/q2/_transformer.py +++ /dev/null @@ -1,36 +0,0 @@ -import qiime2 -import pandas as pd - -from mmvec.q2 import ConditionalFormat, MMvecStatsFormat -from mmvec.q2.plugin_setup import plugin - - -@plugin.register_transformer -def _1(ff: ConditionalFormat) -> pd.DataFrame: - df = pd.read_csv(str(ff), sep='\t', comment='#', skip_blank_lines=True, - header=0, index_col=0) - return df - - -@plugin.register_transformer -def _2(df: pd.DataFrame) -> ConditionalFormat: - ff = ConditionalFormat() - df.to_csv(str(ff), sep='\t', header=True, index=True) - return ff - - -@plugin.register_transformer -def _3(ff: ConditionalFormat) -> qiime2.Metadata: - return qiime2.Metadata.load(str(ff)) - - -@plugin.register_transformer -def _4(obj: qiime2.Metadata) -> MMvecStatsFormat: - ff = MMvecStatsFormat() - obj.save(str(ff)) - return ff - - -@plugin.register_transformer -def _5(ff: MMvecStatsFormat) -> qiime2.Metadata: - return qiime2.Metadata.load(str(ff)) diff --git a/mmvec/q2/_transformers.py b/mmvec/q2/_transformers.py new file mode 100644 index 0000000..e69de29 diff --git a/mmvec/q2/_visualizers.py b/mmvec/q2/_visualizers.py deleted file mode 100644 index 6861768..0000000 --- a/mmvec/q2/_visualizers.py +++ /dev/null @@ -1,88 +0,0 @@ -from os.path import join -import pandas as pd -import qiime2 -import biom -import pkg_resources -import q2templates -from mmvec.heatmap import ranks_heatmap, paired_heatmaps - - -TEMPLATES = pkg_resources.resource_filename('mmvec.q2', 'assets') - - -def heatmap(output_dir: str, - ranks: pd.DataFrame, - microbe_metadata: qiime2.CategoricalMetadataColumn = None, - metabolite_metadata: qiime2.CategoricalMetadataColumn = None, - method: str = 'average', - metric: str = 'euclidean', - color_palette: str = 'seismic', - margin_palette: str = 'cubehelix', - x_labels: bool = False, - y_labels: bool = False, - level: int = -1, - row_center: bool = True) -> None: - if microbe_metadata is not None: - microbe_metadata = microbe_metadata.to_series() - if metabolite_metadata is not None: - metabolite_metadata = metabolite_metadata.to_series() - ranks = ranks.T - - if row_center: - ranks = ranks - ranks.mean(axis=0) - - hotmap = ranks_heatmap(ranks, microbe_metadata, metabolite_metadata, - method, metric, color_palette, margin_palette, - x_labels, y_labels, level) - - hotmap.savefig(join(output_dir, 'heatmap.pdf'), bbox_inches='tight') - hotmap.savefig(join(output_dir, 'heatmap.png'), bbox_inches='tight') - - index = join(TEMPLATES, 'index.html') - q2templates.render(index, output_dir, context={ - 'title': 'Rank Heatmap', - 'pdf_fp': 'heatmap.pdf', - 'png_fp': 'heatmap.png'}) - - -def paired_heatmap(output_dir: str, - ranks: pd.DataFrame, - microbes_table: biom.Table, - metabolites_table: biom.Table, - features: str = None, - top_k_microbes: int = 2, - keep_top_samples: bool = True, - microbe_metadata: qiime2.CategoricalMetadataColumn = None, - normalize: str = 'log10', - color_palette: str = 'magma', - top_k_metabolites: int = 50, - level: int = -1, - row_center: bool = True) -> None: - if microbe_metadata is not None: - microbe_metadata = 
microbe_metadata.to_series() - - ranks = ranks.T - - if row_center: - ranks = ranks - ranks.mean(axis=0) - - select_microbes, select_metabolites, hotmaps = paired_heatmaps( - ranks, microbes_table, metabolites_table, microbe_metadata, features, - top_k_microbes, top_k_metabolites, keep_top_samples, level, normalize, - color_palette) - - hotmaps.savefig(join(output_dir, 'heatmap.pdf'), bbox_inches='tight') - hotmaps.savefig(join(output_dir, 'heatmap.png'), bbox_inches='tight') - select_microbes.to_csv(join(output_dir, 'select_microbes.tsv'), sep='\t') - select_metabolites.to_csv( - join(output_dir, 'select_metabolites.tsv'), sep='\t') - - index = join(TEMPLATES, 'index.html') - q2templates.render(index, output_dir, context={ - 'title': 'Paired Feature Abundance Heatmaps', - 'pdf_fp': 'heatmap.pdf', - 'png_fp': 'heatmap.png', - 'table1_fp': 'select_microbes.tsv', - 'download1_text': 'Download microbe abundances as TSV', - 'table2_fp': 'select_metabolites.tsv', - 'download2_text': 'Download top k metabolite abundances as TSV'}) diff --git a/mmvec/q2/assets/index.html b/mmvec/q2/assets/index.html deleted file mode 100644 index a752d3b..0000000 --- a/mmvec/q2/assets/index.html +++ /dev/null @@ -1,28 +0,0 @@ -{% extends 'base.html' %} - -{% block title %}rhapsody : {{ title }}{% endblock %} - -{% block fixed %}{% endblock %} - -{% block content %} - -
-<h1>{{ title }}</h1> -<img src="{{ png_fp }}"> -<a href="{{ pdf_fp }}">Download as PDF</a> -<!-- remaining template markup (TSV download links etc.) lost in extraction -->
- -{% endblock %} diff --git a/mmvec/q2/plugin_setup.py b/mmvec/q2/plugin_setup.py deleted file mode 100644 index 4285418..0000000 --- a/mmvec/q2/plugin_setup.py +++ /dev/null @@ -1,252 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2016--, gneiss development team. -# -# Distributed under the terms of the Modified BSD License. -# -# The full license is in the file COPYING.txt, distributed with this software. -# ---------------------------------------------------------------------------- -import importlib -import qiime2.plugin -import qiime2.sdk -from mmvec import __version__, _heatmap_choices, _cmaps -from qiime2.plugin import (Str, Properties, Int, Float, Metadata, Bool, - MetadataColumn, Categorical, Range, Choices, List) -from q2_types.feature_table import FeatureTable, Frequency -from q2_types.feature_data import FeatureData -from q2_types.sample_data import SampleData -from q2_types.ordination import PCoAResults -from mmvec.q2 import ( - Conditional, ConditionalFormat, ConditionalDirFmt, - MMvecStats, MMvecStatsFormat, MMvecStatsDirFmt, - paired_omics, heatmap, paired_heatmap, summarize_single, summarize_paired -) - -plugin = qiime2.plugin.Plugin( - name='mmvec', - version=__version__, - website="https://github.com/biocore/mmvec", - short_description='Plugin for performing microbe-metabolite ' - 'co-occurence analysis.', - description='This is a QIIME 2 plugin supporting microbe-metabolite ' - 'co-occurence analysis using mmvec.', - package='mmvec') - -plugin.methods.register_function( - function=paired_omics, - inputs={'microbes': FeatureTable[Frequency], - 'metabolites': FeatureTable[Frequency]}, - parameters={ - 'metadata': Metadata, - 'training_column': Str, - 'num_testing_examples': Int, - 'min_feature_count': Int, - 'epochs': Int, - 'batch_size': Int, - 'arm_the_gpu': Bool, - 'latent_dim': Int, - 'input_prior': Float, - 'output_prior': Float, - 'learning_rate': Float, - 'equalize_biplot': Bool, - 'summary_interval': Int - }, - outputs=[ - ('conditionals', FeatureData[Conditional]), - ('conditional_biplot', PCoAResults % Properties('biplot')), - ('model_stats', SampleData[MMvecStats]), - ], - input_descriptions={ - 'microbes': 'Input table of microbial counts.', - 'metabolites': 'Input table of metabolite intensities.', - }, - output_descriptions={ - 'conditionals': 'Mean-centered Conditional log-probabilities.', - 'conditional_biplot': 'Biplot of microbe-metabolite vectors.', - }, - parameter_descriptions={ - 'metadata': 'Sample metadata table with covariates of interest.', - 'training_column': "The metadata column specifying which " - "samples are for training/testing. " - "Entries must be marked `Train` for training " - "examples and `Test` for testing examples. ", - 'num_testing_examples': "The number of random examples to select " - "if `training_column` isn't specified.", - 'epochs': 'The total number of iterations over the entire dataset.', - 'equalize_biplot': 'Biplot arrows and points are on the same scale.', - 'batch_size': 'The number of samples to be evaluated per ' - 'training iteration.', - 'arm_the_gpu': 'Specifies whether or not to use the GPU.', - 'input_prior': 'Width of normal prior for the microbial ' - 'coefficients. Smaller values will regularize ' - 'parameters towards zero. Values must be greater ' - 'than 0.', - 'output_prior': 'Width of normal prior for the metabolite ' - 'coefficients. Smaller values will regularize ' - 'parameters towards zero. 
Values must be greater ' - 'than 0.', - 'learning_rate': 'Gradient descent decay rate.' - }, - name='Microbe metabolite vectors', - description="Performs bi-loglinear multinomial regression and calculates " - "the conditional probability ranks of metabolite " - "co-occurence given the microbe presence.", - citations=[] -) - -plugin.visualizers.register_function( - function=heatmap, - inputs={'ranks': FeatureData[Conditional]}, - parameters={ - 'microbe_metadata': MetadataColumn[Categorical], - 'metabolite_metadata': MetadataColumn[Categorical], - 'method': Str % Choices(_heatmap_choices['method']), - 'metric': Str % Choices(_heatmap_choices['metric']), - 'color_palette': Str % Choices(_cmaps['heatmap']), - 'margin_palette': Str % Choices(_cmaps['margins']), - 'x_labels': Bool, - 'y_labels': Bool, - 'level': Int % Range(-1, None), - 'row_center': Bool, - }, - input_descriptions={'ranks': 'Conditional probabilities.'}, - parameter_descriptions={ - 'microbe_metadata': 'Optional microbe metadata for annotating plots.', - 'metabolite_metadata': 'Optional metabolite metadata for annotating ' - 'plots.', - 'method': 'Hierarchical clustering method used in clustermap.', - 'metric': 'Distance metric used in clustermap.', - 'color_palette': 'Color palette for clustermap.', - 'margin_palette': 'Name of color palette to use for annotating ' - 'metadata along margin(s) of clustermap.', - 'x_labels': 'Plot x-axis (metabolite) labels?', - 'y_labels': 'Plot y-axis (microbe) labels?', - 'level': 'taxonomic level for annotating clustermap. Set to -1 if not ' - 'parsing semicolon-delimited taxonomies or wish to print ' - 'entire annotation.', - 'row_center': 'Center conditional probability table ' - 'around average row.' - }, - name='Conditional probability heatmap', - description="Generate heatmap depicting mmvec conditional probabilities.", - citations=[] -) - -plugin.visualizers.register_function( - function=paired_heatmap, - inputs={'ranks': FeatureData[Conditional], - 'microbes_table': FeatureTable[Frequency], - 'metabolites_table': FeatureTable[Frequency]}, - parameters={ - 'microbe_metadata': MetadataColumn[Categorical], - 'features': List[Str], - 'top_k_microbes': Int % Range(0, None), - 'color_palette': Str % Choices(_cmaps['heatmap']), - 'normalize': Str % Choices(['log10', 'z_score_col', 'z_score_row', - 'rel_row', 'rel_col', 'None']), - 'top_k_metabolites': Int % Range(1, None) | Str % Choices(['all']), - 'keep_top_samples': Bool, - 'level': Int % Range(-1, None), - 'row_center': Bool, - }, - input_descriptions={'ranks': 'Conditional probabilities.', - 'microbes_table': 'Microbial feature abundances.', - 'metabolites_table': 'Metabolite feature abundances.'}, - parameter_descriptions={ - 'microbe_metadata': 'Optional microbe metadata for annotating plots.', - 'features': 'Microbial feature IDs to display in heatmap. Use this ' - 'parameter to include named feature IDs in the heatmap. ' - 'Can be used in conjunction with top_k_microbes, in which ' - 'case named features will be displayed first, then top ' - 'microbial features in order of log conditional ' - 'probability maximum values.', - 'top_k_microbes': 'Select top k microbes (those with the highest ' - 'relative abundances) to display on the heatmap. 
' - 'Set to "all" to display all metabolites.', - 'color_palette': 'Color palette for clustermap.', - 'normalize': 'Optionally normalize heatmap values by columns or rows.', - 'top_k_metabolites': 'Select top k metabolites associated with each ' - 'of the chosen features to display on heatmap.', - 'keep_top_samples': 'Display only samples in which at least one of ' - 'the selected microbes is the most abundant ' - 'feature.', - 'level': 'taxonomic level for annotating clustermap. Set to -1 if not ' - 'parsing semicolon-delimited taxonomies or wish to print ' - 'entire annotation.', - 'row_center': 'Center conditional probability table ' - 'around average row.' - }, - name='Paired feature abundance heatmaps', - description="Generate paired heatmaps that depict microbial and " - "metabolite feature abundances. The left panel displays the " - "abundance of each selected microbial feature in each sample. " - "The right panel displays the abundances of the top k " - "metabolites most highly correlated with these microbes in " - "each sample. The y-axis (sample axis) is shared between each " - "panel.", - citations=[] -) - - -plugin.visualizers.register_function( - function=summarize_single, - inputs={ - 'model_stats': SampleData[MMvecStats] - }, - parameters={}, - input_descriptions={ - 'model_stats': ( - "Summary information produced by running " - "`qiime mmvec paired-omics`." - ) - }, - parameter_descriptions={ - }, - name='MMvec summary statistics', - description=( - "Visualize the convergence statistics from running " - "`qiime mmvec paired-omics`, giving insight " - "into how the model fit to your data." - ) -) - -plugin.visualizers.register_function( - function=summarize_paired, - inputs={ - 'model_stats': SampleData[MMvecStats], - 'baseline_stats': SampleData[MMvecStats] - }, - parameters={}, - input_descriptions={ - - 'model_stats': ( - "Summary information for the reference model, produced by running " - "`qiime mmvec paired-omics`." - ), - 'baseline_stats': ( - "Summary information for the baseline model, produced by running " - "`qiime mmvec paired-omics`." - ) - - }, - parameter_descriptions={ - }, - name='Paired MMvec summary statistics', - description=( - "Visualize the convergence statistics from two MMvec models, " - "giving insight into how the models fit to your data. " - "The produced visualization includes a 'pseudo-Q-squared' value." 
- ) -) - -# Register types -plugin.register_formats(MMvecStatsFormat, MMvecStatsDirFmt) -plugin.register_semantic_types(MMvecStats) -plugin.register_semantic_type_to_format( - SampleData[MMvecStats], MMvecStatsDirFmt) - -plugin.register_formats(ConditionalFormat, ConditionalDirFmt) -plugin.register_semantic_types(Conditional) -plugin.register_semantic_type_to_format( - FeatureData[Conditional], ConditionalDirFmt) - -importlib.import_module('mmvec.q2._transformer') diff --git a/mmvec/q2/tests/test_method.py b/mmvec/q2/tests/test_method.py deleted file mode 100644 index 2bae849..0000000 --- a/mmvec/q2/tests/test_method.py +++ /dev/null @@ -1,98 +0,0 @@ -import biom -import unittest -import numpy as np -import tensorflow as tf -from mmvec.q2._method import paired_omics -from mmvec.util import random_multimodal -from skbio.stats.composition import clr_inv -from scipy.stats import spearmanr -import numpy.testing as npt - - -class TestMMvec(unittest.TestCase): - - def setUp(self): - np.random.seed(1) - res = random_multimodal( - num_microbes=8, num_metabolites=8, num_samples=150, - latent_dim=2, sigmaQ=2, - microbe_total=1000, metabolite_total=10000, seed=1 - ) - (self.microbes, self.metabolites, self.X, self.B, - self.U, self.Ubias, self.V, self.Vbias) = res - n, d1 = self.microbes.shape - n, d2 = self.metabolites.shape - - self.microbes = biom.Table(self.microbes.values.T, - self.microbes.columns, - self.microbes.index) - self.metabolites = biom.Table(self.metabolites.values.T, - self.metabolites.columns, - self.metabolites.index) - U_ = np.hstack( - (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) - V_ = np.vstack( - (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) - - uv = U_ @ V_ - h = np.zeros((d1, 1)) - self.exp_ranks = clr_inv(np.hstack((h, uv))) - - def test_fit(self): - np.random.seed(1) - tf.reset_default_graph() - tf.set_random_seed(0) - latent_dim = 2 - res_ranks, res_biplot, _ = paired_omics( - self.microbes, self.metabolites, - epochs=1000, latent_dim=latent_dim, - min_feature_count=1, learning_rate=0.1 - ) - res_ranks = clr_inv(res_ranks.T) - s_r, s_p = spearmanr(np.ravel(res_ranks), np.ravel(self.exp_ranks)) - - self.assertGreater(s_r, 0.5) - self.assertLess(s_p, 1e-2) - - # make sure the biplot is of the correct dimensions - npt.assert_allclose( - res_biplot.samples.shape, - np.array([self.microbes.shape[0], latent_dim])) - npt.assert_allclose( - res_biplot.features.shape, - np.array([self.metabolites.shape[0], latent_dim])) - - # make sure that the biplot has the correct ordering - self.assertGreater(res_biplot.proportion_explained[0], - res_biplot.proportion_explained[1]) - self.assertGreater(res_biplot.eigvals[0], - res_biplot.eigvals[1]) - - def test_equalize_sv(self): - np.random.seed(1) - tf.reset_default_graph() - tf.set_random_seed(0) - latent_dim = 2 - res_ranks, res_biplot, _ = paired_omics( - self.microbes, self.metabolites, - epochs=1000, latent_dim=latent_dim, - min_feature_count=1, learning_rate=0.1, - equalize_biplot=True - ) - # make sure the biplot is of the correct dimensions - npt.assert_allclose( - res_biplot.samples.shape, - np.array([self.microbes.shape[0], latent_dim])) - npt.assert_allclose( - res_biplot.features.shape, - np.array([self.metabolites.shape[0], latent_dim])) - - # make sure that the biplot has the correct ordering - self.assertGreater(res_biplot.proportion_explained[0], - res_biplot.proportion_explained[1]) - self.assertGreater(res_biplot.eigvals[0], - res_biplot.eigvals[1]) - - -if __name__ == "__main__": - unittest.main() diff 
--git a/mmvec/q2/tests/test_visualizers.py b/mmvec/q2/tests/test_visualizers.py deleted file mode 100644 index 6171670..0000000 --- a/mmvec/q2/tests/test_visualizers.py +++ /dev/null @@ -1,97 +0,0 @@ -import unittest -import pandas as pd -from qiime2 import Artifact, CategoricalMetadataColumn -from qiime2.plugins import mmvec -import biom -import numpy as np - - -# these tests just make sure the visualizer runs; nuts + bolts are tested in -# the main package. -class TestHeatmap(unittest.TestCase): - - def setUp(self): - _ranks = pd.DataFrame([[4.1, 1.3, 2.1], [0.1, 0.3, 0.2], - [2.2, 4.3, 3.2], [-6.3, -4.4, 2.1]], - index=pd.Index([c for c in 'ABCD'], name='id'), - columns=['m1', 'm2', 'm3']).T - self.ranks = Artifact.import_data('FeatureData[Conditional]', _ranks) - self.taxa = CategoricalMetadataColumn(pd.Series([ - 'k__Bacteria; p__Proteobacteria; c__Deltaproteobacteria; ' - 'o__Desulfobacterales; f__Desulfobulbaceae; g__; s__', - 'k__Bacteria; p__Cyanobacteria; c__Chloroplast; o__Streptophyta', - 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' - 'o__Rickettsiales; f__mitochondria; g__Lardizabala; s__biternata', - 'k__Archaea; p__Euryarchaeota; c__Methanomicrobia; ' - 'o__Methanosarcinales; f__Methanosarcinaceae; g__Methanosarcina'], - index=pd.Index([c for c in 'ABCD'], name='feature-id'), - name='Taxon')) - self.metabolites = CategoricalMetadataColumn(pd.Series([ - 'amino acid', 'carbohydrate', 'drug metabolism'], - index=pd.Index(['m1', 'm2', 'm3'], name='feature-id'), - name='Super Pathway')) - - def test_heatmap_default(self): - mmvec.actions.heatmap(self.ranks, self.taxa, self.metabolites) - - def test_heatmap_no_metadata(self): - mmvec.actions.heatmap(self.ranks) - - def test_heatmap_one_metadata(self): - mmvec.actions.heatmap(self.ranks, self.taxa, None) - - def test_heatmap_no_taxonomy_parsing(self): - mmvec.actions.heatmap(self.ranks, self.taxa, None, level=-1) - - def test_heatmap_plot_axis_labels(self): - mmvec.actions.heatmap(self.ranks, x_labels=True, y_labels=True) - - -class TestPairedHeatmap(unittest.TestCase): - - def setUp(self): - _ranks = pd.DataFrame([[4.1, 1.3, 2.1], [0.1, 0.3, 0.2], - [2.2, 4.3, 3.2], [-6.3, -4.4, 2.1]], - index=pd.Index([c for c in 'ABCD'], name='id'), - columns=['m1', 'm2', 'm3']).T - self.ranks = Artifact.import_data('FeatureData[Conditional]', _ranks) - self.taxa = CategoricalMetadataColumn(pd.Series([ - 'k__Bacteria; p__Proteobacteria; c__Deltaproteobacteria; ' - 'o__Desulfobacterales; f__Desulfobulbaceae; g__; s__', - 'k__Bacteria; p__Cyanobacteria; c__Chloroplast; o__Streptophyta', - 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' - 'o__Rickettsiales; f__mitochondria; g__Lardizabala; s__biternata', - 'k__Archaea; p__Euryarchaeota; c__Methanomicrobia; ' - 'o__Methanosarcinales; f__Methanosarcinaceae; g__Methanosarcina'], - index=pd.Index([c for c in 'ABCD'], name='feature-id'), - name='Taxon')) - metabolites = biom.Table( - np.array([[9, 8, 2], [2, 1, 2], [9, 4, 5], [8, 8, 7]]), - sample_ids=['s1', 's2', 's3'], - observation_ids=['m1', 'm2', 'm3', 'm4']) - self.metabolites = Artifact.import_data( - 'FeatureTable[Frequency]', metabolites) - microbes = biom.Table( - np.array([[1, 2, 3], [3, 6, 3], [1, 9, 9], [8, 8, 7]]), - sample_ids=['s1', 's2', 's3'], observation_ids=[i for i in 'ABCD']) - self.microbes = Artifact.import_data( - 'FeatureTable[Frequency]', microbes) - - def test_paired_heatmaps_single_feature(self): - mmvec.actions.paired_heatmap( - self.ranks, self.microbes, self.metabolites, features=['C'], - 
microbe_metadata=self.taxa) - - def test_paired_heatmaps_multifeature(self): - mmvec.actions.paired_heatmap( - self.ranks, self.microbes, self.metabolites, features=['A', 'C']) - - def test_paired_heatmaps_fail_on_unknown_feature(self): - with self.assertRaisesRegex(ValueError, "must represent feature IDs"): - mmvec.actions.paired_heatmap( - self.ranks, self.microbes, self.metabolites, - features=['A', 'barf']) - - -if __name__ == "__main__": - unittest.main() From d57f4c09c60a6b41cd901291bedaa9c8cec56d24 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Mon, 18 Apr 2022 16:59:40 -0700 Subject: [PATCH 03/27] FEAT: forward likelihood now sum(lu, lv, ly) --- mmvec/multimodal.py | 48 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/mmvec/multimodal.py b/mmvec/multimodal.py index 1f50449..4faa419 100644 --- a/mmvec/multimodal.py +++ b/mmvec/multimodal.py @@ -1,9 +1,12 @@ import torch import torch.nn as nn -from torch.distributions import Multinomial +from torch.distributions import Multinomial, Normal + +from torch.nn.parallel import DistributedDataParallel as ddp class MMvec(nn.Module): - def __init__(self, num_microbes, num_metabolites, latent_dim): + def __init__(self, num_microbes, num_metabolites, latent_dim, sigma_u, + sigma_v): super().__init__() self.encoder = nn.Embedding(num_microbes, latent_dim) @@ -11,17 +14,52 @@ def __init__(self, num_microbes, num_metabolites, latent_dim): nn.Linear(latent_dim, num_metabolites), nn.Softmax(dim=2) ) + self.sigma_u = sigma_u + self.sigma_v = sigma_v def forward(self, X, Y): + # Three likelihoods, the likelihood of each weight and the likelihood + # of the data fitting in the way that we thought + # LY z = self.encoder(X) y_pred = self.decoder(z) - + forward_dist = Multinomial(total_count=0, validate_args=False, probs=y_pred) forward_dist = forward_dist.log_prob(Y) - lp = forward_dist.mean(0).mean() + l_y = forward_dist.mean(0).mean() + + # LU + u_weights = self.encoder.weight#.detach().numpy() + l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum() + #l_u = torch.normal(0, self.sigma_u).log_prob(z + + # LV + # index zero currently holds "linear", may need to be changed later + v_weights = self.decoder[0].weight#.detach().numpy() + l_v = Normal(0, self.sigma_v).log_prob(v_weights).sum() + + likelihood_sum = l_y + l_u + l_v + + return likelihood_sum + + +def train_loop(microbes, metabolites, model, optimizer, batch_size, epochs): + + for epoch in range(epochs): + + draws = torch.multinomial(microbes, + batch_size, + replacement=True).T + + mmvec_model = model(draws, metabolites) + + optimizer.zero_grad() + mmvec_model.backward() + optimizer.step() - return lp +# if epoch % 5 == 0: +# print(f"loss: {mmvec_model.item()}\nBatch #: {epoch}") From 4a29a12bf294ebcbaf728d7c3bc466615cb5ca3f Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Tue, 19 Apr 2022 13:35:57 -0700 Subject: [PATCH 04/27] DEBUG: first pass package refactoring --- __init__.py | 0 mmvec/__init__.py | 3 ++- mmvec/{multimodal.py => model.py} | 2 +- setup.py | 11 +++++------ 4 files changed, 8 insertions(+), 8 deletions(-) create mode 100644 __init__.py rename mmvec/{multimodal.py => model.py} (95%) diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mmvec/__init__.py b/mmvec/__init__.py index d1646ad..9dff8c8 100644 --- a/mmvec/__init__.py +++ b/mmvec/__init__.py @@ -1,5 +1,6 @@ from .heatmap import _heatmap_choices, _cmaps +from .multimodal import MMvec, mmvec_training_loop __version__ 
= "1.0.6" -__all__ = ['_heatmap_choices', '_cmaps'] +__all__ = ['_heatmap_choices', '_cmaps', 'MMvec', 'mmvec_training_loop'] diff --git a/mmvec/multimodal.py b/mmvec/model.py similarity index 95% rename from mmvec/multimodal.py rename to mmvec/model.py index 4faa419..a916c79 100644 --- a/mmvec/multimodal.py +++ b/mmvec/model.py @@ -47,7 +47,7 @@ def forward(self, X, Y): return likelihood_sum -def train_loop(microbes, metabolites, model, optimizer, batch_size, epochs): +def mmvec_training_loop(microbes, metabolites, model, optimizer, batch_size, epochs): for epoch in range(epochs): diff --git a/setup.py b/setup.py index 044f522..018be08 100644 --- a/setup.py +++ b/setup.py @@ -53,14 +53,13 @@ scripts=glob('scripts/mmvec'), install_requires=[ 'biom-format', - 'numpy >= 1.9.2', - 'pandas <= 0.25.3', - 'scipy >= 0.15.1', - 'nose >= 1.3.7', - 'scikit-bio >= 0.5.1', + 'numpy', + 'pandas', + 'scipy', + 'nose', + 'scikit-bio', 'seaborn', 'tqdm', - 'tensorflow>=1.15,<2' ], classifiers=classifiers, entry_points={ From 14e2b13ee841db522f28a5bccf34401f0cbbe63c Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Wed, 27 Apr 2022 13:02:35 -0700 Subject: [PATCH 05/27] IMP: split ILR and ALR, ALR done to ranks --- examples/refactor/ALR.ipynb | 283 ++++++++++++++++++++++++++++++++++++ mmvec/ALR.py | 91 ++++++++++++ mmvec/ILR.py | 112 ++++++++++++++ mmvec/__init__.py | 7 +- mmvec/model.py | 65 --------- mmvec/train.py | 19 +++ 6 files changed, 510 insertions(+), 67 deletions(-) create mode 100644 examples/refactor/ALR.ipynb create mode 100644 mmvec/ALR.py create mode 100644 mmvec/ILR.py delete mode 100644 mmvec/model.py create mode 100644 mmvec/train.py diff --git a/examples/refactor/ALR.ipynb b/examples/refactor/ALR.ipynb new file mode 100644 index 0000000..b50b464 --- /dev/null +++ b/examples/refactor/ALR.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "bae1c8b8", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "51124af9", + "metadata": {}, + "outputs": [], + "source": [ + "%autoreload 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcc95c50", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "acadee5d", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "import biom\n", + "import torch\n", + "import torch.nn as nn\n", + "%aimport mmvec.ALR\n", + "\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "14d8bfab", + "metadata": {}, + "outputs": [], + "source": [ + "microbes = biom.load_table(\"./soil_microbes.biom\")\n", + "metabolites = biom.load_table(\"./soil_metabolites.biom\")\n", + "\n", + "microbes = microbes.to_dataframe().T\n", + "metabolites = metabolites.to_dataframe().T\n", + "microbes = microbes.loc[metabolites.index]\n", + "\n", + "microbe_idx = microbes.columns\n", + "metabolite_idx = metabolites.columns\n", + "microbes = torch.tensor(microbes.values, dtype=torch.int)\n", + "metabolites = torch.tensor(metabolites.values, dtype=torch.int64)\n", + "\n", + "microbe_relative_frequency = (microbes.T/microbes.sum(1)).T\n", + "\n", + "microbe_count = microbes.shape[1]\n", + "metabolite_count = metabolites.shape[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "9003e38d", + "metadata": {}, + "outputs": [], + "source": [ + "model = mmvec.ALR.MMvecALR(microbe_count, 
metabolite_count, 15, sigma_u=1, sigma_v=1)\n", + "learning_rate = 1e-3\n", + "batch_size = 200\n", + "epochs = 10000\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, maximize=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "8c8076ae", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "15" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.decoder.linear.weight.shape[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "b977e212", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: -1554443264.0\n", + "Batch #: 0\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [38]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m maybe \u001b[38;5;241m=\u001b[39m \u001b[43mmmvec\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmmvec_training_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmicrobes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmicrobe_relative_frequency\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mmetabolites\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetabolites\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/work/qiime2/mmvec/mmvec/train.py:12\u001b[0m, in \u001b[0;36mmmvec_training_loop\u001b[0;34m(microbes, metabolites, model, optimizer, batch_size, epochs)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(epochs):\n\u001b[1;32m 8\u001b[0m draws \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmultinomial(microbes,\n\u001b[1;32m 9\u001b[0m batch_size,\n\u001b[1;32m 10\u001b[0m replacement\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mT\n\u001b[0;32m---> 12\u001b[0m mmvec_model \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdraws\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetabolites\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 15\u001b[0m mmvec_model\u001b[38;5;241m.\u001b[39mbackward()\n", + "File \u001b[0;32m~/opt/miniconda3/envs/mmvec/lib/python3.10/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the 
logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/work/qiime2/mmvec/mmvec/ALR.py:56\u001b[0m, in \u001b[0;36mMMvecALR.forward\u001b[0;34m(self, X, Y)\u001b[0m\n\u001b[1;32m 50\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoder(z)\n\u001b[1;32m 52\u001b[0m forward_dist \u001b[38;5;241m=\u001b[39m Multinomial(total_count\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m,\n\u001b[1;32m 53\u001b[0m validate_args\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 54\u001b[0m probs\u001b[38;5;241m=\u001b[39my_pred)\n\u001b[0;32m---> 56\u001b[0m forward_dist \u001b[38;5;241m=\u001b[39m \u001b[43mforward_dist\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_prob\u001b[49m\u001b[43m(\u001b[49m\u001b[43mY\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 58\u001b[0m l_y \u001b[38;5;241m=\u001b[39m forward_dist\u001b[38;5;241m.\u001b[39mmean(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# LU\u001b[39;00m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "maybe = mmvec.train.mmvec_training_loop(microbes=microbe_relative_frequency,\n", + " metabolites=metabolites,\n", + " model=model,\n", + " optimizer=optimizer,\n", + " batch_size=batch_size,\n", + " epochs=epochs)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "cbea21c6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Parameter containing:\n", + "tensor([-0.2304, -0.1210, -0.2237, 1.5584, -0.0357, -0.2358, -0.3364, -0.1522,\n", + " 0.1831, 0.3266, -0.3287, 0.6406, -0.4283, -0.4332, -0.3035, -0.2616,\n", + " -0.3637, -0.0561, 0.3958, -0.4435, 0.1320, 0.3295, 0.3745, -0.3914,\n", + " -0.2479, 1.1489, 0.0920, 0.1924, 0.8578, -0.4792, -0.3238, -0.2298,\n", + " 0.0590, -0.2976, 0.4175, -0.1658, 0.5618, 0.0517, -0.4560, 0.0813,\n", + " 1.3941, 0.2018, -0.4606, 0.0243, -0.3180, 0.2470, -0.2259, 0.0057,\n", + " 0.8352, -0.4047, 0.0856, -0.3115, -0.2540, -0.4270, -0.2389, -0.3265,\n", + " -0.0604, -0.1562, 0.2893, -0.3153, 0.2701, -0.2802, -0.0060, -0.0257,\n", + " -0.1017, -0.2101, -0.4707, -0.3083, -0.4171, 0.9620, -0.2070, -0.0076,\n", + " -0.5172, -0.4790, -0.3156, 0.4800, -0.1798, -0.1373, -0.5174, -0.0725,\n", + " 
-0.3061, -0.2664, 0.0945, -0.4450], requires_grad=True)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.decoder.linear.bias" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "2019c1ff", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "unsupported operand type(s) for @: 'Embedding' and 'Linear'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [35]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m@\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\n", + "\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for @: 'Embedding' and 'Linear'" + ] + } + ], + "source": [ + "model.encoder @ model.decoder.linear" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "a5f97b6d", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'pd' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [27]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ranks \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241m.\u001b[39mDataFrame(model\u001b[38;5;241m.\u001b[39mranks(),\n\u001b[1;32m 2\u001b[0m index\u001b[38;5;241m=\u001b[39mmicrobe_idx,\n\u001b[1;32m 3\u001b[0m columns\u001b[38;5;241m=\u001b[39mmetabolite_idx)\n", + "\u001b[0;31mNameError\u001b[0m: name 'pd' is not defined" + ] + } + ], + "source": [ + "ranks = pd.DataFrame(model.ranks(),\n", + " index=microbe_idx,\n", + " columns=metabolite_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fde6a1e", + "metadata": {}, + "outputs": [], + "source": [ + "model.decoder.linear.weight.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe702e51", + "metadata": {}, + "outputs": [], + "source": [ + "for param in model.parameters():\n", + " print(param.shape)\n", + "\n", + "print(model.parameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa9e8f31", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "46d440f7", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mmvec/ALR.py b/mmvec/ALR.py new file mode 100644 index 0000000..25ca02e --- /dev/null +++ b/mmvec/ALR.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions import Multinomial, Normal + +import numpy as np + + +class 
LinearALR(nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim - 1) + + def forward(self, x): + y = self.linear(x) + z = torch.zeros((y.shape[0], y.shape[1], 1)) + y = torch.cat((z, y), dim=2) + + return F.softmax(y, dim=2) + + +class MMvecALR(nn.Module): + def __init__(self, num_microbes, num_metabolites, latent_dim, sigma_u, + sigma_v): + super().__init__() + + self.latent_dim = latent_dim + self.num_microbes = num_microbes + self.num_metabolites = num_metabolites + + self.u_bias = nn.parameter.Parameter(torch.randn((num_microbes, 1))) + + self.encoder = nn.Embedding(num_microbes, latent_dim) + #self.decoder = nn.Sequential( + # nn.Linear(latent_dim, num_metabolites), + # nn.Softmax(dim=2) + # ) + self.decoder = LinearALR(latent_dim, num_metabolites) + + self.sigma_u = sigma_u + self.sigma_v = sigma_v + + + def forward(self, X, Y): + # Three likelihoods, the likelihood of each weight and the likelihood + # of the data fitting in the way that we thought + # LYs + z = self.encoder(X) + z = z + self.u_bias[X].reshape((*X.shape, 1)) + y_pred = self.decoder(z) + + forward_dist = Multinomial(total_count=0, + validate_args=False, + probs=y_pred) + + forward_dist = forward_dist.log_prob(Y) + + l_y = forward_dist.mean(0).mean() + + # LU + u_weights = self.encoder.weight + l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum() + #l_u = torch.normal(0, self.sigma_u).log_prob(z + + # LV + # index zero currently holds "linear", may need to be changed later + v_weights = self.decoder.linear.weight + l_v = Normal(0, self.sigma_v).log_prob(v_weights).sum() + + likelihood_sum = l_y + l_u + l_v + return likelihood_sum + + def ranks(self): + U = torch.cat( + (torch.ones((self.num_microbes, 1)), + self.u_bias.detach(), + self.encoder.weight.detach()), + dim=-1) + + V = torch.cat( + (self.decoder.linear.bias.detach().unsqueeze(dim=0), + torch.from_numpy(np.ones((1, self.num_metabolites - 1))), + self.decoder.linear.weight.detach().T), + dim=0) + #res = np.hstack((np.zeros((self.num_microbes - 1, 1)), modelU @ modelV)) + res = torch.cat((torch.zeros((self.num_microbes -1, 1)), U @ V), + dim=-1) + res = res - res.mean(axis=1).reshape(-1, 1) + # perform SVD here?..... 
+ return res + diff --git a/mmvec/ILR.py b/mmvec/ILR.py new file mode 100644 index 0000000..1f5c132 --- /dev/null +++ b/mmvec/ILR.py @@ -0,0 +1,112 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions import Multinomial, Normal + +import numpy as np + +from gneiss.cluster import random_linkage +from gneiss.balances import sparse_balance_basis + + + +class MMvecILR(nn.Module): + def __init__(self, num_microbes, num_metabolites, latent_dim, sigma_u, + sigma_v): + super().__init__() + + self.latent_dim = latent_dim + self.num_microbes = num_microbes + self.num_metabolites = num_metabolites + + self.u_bias = nn.parameter.Parameter(torch.randn((num_microbes, 1))) + + self.encoder = nn.Embedding(num_microbes, latent_dim) + #self.decoder = nn.Sequential( + # nn.Linear(latent_dim, num_metabolites), + # nn.Softmax(dim=2) + # ) + self.decoder = LinearILR(latent_dim, num_metabolites) + + self.sigma_u = sigma_u + self.sigma_v = sigma_v + + def forward(self, X, Y): + # Three likelihoods, the likelihood of each weight and the likelihood + # of the data fitting in the way that we thought + # LYs + z = self.encoder(X) + z = z + self.u_bias[X].reshape((*X.shape, 1)) + y_pred = self.decoder(z) + + forward_dist = Multinomial(total_count=0, + validate_args=False, + probs=y_pred) + + forward_dist = forward_dist.log_prob(Y) + + l_y = forward_dist.mean(0).mean() + + # LU + u_weights = self.encoder.weight + l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum() + #l_u = torch.normal(0, self.sigma_u).log_prob(z + + # LV + # index zero currently holds "linear", may need to be changed later + v_weights = self.decoder.linear.weight + l_v = Normal(0, self.sigma_v).log_prob(v_weights).sum() + + likelihood_sum = l_y + l_u + l_v + return likelihood_sum + + + + def ILRranks(self): + #modelU = np.hstack( + # (np.ones((self.num_microbes, 1)), + # self.u_bias.detach().numpy(), + # self.encoder.weight.detach().numpy())) + + U = torch.cat( + (torch.from_numpy(np.ones((self.num_microbes, 1))), + self.u_bias.detach(), + self.encoder.weight.detach()), + dim=-1) + + V = torch.stack( + (self.decoder.linear.bias.detach(), + torch.from_numpy(np.ones((1, self.num_metabolites))), + self.decoder.linear.weight.detach().T), + dim=0) + + #V = torch.sparse.mm(modelV, self.decoder.Psi.T) + #res = modelU V + res = U @ V @ self.decoder.Psi.to_dense().T + #res = modelU @ modelV @ self.decoder.Psi.T + #print(res) + #res = res - res.mean(axis=1).reshape(-1, 1) + return res + + +class LinearILR(nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + tree = random_linkage(output_dim) # pick random tree it doesn't really matter tbh + basis = sparse_balance_basis(tree)[0].copy() + indices = np.vstack((basis.row, basis.col)) + Psi = torch.sparse_coo_tensor( + indices.copy(), + basis.data.astype(np.float32).copy(), + dtype=torch.double, + requires_grad=False).coalesce() + + self.linear = nn.Linear(input_dim, output_dim) + self.register_buffer('Psi', Psi) + + def forward(self, x): + y = self.linear(x) + logy = (Psi.t() @ y.t()).t() + return F.softmax(logy, dim=1) + diff --git a/mmvec/__init__.py b/mmvec/__init__.py index 9dff8c8..97c8211 100644 --- a/mmvec/__init__.py +++ b/mmvec/__init__.py @@ -1,6 +1,9 @@ from .heatmap import _heatmap_choices, _cmaps -from .multimodal import MMvec, mmvec_training_loop +from .ALR import MMvecALR +from .ILR import MMvecILR +from .train import mmvec_training_loop __version__ = "1.0.6" -__all__ = ['_heatmap_choices', '_cmaps', 'MMvec', 
'mmvec_training_loop'] +__all__ = ['_heatmap_choices', '_cmaps', 'MMvecALR', 'MMvecILR', + 'mmvec_training_loop'] diff --git a/mmvec/model.py b/mmvec/model.py deleted file mode 100644 index a916c79..0000000 --- a/mmvec/model.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -import torch.nn as nn -from torch.distributions import Multinomial, Normal - -from torch.nn.parallel import DistributedDataParallel as ddp - -class MMvec(nn.Module): - def __init__(self, num_microbes, num_metabolites, latent_dim, sigma_u, - sigma_v): - super().__init__() - - self.encoder = nn.Embedding(num_microbes, latent_dim) - self.decoder = nn.Sequential( - nn.Linear(latent_dim, num_metabolites), - nn.Softmax(dim=2) - ) - self.sigma_u = sigma_u - self.sigma_v = sigma_v - - def forward(self, X, Y): - # Three likelihoods, the likelihood of each weight and the likelihood - # of the data fitting in the way that we thought - # LY - z = self.encoder(X) - y_pred = self.decoder(z) - - forward_dist = Multinomial(total_count=0, - validate_args=False, - probs=y_pred) - - forward_dist = forward_dist.log_prob(Y) - - l_y = forward_dist.mean(0).mean() - - # LU - u_weights = self.encoder.weight#.detach().numpy() - l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum() - #l_u = torch.normal(0, self.sigma_u).log_prob(z - - # LV - # index zero currently holds "linear", may need to be changed later - v_weights = self.decoder[0].weight#.detach().numpy() - l_v = Normal(0, self.sigma_v).log_prob(v_weights).sum() - - likelihood_sum = l_y + l_u + l_v - - return likelihood_sum - - -def mmvec_training_loop(microbes, metabolites, model, optimizer, batch_size, epochs): - - for epoch in range(epochs): - - draws = torch.multinomial(microbes, - batch_size, - replacement=True).T - - mmvec_model = model(draws, metabolites) - - optimizer.zero_grad() - mmvec_model.backward() - optimizer.step() - -# if epoch % 5 == 0: -# print(f"loss: {mmvec_model.item()}\nBatch #: {epoch}") diff --git a/mmvec/train.py b/mmvec/train.py new file mode 100644 index 0000000..182dcd3 --- /dev/null +++ b/mmvec/train.py @@ -0,0 +1,19 @@ +import torch + + +def mmvec_training_loop(microbes, metabolites, model, optimizer, + batch_size, epochs): + for epoch in range(epochs): + + draws = torch.multinomial(microbes, + batch_size, + replacement=True).T + + mmvec_model = model(draws, metabolites) + + optimizer.zero_grad() + mmvec_model.backward() + optimizer.step() + + if epoch % 500 == 0: + print(f"loss: {mmvec_model.item()}\nBatch #: {epoch}") From b6778c16484fa8506ae3ed8132f21445742ed554 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Wed, 27 Apr 2022 14:45:09 -0700 Subject: [PATCH 06/27] IMP: ALR outputs working --- examples/refactor/ALR.ipynb | 101 +++++++++++++++--------------------- mmvec/ALR.py | 26 +++++----- 2 files changed, 55 insertions(+), 72 deletions(-) diff --git a/examples/refactor/ALR.ipynb b/examples/refactor/ALR.ipynb index b50b464..b5dbf14 100644 --- a/examples/refactor/ALR.ipynb +++ b/examples/refactor/ALR.ipynb @@ -2,8 +2,8 @@ "cells": [ { "cell_type": "code", - "execution_count": null, - "id": "bae1c8b8", + "execution_count": 1, + "id": "c73ac6c4", "metadata": {}, "outputs": [], "source": [ @@ -12,8 +12,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "51124af9", + "execution_count": 2, + "id": "15cefd0d", "metadata": {}, "outputs": [], "source": [ @@ -23,14 +23,14 @@ { "cell_type": "code", "execution_count": null, - "id": "bcc95c50", + "id": "b5234303", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", - "execution_count": 
9,
+   "execution_count": 3,
    "id": "acadee5d",
    "metadata": {
     "scrolled": false
@@ -47,7 +47,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "id": "14d8bfab",
    "metadata": {},
    "outputs": [],
@@ -72,7 +72,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 8,
    "id": "9003e38d",
    "metadata": {},
    "outputs": [],
@@ -87,13 +87,13 @@
   {
    "cell_type": "code",
    "execution_count": 11,
-   "id": "8c8076ae",
+   "id": "1ced7f98",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "15"
+       "torch.Size([84, 15])"
       ]
      },
      "execution_count": 11,
@@ -102,12 +102,12 @@
     }
    ],
    "source": [
-    "model.decoder.linear.weight.shape[1]"
+    "model.decoder.linear.weight.shape"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": 9,
    "id": "b977e212",
    "metadata": {},
    "outputs": [
@@ -115,8 +115,12 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "loss: -1554443264.0\n",
-      "Batch #: 0\n"
+      "loss: -14500305920.0\n",
+      "Batch #: 0\n",
+      "loss: -2583066368.0\n",
+      "Batch #: 500\n",
+      "loss: -1934131328.0\n",
+      "Batch #: 1000\n"
      ]
     },
     {
@@ -126,10 +130,11 @@
      "traceback": [
       "---------------------------------------------------------------------------",
       "KeyboardInterrupt                         Traceback (most recent call last)",
-      "[ANSI-escaped frames truncated: interrupt raised in mmvec_training_loop -> model(draws, metabolites) -> MMvecALR.forward, at the forward_dist.log_prob(Y) call, mmvec/ALR.py:56]",
+      "[ANSI-escaped frames truncated: interrupt raised in mmvec_training_loop -> model(draws, metabolites) -> MMvecALR.forward, now at the self.decoder(z) call, mmvec/ALR.py:51]",
       "KeyboardInterrupt: "
      ]
     }
    ],
    "source": [
     "maybe = mmvec.train.mmvec_training_loop(microbes=microbe_relative_frequency,\n",
     "                                        metabolites=metabolites,\n",
     "                                        model=model,\n",
     "                                        optimizer=optimizer,\n",
     "                                        batch_size=batch_size,\n",
     "                                        epochs=epochs)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
-   "id": "cbea21c6",
+   "execution_count": null,
+   "id": "56b1a17a",
    "metadata": {},
-   "outputs": [
-    [84-element bias Parameter printout removed]
-   ],
+   "outputs": [],
    "source": [
     "model.decoder.linear.bias"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
-   "id": "2019c1ff",
+   "execution_count": null,
+   "id": "9797ed91",
    "metadata": {},
-   "outputs": [
-    [TypeError traceback removed: unsupported operand type(s) for @: 'Embedding' and 'Linear']
-   ],
+   "outputs": [],
    "source": [
-    "model.encoder @ model.decoder.linear"
+    "import pandas as pd"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 12,
    "id": "a5f97b6d",
    "metadata": {},
    "outputs": [
     {
      "ename": "NameError",
      "evalue": "name 'pd' is not defined",
      "output_type": "error",
      "traceback": [
       "---------------------------------------------------------------------------",
       "NameError                                 Traceback (most recent call last)",
-      "[ANSI-escaped frame truncated: Input In [27]]",
+      "[ANSI-escaped frame truncated: Input In [12]]",
       "NameError: name 'pd' is not defined"
      ]
     }
    ],
    "source": [
     "ranks = pd.DataFrame(model.ranks(),\n",
     "                     index=microbe_idx,\n",
     "                     columns=metabolite_idx)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "248a1eff",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ranks"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -228,7 +209,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.decoder.linear.weight.shape" + "(model.encoder.weight.detach() @ model.decoder.linear.weight.detach().T)" ] }, { diff --git a/mmvec/ALR.py b/mmvec/ALR.py index 25ca02e..99c3b21 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -1,4 +1,5 @@ import torch +import pandas as pd import torch.nn as nn import torch.nn.functional as F from torch.distributions import Multinomial, Normal @@ -57,35 +58,36 @@ def forward(self, X, Y): l_y = forward_dist.mean(0).mean() - # LU u_weights = self.encoder.weight l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum() - #l_u = torch.normal(0, self.sigma_u).log_prob(z + l_ubias = Normal(0, self.sigma_u).log_prob(self.u_bias).sum() - # LV - # index zero currently holds "linear", may need to be changed later v_weights = self.decoder.linear.weight l_v = Normal(0, self.sigma_v).log_prob(v_weights).sum() + l_vbias = Normal(0, self.sigma_v).log_prob(self.decoder.linear.bias).sum() - likelihood_sum = l_y + l_u + l_v + likelihood_sum = l_y + l_u + l_v + l_ubias + l_vbias return likelihood_sum + # def get_ordination(self): + # ranks_df = pd.DataFrame(self.ranks(), + + # pass + + def ranks(self): U = torch.cat( (torch.ones((self.num_microbes, 1)), self.u_bias.detach(), self.encoder.weight.detach()), - dim=-1) + dim=1) V = torch.cat( (self.decoder.linear.bias.detach().unsqueeze(dim=0), - torch.from_numpy(np.ones((1, self.num_metabolites - 1))), + torch.ones((1, self.num_metabolites - 1)), self.decoder.linear.weight.detach().T), dim=0) - #res = np.hstack((np.zeros((self.num_microbes - 1, 1)), modelU @ modelV)) - res = torch.cat((torch.zeros((self.num_microbes -1, 1)), U @ V), - dim=-1) + + res = torch.cat((torch.zeros((self.num_microbes, 1)), U @ V), dim=1) res = res - res.mean(axis=1).reshape(-1, 1) - # perform SVD here?..... 
return res - From fcc066bb38c39c3e542193115be8747aaa94f57a Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Fri, 29 Apr 2022 13:07:15 -0700 Subject: [PATCH 07/27] FEAT: function for ordination created /getting ready to refactor model init and data --- mmvec/ALR.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/mmvec/ALR.py b/mmvec/ALR.py index 99c3b21..4f5d7c2 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -1,11 +1,11 @@ -import torch import pandas as pd + +import torch import torch.nn as nn +import torch.linalg as linalg import torch.nn.functional as F from torch.distributions import Multinomial, Normal -import numpy as np - class LinearALR(nn.Module): def __init__(self, input_dim, output_dim): @@ -32,10 +32,6 @@ def __init__(self, num_microbes, num_metabolites, latent_dim, sigma_u, self.u_bias = nn.parameter.Parameter(torch.randn((num_microbes, 1))) self.encoder = nn.Embedding(num_microbes, latent_dim) - #self.decoder = nn.Sequential( - # nn.Linear(latent_dim, num_metabolites), - # nn.Softmax(dim=2) - # ) self.decoder = LinearALR(latent_dim, num_metabolites) self.sigma_u = sigma_u @@ -69,13 +65,14 @@ def forward(self, X, Y): likelihood_sum = l_y + l_u + l_v + l_ubias + l_vbias return likelihood_sum - # def get_ordination(self): - # ranks_df = pd.DataFrame(self.ranks(), + def get_ordination(self, equalize_biplot=False): + ranks = self.ranks_matrix - self.ranks_matrix.mean(dim=0) + u, s, v = linalg.svd(ranks, full_matrices=False) + print(u) + print(s) + print(v) - # pass - - - def ranks(self): + def ranks(self, microbe_ids, metabolite_ids): U = torch.cat( (torch.ones((self.num_microbes, 1)), self.u_bias.detach(), @@ -90,4 +87,7 @@ def ranks(self): res = torch.cat((torch.zeros((self.num_microbes, 1)), U @ V), dim=1) res = res - res.mean(axis=1).reshape(-1, 1) - return res + + self.ranks_matrix = res + self.ranks_df = pd.DataFrame(res, index=microbe_ids, + columns=metabolite_ids) From dc3eaa9ef441c16a3170a62db4ca5b035a34cc04 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Fri, 29 Apr 2022 17:27:46 -0700 Subject: [PATCH 08/27] FEAT: Produces OrdinationResults --- examples/refactor/ALR.ipynb | 1657 +++++++++++++++++++++++++++++++++-- mmvec/ALR.py | 92 +- mmvec/train.py | 6 +- 3 files changed, 1670 insertions(+), 85 deletions(-) diff --git a/examples/refactor/ALR.ipynb b/examples/refactor/ALR.ipynb index b5dbf14..10b203d 100644 --- a/examples/refactor/ALR.ipynb +++ b/examples/refactor/ALR.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "c73ac6c4", + "id": "2388b10b", "metadata": {}, "outputs": [], "source": [ @@ -13,21 +13,13 @@ { "cell_type": "code", "execution_count": 2, - "id": "15cefd0d", + "id": "a0aea731", "metadata": {}, "outputs": [], "source": [ "%autoreload 1" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "b5234303", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 3, @@ -35,14 +27,27 @@ "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import biom\n", "import torch\n", "import torch.nn as nn\n", "%aimport mmvec.ALR\n", "\n", - "import numpy as np" + "import numpy as np\n", + "\n", + "torch.manual_seed(15)" ] }, { @@ -55,6 +60,8 @@ "microbes = biom.load_table(\"./soil_microbes.biom\")\n", "metabolites = biom.load_table(\"./soil_metabolites.biom\")\n", "\n", + "model = 
mmvec.ALR.MMvecALR(microbes, metabolites, 15, sigma_u=1, sigma_v=1)\n", + "\n", "microbes = microbes.to_dataframe().T\n", "metabolites = metabolites.to_dataframe().T\n", "microbes = microbes.loc[metabolites.index]\n", @@ -72,42 +79,48 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "id": "9003e38d", "metadata": {}, "outputs": [], "source": [ - "model = mmvec.ALR.MMvecALR(microbe_count, metabolite_count, 15, sigma_u=1, sigma_v=1)\n", + "#model = mmvec.ALR.MMvecALR(microbe_count, metabolite_count, 15, sigma_u=1, sigma_v=1)\n", "learning_rate = 1e-3\n", "batch_size = 200\n", - "epochs = 10000\n", + "epochs = 100\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, maximize=True)" ] }, { "cell_type": "code", - "execution_count": 11, - "id": "1ced7f98", + "execution_count": 7, + "id": "6bc4fcf6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([84, 15])" + "tensor([[0.2017, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [0.0950, 0.1752, 0.0000, ..., 0.0000, 0.0013, 0.0000],\n", + " [0.1416, 0.0973, 0.0000, ..., 0.0000, 0.0064, 0.0000],\n", + " ...,\n", + " [0.0507, 0.0000, 0.0090, ..., 0.0058, 0.0000, 0.0051],\n", + " [0.0382, 0.0076, 0.0025, ..., 0.0000, 0.0009, 0.0100],\n", + " [0.0027, 0.0135, 0.0000, ..., 0.0000, 0.0008, 0.0036]])" ] }, - "execution_count": 11, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "model.decoder.linear.weight.shape" + "model.microbe_relative_freq" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, "id": "b977e212", "metadata": {}, "outputs": [ @@ -115,33 +128,831 @@ "name": "stdout", "output_type": "stream", "text": [ - "loss: -14500305920.0\n", + "torch.Size([200, 19])\n", + "tensor([[ 79, 139, 435, ..., 189, 202, 335],\n", + " [169, 433, 150, ..., 68, 136, 224],\n", + " [ 9, 1, 224, ..., 358, 224, 436],\n", + " ...,\n", + " [ 0, 156, 71, ..., 114, 402, 224],\n", + " [143, 0, 0, ..., 3, 189, 224],\n", + " [ 39, 426, 4, ..., 81, 280, 70]])\n", + "loss: -14241466368.0\n", "Batch #: 0\n", - "loss: -2583066368.0\n", - "Batch #: 500\n", - "loss: -1934131328.0\n", - "Batch #: 1000\n" + "torch.Size([200, 19])\n", + "tensor([[119, 424, 413, ..., 103, 416, 335],\n", + " [418, 126, 225, ..., 26, 7, 129],\n", + " [ 48, 1, 1, ..., 224, 438, 224],\n", + " ...,\n", + " [224, 1, 324, ..., 193, 0, 224],\n", + " [418, 1, 224, ..., 188, 0, 224],\n", + " [ 21, 415, 0, ..., 103, 224, 121]])\n", + "torch.Size([200, 19])\n", + "tensor([[284, 1, 191, ..., 97, 137, 26],\n", + " [ 0, 174, 70, ..., 440, 59, 335],\n", + " [455, 158, 49, ..., 451, 141, 147],\n", + " ...,\n", + " [455, 269, 70, ..., 448, 335, 3],\n", + " [158, 455, 335, ..., 103, 59, 15],\n", + " [431, 101, 191, ..., 0, 125, 202]])\n", + "torch.Size([200, 19])\n", + "tensor([[439, 66, 60, ..., 103, 335, 224],\n", + " [430, 159, 283, ..., 188, 213, 335],\n", + " [ 37, 142, 1, ..., 103, 302, 56],\n", + " ...,\n", + " [174, 0, 0, ..., 103, 4, 224],\n", + " [ 0, 460, 225, ..., 437, 441, 436],\n", + " [ 5, 411, 166, ..., 419, 413, 412]])\n", + "torch.Size([200, 19])\n", + "tensor([[431, 18, 92, ..., 103, 413, 224],\n", + " [ 0, 423, 0, ..., 97, 79, 440],\n", + " [224, 0, 1, ..., 114, 79, 418],\n", + " ...,\n", + " [191, 434, 433, ..., 85, 154, 166],\n", + " [ 0, 81, 169, ..., 122, 10, 52],\n", + " [158, 1, 0, ..., 103, 335, 189]])\n", + "torch.Size([200, 19])\n", + "tensor([[ 0, 38, 0, ..., 40, 300, 224],\n", + " [283, 182, 455, ..., 32, 35, 224],\n", + " [152, 70, 286, ..., 103, 5, 
465],\n",
+        "        ...,\n",
+        "        [ 59,   1,   1,  ...,   0, 323, 158]])\n",
+        "    [... output truncated: one torch.Size([200, 19]) line and one tensor of sampled microbe indices are printed for every remaining batch; this hunk also drops the previous run's ANSI-escaped KeyboardInterrupt traceback ...]\n",
+        "torch.Size([200, 19])\n",
+        "tensor([[439,   1,  93,  ...,  64,  40, 224],\n",
+        "        [ 67,   3, 294,  ...,  64, 224, 440],\n",
+        "        [169, 337,   1,  ..., 147, 146,   1],\n",
+        "        ...,\n",
+        "        [444, 423,  48,  ..., 
81, 335, 38],\n", + " [439, 1, 48, ..., 103, 0, 420],\n", + " [ 0, 0, 48, ..., 426, 221, 224]])\n", + "torch.Size([200, 19])\n", + "tensor([[ 59, 1, 0, ..., 317, 105, 335],\n", + " [439, 1, 224, ..., 414, 37, 440],\n", + " [ 70, 1, 0, ..., 3, 295, 436],\n", + " ...,\n", + " [415, 1, 0, ..., 136, 224, 335],\n", + " [213, 0, 1, ..., 103, 411, 436],\n", + " [ 0, 0, 0, ..., 358, 414, 1]])\n" ] } ], "source": [ - "maybe = mmvec.train.mmvec_training_loop(microbes=microbe_relative_frequency,\n", - " metabolites=metabolites,\n", + "maybe = mmvec.train.mmvec_training_loop(\n", " model=model,\n", " optimizer=optimizer,\n", " batch_size=batch_size,\n", @@ -151,65 +962,629 @@ { "cell_type": "code", "execution_count": null, - "id": "56b1a17a", + "id": "50f42c33", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4b7fc6fc", + "metadata": {}, + "outputs": [], + "source": [ + "model.ranks()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c962e22f", + "metadata": {}, + "outputs": [], + "source": [ + "# h = model.ranks_df - model.ranks_df.mean(axis=0)\n", + "h = model.ranks_matrix\n", + "\n", + "k = model.latent_dim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72dfa241", "metadata": {}, "outputs": [], "source": [ - "model.decoder.linear.bias" + "h.mean(dim=0).shape" ] }, { "cell_type": "code", "execution_count": null, - "id": "9797ed91", + "id": "564f2523", "metadata": {}, "outputs": [], "source": [ - "import pandas as pd" + "from torch.linalg import svd\n" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "a5f97b6d", + "execution_count": null, + "id": "1ccd4896", + "metadata": {}, + "outputs": [], + "source": [ + "u, s, v = svd(h, full_matrices=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b952d538", + "metadata": {}, + "outputs": [], + "source": [ + "u.shape, s.shape, v.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa680b9b", + "metadata": {}, + "outputs": [], + "source": [ + "s" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "79869201", "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'pd' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [12]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ranks \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241m.\u001b[39mDataFrame(model\u001b[38;5;241m.\u001b[39mranks(),\n\u001b[1;32m 2\u001b[0m index\u001b[38;5;241m=\u001b[39mmicrobe_idx,\n\u001b[1;32m 3\u001b[0m columns\u001b[38;5;241m=\u001b[39mmetabolite_idx)\n", - "\u001b[0;31mNameError\u001b[0m: name 'pd' is not defined" + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
(2,3-dihydroxy-3-methylbutanoate)(2,5-diaminohexanoate)(3-hydroxypyridine)(3-methyladenine)(4-oxoproline)(5,6-dihydrothymine)(alanyl-leucine)(dehydroalanine)(glycero-3-phosphoethanolamine)(indoleacrylate)...thyminetryptophantyrosineuracilurateuridineurocanatevalinexanthinexylitol
rplo 1 (Cyanobacteria)1.1757660.067746-0.552615-0.1754780.6255560.0718180.807590-0.603510-0.627171-0.381616...-0.0435110.210599-0.1623330.025163-0.007784-0.5240750.659753-0.782468-0.526718-0.181540
rplo 2 (Firmicutes)0.258168-0.7441900.057575-0.1398430.546534-0.1656670.343791-0.439517-0.666039-0.410150...0.1379200.1694660.3499170.342212-0.071381-0.052086-0.100996-1.005960-0.265847-0.966022
rplo 60 (Firmicutes)0.958025-0.507875-0.8934010.241754-0.286902-0.0437520.2722530.417226-0.2924480.065090...0.6973500.051952-0.6837150.277871-0.3961330.8440740.610815-0.597109-0.023769-0.329333
rplo 7 (Actinobacteria)0.9198030.355543-0.450468-0.3769220.4424900.0494980.4198290.1836860.026987-0.625513...-0.1346560.4477720.1403080.308640-0.013243-0.8494601.202816-0.2298810.251655-0.254032
rplo 10 (Firmicutes)1.143667-0.617700-0.6932220.1995390.394505-0.239950-0.1745970.0469500.183876-0.300330...0.495663-0.2309000.343985-0.021149-0.2151280.4891720.304803-0.8025400.111719-0.414126
..................................................................
rplo 95 (Proteobacteria)-0.8979420.236029-0.405040-0.3714270.381187-0.0110430.1398770.2280630.1039170.207617...0.3556110.438821-0.261516-0.1030440.0158210.6492420.696189-0.050961-0.2506350.593225
rplo 96 (unknown)0.8719050.071470-0.382577-0.0898500.0997910.226773-0.2171440.7699840.7849190.361632...0.619834-0.3059870.616732-0.252027-0.8365410.1751150.724142-0.492409-0.038374-0.038857
rplo 97 (Firmicutes)0.0645210.104134-0.882605-0.478958-0.463571-0.6453460.0435120.4364980.9191830.265277...-0.427435-0.177776-0.9167010.122711-0.3019270.9857570.0654690.094737-0.151891-0.177370
rplo 98 (Actinobacteria)0.4001700.120926-0.454491-0.0275680.462932-0.8092990.197001-0.167618-0.0427040.013358...1.281862-0.094519-0.3422070.218514-0.343730-0.531529-0.677377-0.828105-1.0183690.738386
rplo 99 (Cyanobacteria)1.129443-0.654812-0.0166710.0420460.2327940.0029600.087835-0.866775-1.0589320.196746...0.2675050.2468620.594449-0.000066-0.6790700.2334830.244907-1.055982-0.158424-0.869372
\n", + "

466 rows × 85 columns

\n", + "
" + ], + "text/plain": [ + " (2,3-dihydroxy-3-methylbutanoate) \\\n", + "rplo 1 (Cyanobacteria) 1.175766 \n", + "rplo 2 (Firmicutes) 0.258168 \n", + "rplo 60 (Firmicutes) 0.958025 \n", + "rplo 7 (Actinobacteria) 0.919803 \n", + "rplo 10 (Firmicutes) 1.143667 \n", + "... ... \n", + "rplo 95 (Proteobacteria) -0.897942 \n", + "rplo 96 (unknown) 0.871905 \n", + "rplo 97 (Firmicutes) 0.064521 \n", + "rplo 98 (Actinobacteria) 0.400170 \n", + "rplo 99 (Cyanobacteria) 1.129443 \n", + "\n", + " (2,5-diaminohexanoate) (3-hydroxypyridine) \\\n", + "rplo 1 (Cyanobacteria) 0.067746 -0.552615 \n", + "rplo 2 (Firmicutes) -0.744190 0.057575 \n", + "rplo 60 (Firmicutes) -0.507875 -0.893401 \n", + "rplo 7 (Actinobacteria) 0.355543 -0.450468 \n", + "rplo 10 (Firmicutes) -0.617700 -0.693222 \n", + "... ... ... \n", + "rplo 95 (Proteobacteria) 0.236029 -0.405040 \n", + "rplo 96 (unknown) 0.071470 -0.382577 \n", + "rplo 97 (Firmicutes) 0.104134 -0.882605 \n", + "rplo 98 (Actinobacteria) 0.120926 -0.454491 \n", + "rplo 99 (Cyanobacteria) -0.654812 -0.016671 \n", + "\n", + " (3-methyladenine) (4-oxoproline) \\\n", + "rplo 1 (Cyanobacteria) -0.175478 0.625556 \n", + "rplo 2 (Firmicutes) -0.139843 0.546534 \n", + "rplo 60 (Firmicutes) 0.241754 -0.286902 \n", + "rplo 7 (Actinobacteria) -0.376922 0.442490 \n", + "rplo 10 (Firmicutes) 0.199539 0.394505 \n", + "... ... ... \n", + "rplo 95 (Proteobacteria) -0.371427 0.381187 \n", + "rplo 96 (unknown) -0.089850 0.099791 \n", + "rplo 97 (Firmicutes) -0.478958 -0.463571 \n", + "rplo 98 (Actinobacteria) -0.027568 0.462932 \n", + "rplo 99 (Cyanobacteria) 0.042046 0.232794 \n", + "\n", + " (5,6-dihydrothymine) (alanyl-leucine) \\\n", + "rplo 1 (Cyanobacteria) 0.071818 0.807590 \n", + "rplo 2 (Firmicutes) -0.165667 0.343791 \n", + "rplo 60 (Firmicutes) -0.043752 0.272253 \n", + "rplo 7 (Actinobacteria) 0.049498 0.419829 \n", + "rplo 10 (Firmicutes) -0.239950 -0.174597 \n", + "... ... ... \n", + "rplo 95 (Proteobacteria) -0.011043 0.139877 \n", + "rplo 96 (unknown) 0.226773 -0.217144 \n", + "rplo 97 (Firmicutes) -0.645346 0.043512 \n", + "rplo 98 (Actinobacteria) -0.809299 0.197001 \n", + "rplo 99 (Cyanobacteria) 0.002960 0.087835 \n", + "\n", + " (dehydroalanine) (glycero-3-phosphoethanolamine) \\\n", + "rplo 1 (Cyanobacteria) -0.603510 -0.627171 \n", + "rplo 2 (Firmicutes) -0.439517 -0.666039 \n", + "rplo 60 (Firmicutes) 0.417226 -0.292448 \n", + "rplo 7 (Actinobacteria) 0.183686 0.026987 \n", + "rplo 10 (Firmicutes) 0.046950 0.183876 \n", + "... ... ... \n", + "rplo 95 (Proteobacteria) 0.228063 0.103917 \n", + "rplo 96 (unknown) 0.769984 0.784919 \n", + "rplo 97 (Firmicutes) 0.436498 0.919183 \n", + "rplo 98 (Actinobacteria) -0.167618 -0.042704 \n", + "rplo 99 (Cyanobacteria) -0.866775 -1.058932 \n", + "\n", + " (indoleacrylate) ... thymine tryptophan \\\n", + "rplo 1 (Cyanobacteria) -0.381616 ... -0.043511 0.210599 \n", + "rplo 2 (Firmicutes) -0.410150 ... 0.137920 0.169466 \n", + "rplo 60 (Firmicutes) 0.065090 ... 0.697350 0.051952 \n", + "rplo 7 (Actinobacteria) -0.625513 ... -0.134656 0.447772 \n", + "rplo 10 (Firmicutes) -0.300330 ... 0.495663 -0.230900 \n", + "... ... ... ... ... \n", + "rplo 95 (Proteobacteria) 0.207617 ... 0.355611 0.438821 \n", + "rplo 96 (unknown) 0.361632 ... 0.619834 -0.305987 \n", + "rplo 97 (Firmicutes) 0.265277 ... -0.427435 -0.177776 \n", + "rplo 98 (Actinobacteria) 0.013358 ... 1.281862 -0.094519 \n", + "rplo 99 (Cyanobacteria) 0.196746 ... 
0.267505 0.246862 \n", + "\n", + " tyrosine uracil urate uridine urocanate \\\n", + "rplo 1 (Cyanobacteria) -0.162333 0.025163 -0.007784 -0.524075 0.659753 \n", + "rplo 2 (Firmicutes) 0.349917 0.342212 -0.071381 -0.052086 -0.100996 \n", + "rplo 60 (Firmicutes) -0.683715 0.277871 -0.396133 0.844074 0.610815 \n", + "rplo 7 (Actinobacteria) 0.140308 0.308640 -0.013243 -0.849460 1.202816 \n", + "rplo 10 (Firmicutes) 0.343985 -0.021149 -0.215128 0.489172 0.304803 \n", + "... ... ... ... ... ... \n", + "rplo 95 (Proteobacteria) -0.261516 -0.103044 0.015821 0.649242 0.696189 \n", + "rplo 96 (unknown) 0.616732 -0.252027 -0.836541 0.175115 0.724142 \n", + "rplo 97 (Firmicutes) -0.916701 0.122711 -0.301927 0.985757 0.065469 \n", + "rplo 98 (Actinobacteria) -0.342207 0.218514 -0.343730 -0.531529 -0.677377 \n", + "rplo 99 (Cyanobacteria) 0.594449 -0.000066 -0.679070 0.233483 0.244907 \n", + "\n", + " valine xanthine xylitol \n", + "rplo 1 (Cyanobacteria) -0.782468 -0.526718 -0.181540 \n", + "rplo 2 (Firmicutes) -1.005960 -0.265847 -0.966022 \n", + "rplo 60 (Firmicutes) -0.597109 -0.023769 -0.329333 \n", + "rplo 7 (Actinobacteria) -0.229881 0.251655 -0.254032 \n", + "rplo 10 (Firmicutes) -0.802540 0.111719 -0.414126 \n", + "... ... ... ... \n", + "rplo 95 (Proteobacteria) -0.050961 -0.250635 0.593225 \n", + "rplo 96 (unknown) -0.492409 -0.038374 -0.038857 \n", + "rplo 97 (Firmicutes) 0.094737 -0.151891 -0.177370 \n", + "rplo 98 (Actinobacteria) -0.828105 -1.018369 0.738386 \n", + "rplo 99 (Cyanobacteria) -1.055982 -0.158424 -0.869372 \n", + "\n", + "[466 rows x 85 columns]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.ranks_df" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "9e5c5c7f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([466, 85])\n", + "torch.Size([85])\n", + "torch.Size([85, 85])\n", + "torch.Size([466])\n", + "tensor([ -1.1445, -6.3680, -4.3185, 4.9945, -1.4082, -8.2258, 1.8486,\n", + " -4.3756, 4.7173, -6.9559, 3.2740, 0.4376, -5.9801, 6.3349,\n", + " 1.0003, -8.6350, 2.8014, 0.6915, 1.2820, -4.5347, 4.0379,\n", + " -5.6501, 1.9838, 0.2808, -0.2633, -3.2443, -4.1948, 1.5267,\n", + " -1.1784, -2.8921, 2.6564, 1.6168, -0.1698, -0.9756, -4.8180,\n", + " -4.7775, 1.7069, -1.3242, -4.7067, 0.6504, -5.5535, 4.1691,\n", + " -2.5773, 4.0220, -2.2959, -6.7313, -6.9085, 1.4211, 3.6404,\n", + " 1.5132, 2.4939, 0.5881, 1.9394, 7.5292, -8.3888, -0.6539,\n", + " 10.1397, -2.4303, -3.2335, 1.4971, -1.7446, 2.4674, 3.4813,\n", + " 0.0457, 0.8573, 3.2580, -3.3791, 3.7058, 4.9621, -1.6585,\n", + " 0.2361, 0.1459, 4.4815, -4.4632, 7.9341, -3.4729, 6.7684,\n", + " -3.3562, 2.3691, 6.9417, 5.4023, -0.2494, 6.6950, 3.0350,\n", + " -3.9309, 0.9241, 0.5092, 1.8853, 2.6824, 6.8036, -0.4861,\n", + " 5.1062, 2.4624, -0.9210, 9.8531, 0.5789, 1.9744, -6.0279,\n", + " 2.6953, -7.4324, 3.2620, -2.9513, 9.6945, -1.3096, 2.4214,\n", + " -0.9256, -4.1337, 11.0306, 0.4112, 0.1319, -4.7474, 4.4109,\n", + " 2.7715, 2.8321, -0.6753, 3.6894, 2.2041, -6.7310, 3.7233,\n", + " -2.9628, 3.3455, -6.5863, -4.7063, 5.9036, -3.2456, -0.0869,\n", + " -3.5623, -6.0653, -5.9259, 3.8253, -11.0783, 0.0408, -4.8903,\n", + " 0.9617, -1.5991, 4.5272, 5.0266, 1.9491, -3.0679, -0.6566,\n", + " -5.9211, 7.6033, -7.0827, -3.6042, 1.3228, -6.4924, 1.7801,\n", + " -2.6599, 5.8849, 0.3166, -2.8488, -5.0392, -7.5366, -1.9267,\n", + " -1.2711, -5.4646, 1.5345, 2.9971, 0.9353, -5.5945, 
5.0444,\n", + " -0.5978, 3.8224, 3.5736, 0.0708, 2.3189, 0.8124, -10.2031,\n", + " 0.1141, -3.4993, 1.6761, 1.1674, -6.0600, -5.8619, -5.7287,\n", + " -3.9315, -1.6097, 3.0353, 1.5874, 0.5476, 2.2777, 0.8731,\n", + " -0.6450, -7.1289, -0.0645, 0.5143, 0.9316, 2.2382, 3.6106,\n", + " -8.4546, -1.4214, -3.5985, -3.4575, -0.5916, 2.5210, -0.4692,\n", + " 2.5000, -6.4857, 0.7608, -5.6512, -7.5485, 4.4014, -6.1880,\n", + " -4.3598, -3.7198, 2.0294, -3.4157, 3.5193, -0.0775, 4.9616,\n", + " -2.4156, -2.3211, -0.2058, 5.9809, 2.5119, 0.0125, -2.1752,\n", + " -5.8627, 3.3680, 3.4236, -4.8590, -1.8029, -2.5155, 5.7980,\n", + " -5.7988, -6.8575, 8.2150, 6.3653, -3.2859, 3.0349, -9.1496,\n", + " 6.2771, -5.9259, 7.2524, 3.4612, -0.6388, 6.2876, 3.6062,\n", + " 1.0134, 0.9312, -2.3465, 3.3057, 1.0286, 9.0334, -8.9800,\n", + " -3.1666, 5.6102, -1.1290, -0.4083, 4.8317, 9.2724, 0.0997,\n", + " -8.8108, 3.4332, 0.3276, 5.0469, -2.0226, -1.6557, 5.6105,\n", + " 1.8530, 2.8858, 1.5988, 2.5177, -2.7918, -2.6911, 2.7218,\n", + " -3.1462, -6.2753, -2.6276, -1.8484, 3.0457, -3.4599, -5.8190,\n", + " -0.9930, -9.0980, 7.5351, -1.4414, -3.8330, -3.9160, -2.0748,\n", + " -4.7279, -2.6979, 2.2114, -9.7617, -1.9074, 9.8307, 1.1703,\n", + " 2.3597, 3.6719, -1.4355, 1.3314, 0.9512, -5.2816, -3.2768,\n", + " 2.1892, -8.9302, -2.4061, 4.7443, -0.6404, -7.9222, 3.9574,\n", + " -5.7212, -2.4539, 3.4378, -7.4782, 1.8264, -1.5297, 4.7548,\n", + " -5.7164, -1.8924, -0.9265, -3.2981, -2.6631, -0.3037, 2.1184,\n", + " 2.4061, -3.7237, 0.9267, 2.6104, -3.0550, 7.8785, 1.0147,\n", + " 3.5998, 3.8647, 3.3049, 0.7033, -4.0938, -6.9029, -2.7553,\n", + " -0.8194, 4.4504, -0.2810, -3.4939, -1.3974, 4.2549, -9.4413,\n", + " 6.4951, 5.9425, -2.5674, 3.1822, -0.9808, -4.4396, 6.4448,\n", + " -0.3536, 5.1797, 0.8818, 4.1052, 4.9712, -0.7238, 8.7621,\n", + " -5.2645, 1.5924, -4.0963, 4.6621, -10.9097, 0.4642, -0.5150,\n", + " -2.9584, -5.4681, 2.4455, -1.9391, 4.9934, -4.7105, -0.8750,\n", + " 6.3088, 0.2136, -2.9872, -1.8482, -5.2081, 1.9450, -3.2619,\n", + " -0.6486, 3.6653, -1.8660, -1.0397, 8.5315, -1.5133, 4.1649,\n", + " -0.9625, 0.8924, -1.6494, 5.3174, 7.2113, -0.4926, 2.1117,\n", + " 2.0516, 3.9590, 3.3258, 4.5366, 2.1683, -1.0748, 2.4090,\n", + " -4.1125, -1.6299, 1.8558, 0.0114, -3.8395, -0.3071, -0.2672,\n", + " 4.3818, 2.9695, 3.4528, 9.6955, -3.0135, -3.6088, 3.4343,\n", + " -3.9485, -3.1757, 4.3005, 1.0197, -3.4628, 0.1942, 0.7603,\n", + " 2.1585, 4.4071, 2.2928, 10.1469, 7.1473, 5.0083, -3.6591,\n", + " -0.3181, 4.8017, 2.0600, 0.7875, 3.8353, 1.9623, 3.0753,\n", + " 3.4961, 0.2156, -1.6791, 2.8405, 3.2189, 5.8801, 0.5369,\n", + " 1.1090, 0.5457, 1.0708, 3.6782, 2.6795, 0.1788, 7.0609,\n", + " -0.4870, -1.4217, -3.9887, 4.7482, -3.6168, -2.9442, -3.7465,\n", + " -0.3917, 5.7974, -1.7506, 1.5932, 2.9426, -4.3741, 0.0520,\n", + " 0.4566, -2.2609, 1.0170, 4.9163, -2.8058, -7.9425, 5.4053,\n", + " -1.4912, -7.1008, -8.9607, -5.4042])\n" ] } ], "source": [ - "ranks = pd.DataFrame(model.ranks(),\n", - " index=microbe_idx,\n", - " columns=metabolite_idx)" + "model.get_ordination()" ] }, { "cell_type": "code", "execution_count": null, - "id": "248a1eff", + "id": "a5f97b6d", "metadata": {}, "outputs": [], "source": [ - "ranks" + "ranks = pd.DataFrame(model.ranks(),\n", + " index=microbe_idx,\n", + " columns=metabolite_idx)" ] }, { "cell_type": "code", "execution_count": null, - "id": "0fde6a1e", + "id": "6d0e4892", "metadata": {}, "outputs": [], "source": [ - "(model.encoder.weight.detach() @ 
model.decoder.linear.weight.detach().T)" + "ranks" ] }, { @@ -227,11 +1602,167 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 115, + "id": "d8aa9b3a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[9., 6., 7.],\n", + " [3., 8., 2.],\n", + " [6., 5., 3.],\n", + " [4., 6., 4.],\n", + " [7., 5., 0.],\n", + " [6., 0., 3.],\n", + " [2., 4., 3.],\n", + " [5., 4., 9.],\n", + " [7., 3., 1.]]),\n", + " tensor([[ 3.5556, 1.4444, 3.4444],\n", + " [-2.4444, 3.4444, -1.5556],\n", + " [ 0.5556, 0.4444, -0.5556],\n", + " [-1.4444, 1.4444, 0.4444],\n", + " [ 1.5556, 0.4444, -3.5556],\n", + " [ 0.5556, -4.5556, -0.5556],\n", + " [-3.4444, -0.5556, -0.5556],\n", + " [-0.4444, -0.5556, 5.4444],\n", + " [ 1.5556, -1.5556, -2.5556]]),\n", + " tensor([5.4444, 4.5556, 3.5556]))" + ] + }, + "execution_count": 115, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = torch.randint(10, (9, 3), dtype=torch.float)\n", + "b = torch.randint(10, (3,))\n", + "a, a - a.mean(dim=0), a.mean(dim=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 91, "id": "aa9e8f31", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([466, 85])\n", + "torch.Size([85])\n", + "torch.Size([85, 85])\n", + "compare unmodified:\n", + "tensor([3.9370e+01, 3.0685e+01, 2.7607e+01, 2.6057e+01, 2.4771e+01, 2.4166e+01,\n", + " 2.3556e+01, 2.1152e+01, 2.0435e+01, 1.9499e+01, 1.8038e+01, 1.7893e+01,\n", + " 1.6577e+01, 1.5838e+01, 1.4549e+01, 8.0469e+00, 1.3352e-05, 7.6159e-06,\n", + " 6.4840e-06, 5.7771e-06, 5.1234e-06, 4.8650e-06, 4.7076e-06, 4.3261e-06,\n", + " 3.7033e-06, 3.3649e-06, 3.2426e-06, 3.1312e-06, 2.9980e-06, 2.9524e-06,\n", + " 2.7562e-06, 2.7255e-06, 2.5842e-06, 2.5818e-06, 2.5675e-06, 2.5544e-06,\n", + " 2.5496e-06, 2.5244e-06, 2.5015e-06, 2.4779e-06, 2.4762e-06, 2.4727e-06,\n", + " 2.4645e-06, 2.4528e-06, 2.4424e-06, 2.4378e-06, 2.4306e-06, 2.4263e-06,\n", + " 2.4155e-06, 2.3750e-06, 2.2824e-06, 2.2305e-06, 2.2250e-06, 2.2146e-06,\n", + " 2.2125e-06, 2.1982e-06, 2.1954e-06, 2.1732e-06, 2.1682e-06, 2.1579e-06,\n", + " 2.1559e-06, 2.1535e-06, 2.1495e-06, 2.1278e-06, 2.0853e-06, 2.0755e-06,\n", + " 2.0694e-06, 2.0658e-06, 2.0506e-06, 2.0451e-06, 2.0284e-06, 2.0273e-06,\n", + " 2.0137e-06, 2.0100e-06, 2.0076e-06, 2.0001e-06, 1.9894e-06, 1.9617e-06,\n", + " 1.9169e-06, 1.9098e-06, 1.9008e-06, 1.8981e-06, 1.8605e-06, 1.7143e-06,\n", + " 1.3593e-06])\n", + "sqrt:\n", + "tensor([6.2745e+00, 5.5394e+00, 5.2542e+00, 5.1046e+00, 4.9770e+00, 4.9159e+00,\n", + " 4.8535e+00, 4.5991e+00, 4.5205e+00, 4.4158e+00, 4.2471e+00, 4.2300e+00,\n", + " 4.0714e+00, 3.9797e+00, 3.8143e+00, 2.8367e+00, 3.6540e-03, 2.7597e-03,\n", + " 2.5464e-03, 2.4036e-03, 2.2635e-03, 2.2057e-03, 2.1697e-03, 2.0799e-03,\n", + " 1.9244e-03, 1.8344e-03, 1.8007e-03, 1.7695e-03, 1.7315e-03, 1.7182e-03,\n", + " 1.6602e-03, 1.6509e-03, 1.6075e-03, 1.6068e-03, 1.6023e-03, 1.5983e-03,\n", + " 1.5968e-03, 1.5888e-03, 1.5816e-03, 1.5741e-03, 1.5736e-03, 1.5725e-03,\n", + " 1.5699e-03, 1.5661e-03, 1.5628e-03, 1.5613e-03, 1.5590e-03, 1.5577e-03,\n", + " 1.5542e-03, 1.5411e-03, 1.5108e-03, 1.4935e-03, 1.4916e-03, 1.4882e-03,\n", + " 1.4874e-03, 1.4826e-03, 1.4817e-03, 1.4742e-03, 1.4725e-03, 1.4690e-03,\n", + " 1.4683e-03, 1.4675e-03, 1.4661e-03, 1.4587e-03, 1.4441e-03, 1.4407e-03,\n", + " 1.4385e-03, 1.4373e-03, 1.4320e-03, 1.4301e-03, 1.4242e-03, 1.4238e-03,\n", + " 1.4190e-03, 1.4177e-03, 1.4169e-03, 1.4142e-03, 1.4104e-03, 
1.4006e-03,\n", + " 1.3845e-03, 1.3820e-03, 1.3787e-03, 1.3777e-03, 1.3640e-03, 1.3093e-03,\n", + " 1.1659e-03])\n", + "torch.Size([466, 85])\n", + "tensor([[-8.1189e-01, -1.8219e+00, -2.1267e+00, ..., 6.5056e-09,\n", + " 1.0232e-07, -5.1138e-09],\n", + " [-4.1947e+00, 3.4689e-02, 1.1555e+00, ..., -4.7035e-08,\n", + " 3.9148e-08, -6.2629e-08],\n", + " [-9.1145e-01, 1.1518e+00, -6.4364e-01, ..., 1.0274e-07,\n", + " 1.3688e-07, -2.0395e-10],\n", + " ...,\n", + " [ 9.2180e-01, -2.7336e-01, -2.8572e-01, ..., -5.4841e-08,\n", + " 5.8509e-08, -1.8683e-08],\n", + " [-9.7168e-01, -1.1153e+00, -1.3348e+00, ..., -1.0824e-07,\n", + " -6.6422e-08, -4.7472e-08],\n", + " [-4.2655e+00, -1.0002e+00, 2.2608e+00, ..., -1.4434e-07,\n", + " -2.8319e-08, 4.3824e-08]])\n", + "['PC0', 'PC1', 'PC2', 'PC3', 'PC4', 'PC5', 'PC6', 'PC7', 'PC8', 'PC9', 'PC10', 'PC11', 'PC12', 'PC13', 'PC14', 'PC15', 'PC16', 'PC17', 'PC18', 'PC19', 'PC20', 'PC21', 'PC22', 'PC23', 'PC24', 'PC25', 'PC26', 'PC27', 'PC28', 'PC29', 'PC30', 'PC31', 'PC32', 'PC33', 'PC34', 'PC35', 'PC36', 'PC37', 'PC38', 'PC39', 'PC40', 'PC41', 'PC42', 'PC43', 'PC44', 'PC45', 'PC46', 'PC47', 'PC48', 'PC49', 'PC50', 'PC51', 'PC52', 'PC53', 'PC54', 'PC55', 'PC56', 'PC57', 'PC58', 'PC59', 'PC60', 'PC61', 'PC62', 'PC63', 'PC64', 'PC65', 'PC66', 'PC67', 'PC68', 'PC69', 'PC70', 'PC71', 'PC72', 'PC73', 'PC74', 'PC75', 'PC76', 'PC77', 'PC78', 'PC79', 'PC80', 'PC81', 'PC82', 'PC83', 'PC84']\n", + "tensor([[-8.1189e-01, -1.8219e+00, -2.1267e+00, ..., 6.5056e-09,\n", + " 1.0232e-07, -5.1138e-09],\n", + " [-4.1947e+00, 3.4689e-02, 1.1555e+00, ..., -4.7035e-08,\n", + " 3.9148e-08, -6.2629e-08],\n", + " [-9.1145e-01, 1.1518e+00, -6.4364e-01, ..., 1.0274e-07,\n", + " 1.3688e-07, -2.0395e-10],\n", + " ...,\n", + " [ 9.2180e-01, -2.7336e-01, -2.8572e-01, ..., -5.4841e-08,\n", + " 5.8509e-08, -1.8683e-08],\n", + " [-9.7168e-01, -1.1153e+00, -1.3348e+00, ..., -1.0824e-07,\n", + " -6.6422e-08, -4.7472e-08],\n", + " [-4.2655e+00, -1.0002e+00, 2.2608e+00, ..., -1.4434e-07,\n", + " -2.8319e-08, 4.3824e-08]])\n" + ] + } + ], + "source": [ + "bp = model.get_ordination()" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "ede38482", + "metadata": { + "collapsed": true + }, "outputs": [], - "source": [] + "source": [ + "import os.path" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "id": "c79d0e79", + "metadata": {}, + "outputs": [], + "source": [ + "bp_path = os.path.join(\"/Users/keeganevans/Desktop/\", \"biplot\")" + ] + }, + { + "cell_type": "code", + "execution_count": 118, + "id": "605bce56", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/Users/keeganevans/Desktop/biplot'" + ] + }, + "execution_count": 118, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bp.write(bp_path)" + ] }, { "cell_type": "markdown", diff --git a/mmvec/ALR.py b/mmvec/ALR.py index 4f5d7c2..d8fe0b4 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -6,6 +6,28 @@ import torch.nn.functional as F from torch.distributions import Multinomial, Normal +from skbio import OrdinationResults + + +def structure_data(microbes, metabolites): + microbes = microbes.to_dataframe().T + metabolites = metabolites.to_dataframe().T + microbes = microbes.loc[metabolites.index] + + microbe_idx = microbes.columns + metabolite_idx = metabolites.columns + + microbe_count = microbes.shape[1] + metabolite_count = metabolites.shape[1] + + microbes = torch.tensor(microbes.values, dtype=torch.int) + metabolites = torch.tensor(metabolites.values, 
dtype=torch.int64)
+
+    microbe_relative_frequency = (microbes.T/microbes.sum(1)).T
+
+    return (microbes, metabolites, microbe_idx, metabolite_idx, microbe_count,
+            metabolite_count, microbe_relative_frequency)
+
 
 class LinearALR(nn.Module):
     def __init__(self, input_dim, output_dim):
@@ -21,24 +43,25 @@ def forward(self, x):
 
 
 class MMvecALR(nn.Module):
-    def __init__(self, num_microbes, num_metabolites, latent_dim, sigma_u,
+    def __init__(self, microbes, metabolites, latent_dim, sigma_u,
                  sigma_v):
         super().__init__()
 
-        self.latent_dim = latent_dim
-        self.num_microbes = num_microbes
-        self.num_metabolites = num_metabolites
-
-        self.u_bias = nn.parameter.Parameter(torch.randn((num_microbes, 1)))
-
-        self.encoder = nn.Embedding(num_microbes, latent_dim)
-        self.decoder = LinearALR(latent_dim, num_metabolites)
-
+        # Data setup
+        self.microbes, self.metabolites, \
+            self.microbe_idx, self.metabolite_idx, \
+            self.num_microbes, self.num_metabolites, \
+            self.microbe_relative_freq = structure_data(microbes,
+                                                        metabolites)
         self.sigma_u = sigma_u
         self.sigma_v = sigma_v
+        self.latent_dim = latent_dim
+        self.u_bias = nn.parameter.Parameter(torch.randn((self.num_microbes, 1)))
+        self.encoder = nn.Embedding(self.num_microbes, self.latent_dim)
+        self.decoder = LinearALR(self.latent_dim, self.num_metabolites)
 
-    def forward(self, X, Y):
+    def forward(self, X):
         # Three likelihoods, the likelihood of each weight and the likelihood
         # of the data fitting in the way that we thought
         # LYs
@@ -50,7 +73,7 @@ def forward(self, X):
                                    validate_args=False,
                                    probs=y_pred)
 
-        forward_dist = forward_dist.log_prob(Y)
+        forward_dist = forward_dist.log_prob(self.metabolites)
 
         l_y = forward_dist.mean(0).mean()
 
@@ -66,13 +89,44 @@ def forward(self, X):
         return likelihood_sum
 
     def get_ordination(self, equalize_biplot=False):
+        ranks = self.ranks_matrix - self.ranks_matrix.mean(dim=0)
-        u, s, v = linalg.svd(ranks, full_matrices=False)
-        print(u)
-        print(s)
-        print(v)
 
-    def ranks(self, microbe_ids, metabolite_ids):
+        u, s_diag, v = linalg.svd(ranks, full_matrices=False)
+
+        # use torch.diag to go from a vector to a matrix with the vector on
+        # the diagonal
+        if equalize_biplot:
+            microbe_embed = u @ torch.sqrt(torch.diag(s_diag))
+            metabolite_embed = v.T @ torch.sqrt(s_diag)
+        else:
+            microbe_embed = u @ torch.diag(s_diag)
+            metabolite_embed = v.T
+
+        pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])]
+
+        features = pd.DataFrame(
+            microbe_embed, columns=pc_ids, index=self.microbe_idx)
+
+        samples = pd.DataFrame(metabolite_embed, columns=pc_ids,
+                               index=self.metabolite_idx)
+
+        short_method_name = 'mmvec biplot'
+        long_method_name = 'Multiomics mmvec biplot'
+        eigvals = pd.Series(s_diag, index=pc_ids)
+        proportion_explained = pd.Series(torch.square(s_diag) /
+                                         torch.sum(torch.square(s_diag)),
+                                         index=pc_ids)
+
+        biplot = OrdinationResults(
+            short_method_name, long_method_name, eigvals,
+            samples=samples, features=features,
+            proportion_explained=proportion_explained)
+
+        return biplot
+
+    def ranks(self):
         U = torch.cat(
             (torch.ones((self.num_microbes, 1)),
              self.u_bias.detach(),
@@ -89,5 +143,5 @@ def ranks(self):
         res = res - res.mean(axis=1).reshape(-1, 1)
 
         self.ranks_matrix = res
-        self.ranks_df = pd.DataFrame(res, index=microbe_ids,
-                                     columns=metabolite_ids)
+        self.ranks_df = pd.DataFrame(res, index=self.microbe_idx,
+                                     columns=self.metabolite_idx)
diff --git a/mmvec/train.py b/mmvec/train.py
index 182dcd3..a166e83 100644
--- a/mmvec/train.py
+++ b/mmvec/train.py
@@ -1,15 +1,15 @@
 import torch
 
 
-def 
mmvec_training_loop(microbes, metabolites, model, optimizer, +def mmvec_training_loop(model, optimizer, batch_size, epochs): for epoch in range(epochs): - draws = torch.multinomial(microbes, + draws = torch.multinomial(model.microbe_relative_freq, batch_size, replacement=True).T - mmvec_model = model(draws, metabolites) + mmvec_model = model(draws) optimizer.zero_grad() mmvec_model.backward() From 449f4a2312e39d8011143f29074f59ccb0cd29c6 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Mon, 2 May 2022 15:45:35 -0700 Subject: [PATCH 09/27] IMP: cleanup before working on tests. --- __init__.py | 0 examples/refactor/ALR.ipynb | 473 ++++++++++++++++++++++++++++++++++-- mmvec/ALR.py | 17 +- 3 files changed, 470 insertions(+), 20 deletions(-) delete mode 100644 __init__.py diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/examples/refactor/ALR.ipynb b/examples/refactor/ALR.ipynb index 10b203d..fa1e31d 100644 --- a/examples/refactor/ALR.ipynb +++ b/examples/refactor/ALR.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "2388b10b", + "id": "d71b64f1", "metadata": {}, "outputs": [], "source": [ @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "a0aea731", + "id": "83368d02", "metadata": {}, "outputs": [], "source": [ @@ -94,7 +94,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "6bc4fcf6", + "id": "1fdc5c3b", "metadata": {}, "outputs": [ { @@ -962,7 +962,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50f42c33", + "id": "9c6d7d97", "metadata": {}, "outputs": [], "source": [] @@ -970,7 +970,7 @@ { "cell_type": "code", "execution_count": 11, - "id": "4b7fc6fc", + "id": "3d543ae6", "metadata": {}, "outputs": [], "source": [ @@ -980,7 +980,7 @@ { "cell_type": "code", "execution_count": 12, - "id": "c962e22f", + "id": "b924530f", "metadata": {}, "outputs": [], "source": [ @@ -993,7 +993,7 @@ { "cell_type": "code", "execution_count": null, - "id": "72dfa241", + "id": "c64267df", "metadata": {}, "outputs": [], "source": [ @@ -1003,7 +1003,7 @@ { "cell_type": "code", "execution_count": null, - "id": "564f2523", + "id": "66ee80e0", "metadata": {}, "outputs": [], "source": [ @@ -1013,7 +1013,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ccd4896", + "id": "c1ba8000", "metadata": {}, "outputs": [], "source": [ @@ -1023,7 +1023,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b952d538", + "id": "b550250e", "metadata": {}, "outputs": [], "source": [ @@ -1033,7 +1033,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa680b9b", + "id": "42e4637c", "metadata": {}, "outputs": [], "source": [ @@ -1043,7 +1043,7 @@ { "cell_type": "code", "execution_count": 13, - "id": "79869201", + "id": "1cf51885", "metadata": {}, "outputs": [ { @@ -1480,7 +1480,7 @@ { "cell_type": "code", "execution_count": 48, - "id": "9e5c5c7f", + "id": "a0d8a2a2", "metadata": {}, "outputs": [ { @@ -1580,7 +1580,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6d0e4892", + "id": "38b0b239", "metadata": {}, "outputs": [], "source": [ @@ -1603,7 +1603,7 @@ { "cell_type": "code", "execution_count": 115, - "id": "d8aa9b3a", + "id": "6157d538", "metadata": {}, "outputs": [ { @@ -1724,7 +1724,7 @@ { "cell_type": "code", "execution_count": 116, - "id": "ede38482", + "id": "7ac106a9", "metadata": { "collapsed": true }, @@ -1736,7 +1736,7 @@ { "cell_type": "code", "execution_count": 117, - "id": "c79d0e79", + "id": "780e8bd7", "metadata": {}, "outputs": [], "source": [ @@ -1746,7 
+1746,7 @@ { "cell_type": "code", "execution_count": 118, - "id": "605bce56", + "id": "188a5f70", "metadata": {}, "outputs": [ { @@ -1764,6 +1764,443 @@ "bp.write(bp_path)" ] }, + { + "cell_type": "code", + "execution_count": 119, + "id": "ba4b85ce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
(2,3-dihydroxy-3-methylbutanoate)(2,5-diaminohexanoate)(3-hydroxypyridine)(3-methyladenine)(4-oxoproline)(5,6-dihydrothymine)(alanyl-leucine)(dehydroalanine)(glycero-3-phosphoethanolamine)(indoleacrylate)...thyminetryptophantyrosineuracilurateuridineurocanatevalinexanthinexylitol
rplo 1 (Cyanobacteria)1.1757660.067746-0.552615-0.1754780.6255560.0718180.807590-0.603510-0.627171-0.381616...-0.0435110.210599-0.1623330.025163-0.007784-0.5240750.659753-0.782468-0.526718-0.181540
rplo 2 (Firmicutes)0.258168-0.7441900.057575-0.1398430.546534-0.1656670.343791-0.439517-0.666039-0.410150...0.1379200.1694660.3499170.342212-0.071381-0.052086-0.100996-1.005960-0.265847-0.966022
rplo 60 (Firmicutes)0.958025-0.507875-0.8934010.241754-0.286902-0.0437520.2722530.417226-0.2924480.065090...0.6973500.051952-0.6837150.277871-0.3961330.8440740.610815-0.597109-0.023769-0.329333
rplo 7 (Actinobacteria)0.9198030.355543-0.450468-0.3769220.4424900.0494980.4198290.1836860.026987-0.625513...-0.1346560.4477720.1403080.308640-0.013243-0.8494601.202816-0.2298810.251655-0.254032
rplo 10 (Firmicutes)1.143667-0.617700-0.6932220.1995390.394505-0.239950-0.1745970.0469500.183876-0.300330...0.495663-0.2309000.343985-0.021149-0.2151280.4891720.304803-0.8025400.111719-0.414126
..................................................................
rplo 95 (Proteobacteria)-0.8979420.236029-0.405040-0.3714270.381187-0.0110430.1398770.2280630.1039170.207617...0.3556110.438821-0.261516-0.1030440.0158210.6492420.696189-0.050961-0.2506350.593225
rplo 96 (unknown)0.8719050.071470-0.382577-0.0898500.0997910.226773-0.2171440.7699840.7849190.361632...0.619834-0.3059870.616732-0.252027-0.8365410.1751150.724142-0.492409-0.038374-0.038857
rplo 97 (Firmicutes)0.0645210.104134-0.882605-0.478958-0.463571-0.6453460.0435120.4364980.9191830.265277...-0.427435-0.177776-0.9167010.122711-0.3019270.9857570.0654690.094737-0.151891-0.177370
rplo 98 (Actinobacteria)0.4001700.120926-0.454491-0.0275680.462932-0.8092990.197001-0.167618-0.0427040.013358...1.281862-0.094519-0.3422070.218514-0.343730-0.531529-0.677377-0.828105-1.0183690.738386
rplo 99 (Cyanobacteria)1.129443-0.654812-0.0166710.0420460.2327940.0029600.087835-0.866775-1.0589320.196746...0.2675050.2468620.594449-0.000066-0.6790700.2334830.244907-1.055982-0.158424-0.869372
\n", + "

466 rows × 85 columns

\n", + "
" + ], + "text/plain": [ + " (2,3-dihydroxy-3-methylbutanoate) \\\n", + "rplo 1 (Cyanobacteria) 1.175766 \n", + "rplo 2 (Firmicutes) 0.258168 \n", + "rplo 60 (Firmicutes) 0.958025 \n", + "rplo 7 (Actinobacteria) 0.919803 \n", + "rplo 10 (Firmicutes) 1.143667 \n", + "... ... \n", + "rplo 95 (Proteobacteria) -0.897942 \n", + "rplo 96 (unknown) 0.871905 \n", + "rplo 97 (Firmicutes) 0.064521 \n", + "rplo 98 (Actinobacteria) 0.400170 \n", + "rplo 99 (Cyanobacteria) 1.129443 \n", + "\n", + " (2,5-diaminohexanoate) (3-hydroxypyridine) \\\n", + "rplo 1 (Cyanobacteria) 0.067746 -0.552615 \n", + "rplo 2 (Firmicutes) -0.744190 0.057575 \n", + "rplo 60 (Firmicutes) -0.507875 -0.893401 \n", + "rplo 7 (Actinobacteria) 0.355543 -0.450468 \n", + "rplo 10 (Firmicutes) -0.617700 -0.693222 \n", + "... ... ... \n", + "rplo 95 (Proteobacteria) 0.236029 -0.405040 \n", + "rplo 96 (unknown) 0.071470 -0.382577 \n", + "rplo 97 (Firmicutes) 0.104134 -0.882605 \n", + "rplo 98 (Actinobacteria) 0.120926 -0.454491 \n", + "rplo 99 (Cyanobacteria) -0.654812 -0.016671 \n", + "\n", + " (3-methyladenine) (4-oxoproline) \\\n", + "rplo 1 (Cyanobacteria) -0.175478 0.625556 \n", + "rplo 2 (Firmicutes) -0.139843 0.546534 \n", + "rplo 60 (Firmicutes) 0.241754 -0.286902 \n", + "rplo 7 (Actinobacteria) -0.376922 0.442490 \n", + "rplo 10 (Firmicutes) 0.199539 0.394505 \n", + "... ... ... \n", + "rplo 95 (Proteobacteria) -0.371427 0.381187 \n", + "rplo 96 (unknown) -0.089850 0.099791 \n", + "rplo 97 (Firmicutes) -0.478958 -0.463571 \n", + "rplo 98 (Actinobacteria) -0.027568 0.462932 \n", + "rplo 99 (Cyanobacteria) 0.042046 0.232794 \n", + "\n", + " (5,6-dihydrothymine) (alanyl-leucine) \\\n", + "rplo 1 (Cyanobacteria) 0.071818 0.807590 \n", + "rplo 2 (Firmicutes) -0.165667 0.343791 \n", + "rplo 60 (Firmicutes) -0.043752 0.272253 \n", + "rplo 7 (Actinobacteria) 0.049498 0.419829 \n", + "rplo 10 (Firmicutes) -0.239950 -0.174597 \n", + "... ... ... \n", + "rplo 95 (Proteobacteria) -0.011043 0.139877 \n", + "rplo 96 (unknown) 0.226773 -0.217144 \n", + "rplo 97 (Firmicutes) -0.645346 0.043512 \n", + "rplo 98 (Actinobacteria) -0.809299 0.197001 \n", + "rplo 99 (Cyanobacteria) 0.002960 0.087835 \n", + "\n", + " (dehydroalanine) (glycero-3-phosphoethanolamine) \\\n", + "rplo 1 (Cyanobacteria) -0.603510 -0.627171 \n", + "rplo 2 (Firmicutes) -0.439517 -0.666039 \n", + "rplo 60 (Firmicutes) 0.417226 -0.292448 \n", + "rplo 7 (Actinobacteria) 0.183686 0.026987 \n", + "rplo 10 (Firmicutes) 0.046950 0.183876 \n", + "... ... ... \n", + "rplo 95 (Proteobacteria) 0.228063 0.103917 \n", + "rplo 96 (unknown) 0.769984 0.784919 \n", + "rplo 97 (Firmicutes) 0.436498 0.919183 \n", + "rplo 98 (Actinobacteria) -0.167618 -0.042704 \n", + "rplo 99 (Cyanobacteria) -0.866775 -1.058932 \n", + "\n", + " (indoleacrylate) ... thymine tryptophan \\\n", + "rplo 1 (Cyanobacteria) -0.381616 ... -0.043511 0.210599 \n", + "rplo 2 (Firmicutes) -0.410150 ... 0.137920 0.169466 \n", + "rplo 60 (Firmicutes) 0.065090 ... 0.697350 0.051952 \n", + "rplo 7 (Actinobacteria) -0.625513 ... -0.134656 0.447772 \n", + "rplo 10 (Firmicutes) -0.300330 ... 0.495663 -0.230900 \n", + "... ... ... ... ... \n", + "rplo 95 (Proteobacteria) 0.207617 ... 0.355611 0.438821 \n", + "rplo 96 (unknown) 0.361632 ... 0.619834 -0.305987 \n", + "rplo 97 (Firmicutes) 0.265277 ... -0.427435 -0.177776 \n", + "rplo 98 (Actinobacteria) 0.013358 ... 1.281862 -0.094519 \n", + "rplo 99 (Cyanobacteria) 0.196746 ... 
0.267505 0.246862 \n", + "\n", + " tyrosine uracil urate uridine urocanate \\\n", + "rplo 1 (Cyanobacteria) -0.162333 0.025163 -0.007784 -0.524075 0.659753 \n", + "rplo 2 (Firmicutes) 0.349917 0.342212 -0.071381 -0.052086 -0.100996 \n", + "rplo 60 (Firmicutes) -0.683715 0.277871 -0.396133 0.844074 0.610815 \n", + "rplo 7 (Actinobacteria) 0.140308 0.308640 -0.013243 -0.849460 1.202816 \n", + "rplo 10 (Firmicutes) 0.343985 -0.021149 -0.215128 0.489172 0.304803 \n", + "... ... ... ... ... ... \n", + "rplo 95 (Proteobacteria) -0.261516 -0.103044 0.015821 0.649242 0.696189 \n", + "rplo 96 (unknown) 0.616732 -0.252027 -0.836541 0.175115 0.724142 \n", + "rplo 97 (Firmicutes) -0.916701 0.122711 -0.301927 0.985757 0.065469 \n", + "rplo 98 (Actinobacteria) -0.342207 0.218514 -0.343730 -0.531529 -0.677377 \n", + "rplo 99 (Cyanobacteria) 0.594449 -0.000066 -0.679070 0.233483 0.244907 \n", + "\n", + " valine xanthine xylitol \n", + "rplo 1 (Cyanobacteria) -0.782468 -0.526718 -0.181540 \n", + "rplo 2 (Firmicutes) -1.005960 -0.265847 -0.966022 \n", + "rplo 60 (Firmicutes) -0.597109 -0.023769 -0.329333 \n", + "rplo 7 (Actinobacteria) -0.229881 0.251655 -0.254032 \n", + "rplo 10 (Firmicutes) -0.802540 0.111719 -0.414126 \n", + "... ... ... ... \n", + "rplo 95 (Proteobacteria) -0.050961 -0.250635 0.593225 \n", + "rplo 96 (unknown) -0.492409 -0.038374 -0.038857 \n", + "rplo 97 (Firmicutes) 0.094737 -0.151891 -0.177370 \n", + "rplo 98 (Actinobacteria) -0.828105 -1.018369 0.738386 \n", + "rplo 99 (Cyanobacteria) -1.055982 -0.158424 -0.869372 \n", + "\n", + "[466 rows x 85 columns]" + ] + }, + "execution_count": 119, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.ranks_df" + ] + }, { "cell_type": "markdown", "id": "46d440f7", diff --git a/mmvec/ALR.py b/mmvec/ALR.py index d8fe0b4..e30dfe4 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -56,7 +56,7 @@ def __init__(self, microbes, metabolites, latent_dim, sigma_u, self.sigma_u = sigma_u self.sigma_v = sigma_v self.latent_dim = latent_dim - self.u_bias = nn.parameter.Parameter(torch.randn((self.num_microbes, 1))) + self.encoder_bias = nn.parameter.Parameter(torch.randn((self.num_microbes, 1))) self.encoder = nn.Embedding(self.num_microbes, self.latent_dim) self.decoder = LinearALR(self.latent_dim, self.num_metabolites) @@ -126,6 +126,18 @@ def get_ordination(self, equalize_biplot=False): return biplot + + @property + def u_bias(self): + #ensure consistent access + return self.encoder_bias + + @property + def v_bias(self): + #ensure consistent access + return self.decoder.linear.bias + + @property def ranks(self): U = torch.cat( (torch.ones((self.num_microbes, 1)), @@ -142,6 +154,7 @@ def ranks(self): res = torch.cat((torch.zeros((self.num_microbes, 1)), U @ V), dim=1) res = res - res.mean(axis=1).reshape(-1, 1) - self.ranks_matrix = res self.ranks_df = pd.DataFrame(res, index=self.microbe_idx, columns=self.metabolite_idx) + + return res From 5946e334236818890209bdda713138ed872f6dc8 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Tue, 3 May 2022 16:36:52 -0700 Subject: [PATCH 10/27] TEST: test_multimodal runs but fails. 
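A minimal, self-contained sketch (illustrative only; not part of this commit) of
the MAP objective that the reworked forward() computes -- a summed multinomial
log-likelihood plus Normal log-priors on the embedding weights -- and that the
training loop then maximizes with Adam(..., maximize=True). All data, names, and
shapes below are toy stand-ins:

    import torch
    from torch.distributions import Multinomial, Normal

    torch.manual_seed(0)
    counts = torch.randint(0, 5, (4, 7)).float()      # 4 samples x 7 metabolites
    probs = torch.softmax(torch.randn(4, 7), dim=-1)  # predicted compositions
    weights = torch.randn(10, 3)                      # stand-in for encoder weights

    # total_count=0 with validate_args=False lets one Multinomial score rows
    # with different sequencing depths: total_count is only checked during
    # argument validation, never used by log_prob itself.
    log_lik = Multinomial(total_count=0, validate_args=False,
                          probs=probs).log_prob(counts).sum()
    log_prior = Normal(0, 1).log_prob(weights).sum()  # sigma_u = sigma_v = 1
    map_objective = log_lik + log_prior               # ascended by the optimizer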
---
 mmvec/ALR.py                   |  63 +++++---
 mmvec/tests/test_multimodal.py | 270 +++++++++++++++++----------------
 mmvec/train.py                 |   6 +-
 setup.py                       |   1 +
 4 files changed, 181 insertions(+), 159 deletions(-)

diff --git a/mmvec/ALR.py b/mmvec/ALR.py
index e30dfe4..766487e 100644
--- a/mmvec/ALR.py
+++ b/mmvec/ALR.py
@@ -1,8 +1,8 @@
 import pandas as pd
 
 import torch
+from torch import linalg
 import torch.nn as nn
-import torch.linalg as linalg
 import torch.nn.functional as F
 from torch.distributions import Multinomial, Normal
 
@@ -10,8 +10,8 @@
 
 
 def structure_data(microbes, metabolites):
-    microbes = microbes.to_dataframe().T
-    metabolites = metabolites.to_dataframe().T
+    #microbes = microbes.to_dataframe().T
+    #metabolites = metabolites.to_dataframe().T
     microbes = microbes.loc[metabolites.index]
 
     microbe_idx = microbes.columns
@@ -43,30 +43,33 @@ def forward(self, x):
 
 
 class MMvecALR(nn.Module):
-    def __init__(self, microbes, metabolites, latent_dim, sigma_u,
-                 sigma_v):
+    def __init__(self, microbes, metabolites, latent_dim, sigma_u=1,
+                 sigma_v=1):
         super().__init__()
 
         # Data setup
-        self.microbes, self.metabolites, \
-            self.microbe_idx, self.metabolite_idx, \
-            self.num_microbes, self.num_metabolites, \
-            self.microbe_relative_freq = structure_data(microbes,
+        (self.microbes, self.metabolites,
+         self.microbe_idx, self.metabolite_idx,
+         self.num_microbes, self.num_metabolites,
+         self.microbe_relative_freq) = structure_data(microbes,
                                                       metabolites)
         self.sigma_u = sigma_u
         self.sigma_v = sigma_v
         self.latent_dim = latent_dim
-        self.encoder_bias = nn.parameter.Parameter(torch.randn((self.num_microbes, 1)))
+        # TODO: initialize same way as linear bias
+        self.encoder_bias = nn.parameter.Parameter(
+                torch.randn((self.num_microbes, 1)))
         self.encoder = nn.Embedding(self.num_microbes, self.latent_dim)
         self.decoder = LinearALR(self.latent_dim, self.num_metabolites)
 
+
     def forward(self, X):
         # Three likelihoods, the likelihood of each weight and the likelihood
         # of the data fitting in the way that we thought
         # LYs
         z = self.encoder(X)
-        z = z + self.u_bias[X].reshape((*X.shape, 1))
+        z = z + self.encoder_bias[X].reshape((*X.shape, 1))
         y_pred = self.decoder(z)
 
         forward_dist = Multinomial(total_count=0,
@@ -75,11 +78,11 @@ def forward(self, X):
 
         forward_dist = forward_dist.log_prob(self.metabolites)
 
-        l_y = forward_dist.mean(0).mean()
+        l_y = forward_dist.sum(0).sum()
 
         u_weights = self.encoder.weight
         l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum()
-        l_ubias = Normal(0, self.sigma_u).log_prob(self.u_bias).sum()
+        l_ubias = Normal(0, self.sigma_u).log_prob(self.encoder_bias).sum()
 
         v_weights = self.decoder.linear.weight
         l_v = Normal(0, self.sigma_v).log_prob(v_weights).sum()
@@ -90,7 +93,8 @@ def forward(self, X):
 
     def get_ordination(self, equalize_biplot=False):
-        ranks = self.ranks_matrix - self.ranks_matrix.mean(dim=0)
+        ranks = self.ranks()
+        ranks = ranks - ranks.mean(dim=0)
 
         u, s_diag, v = linalg.svd(ranks, full_matrices=False)
 
@@ -130,31 +134,40 @@ def get_ordination(self, equalize_biplot=False):
 
     @property
     def u_bias(self):
         #ensure consistent access
-        return self.encoder_bias
+        return self.encoder_bias.detach()
 
     @property
     def v_bias(self):
         #ensure consistent access
-        return self.decoder.linear.bias
-
+        return self.decoder.linear.bias.detach()
+
     @property
-    def ranks(self):
+    def U(self):
         U = torch.cat(
             (torch.ones((self.num_microbes, 1)),
-             self.u_bias.detach(),
+             self.u_bias,
              self.encoder.weight.detach()),
             dim=1)
+        return U
 
+    @property
+    def V(self):
         V = torch.cat(
-            (self.decoder.linear.bias.detach().unsqueeze(dim=0),
+            
(self.v_bias.unsqueeze(dim=0), torch.ones((1, self.num_metabolites - 1)), self.decoder.linear.weight.detach().T), dim=0) + return V - res = torch.cat((torch.zeros((self.num_microbes, 1)), U @ V), dim=1) - res = res - res.mean(axis=1).reshape(-1, 1) - - self.ranks_df = pd.DataFrame(res, index=self.microbe_idx, - columns=self.metabolite_idx) + def ranks_dataframe(self): + return pd.DataFrame(self.ranks(), index=self.microbe_idx, + columns=self.metabolite_idx) + def ranks(self): + # Adding the zeros is part of the inverse ALR. + res = torch.cat(( + torch.zeros((self.num_microbes, 1)), + self.U @ self.V + ), dim=1) + res = res - res.mean(axis=1).reshape(-1, 1) return res diff --git a/mmvec/tests/test_multimodal.py b/mmvec/tests/test_multimodal.py index d98f4c6..e312f1c 100644 --- a/mmvec/tests/test_multimodal.py +++ b/mmvec/tests/test_multimodal.py @@ -9,10 +9,9 @@ from scipy.stats import spearmanr from scipy.sparse import coo_matrix from scipy.spatial.distance import pdist -from mmvec.multimodal import MMvec +from mmvec.ALR import MMvecALR +from mmvec.train import mmvec_training_loop from mmvec.util import random_multimodal -from tensorflow import set_random_seed -import tensorflow as tf class TestMMvec(unittest.TestCase): @@ -39,136 +38,145 @@ def tearDown(self): def test_fit(self): np.random.seed(1) - tf.reset_default_graph() + #tf.reset_default_graph() n, d1 = self.trainX.shape n, d2 = self.trainY.shape - with tf.Graph().as_default(), tf.Session() as session: - set_random_seed(0) - model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2) - model(session, - coo_matrix(self.trainX.values), self.trainY.values, - coo_matrix(self.testX.values), self.testY.values) - model.fit(epoch=1000) - - U_ = np.hstack( - (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) - V_ = np.vstack( - (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) - - u_r, u_p = spearmanr(pdist(model.U), pdist(self.U)) - v_r, v_p = spearmanr(pdist(model.V.T), pdist(self.V.T)) - - res = softmax(model.ranks()) - exp = softmax(np.hstack((np.zeros((d1, 1)), U_ @ V_))) - s_r, s_p = spearmanr(np.ravel(res), np.ravel(exp)) - - self.assertGreater(u_r, 0.5) - self.assertGreater(v_r, 0.5) - self.assertGreater(s_r, 0.5) - self.assertLess(u_p, 5e-2) - self.assertLess(v_p, 5e-2) - self.assertLess(s_p, 5e-2) - - # sanity check cross validation - self.assertLess(model.cv.eval(), 500) - - -class TestMMvecSoilsBenchmark(unittest.TestCase): - def setUp(self): - self.microbes = load_table(get_data_path('soil_microbes.biom')) - self.metabolites = load_table(get_data_path('soil_metabolites.biom')) - X = self.microbes.to_dataframe().T - Y = self.metabolites.to_dataframe().T - X = X.loc[Y.index] - self.trainX = X.iloc[:-2] - self.trainY = Y.iloc[:-2] - self.testX = X.iloc[-2:] - self.testY = Y.iloc[-2:] - - def tearDown(self): - # remove all log directories - for r in glob.glob("logdir*"): - shutil.rmtree(r) - - def test_soils(self): - np.random.seed(1) - tf.reset_default_graph() - n, d1 = self.trainX.shape - n, d2 = self.trainY.shape - - with tf.Graph().as_default(), tf.Session() as session: - set_random_seed(0) - model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=1, - learning_rate=1e-3) - model(session, - coo_matrix(self.trainX.values), self.trainY.values, - coo_matrix(self.testX.values), self.testY.values) - model.fit(epoch=1000) - - ranks = pd.DataFrame( - model.ranks(), - index=self.microbes.ids(axis='observation'), - columns=self.metabolites.ids(axis='observation')) - - microcoleus_metabolites = [ - '(3-methyladenine)', '7-methyladenine', 
'4-guanidinobutanoate', - 'uracil', 'xanthine', 'hypoxanthine', '(N6-acetyl-lysine)', - 'cytosine', 'N-acetylornithine', 'N-acetylornithine', - 'succinate', 'adenosine', 'guanine', 'adenine'] - mprobs = ranks.loc['rplo 1 (Cyanobacteria)'] - self.assertEqual(np.sum(mprobs.loc[microcoleus_metabolites] > 0), - len(microcoleus_metabolites)) - - -class TestMMvecBenchmark(unittest.TestCase): - def setUp(self): - # build small simulation - res = random_multimodal( - num_microbes=100, num_metabolites=1000, num_samples=300, - latent_dim=2, sigmaQ=2, - microbe_total=5000, metabolite_total=10000, seed=1 - ) - (self.microbes, self.metabolites, self.X, self.B, - self.U, self.Ubias, self.V, self.Vbias) = res - num_train = 10 - self.trainX = self.microbes.iloc[:-num_train] - self.testX = self.microbes.iloc[-num_train:] - self.trainY = self.metabolites.iloc[:-num_train] - self.testY = self.metabolites.iloc[-num_train:] - - @unittest.skip("Only for benchmarking") - def test_gpu(self): - np.random.seed(1) - tf.reset_default_graph() - n, d1 = self.trainX.shape - n, d2 = self.trainY.shape - - with tf.Graph().as_default(), tf.Session() as session: - set_random_seed(0) - model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2, - batch_size=2000, - device_name="/device:GPU:0") - model(session, - coo_matrix(self.trainX.values), self.trainY.values, - coo_matrix(self.testX.values), self.testY.values) - model.fit(epoch=10000) - - @unittest.skip("Only for benchmarking") - def test_cpu(self): - print('CPU run') - np.random.seed(1) - tf.reset_default_graph() - n, d1 = self.trainX.shape - n, d2 = self.trainY.shape - - with tf.Graph().as_default(), tf.Session() as session: - set_random_seed(0) - model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2, - batch_size=2000) - model(session, - coo_matrix(self.trainX.values), self.trainY.values, - coo_matrix(self.testX.values), self.testY.values) - model.fit(epoch=10000) + model = MMvecALR(self.trainX, self.trainY, latent_dim=2) + mmvec_training_loop(model=model, learning_rate=0.1, batch_size=1000, + epochs=1000) + + U_ = np.hstack( + (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) + V_ = np.vstack( + (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) + + + res = softmax(model.ranks().numpy()) + exp = softmax(np.hstack((np.zeros((d1, 1)), U_ @ V_))) + + s_r, s_p = spearmanr(np.ravel(res), np.ravel(exp)) + + u_r, u_p = spearmanr(pdist(model.U), pdist(self.U)) + v_r, v_p = spearmanr(pdist(model.V.T), pdist(self.V.T)) + + self.assertGreater(u_r, 0.5) + self.assertGreater(v_r, 0.5) + self.assertGreater(s_r, 0.5) + self.assertLess(u_p, 5e-2) + self.assertLess(v_p, 5e-2) + self.assertLess(s_p, 5e-2) + + + assert False +# with tf.Graph().as_default(), tf.Session() as session: +# set_random_seed(0) +# model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2) +# model(session, +# coo_matrix(self.trainX.values), self.trainY.values, +# coo_matrix(self.testX.values), self.testY.values) +# model.fit(epoch=1000) +# +# +# # sanity check cross validation +# self.assertLess(model.cv.eval(), 500) +# + + +#class TestMMvecSoilsBenchmark(unittest.TestCase): +# def setUp(self): +# self.microbes = load_table(get_data_path('soil_microbes.biom')) +# self.metabolites = load_table(get_data_path('soil_metabolites.biom')) +# X = self.microbes.to_dataframe().T +# Y = self.metabolites.to_dataframe().T +# X = X.loc[Y.index] +# self.trainX = X.iloc[:-2] +# self.trainY = Y.iloc[:-2] +# self.testX = X.iloc[-2:] +# self.testY = Y.iloc[-2:] +# +# def tearDown(self): +# # remove all log directories +# for r in 
glob.glob("logdir*"): +# shutil.rmtree(r) + +# def test_soils(self): +# np.random.seed(1) +# n, d1 = self.trainX.shape +# n, d2 = self.trainY.shape +# +# with tf.Graph().as_default(), tf.Session() as session: +# set_random_seed(0) +# model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=1, +# learning_rate=1e-3) +# model(session, +# coo_matrix(self.trainX.values), self.trainY.values, +# coo_matrix(self.testX.values), self.testY.values) +# model.fit(epoch=1000) +# +# ranks = pd.DataFrame( +# model.ranks(), +# index=self.microbes.ids(axis='observation'), +# columns=self.metabolites.ids(axis='observation')) +# +# microcoleus_metabolites = [ +# '(3-methyladenine)', '7-methyladenine', '4-guanidinobutanoate', +# 'uracil', 'xanthine', 'hypoxanthine', '(N6-acetyl-lysine)', +# 'cytosine', 'N-acetylornithine', 'N-acetylornithine', +# 'succinate', 'adenosine', 'guanine', 'adenine'] +# mprobs = ranks.loc['rplo 1 (Cyanobacteria)'] +# self.assertEqual(np.sum(mprobs.loc[microcoleus_metabolites] > 0), +# len(microcoleus_metabolites)) +# + +#class TestMMvecBenchmark(unittest.TestCase): +# def setUp(self): +# # build small simulation +# res = random_multimodal( +# num_microbes=100, num_metabolites=1000, num_samples=300, +# latent_dim=2, sigmaQ=2, +# microbe_total=5000, metabolite_total=10000, seed=1 +# ) +# (self.microbes, self.metabolites, self.X, self.B, +# self.U, self.Ubias, self.V, self.Vbias) = res +# num_train = 10 +# self.trainX = self.microbes.iloc[:-num_train] +# self.testX = self.microbes.iloc[-num_train:] +# self.trainY = self.metabolites.iloc[:-num_train] +# self.testY = self.metabolites.iloc[-num_train:] +# +# @unittest.skip("Only for benchmarking") +# def test_gpu(self): +# np.random.seed(1) +# tf.reset_default_graph() +# n, d1 = self.trainX.shape +# n, d2 = self.trainY.shape +# +# with tf.Graph().as_default(), tf.Session() as session: +# set_random_seed(0) +# model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2, +# batch_size=2000, +# device_name="/device:GPU:0") +# model(session, +# coo_matrix(self.trainX.values), self.trainY.values, +# coo_matrix(self.testX.values), self.testY.values) +# model.fit(epoch=10000) + + #@unittest.skip("Only for benchmarking") + #def test_cpu(self): + # print('CPU run') + # np.random.seed(1) + # tf.reset_default_graph() + # n, d1 = self.trainX.shape + # n, d2 = self.trainY.shape + + # with tf.Graph().as_default(), tf.Session() as session: + # set_random_seed(0) + # model = MMvec(beta_1=0.8, beta_2=0.9, latent_dim=2, + # batch_size=2000) + # model(session, + # coo_matrix(self.trainX.values), self.trainY.values, + # coo_matrix(self.testX.values), self.testY.values) + # model.fit(epoch=10000) if __name__ == "__main__": diff --git a/mmvec/train.py b/mmvec/train.py index a166e83..e2328cb 100644 --- a/mmvec/train.py +++ b/mmvec/train.py @@ -1,8 +1,8 @@ import torch - -def mmvec_training_loop(model, optimizer, - batch_size, epochs): +def mmvec_training_loop(model, learning_rate, batch_size, epochs): + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, + betas=(0.8, 0.9), maximize=True) for epoch in range(epochs): draws = torch.multinomial(model.microbe_relative_freq, diff --git a/setup.py b/setup.py index 018be08..4198c7a 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,7 @@ 'scikit-bio', 'seaborn', 'tqdm', + 'pytorch' ], classifiers=classifiers, entry_points={ From b32d5f1fbb233d7c7b29e8ac8b458dcf33e8f4ea Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 5 May 2022 12:24:03 -0700 Subject: [PATCH 11/27] checkpoint before cleanup for pr --- 
From b32d5f1fbb233d7c7b29e8ac8b458dcf33e8f4ea Mon Sep 17 00:00:00 2001
From: Keegan Evans
Date: Thu, 5 May 2022 12:24:03 -0700
Subject: [PATCH 11/27] checkpoint before cleanup for pr

---
 examples/refactor/ALR.ipynb | 1987 ++--------------------------------
 mmvec/ALR.py                |    6 +-
 2 files changed, 87 insertions(+), 1906 deletions(-)

diff --git a/examples/refactor/ALR.ipynb b/examples/refactor/ALR.ipynb
index fa1e31d..9e84413 100644
--- a/examples/refactor/ALR.ipynb
+++ b/examples/refactor/ALR.ipynb
@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": 1,
-   "id": "d71b64f1",
+   "id": "461ff352",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,7 +13,7 @@
   {
    "cell_type": "code",
    "execution_count": 2,
-   "id": "83368d02",
+   "id": "e536cc44",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -31,7 +31,7 @@
     {
      "data": {
       "text/plain": [
-       ""
+       ""
      ]
     },
     "execution_count": 3,
@@ -93,8 +93,8 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 7,
-   "id": "1fdc5c3b",
+   "execution_count": 6,
+   "id": "29799ea5",
   "metadata": {},
   "outputs": [
    {
@@ -109,7 +109,7 @@
      "        [0.0027, 0.0135, 0.0000,  ..., 0.0000, 0.0008, 0.0036]])"
     ]
    },
-    "execution_count": 7,
+    "execution_count": 6,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -120,7 +120,7 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
   "id": "b977e212",
   "metadata": {},
   "outputs": [
    {
@@ -128,826 +128,8 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-      "torch.Size([200, 19])\n",
-      "tensor([[ 79, 139, 435,  ..., 189, 202, 335],\n",
-      [... roughly 400 deleted lines elided: repeated "torch.Size([200, 19])" shape prints and integer tensors of the per-batch multinomial draws, one stream block per batch ...]
      "loss: -14241466368.0\n",
-      "Batch #: 0\n",
-      [... roughly 400 more deleted lines of the same per-batch draw dumps elided ...]
+      "Batch #: 0\n"
     ]
    }
   ],
@@ -962,27 +144,57 @@
  },
  {
   "cell_type": "code",
   "execution_count": null,
-   "id": "9c6d7d97",
+   "id": "1f27301f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
- "execution_count": 11, - "id": "3d543ae6", + "execution_count": 9, + "id": "8a16a24a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 1.1758, 0.0677, -0.5526, ..., -0.7825, -0.5267, -0.1815],\n", + " [ 0.2582, -0.7442, 0.0576, ..., -1.0060, -0.2658, -0.9660],\n", + " [ 0.9580, -0.5079, -0.8934, ..., -0.5971, -0.0238, -0.3293],\n", + " ...,\n", + " [ 0.0645, 0.1041, -0.8826, ..., 0.0947, -0.1519, -0.1774],\n", + " [ 0.4002, 0.1209, -0.4545, ..., -0.8281, -1.0184, 0.7384],\n", + " [ 1.1294, -0.6548, -0.0167, ..., -1.0560, -0.1584, -0.8694]])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "model.ranks()" + "model.ranks" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "b924530f", + "execution_count": 10, + "id": "3663ecb4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'MMvecALR' object has no attribute 'ranks_matrix'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# h = model.ranks_df - model.ranks_df.mean(axis=0)\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m h \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mranks_matrix\u001b[49m\n\u001b[1;32m 4\u001b[0m k \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mlatent_dim\n", + "File \u001b[0;32m~/opt/miniconda3/envs/mmvec/lib/python3.10/site-packages/torch/nn/modules/module.py:1185\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1185\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, name))\n", + "\u001b[0;31mAttributeError\u001b[0m: 'MMvecALR' object has no attribute 'ranks_matrix'" + ] + } + ], "source": [ "# h = model.ranks_df - model.ranks_df.mean(axis=0)\n", "h = model.ranks_matrix\n", @@ -993,7 +205,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c64267df", + "id": "6b3bcb21", "metadata": {}, "outputs": [], "source": [ @@ -1003,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66ee80e0", + "id": "52e80094", "metadata": {}, "outputs": [], "source": [ @@ -1013,7 +225,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c1ba8000", + "id": "b24646d3", "metadata": {}, "outputs": [], "source": [ @@ -1023,7 +235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b550250e", + "id": "afc2f855", "metadata": {}, "outputs": [], "source": [ @@ -1033,7 +245,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42e4637c", + "id": "5b427a72", 
"metadata": {}, "outputs": [], "source": [ @@ -1042,525 +254,33 @@ }, { "cell_type": "code", - "execution_count": 13, - "id": "1cf51885", + "execution_count": 11, + "id": "752acb7a", "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
(2,3-dihydroxy-3-methylbutanoate)(2,5-diaminohexanoate)(3-hydroxypyridine)(3-methyladenine)(4-oxoproline)(5,6-dihydrothymine)(alanyl-leucine)(dehydroalanine)(glycero-3-phosphoethanolamine)(indoleacrylate)...thyminetryptophantyrosineuracilurateuridineurocanatevalinexanthinexylitol
rplo 1 (Cyanobacteria)1.1757660.067746-0.552615-0.1754780.6255560.0718180.807590-0.603510-0.627171-0.381616...-0.0435110.210599-0.1623330.025163-0.007784-0.5240750.659753-0.782468-0.526718-0.181540
rplo 2 (Firmicutes)0.258168-0.7441900.057575-0.1398430.546534-0.1656670.343791-0.439517-0.666039-0.410150...0.1379200.1694660.3499170.342212-0.071381-0.052086-0.100996-1.005960-0.265847-0.966022
rplo 60 (Firmicutes)0.958025-0.507875-0.8934010.241754-0.286902-0.0437520.2722530.417226-0.2924480.065090...0.6973500.051952-0.6837150.277871-0.3961330.8440740.610815-0.597109-0.023769-0.329333
rplo 7 (Actinobacteria)0.9198030.355543-0.450468-0.3769220.4424900.0494980.4198290.1836860.026987-0.625513...-0.1346560.4477720.1403080.308640-0.013243-0.8494601.202816-0.2298810.251655-0.254032
rplo 10 (Firmicutes)1.143667-0.617700-0.6932220.1995390.394505-0.239950-0.1745970.0469500.183876-0.300330...0.495663-0.2309000.343985-0.021149-0.2151280.4891720.304803-0.8025400.111719-0.414126
..................................................................
rplo 95 (Proteobacteria)-0.8979420.236029-0.405040-0.3714270.381187-0.0110430.1398770.2280630.1039170.207617...0.3556110.438821-0.261516-0.1030440.0158210.6492420.696189-0.050961-0.2506350.593225
rplo 96 (unknown)0.8719050.071470-0.382577-0.0898500.0997910.226773-0.2171440.7699840.7849190.361632...0.619834-0.3059870.616732-0.252027-0.8365410.1751150.724142-0.492409-0.038374-0.038857
rplo 97 (Firmicutes)0.0645210.104134-0.882605-0.478958-0.463571-0.6453460.0435120.4364980.9191830.265277...-0.427435-0.177776-0.9167010.122711-0.3019270.9857570.0654690.094737-0.151891-0.177370
rplo 98 (Actinobacteria)0.4001700.120926-0.454491-0.0275680.462932-0.8092990.197001-0.167618-0.0427040.013358...1.281862-0.094519-0.3422070.218514-0.343730-0.531529-0.677377-0.828105-1.0183690.738386
rplo 99 (Cyanobacteria)1.129443-0.654812-0.0166710.0420460.2327940.0029600.087835-0.866775-1.0589320.196746...0.2675050.2468620.594449-0.000066-0.6790700.2334830.244907-1.055982-0.158424-0.869372
\n", - "

466 rows × 85 columns

\n", - "
" - ], - "text/plain": [ - " (2,3-dihydroxy-3-methylbutanoate) \\\n", - "rplo 1 (Cyanobacteria) 1.175766 \n", - "rplo 2 (Firmicutes) 0.258168 \n", - "rplo 60 (Firmicutes) 0.958025 \n", - "rplo 7 (Actinobacteria) 0.919803 \n", - "rplo 10 (Firmicutes) 1.143667 \n", - "... ... \n", - "rplo 95 (Proteobacteria) -0.897942 \n", - "rplo 96 (unknown) 0.871905 \n", - "rplo 97 (Firmicutes) 0.064521 \n", - "rplo 98 (Actinobacteria) 0.400170 \n", - "rplo 99 (Cyanobacteria) 1.129443 \n", - "\n", - " (2,5-diaminohexanoate) (3-hydroxypyridine) \\\n", - "rplo 1 (Cyanobacteria) 0.067746 -0.552615 \n", - "rplo 2 (Firmicutes) -0.744190 0.057575 \n", - "rplo 60 (Firmicutes) -0.507875 -0.893401 \n", - "rplo 7 (Actinobacteria) 0.355543 -0.450468 \n", - "rplo 10 (Firmicutes) -0.617700 -0.693222 \n", - "... ... ... \n", - "rplo 95 (Proteobacteria) 0.236029 -0.405040 \n", - "rplo 96 (unknown) 0.071470 -0.382577 \n", - "rplo 97 (Firmicutes) 0.104134 -0.882605 \n", - "rplo 98 (Actinobacteria) 0.120926 -0.454491 \n", - "rplo 99 (Cyanobacteria) -0.654812 -0.016671 \n", - "\n", - " (3-methyladenine) (4-oxoproline) \\\n", - "rplo 1 (Cyanobacteria) -0.175478 0.625556 \n", - "rplo 2 (Firmicutes) -0.139843 0.546534 \n", - "rplo 60 (Firmicutes) 0.241754 -0.286902 \n", - "rplo 7 (Actinobacteria) -0.376922 0.442490 \n", - "rplo 10 (Firmicutes) 0.199539 0.394505 \n", - "... ... ... \n", - "rplo 95 (Proteobacteria) -0.371427 0.381187 \n", - "rplo 96 (unknown) -0.089850 0.099791 \n", - "rplo 97 (Firmicutes) -0.478958 -0.463571 \n", - "rplo 98 (Actinobacteria) -0.027568 0.462932 \n", - "rplo 99 (Cyanobacteria) 0.042046 0.232794 \n", - "\n", - " (5,6-dihydrothymine) (alanyl-leucine) \\\n", - "rplo 1 (Cyanobacteria) 0.071818 0.807590 \n", - "rplo 2 (Firmicutes) -0.165667 0.343791 \n", - "rplo 60 (Firmicutes) -0.043752 0.272253 \n", - "rplo 7 (Actinobacteria) 0.049498 0.419829 \n", - "rplo 10 (Firmicutes) -0.239950 -0.174597 \n", - "... ... ... \n", - "rplo 95 (Proteobacteria) -0.011043 0.139877 \n", - "rplo 96 (unknown) 0.226773 -0.217144 \n", - "rplo 97 (Firmicutes) -0.645346 0.043512 \n", - "rplo 98 (Actinobacteria) -0.809299 0.197001 \n", - "rplo 99 (Cyanobacteria) 0.002960 0.087835 \n", - "\n", - " (dehydroalanine) (glycero-3-phosphoethanolamine) \\\n", - "rplo 1 (Cyanobacteria) -0.603510 -0.627171 \n", - "rplo 2 (Firmicutes) -0.439517 -0.666039 \n", - "rplo 60 (Firmicutes) 0.417226 -0.292448 \n", - "rplo 7 (Actinobacteria) 0.183686 0.026987 \n", - "rplo 10 (Firmicutes) 0.046950 0.183876 \n", - "... ... ... \n", - "rplo 95 (Proteobacteria) 0.228063 0.103917 \n", - "rplo 96 (unknown) 0.769984 0.784919 \n", - "rplo 97 (Firmicutes) 0.436498 0.919183 \n", - "rplo 98 (Actinobacteria) -0.167618 -0.042704 \n", - "rplo 99 (Cyanobacteria) -0.866775 -1.058932 \n", - "\n", - " (indoleacrylate) ... thymine tryptophan \\\n", - "rplo 1 (Cyanobacteria) -0.381616 ... -0.043511 0.210599 \n", - "rplo 2 (Firmicutes) -0.410150 ... 0.137920 0.169466 \n", - "rplo 60 (Firmicutes) 0.065090 ... 0.697350 0.051952 \n", - "rplo 7 (Actinobacteria) -0.625513 ... -0.134656 0.447772 \n", - "rplo 10 (Firmicutes) -0.300330 ... 0.495663 -0.230900 \n", - "... ... ... ... ... \n", - "rplo 95 (Proteobacteria) 0.207617 ... 0.355611 0.438821 \n", - "rplo 96 (unknown) 0.361632 ... 0.619834 -0.305987 \n", - "rplo 97 (Firmicutes) 0.265277 ... -0.427435 -0.177776 \n", - "rplo 98 (Actinobacteria) 0.013358 ... 1.281862 -0.094519 \n", - "rplo 99 (Cyanobacteria) 0.196746 ... 
0.267505 0.246862 \n", - "\n", - " tyrosine uracil urate uridine urocanate \\\n", - "rplo 1 (Cyanobacteria) -0.162333 0.025163 -0.007784 -0.524075 0.659753 \n", - "rplo 2 (Firmicutes) 0.349917 0.342212 -0.071381 -0.052086 -0.100996 \n", - "rplo 60 (Firmicutes) -0.683715 0.277871 -0.396133 0.844074 0.610815 \n", - "rplo 7 (Actinobacteria) 0.140308 0.308640 -0.013243 -0.849460 1.202816 \n", - "rplo 10 (Firmicutes) 0.343985 -0.021149 -0.215128 0.489172 0.304803 \n", - "... ... ... ... ... ... \n", - "rplo 95 (Proteobacteria) -0.261516 -0.103044 0.015821 0.649242 0.696189 \n", - "rplo 96 (unknown) 0.616732 -0.252027 -0.836541 0.175115 0.724142 \n", - "rplo 97 (Firmicutes) -0.916701 0.122711 -0.301927 0.985757 0.065469 \n", - "rplo 98 (Actinobacteria) -0.342207 0.218514 -0.343730 -0.531529 -0.677377 \n", - "rplo 99 (Cyanobacteria) 0.594449 -0.000066 -0.679070 0.233483 0.244907 \n", - "\n", - " valine xanthine xylitol \n", - "rplo 1 (Cyanobacteria) -0.782468 -0.526718 -0.181540 \n", - "rplo 2 (Firmicutes) -1.005960 -0.265847 -0.966022 \n", - "rplo 60 (Firmicutes) -0.597109 -0.023769 -0.329333 \n", - "rplo 7 (Actinobacteria) -0.229881 0.251655 -0.254032 \n", - "rplo 10 (Firmicutes) -0.802540 0.111719 -0.414126 \n", - "... ... ... ... \n", - "rplo 95 (Proteobacteria) -0.050961 -0.250635 0.593225 \n", - "rplo 96 (unknown) -0.492409 -0.038374 -0.038857 \n", - "rplo 97 (Firmicutes) 0.094737 -0.151891 -0.177370 \n", - "rplo 98 (Actinobacteria) -0.828105 -1.018369 0.738386 \n", - "rplo 99 (Cyanobacteria) -1.055982 -0.158424 -0.869372 \n", - "\n", - "[466 rows x 85 columns]" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" + "ename": "AttributeError", + "evalue": "'MMvecALR' object has no attribute 'U'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mU\u001b[49m\n", + "File \u001b[0;32m~/opt/miniconda3/envs/mmvec/lib/python3.10/site-packages/torch/nn/modules/module.py:1185\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1185\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, name))\n", + "\u001b[0;31mAttributeError\u001b[0m: 'MMvecALR' object has no attribute 'U'" + ] } ], "source": [ - "model.ranks_df" + "model.U" ] }, { "cell_type": "code", - "execution_count": 48, - "id": "a0d8a2a2", + "execution_count": null, + "id": "18824f01", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([466, 85])\n", - "torch.Size([85])\n", - 
"torch.Size([85, 85])\n", - "torch.Size([466])\n", - "tensor([ -1.1445, -6.3680, -4.3185, 4.9945, -1.4082, -8.2258, 1.8486,\n", - " -4.3756, 4.7173, -6.9559, 3.2740, 0.4376, -5.9801, 6.3349,\n", - " 1.0003, -8.6350, 2.8014, 0.6915, 1.2820, -4.5347, 4.0379,\n", - " -5.6501, 1.9838, 0.2808, -0.2633, -3.2443, -4.1948, 1.5267,\n", - " -1.1784, -2.8921, 2.6564, 1.6168, -0.1698, -0.9756, -4.8180,\n", - " -4.7775, 1.7069, -1.3242, -4.7067, 0.6504, -5.5535, 4.1691,\n", - " -2.5773, 4.0220, -2.2959, -6.7313, -6.9085, 1.4211, 3.6404,\n", - " 1.5132, 2.4939, 0.5881, 1.9394, 7.5292, -8.3888, -0.6539,\n", - " 10.1397, -2.4303, -3.2335, 1.4971, -1.7446, 2.4674, 3.4813,\n", - " 0.0457, 0.8573, 3.2580, -3.3791, 3.7058, 4.9621, -1.6585,\n", - " 0.2361, 0.1459, 4.4815, -4.4632, 7.9341, -3.4729, 6.7684,\n", - " -3.3562, 2.3691, 6.9417, 5.4023, -0.2494, 6.6950, 3.0350,\n", - " -3.9309, 0.9241, 0.5092, 1.8853, 2.6824, 6.8036, -0.4861,\n", - " 5.1062, 2.4624, -0.9210, 9.8531, 0.5789, 1.9744, -6.0279,\n", - " 2.6953, -7.4324, 3.2620, -2.9513, 9.6945, -1.3096, 2.4214,\n", - " -0.9256, -4.1337, 11.0306, 0.4112, 0.1319, -4.7474, 4.4109,\n", - " 2.7715, 2.8321, -0.6753, 3.6894, 2.2041, -6.7310, 3.7233,\n", - " -2.9628, 3.3455, -6.5863, -4.7063, 5.9036, -3.2456, -0.0869,\n", - " -3.5623, -6.0653, -5.9259, 3.8253, -11.0783, 0.0408, -4.8903,\n", - " 0.9617, -1.5991, 4.5272, 5.0266, 1.9491, -3.0679, -0.6566,\n", - " -5.9211, 7.6033, -7.0827, -3.6042, 1.3228, -6.4924, 1.7801,\n", - " -2.6599, 5.8849, 0.3166, -2.8488, -5.0392, -7.5366, -1.9267,\n", - " -1.2711, -5.4646, 1.5345, 2.9971, 0.9353, -5.5945, 5.0444,\n", - " -0.5978, 3.8224, 3.5736, 0.0708, 2.3189, 0.8124, -10.2031,\n", - " 0.1141, -3.4993, 1.6761, 1.1674, -6.0600, -5.8619, -5.7287,\n", - " -3.9315, -1.6097, 3.0353, 1.5874, 0.5476, 2.2777, 0.8731,\n", - " -0.6450, -7.1289, -0.0645, 0.5143, 0.9316, 2.2382, 3.6106,\n", - " -8.4546, -1.4214, -3.5985, -3.4575, -0.5916, 2.5210, -0.4692,\n", - " 2.5000, -6.4857, 0.7608, -5.6512, -7.5485, 4.4014, -6.1880,\n", - " -4.3598, -3.7198, 2.0294, -3.4157, 3.5193, -0.0775, 4.9616,\n", - " -2.4156, -2.3211, -0.2058, 5.9809, 2.5119, 0.0125, -2.1752,\n", - " -5.8627, 3.3680, 3.4236, -4.8590, -1.8029, -2.5155, 5.7980,\n", - " -5.7988, -6.8575, 8.2150, 6.3653, -3.2859, 3.0349, -9.1496,\n", - " 6.2771, -5.9259, 7.2524, 3.4612, -0.6388, 6.2876, 3.6062,\n", - " 1.0134, 0.9312, -2.3465, 3.3057, 1.0286, 9.0334, -8.9800,\n", - " -3.1666, 5.6102, -1.1290, -0.4083, 4.8317, 9.2724, 0.0997,\n", - " -8.8108, 3.4332, 0.3276, 5.0469, -2.0226, -1.6557, 5.6105,\n", - " 1.8530, 2.8858, 1.5988, 2.5177, -2.7918, -2.6911, 2.7218,\n", - " -3.1462, -6.2753, -2.6276, -1.8484, 3.0457, -3.4599, -5.8190,\n", - " -0.9930, -9.0980, 7.5351, -1.4414, -3.8330, -3.9160, -2.0748,\n", - " -4.7279, -2.6979, 2.2114, -9.7617, -1.9074, 9.8307, 1.1703,\n", - " 2.3597, 3.6719, -1.4355, 1.3314, 0.9512, -5.2816, -3.2768,\n", - " 2.1892, -8.9302, -2.4061, 4.7443, -0.6404, -7.9222, 3.9574,\n", - " -5.7212, -2.4539, 3.4378, -7.4782, 1.8264, -1.5297, 4.7548,\n", - " -5.7164, -1.8924, -0.9265, -3.2981, -2.6631, -0.3037, 2.1184,\n", - " 2.4061, -3.7237, 0.9267, 2.6104, -3.0550, 7.8785, 1.0147,\n", - " 3.5998, 3.8647, 3.3049, 0.7033, -4.0938, -6.9029, -2.7553,\n", - " -0.8194, 4.4504, -0.2810, -3.4939, -1.3974, 4.2549, -9.4413,\n", - " 6.4951, 5.9425, -2.5674, 3.1822, -0.9808, -4.4396, 6.4448,\n", - " -0.3536, 5.1797, 0.8818, 4.1052, 4.9712, -0.7238, 8.7621,\n", - " -5.2645, 1.5924, -4.0963, 4.6621, -10.9097, 0.4642, -0.5150,\n", - " -2.9584, -5.4681, 2.4455, -1.9391, 4.9934, 
-4.7105, -0.8750,\n", - " 6.3088, 0.2136, -2.9872, -1.8482, -5.2081, 1.9450, -3.2619,\n", - " -0.6486, 3.6653, -1.8660, -1.0397, 8.5315, -1.5133, 4.1649,\n", - " -0.9625, 0.8924, -1.6494, 5.3174, 7.2113, -0.4926, 2.1117,\n", - " 2.0516, 3.9590, 3.3258, 4.5366, 2.1683, -1.0748, 2.4090,\n", - " -4.1125, -1.6299, 1.8558, 0.0114, -3.8395, -0.3071, -0.2672,\n", - " 4.3818, 2.9695, 3.4528, 9.6955, -3.0135, -3.6088, 3.4343,\n", - " -3.9485, -3.1757, 4.3005, 1.0197, -3.4628, 0.1942, 0.7603,\n", - " 2.1585, 4.4071, 2.2928, 10.1469, 7.1473, 5.0083, -3.6591,\n", - " -0.3181, 4.8017, 2.0600, 0.7875, 3.8353, 1.9623, 3.0753,\n", - " 3.4961, 0.2156, -1.6791, 2.8405, 3.2189, 5.8801, 0.5369,\n", - " 1.1090, 0.5457, 1.0708, 3.6782, 2.6795, 0.1788, 7.0609,\n", - " -0.4870, -1.4217, -3.9887, 4.7482, -3.6168, -2.9442, -3.7465,\n", - " -0.3917, 5.7974, -1.7506, 1.5932, 2.9426, -4.3741, 0.0520,\n", - " 0.4566, -2.2609, 1.0170, 4.9163, -2.8058, -7.9425, 5.4053,\n", - " -1.4912, -7.1008, -8.9607, -5.4042])\n" - ] - } - ], + "outputs": [], "source": [ "model.get_ordination()" ] @@ -1580,7 +300,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38b0b239", + "id": "99a455e5", "metadata": {}, "outputs": [], "source": [ @@ -1602,39 +322,10 @@ }, { "cell_type": "code", - "execution_count": 115, - "id": "6157d538", + "execution_count": null, + "id": "76a6f81b", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(tensor([[9., 6., 7.],\n", - " [3., 8., 2.],\n", - " [6., 5., 3.],\n", - " [4., 6., 4.],\n", - " [7., 5., 0.],\n", - " [6., 0., 3.],\n", - " [2., 4., 3.],\n", - " [5., 4., 9.],\n", - " [7., 3., 1.]]),\n", - " tensor([[ 3.5556, 1.4444, 3.4444],\n", - " [-2.4444, 3.4444, -1.5556],\n", - " [ 0.5556, 0.4444, -0.5556],\n", - " [-1.4444, 1.4444, 0.4444],\n", - " [ 1.5556, 0.4444, -3.5556],\n", - " [ 0.5556, -4.5556, -0.5556],\n", - " [-3.4444, -0.5556, -0.5556],\n", - " [-0.4444, -0.5556, 5.4444],\n", - " [ 1.5556, -1.5556, -2.5556]]),\n", - " tensor([5.4444, 4.5556, 3.5556]))" - ] - }, - "execution_count": 115, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "a = torch.randint(10, (9, 3), dtype=torch.float)\n", "b = torch.randint(10, (3,))\n", @@ -1643,91 +334,19 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": null, "id": "aa9e8f31", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([466, 85])\n", - "torch.Size([85])\n", - "torch.Size([85, 85])\n", - "compare unmodified:\n", - "tensor([3.9370e+01, 3.0685e+01, 2.7607e+01, 2.6057e+01, 2.4771e+01, 2.4166e+01,\n", - " 2.3556e+01, 2.1152e+01, 2.0435e+01, 1.9499e+01, 1.8038e+01, 1.7893e+01,\n", - " 1.6577e+01, 1.5838e+01, 1.4549e+01, 8.0469e+00, 1.3352e-05, 7.6159e-06,\n", - " 6.4840e-06, 5.7771e-06, 5.1234e-06, 4.8650e-06, 4.7076e-06, 4.3261e-06,\n", - " 3.7033e-06, 3.3649e-06, 3.2426e-06, 3.1312e-06, 2.9980e-06, 2.9524e-06,\n", - " 2.7562e-06, 2.7255e-06, 2.5842e-06, 2.5818e-06, 2.5675e-06, 2.5544e-06,\n", - " 2.5496e-06, 2.5244e-06, 2.5015e-06, 2.4779e-06, 2.4762e-06, 2.4727e-06,\n", - " 2.4645e-06, 2.4528e-06, 2.4424e-06, 2.4378e-06, 2.4306e-06, 2.4263e-06,\n", - " 2.4155e-06, 2.3750e-06, 2.2824e-06, 2.2305e-06, 2.2250e-06, 2.2146e-06,\n", - " 2.2125e-06, 2.1982e-06, 2.1954e-06, 2.1732e-06, 2.1682e-06, 2.1579e-06,\n", - " 2.1559e-06, 2.1535e-06, 2.1495e-06, 2.1278e-06, 2.0853e-06, 2.0755e-06,\n", - " 2.0694e-06, 2.0658e-06, 2.0506e-06, 2.0451e-06, 2.0284e-06, 2.0273e-06,\n", - " 2.0137e-06, 2.0100e-06, 2.0076e-06, 
2.0001e-06, 1.9894e-06, 1.9617e-06,\n", - " 1.9169e-06, 1.9098e-06, 1.9008e-06, 1.8981e-06, 1.8605e-06, 1.7143e-06,\n", - " 1.3593e-06])\n", - "sqrt:\n", - "tensor([6.2745e+00, 5.5394e+00, 5.2542e+00, 5.1046e+00, 4.9770e+00, 4.9159e+00,\n", - " 4.8535e+00, 4.5991e+00, 4.5205e+00, 4.4158e+00, 4.2471e+00, 4.2300e+00,\n", - " 4.0714e+00, 3.9797e+00, 3.8143e+00, 2.8367e+00, 3.6540e-03, 2.7597e-03,\n", - " 2.5464e-03, 2.4036e-03, 2.2635e-03, 2.2057e-03, 2.1697e-03, 2.0799e-03,\n", - " 1.9244e-03, 1.8344e-03, 1.8007e-03, 1.7695e-03, 1.7315e-03, 1.7182e-03,\n", - " 1.6602e-03, 1.6509e-03, 1.6075e-03, 1.6068e-03, 1.6023e-03, 1.5983e-03,\n", - " 1.5968e-03, 1.5888e-03, 1.5816e-03, 1.5741e-03, 1.5736e-03, 1.5725e-03,\n", - " 1.5699e-03, 1.5661e-03, 1.5628e-03, 1.5613e-03, 1.5590e-03, 1.5577e-03,\n", - " 1.5542e-03, 1.5411e-03, 1.5108e-03, 1.4935e-03, 1.4916e-03, 1.4882e-03,\n", - " 1.4874e-03, 1.4826e-03, 1.4817e-03, 1.4742e-03, 1.4725e-03, 1.4690e-03,\n", - " 1.4683e-03, 1.4675e-03, 1.4661e-03, 1.4587e-03, 1.4441e-03, 1.4407e-03,\n", - " 1.4385e-03, 1.4373e-03, 1.4320e-03, 1.4301e-03, 1.4242e-03, 1.4238e-03,\n", - " 1.4190e-03, 1.4177e-03, 1.4169e-03, 1.4142e-03, 1.4104e-03, 1.4006e-03,\n", - " 1.3845e-03, 1.3820e-03, 1.3787e-03, 1.3777e-03, 1.3640e-03, 1.3093e-03,\n", - " 1.1659e-03])\n", - "torch.Size([466, 85])\n", - "tensor([[-8.1189e-01, -1.8219e+00, -2.1267e+00, ..., 6.5056e-09,\n", - " 1.0232e-07, -5.1138e-09],\n", - " [-4.1947e+00, 3.4689e-02, 1.1555e+00, ..., -4.7035e-08,\n", - " 3.9148e-08, -6.2629e-08],\n", - " [-9.1145e-01, 1.1518e+00, -6.4364e-01, ..., 1.0274e-07,\n", - " 1.3688e-07, -2.0395e-10],\n", - " ...,\n", - " [ 9.2180e-01, -2.7336e-01, -2.8572e-01, ..., -5.4841e-08,\n", - " 5.8509e-08, -1.8683e-08],\n", - " [-9.7168e-01, -1.1153e+00, -1.3348e+00, ..., -1.0824e-07,\n", - " -6.6422e-08, -4.7472e-08],\n", - " [-4.2655e+00, -1.0002e+00, 2.2608e+00, ..., -1.4434e-07,\n", - " -2.8319e-08, 4.3824e-08]])\n", - "['PC0', 'PC1', 'PC2', 'PC3', 'PC4', 'PC5', 'PC6', 'PC7', 'PC8', 'PC9', 'PC10', 'PC11', 'PC12', 'PC13', 'PC14', 'PC15', 'PC16', 'PC17', 'PC18', 'PC19', 'PC20', 'PC21', 'PC22', 'PC23', 'PC24', 'PC25', 'PC26', 'PC27', 'PC28', 'PC29', 'PC30', 'PC31', 'PC32', 'PC33', 'PC34', 'PC35', 'PC36', 'PC37', 'PC38', 'PC39', 'PC40', 'PC41', 'PC42', 'PC43', 'PC44', 'PC45', 'PC46', 'PC47', 'PC48', 'PC49', 'PC50', 'PC51', 'PC52', 'PC53', 'PC54', 'PC55', 'PC56', 'PC57', 'PC58', 'PC59', 'PC60', 'PC61', 'PC62', 'PC63', 'PC64', 'PC65', 'PC66', 'PC67', 'PC68', 'PC69', 'PC70', 'PC71', 'PC72', 'PC73', 'PC74', 'PC75', 'PC76', 'PC77', 'PC78', 'PC79', 'PC80', 'PC81', 'PC82', 'PC83', 'PC84']\n", - "tensor([[-8.1189e-01, -1.8219e+00, -2.1267e+00, ..., 6.5056e-09,\n", - " 1.0232e-07, -5.1138e-09],\n", - " [-4.1947e+00, 3.4689e-02, 1.1555e+00, ..., -4.7035e-08,\n", - " 3.9148e-08, -6.2629e-08],\n", - " [-9.1145e-01, 1.1518e+00, -6.4364e-01, ..., 1.0274e-07,\n", - " 1.3688e-07, -2.0395e-10],\n", - " ...,\n", - " [ 9.2180e-01, -2.7336e-01, -2.8572e-01, ..., -5.4841e-08,\n", - " 5.8509e-08, -1.8683e-08],\n", - " [-9.7168e-01, -1.1153e+00, -1.3348e+00, ..., -1.0824e-07,\n", - " -6.6422e-08, -4.7472e-08],\n", - " [-4.2655e+00, -1.0002e+00, 2.2608e+00, ..., -1.4434e-07,\n", - " -2.8319e-08, 4.3824e-08]])\n" - ] - } - ], + "outputs": [], "source": [ "bp = model.get_ordination()" ] }, { "cell_type": "code", - "execution_count": 116, - "id": "7ac106a9", - "metadata": { - "collapsed": true - }, + "execution_count": null, + "id": "6b850001", + "metadata": {}, "outputs": [], "source": [ "import os.path" @@ -1735,8 +354,8 
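The removed output above records only shapes, eigenvalues, and PC labels, so for orientation here is a minimal sketch of the kind of computation `model.get_ordination()` appears to perform, inferred from those traces; the function body and the names below are assumptions, not the refactor's actual implementation:

    import torch

    def ordination_sketch(ranks: torch.Tensor):
        # column-center the 466 x 85 microbe-by-metabolite ranks matrix,
        # mirroring the mean-subtraction scratch work in the cell above
        centered = ranks - ranks.mean(dim=0)
        # SVD gives loadings and singular values; the removed output printed
        # 85 eigenvalue-like values and their square roots, consistent with this
        u, s, vh = torch.linalg.svd(centered, full_matrices=False)
        axis_labels = [f"PC{i}" for i in range(len(s))]  # 'PC0' ... 'PC84'
        return u * s, vh.T, s ** 2, axis_labels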
@@ -1735,8 +354,8 @@
 },
 {
  "cell_type": "code",
- "execution_count": 117,
- "id": "780e8bd7",
+ "execution_count": null,
+ "id": "8e3b2153",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1745,458 +364,20 @@
 },
 {
  "cell_type": "code",
- "execution_count": 118,
- "id": "188a5f70",
+ "execution_count": null,
+ "id": "b18f859b",
  "metadata": {},
- "outputs": [
-  {
-   "data": {
-    "text/plain": [
-     "'/Users/keeganevans/Desktop/biplot'"
-    ]
-   },
-   "execution_count": 118,
-   "metadata": {},
-   "output_type": "execute_result"
-  }
- ],
+ "outputs": [],
  "source": [
   "bp.write(bp_path)"
  ]
 },
 {
  "cell_type": "code",
- "execution_count": 119,
- "id": "ba4b85ce",
+ "execution_count": null,
+ "id": "34569491",
  "metadata": {},
- "outputs": [
-  (removed output: the HTML and plain-text renderings of model.ranks_df, a DataFrame of 466 rows x 85 columns -- one row per microbe, indexed from "rplo 1 (Cyanobacteria)" to "rplo 99 (Cyanobacteria)", and one column per metabolite, from "(2,3-dihydroxy-3-methylbutanoate)" to "xylitol", holding the learned conditional ranks)
- ],
+ "outputs": [],
  "source": [
   "model.ranks_df"
  ]
 }
diff --git a/mmvec/ALR.py b/mmvec/ALR.py
index 766487e..1f8ca45 100644
--- a/mmvec/ALR.py
+++ b/mmvec/ALR.py
@@ -62,7 +62,7 @@ def __init__(self, microbes, metabolites, latent_dim, sigma_u=1,
 
         self.encoder = nn.Embedding(self.num_microbes, self.latent_dim)
         self.decoder = LinearALR(self.latent_dim, self.num_metabolites)
-        
+
     def forward(self, X):
         # Three likelihoods, the likelihood of each weight and the likelihood
@@ -140,7 +140,7 @@ def u_bias(self):
     def v_bias(self):
         #ensure consistent access
         return self.decoder.linear.bias.detach()
-        
+
     @property
     def U(self):
         U = torch.cat(
@@ -157,7 +157,7 @@ def V(self):
                 torch.ones((1, self.num_metabolites - 1)),
                 self.decoder.linear.weight.detach().T),
             dim=0)
-        return V 
+        return V
 
     def ranks_dataframe(self):
         return pd.DataFrame(self.ranks(), index=self.microbe_idx,
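Before the cleanup patches below, a note on where the ranks table above comes from. The ALR.py hunks just shown expose U and V factor matrices (with ones and the bias terms concatenated in) and a ranks_dataframe() helper; self.ranks() itself is not shown in this hunk, so the body below is a sketch assuming the usual mmvec formulation of ranks as the product of the two factors:

    import pandas as pd
    import torch

    def ranks_sketch(U: torch.Tensor, V: torch.Tensor, microbe_idx, metabolite_idx):
        # one row per microbe, one column per metabolite -- the same 466 x 85
        # shape as the model.ranks_df output removed above (in the ALR
        # parameterization, a reference column fills out the final metabolite)
        ranks = U @ V
        return pd.DataFrame(ranks.detach().numpy(),
                            index=microbe_idx, columns=metabolite_idx)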
"import biom" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "382bb9ce", - "metadata": {}, - "outputs": [], - "source": [ - "# some example data\n", - "microbes = biom.load_table(\"./soil_microbes.biom\")\n", - "metabolites = biom.load_table(\"./soil_metabolites.biom\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "96fac3bf", - "metadata": {}, - "outputs": [], - "source": [ - "class MicrobeMetaboliteData(Dataset):\n", - " def __init__(self, microbes: biom.table, metabolites: biom.table):\n", - " # arrange\n", - " self.microbes = microbes.to_dataframe().T \n", - " self.metabolites = metabolites.to_dataframe().T\n", - " \n", - " # only samples that have results\n", - " self.microbes = self.microbes.loc[self.metabolites.index]\n", - " \n", - " # convert to tensors/final form\n", - " self.microbes = torch.tensor(self.microbes.values, dtype=torch.int)\n", - " self.metabolites = torch.tensor(self.metabolites.values, dtype=torch.int64)\n", - " \n", - " # counts\n", - " self.microbe_count = self.microbes.shape[1]\n", - " self.metabolite_count = self.metabolites.shape[1]\n", - " \n", - " # relative frequencies\n", - " self.microbe_relative_frequency = (self.microbes.T\n", - " / self.microbes.sum(1)\n", - " ).T\n", - " \n", - " self.metabolite_relative_frequency = (self.metabolites.T\n", - " / self.metabolites.sum(1)\n", - " ).T\n", - " \n", - " self.total_microbe_observations = self.microbes.sum()\n", - " \n", - " def __len__(self):\n", - " return self.total_microbe_observations" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "234ccc47", - "metadata": {}, - "outputs": [], - "source": [ - "example_data = MicrobeMetaboliteData(microbes, metabolites)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0ab12e60", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "424846" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "example_data.total_microbe_observations.item()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f106a231", - "metadata": {}, - "outputs": [], - "source": [ - "class MMVec(nn.Module):\n", - " def __init__(self, num_microbes, num_metabolites, latent_dim):\n", - " super().__init__()\n", - " #\n", - " self.encoder = nn.Embedding(num_microbes, latent_dim)\n", - " self.decoder = nn.Sequential(\n", - " nn.Linear(latent_dim, num_metabolites),\n", - " # [batch, sample, metabolite]\n", - " nn.Softmax(dim=2)\n", - " )\n", - " \n", - " # X = batch_size of microbe indexes\n", - " # Y = expected metabolite data\n", - " def forward(self, X, Y):\n", - " \n", - " # pass our random draws to our embedding\n", - " z = self.encoder(X)\n", - " \n", - " # from latent dimensions in embedding through\n", - " # our linear function to predicted metabolite frequencies which\n", - " # we then normalize with softmax\n", - " y_pred = self.decoder(z)\n", - " \n", - " # total_count=0 and validate_args=False allows skipping total count when calling log_prob\n", - " # as there having floating point issues leading to \"incorrect\" total counts.\n", - " # This multinomial is generated from the output of the single\n", - " forward_dist = Multinomial(total_count=0,\n", - " validate_args=False,\n", - " probs=y_pred)\n", - " \n", - " # the log probability of drawing our expected results from our \"predictions\"\n", - " forward_dist = forward_dist.log_prob(Y)\n", - " \n", - " # get sample loss, a sample in each \"row\"/ zeroeth dimension of the 
tensor\n", - " forward_dist = forward_dist.mean(0)\n", - " \n", - " # total log probability loss in regards to all samples\n", - " lp = forward_dist.mean()\n", - "\n", - " return lp" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "b74bdf61", - "metadata": {}, - "outputs": [], - "source": [ - "mmvec_model = MMVec(example_data.microbe_count, example_data.metabolite_count, 15)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "cbc8d647", - "metadata": {}, - "outputs": [], - "source": [ - "def train_loop(dataset, model, optimizer, batch_size):\n", - " \n", - " # because we are wanting to look at all of the samples together we are having to \n", - " # handle our own batching for now. This method currently leads to slight over-\n", - " # sampling but can be refined.\n", - " n_batches = torch.div(dataset.total_microbe_observations.item(),\n", - " batch_size,\n", - " rounding_mode = 'floor') + 1\n", - " \n", - " # We will want to implement batching functionality later for\n", - " # paralizability, but for now running on cpu this works.\n", - " for batch in range(n_batches * epochs):\n", - " \n", - " # the draws we will be training each batch on that will\n", - " # be fed to all samples in our model. This step will probably be\n", - " # moved to a sampler or collate_fn somewhere in the dataset/dataloader\n", - " # but how exactly that will work is not clear at the moment\n", - " draws = torch.multinomial(dataset.microbe_relative_frequency,\n", - " batch_size,\n", - " replacement=True).T\n", - " \n", - " # \"forward step\", our model generates our \"predictions\", so there is no need to\n", - " # call `forward` separately.\n", - " lp = model(draws,\n", - " dataset.metabolite_relative_frequency)\n", - " \n", - " # this location is idiomatic but flexible\n", - " optimizer.zero_grad()\n", - " \n", - " # the typical training bit.\n", - " lp.backward()\n", - " optimizer.step()\n", - " \n", - " if batch % 100 == 0:\n", - " print(f\"loss: {lp.item()}\\nBatch #: {batch}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cfb75b21", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loss: -4.114527225494385\n", - "Batch #: 0\n", - "loss: -3.6144325733184814\n", - "Batch #: 100\n", - "loss: -3.0469698905944824\n", - "Batch #: 200\n", - "loss: -2.70939564704895\n", - "Batch #: 300\n", - "loss: -2.5499744415283203\n", - "Batch #: 400\n", - "loss: -2.473045587539673\n", - "Batch #: 500\n", - "loss: -2.4374732971191406\n", - "Batch #: 600\n", - "loss: -2.421781539916992\n", - "Batch #: 700\n", - "loss: -2.4101920127868652\n", - "Batch #: 800\n", - "loss: -2.4041030406951904\n", - "Batch #: 900\n", - "loss: -2.4012131690979004\n", - "Batch #: 1000\n", - "loss: -2.397974967956543\n", - "Batch #: 1100\n", - "loss: -2.3931915760040283\n", - "Batch #: 1200\n", - "loss: -2.3923048973083496\n", - "Batch #: 1300\n", - "loss: -2.389982223510742\n", - "Batch #: 1400\n", - "loss: -2.3868303298950195\n", - "Batch #: 1500\n", - "loss: -2.3855628967285156\n", - "Batch #: 1600\n", - "loss: -2.382643222808838\n", - "Batch #: 1700\n", - "loss: -2.381664991378784\n", - "Batch #: 1800\n", - "loss: -2.3774473667144775\n", - "Batch #: 1900\n", - "loss: -2.378610372543335\n", - "Batch #: 2000\n", - "loss: -2.3776485919952393\n", - "Batch #: 2100\n", - "loss: -2.376375675201416\n", - "Batch #: 2200\n", - "loss: -2.3723671436309814\n", - "Batch #: 2300\n", - "loss: -2.372851848602295\n", - "Batch #: 2400\n", - "loss: 
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "cfb75b21",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "loss: -4.114527225494385\n",
-      "Batch #: 0\n",
-      "loss: -3.6144325733184814\n",
-      "Batch #: 100\n",
-      ... (121 intermediate loss reports elided; the printed loss rises from -4.11 at batch 0 to about -2.41 by batch 800, then flattens near -2.355) ...
-      "loss: -2.354964256286621\n",
-      "Batch #: 12300\n"
-     ]
-    }
-   ],
-   "source": [
-    "learning_rate = 1e-3\n",
-    "batch_size = 500\n",
-    "epochs = 25\n",
-    "optimizer = torch.optim.Adam(mmvec_model.parameters(), lr=learning_rate, maximize=True)\n",
-    "\n",
-    "# run the training loop\n",
-    "train_loop(dataset=example_data, model=mmvec_model, optimizer=optimizer, batch_size=batch_size)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.12"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
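For concreteness, the slight oversampling that the train_loop comment warns about is easy to quantify from the numbers recorded in this notebook (424,846 total observations, batch size 500, 25 epochs); note too that the run above stops reporting at batch 12,300, so its output was captured before the loop finished:

    total_obs, batch_size, epochs = 424846, 500, 25
    n_batches = total_obs // batch_size + 1      # 850 draws per epoch
    print(n_batches * batch_size - total_obs)    # 154 observations oversampled per epoch
    print(n_batches * epochs)                    # 21250 training steps in all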
- "Batch #: 9100\n", - "loss: -2.358710527420044\n", - "Batch #: 9200\n", - "loss: -2.3547816276550293\n", - "Batch #: 9300\n", - "loss: -2.3565027713775635\n", - "Batch #: 9400\n", - "loss: -2.3561108112335205\n", - "Batch #: 9500\n", - "loss: -2.356635808944702\n", - "Batch #: 9600\n", - "loss: -2.356121301651001\n", - "Batch #: 9700\n", - "loss: -2.3586411476135254\n", - "Batch #: 9800\n", - "loss: -2.3572912216186523\n", - "Batch #: 9900\n", - "loss: -2.35567045211792\n", - "Batch #: 10000\n", - "loss: -2.3584144115448\n", - "Batch #: 10100\n", - "loss: -2.3562276363372803\n", - "Batch #: 10200\n", - "loss: -2.3546085357666016\n", - "Batch #: 10300\n", - "loss: -2.3559350967407227\n", - "Batch #: 10400\n", - "loss: -2.356455087661743\n", - "Batch #: 10500\n", - "loss: -2.3574140071868896\n", - "Batch #: 10600\n", - "loss: -2.3562002182006836\n", - "Batch #: 10700\n", - "loss: -2.35746169090271\n", - "Batch #: 10800\n", - "loss: -2.3548736572265625\n", - "Batch #: 10900\n", - "loss: -2.3564090728759766\n", - "Batch #: 11000\n", - "loss: -2.3564658164978027\n", - "Batch #: 11100\n", - "loss: -2.3554699420928955\n", - "Batch #: 11200\n", - "loss: -2.3563244342803955\n", - "Batch #: 11300\n", - "loss: -2.357598066329956\n", - "Batch #: 11400\n", - "loss: -2.35477614402771\n", - "Batch #: 11500\n", - "loss: -2.3572442531585693\n", - "Batch #: 11600\n", - "loss: -2.357273817062378\n", - "Batch #: 11700\n", - "loss: -2.3560562133789062\n", - "Batch #: 11800\n", - "loss: -2.355698823928833\n", - "Batch #: 11900\n", - "loss: -2.3559463024139404\n", - "Batch #: 12000\n", - "loss: -2.35664439201355\n", - "Batch #: 12100\n", - "loss: -2.355379104614258\n", - "Batch #: 12200\n", - "loss: -2.354964256286621\n", - "Batch #: 12300\n" - ] - } - ], - "source": [ - "learning_rate = 1e-3\n", - "batch_size = 500\n", - "epochs = 25\n", - "optimizer = torch.optim.Adam(mmvec_model.parameters(), lr=learning_rate, maximize=True)\n", - "\n", - "# run the training loop \n", - "train_loop(dataset=example_data, model=mmvec_model, optimizer=optimizer, batch_size=batch_size)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/refactor/041422pytorchdraft.ipynb b/examples/refactor/041422pytorchdraft.ipynb deleted file mode 100644 index 019843d..0000000 --- a/examples/refactor/041422pytorchdraft.ipynb +++ /dev/null @@ -1,495 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "213bcdfc", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "from torch.utils.data import Dataset\n", - "from torch.distributions import Multinomial\n", - "import biom" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "382bb9ce", - "metadata": {}, - "outputs": [], - "source": [ - "# some example data\n", - "microbes = biom.load_table(\"./soil_microbes.biom\")\n", - "metabolites = biom.load_table(\"./soil_metabolites.biom\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "96fac3bf", - "metadata": {}, - "outputs": [], - "source": [ - "class MicrobeMetaboliteData(Dataset):\n", - " def __init__(self, microbes: biom.table, 
From b45f04e0c23cd8dea98918f59743a520fc63699f Mon Sep 17 00:00:00 2001
From: Keegan Evans
Date: Thu, 5 May 2022 12:35:38 -0700
Subject: [PATCH 13/27] IMP: remove sneaky notebook checkpoints

---
 .../041222pytorchdraft-checkpoint.ipynb       | 495 ------------------
 .../041422pytorchdraft-checkpoint.ipynb       | 495 ------------------
 2 files changed, 990 deletions(-)
 delete mode 100644 examples/refactor/.ipynb_checkpoints/041222pytorchdraft-checkpoint.ipynb
 delete mode 100644 examples/refactor/.ipynb_checkpoints/041422pytorchdraft-checkpoint.ipynb

diff --git a/examples/refactor/.ipynb_checkpoints/041222pytorchdraft-checkpoint.ipynb b/examples/refactor/.ipynb_checkpoints/041222pytorchdraft-checkpoint.ipynb
deleted file mode 100644
index 019843d..0000000
--- a/examples/refactor/.ipynb_checkpoints/041222pytorchdraft-checkpoint.ipynb
+++ /dev/null
@@ -1,495 +0,0 @@
-(495 deleted lines elided: checkpoint copy, blob 019843d, byte-for-byte identical to the 041222pytorchdraft.ipynb deleted in PATCH 12/27 above)
"outputs": [], - "source": [ - "class MicrobeMetaboliteData(Dataset):\n", - " def __init__(self, microbes: biom.table, metabolites: biom.table):\n", - " # arrange\n", - " self.microbes = microbes.to_dataframe().T \n", - " self.metabolites = metabolites.to_dataframe().T\n", - " \n", - " # only samples that have results\n", - " self.microbes = self.microbes.loc[self.metabolites.index]\n", - " \n", - " # convert to tensors/final form\n", - " self.microbes = torch.tensor(self.microbes.values, dtype=torch.int)\n", - " self.metabolites = torch.tensor(self.metabolites.values, dtype=torch.int64)\n", - " \n", - " # counts\n", - " self.microbe_count = self.microbes.shape[1]\n", - " self.metabolite_count = self.metabolites.shape[1]\n", - " \n", - " # relative frequencies\n", - " self.microbe_relative_frequency = (self.microbes.T\n", - " / self.microbes.sum(1)\n", - " ).T\n", - " \n", - " self.metabolite_relative_frequency = (self.metabolites.T\n", - " / self.metabolites.sum(1)\n", - " ).T\n", - " \n", - " self.total_microbe_observations = self.microbes.sum()\n", - " \n", - " def __len__(self):\n", - " return self.total_microbe_observations" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "234ccc47", - "metadata": {}, - "outputs": [], - "source": [ - "example_data = MicrobeMetaboliteData(microbes, metabolites)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0ab12e60", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "424846" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "example_data.total_microbe_observations.item()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f106a231", - "metadata": {}, - "outputs": [], - "source": [ - "class MMVec(nn.Module):\n", - " def __init__(self, num_microbes, num_metabolites, latent_dim):\n", - " super().__init__()\n", - " #\n", - " self.encoder = nn.Embedding(num_microbes, latent_dim)\n", - " self.decoder = nn.Sequential(\n", - " nn.Linear(latent_dim, num_metabolites),\n", - " # [batch, sample, metabolite]\n", - " nn.Softmax(dim=2)\n", - " )\n", - " \n", - " # X = batch_size of microbe indexes\n", - " # Y = expected metabolite data\n", - " def forward(self, X, Y):\n", - " \n", - " # pass our random draws to our embedding\n", - " z = self.encoder(X)\n", - " \n", - " # from latent dimensions in embedding through\n", - " # our linear function to predicted metabolite frequencies which\n", - " # we then normalize with softmax\n", - " y_pred = self.decoder(z)\n", - " \n", - " # total_count=0 and validate_args=False allows skipping total count when calling log_prob\n", - " # as there having floating point issues leading to \"incorrect\" total counts.\n", - " # This multinomial is generated from the output of the single\n", - " forward_dist = Multinomial(total_count=0,\n", - " validate_args=False,\n", - " probs=y_pred)\n", - " \n", - " # the log probability of drawing our expected results from our \"predictions\"\n", - " forward_dist = forward_dist.log_prob(Y)\n", - " \n", - " # get sample loss, a sample in each \"row\"/ zeroeth dimension of the tensor\n", - " forward_dist = forward_dist.mean(0)\n", - " \n", - " # total log probability loss in regards to all samples\n", - " lp = forward_dist.mean()\n", - "\n", - " return lp" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "b74bdf61", - "metadata": {}, - "outputs": [], - "source": [ - "mmvec_model = MMVec(example_data.microbe_count, example_data.metabolite_count, 
15)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "cbc8d647", - "metadata": {}, - "outputs": [], - "source": [ - "def train_loop(dataset, model, optimizer, batch_size):\n", - " \n", - " # because we are wanting to look at all of the samples together we are having to \n", - " # handle our own batching for now. This method currently leads to slight over-\n", - " # sampling but can be refined.\n", - " n_batches = torch.div(dataset.total_microbe_observations.item(),\n", - " batch_size,\n", - " rounding_mode = 'floor') + 1\n", - " \n", - " # We will want to implement batching functionality later for\n", - " # paralizability, but for now running on cpu this works.\n", - " for batch in range(n_batches * epochs):\n", - " \n", - " # the draws we will be training each batch on that will\n", - " # be fed to all samples in our model. This step will probably be\n", - " # moved to a sampler or collate_fn somewhere in the dataset/dataloader\n", - " # but how exactly that will work is not clear at the moment\n", - " draws = torch.multinomial(dataset.microbe_relative_frequency,\n", - " batch_size,\n", - " replacement=True).T\n", - " \n", - " # \"forward step\", our model generates our \"predictions\", so there is no need to\n", - " # call `forward` separately.\n", - " lp = model(draws,\n", - " dataset.metabolite_relative_frequency)\n", - " \n", - " # this location is idiomatic but flexible\n", - " optimizer.zero_grad()\n", - " \n", - " # the typical training bit.\n", - " lp.backward()\n", - " optimizer.step()\n", - " \n", - " if batch % 100 == 0:\n", - " print(f\"loss: {lp.item()}\\nBatch #: {batch}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cfb75b21", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loss: -4.114527225494385\n", - "Batch #: 0\n", - "loss: -3.6144325733184814\n", - "Batch #: 100\n", - "loss: -3.0469698905944824\n", - "Batch #: 200\n", - "loss: -2.70939564704895\n", - "Batch #: 300\n", - "loss: -2.5499744415283203\n", - "Batch #: 400\n", - "loss: -2.473045587539673\n", - "Batch #: 500\n", - "loss: -2.4374732971191406\n", - "Batch #: 600\n", - "loss: -2.421781539916992\n", - "Batch #: 700\n", - "loss: -2.4101920127868652\n", - "Batch #: 800\n", - "loss: -2.4041030406951904\n", - "Batch #: 900\n", - "loss: -2.4012131690979004\n", - "Batch #: 1000\n", - "loss: -2.397974967956543\n", - "Batch #: 1100\n", - "loss: -2.3931915760040283\n", - "Batch #: 1200\n", - "loss: -2.3923048973083496\n", - "Batch #: 1300\n", - "loss: -2.389982223510742\n", - "Batch #: 1400\n", - "loss: -2.3868303298950195\n", - "Batch #: 1500\n", - "loss: -2.3855628967285156\n", - "Batch #: 1600\n", - "loss: -2.382643222808838\n", - "Batch #: 1700\n", - "loss: -2.381664991378784\n", - "Batch #: 1800\n", - "loss: -2.3774473667144775\n", - "Batch #: 1900\n", - "loss: -2.378610372543335\n", - "Batch #: 2000\n", - "loss: -2.3776485919952393\n", - "Batch #: 2100\n", - "loss: -2.376375675201416\n", - "Batch #: 2200\n", - "loss: -2.3723671436309814\n", - "Batch #: 2300\n", - "loss: -2.372851848602295\n", - "Batch #: 2400\n", - "loss: -2.373134136199951\n", - "Batch #: 2500\n", - "loss: -2.3704051971435547\n", - "Batch #: 2600\n", - "loss: -2.37052059173584\n", - "Batch #: 2700\n", - "loss: -2.371293306350708\n", - "Batch #: 2800\n", - "loss: -2.3711659908294678\n", - "Batch #: 2900\n", - "loss: -2.3693435192108154\n", - "Batch #: 3000\n", - "loss: -2.370833396911621\n", - "Batch #: 3100\n", - "loss: -2.36956787109375\n", - "Batch 
#: 3200\n", - "loss: -2.3683981895446777\n", - "Batch #: 3300\n", - "loss: -2.368025064468384\n", - "Batch #: 3400\n", - "loss: -2.3673665523529053\n", - "Batch #: 3500\n", - "loss: -2.3669538497924805\n", - "Batch #: 3600\n", - "loss: -2.364877700805664\n", - "Batch #: 3700\n", - "loss: -2.3676393032073975\n", - "Batch #: 3800\n", - "loss: -2.3655707836151123\n", - "Batch #: 3900\n", - "loss: -2.365952253341675\n", - "Batch #: 4000\n", - "loss: -2.366527557373047\n", - "Batch #: 4100\n", - "loss: -2.364421844482422\n", - "Batch #: 4200\n", - "loss: -2.363978385925293\n", - "Batch #: 4300\n", - "loss: -2.3649704456329346\n", - "Batch #: 4400\n", - "loss: -2.364382743835449\n", - "Batch #: 4500\n", - "loss: -2.361299991607666\n", - "Batch #: 4600\n", - "loss: -2.3609752655029297\n", - "Batch #: 4700\n", - "loss: -2.3623459339141846\n", - "Batch #: 4800\n", - "loss: -2.3606176376342773\n", - "Batch #: 4900\n", - "loss: -2.3621227741241455\n", - "Batch #: 5000\n", - "loss: -2.3601856231689453\n", - "Batch #: 5100\n", - "loss: -2.3616325855255127\n", - "Batch #: 5200\n", - "loss: -2.3607864379882812\n", - "Batch #: 5300\n", - "loss: -2.3603267669677734\n", - "Batch #: 5400\n", - "loss: -2.3611979484558105\n", - "Batch #: 5500\n", - "loss: -2.36138653755188\n", - "Batch #: 5600\n", - "loss: -2.3617565631866455\n", - "Batch #: 5700\n", - "loss: -2.3602635860443115\n", - "Batch #: 5800\n", - "loss: -2.3588624000549316\n", - "Batch #: 5900\n", - "loss: -2.363048791885376\n", - "Batch #: 6000\n", - "loss: -2.357430934906006\n", - "Batch #: 6100\n", - "loss: -2.359692335128784\n", - "Batch #: 6200\n", - "loss: -2.359476327896118\n", - "Batch #: 6300\n", - "loss: -2.358708381652832\n", - "Batch #: 6400\n", - "loss: -2.3578848838806152\n", - "Batch #: 6500\n", - "loss: -2.3591620922088623\n", - "Batch #: 6600\n", - "loss: -2.3596458435058594\n", - "Batch #: 6700\n", - "loss: -2.358290672302246\n", - "Batch #: 6800\n", - "loss: -2.3569066524505615\n", - "Batch #: 6900\n", - "loss: -2.3586177825927734\n", - "Batch #: 7000\n", - "loss: -2.359415054321289\n", - "Batch #: 7100\n", - "loss: -2.358649969100952\n", - "Batch #: 7200\n", - "loss: -2.35966420173645\n", - "Batch #: 7300\n", - "loss: -2.358867883682251\n", - "Batch #: 7400\n", - "loss: -2.3568341732025146\n", - "Batch #: 7500\n", - "loss: -2.3596749305725098\n", - "Batch #: 7600\n", - "loss: -2.359412670135498\n", - "Batch #: 7700\n", - "loss: -2.357198476791382\n", - "Batch #: 7800\n", - "loss: -2.358001947402954\n", - "Batch #: 7900\n", - "loss: -2.3569891452789307\n", - "Batch #: 8000\n", - "loss: -2.3587193489074707\n", - "Batch #: 8100\n", - "loss: -2.3581130504608154\n", - "Batch #: 8200\n", - "loss: -2.3578381538391113\n", - "Batch #: 8300\n", - "loss: -2.357231855392456\n", - "Batch #: 8400\n", - "loss: -2.3578529357910156\n", - "Batch #: 8500\n", - "loss: -2.3557262420654297\n", - "Batch #: 8600\n", - "loss: -2.355126142501831\n", - "Batch #: 8700\n", - "loss: -2.3567700386047363\n", - "Batch #: 8800\n", - "loss: -2.3553476333618164\n", - "Batch #: 8900\n", - "loss: -2.356520175933838\n", - "Batch #: 9000\n", - "loss: -2.3572936058044434\n", - "Batch #: 9100\n", - "loss: -2.358710527420044\n", - "Batch #: 9200\n", - "loss: -2.3547816276550293\n", - "Batch #: 9300\n", - "loss: -2.3565027713775635\n", - "Batch #: 9400\n", - "loss: -2.3561108112335205\n", - "Batch #: 9500\n", - "loss: -2.356635808944702\n", - "Batch #: 9600\n", - "loss: -2.356121301651001\n", - "Batch #: 9700\n", - "loss: -2.3586411476135254\n", - "Batch #: 9800\n", - "loss: 
-2.3572912216186523\n", - "Batch #: 9900\n", - "loss: -2.35567045211792\n", - "Batch #: 10000\n", - "loss: -2.3584144115448\n", - "Batch #: 10100\n", - "loss: -2.3562276363372803\n", - "Batch #: 10200\n", - "loss: -2.3546085357666016\n", - "Batch #: 10300\n", - "loss: -2.3559350967407227\n", - "Batch #: 10400\n", - "loss: -2.356455087661743\n", - "Batch #: 10500\n", - "loss: -2.3574140071868896\n", - "Batch #: 10600\n", - "loss: -2.3562002182006836\n", - "Batch #: 10700\n", - "loss: -2.35746169090271\n", - "Batch #: 10800\n", - "loss: -2.3548736572265625\n", - "Batch #: 10900\n", - "loss: -2.3564090728759766\n", - "Batch #: 11000\n", - "loss: -2.3564658164978027\n", - "Batch #: 11100\n", - "loss: -2.3554699420928955\n", - "Batch #: 11200\n", - "loss: -2.3563244342803955\n", - "Batch #: 11300\n", - "loss: -2.357598066329956\n", - "Batch #: 11400\n", - "loss: -2.35477614402771\n", - "Batch #: 11500\n", - "loss: -2.3572442531585693\n", - "Batch #: 11600\n", - "loss: -2.357273817062378\n", - "Batch #: 11700\n", - "loss: -2.3560562133789062\n", - "Batch #: 11800\n", - "loss: -2.355698823928833\n", - "Batch #: 11900\n", - "loss: -2.3559463024139404\n", - "Batch #: 12000\n", - "loss: -2.35664439201355\n", - "Batch #: 12100\n", - "loss: -2.355379104614258\n", - "Batch #: 12200\n", - "loss: -2.354964256286621\n", - "Batch #: 12300\n" - ] - } - ], - "source": [ - "learning_rate = 1e-3\n", - "batch_size = 500\n", - "epochs = 25\n", - "optimizer = torch.optim.Adam(mmvec_model.parameters(), lr=learning_rate, maximize=True)\n", - "\n", - "# run the training loop \n", - "train_loop(dataset=example_data, model=mmvec_model, optimizer=optimizer, batch_size=batch_size)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/refactor/.ipynb_checkpoints/041422pytorchdraft-checkpoint.ipynb b/examples/refactor/.ipynb_checkpoints/041422pytorchdraft-checkpoint.ipynb deleted file mode 100644 index 019843d..0000000 --- a/examples/refactor/.ipynb_checkpoints/041422pytorchdraft-checkpoint.ipynb +++ /dev/null @@ -1,495 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "213bcdfc", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "from torch.utils.data import Dataset\n", - "from torch.distributions import Multinomial\n", - "import biom" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "382bb9ce", - "metadata": {}, - "outputs": [], - "source": [ - "# some example data\n", - "microbes = biom.load_table(\"./soil_microbes.biom\")\n", - "metabolites = biom.load_table(\"./soil_metabolites.biom\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "96fac3bf", - "metadata": {}, - "outputs": [], - "source": [ - "class MicrobeMetaboliteData(Dataset):\n", - " def __init__(self, microbes: biom.table, metabolites: biom.table):\n", - " # arrange\n", - " self.microbes = microbes.to_dataframe().T \n", - " self.metabolites = metabolites.to_dataframe().T\n", - " \n", - " # only samples that have results\n", - " self.microbes = self.microbes.loc[self.metabolites.index]\n", - " \n", - " # convert to tensors/final form\n", 
- " self.microbes = torch.tensor(self.microbes.values, dtype=torch.int)\n", - " self.metabolites = torch.tensor(self.metabolites.values, dtype=torch.int64)\n", - " \n", - " # counts\n", - " self.microbe_count = self.microbes.shape[1]\n", - " self.metabolite_count = self.metabolites.shape[1]\n", - " \n", - " # relative frequencies\n", - " self.microbe_relative_frequency = (self.microbes.T\n", - " / self.microbes.sum(1)\n", - " ).T\n", - " \n", - " self.metabolite_relative_frequency = (self.metabolites.T\n", - " / self.metabolites.sum(1)\n", - " ).T\n", - " \n", - " self.total_microbe_observations = self.microbes.sum()\n", - " \n", - " def __len__(self):\n", - " return self.total_microbe_observations" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "234ccc47", - "metadata": {}, - "outputs": [], - "source": [ - "example_data = MicrobeMetaboliteData(microbes, metabolites)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0ab12e60", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "424846" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "example_data.total_microbe_observations.item()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f106a231", - "metadata": {}, - "outputs": [], - "source": [ - "class MMVec(nn.Module):\n", - " def __init__(self, num_microbes, num_metabolites, latent_dim):\n", - " super().__init__()\n", - " #\n", - " self.encoder = nn.Embedding(num_microbes, latent_dim)\n", - " self.decoder = nn.Sequential(\n", - " nn.Linear(latent_dim, num_metabolites),\n", - " # [batch, sample, metabolite]\n", - " nn.Softmax(dim=2)\n", - " )\n", - " \n", - " # X = batch_size of microbe indexes\n", - " # Y = expected metabolite data\n", - " def forward(self, X, Y):\n", - " \n", - " # pass our random draws to our embedding\n", - " z = self.encoder(X)\n", - " \n", - " # from latent dimensions in embedding through\n", - " # our linear function to predicted metabolite frequencies which\n", - " # we then normalize with softmax\n", - " y_pred = self.decoder(z)\n", - " \n", - " # total_count=0 and validate_args=False allows skipping total count when calling log_prob\n", - " # as there having floating point issues leading to \"incorrect\" total counts.\n", - " # This multinomial is generated from the output of the single\n", - " forward_dist = Multinomial(total_count=0,\n", - " validate_args=False,\n", - " probs=y_pred)\n", - " \n", - " # the log probability of drawing our expected results from our \"predictions\"\n", - " forward_dist = forward_dist.log_prob(Y)\n", - " \n", - " # get sample loss, a sample in each \"row\"/ zeroeth dimension of the tensor\n", - " forward_dist = forward_dist.mean(0)\n", - " \n", - " # total log probability loss in regards to all samples\n", - " lp = forward_dist.mean()\n", - "\n", - " return lp" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "b74bdf61", - "metadata": {}, - "outputs": [], - "source": [ - "mmvec_model = MMVec(example_data.microbe_count, example_data.metabolite_count, 15)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "cbc8d647", - "metadata": {}, - "outputs": [], - "source": [ - "def train_loop(dataset, model, optimizer, batch_size):\n", - " \n", - " # because we are wanting to look at all of the samples together we are having to \n", - " # handle our own batching for now. 
This method currently leads to slight over-\n", - " # sampling but can be refined.\n", - " n_batches = torch.div(dataset.total_microbe_observations.item(),\n", - " batch_size,\n", - " rounding_mode = 'floor') + 1\n", - " \n", - " # We will want to implement batching functionality later for\n", - " # paralizability, but for now running on cpu this works.\n", - " for batch in range(n_batches * epochs):\n", - " \n", - " # the draws we will be training each batch on that will\n", - " # be fed to all samples in our model. This step will probably be\n", - " # moved to a sampler or collate_fn somewhere in the dataset/dataloader\n", - " # but how exactly that will work is not clear at the moment\n", - " draws = torch.multinomial(dataset.microbe_relative_frequency,\n", - " batch_size,\n", - " replacement=True).T\n", - " \n", - " # \"forward step\", our model generates our \"predictions\", so there is no need to\n", - " # call `forward` separately.\n", - " lp = model(draws,\n", - " dataset.metabolite_relative_frequency)\n", - " \n", - " # this location is idiomatic but flexible\n", - " optimizer.zero_grad()\n", - " \n", - " # the typical training bit.\n", - " lp.backward()\n", - " optimizer.step()\n", - " \n", - " if batch % 100 == 0:\n", - " print(f\"loss: {lp.item()}\\nBatch #: {batch}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cfb75b21", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loss: -4.114527225494385\n", - "Batch #: 0\n", - "loss: -3.6144325733184814\n", - "Batch #: 100\n", - "loss: -3.0469698905944824\n", - "Batch #: 200\n", - "loss: -2.70939564704895\n", - "Batch #: 300\n", - "loss: -2.5499744415283203\n", - "Batch #: 400\n", - "loss: -2.473045587539673\n", - "Batch #: 500\n", - "loss: -2.4374732971191406\n", - "Batch #: 600\n", - "loss: -2.421781539916992\n", - "Batch #: 700\n", - "loss: -2.4101920127868652\n", - "Batch #: 800\n", - "loss: -2.4041030406951904\n", - "Batch #: 900\n", - "loss: -2.4012131690979004\n", - "Batch #: 1000\n", - "loss: -2.397974967956543\n", - "Batch #: 1100\n", - "loss: -2.3931915760040283\n", - "Batch #: 1200\n", - "loss: -2.3923048973083496\n", - "Batch #: 1300\n", - "loss: -2.389982223510742\n", - "Batch #: 1400\n", - "loss: -2.3868303298950195\n", - "Batch #: 1500\n", - "loss: -2.3855628967285156\n", - "Batch #: 1600\n", - "loss: -2.382643222808838\n", - "Batch #: 1700\n", - "loss: -2.381664991378784\n", - "Batch #: 1800\n", - "loss: -2.3774473667144775\n", - "Batch #: 1900\n", - "loss: -2.378610372543335\n", - "Batch #: 2000\n", - "loss: -2.3776485919952393\n", - "Batch #: 2100\n", - "loss: -2.376375675201416\n", - "Batch #: 2200\n", - "loss: -2.3723671436309814\n", - "Batch #: 2300\n", - "loss: -2.372851848602295\n", - "Batch #: 2400\n", - "loss: -2.373134136199951\n", - "Batch #: 2500\n", - "loss: -2.3704051971435547\n", - "Batch #: 2600\n", - "loss: -2.37052059173584\n", - "Batch #: 2700\n", - "loss: -2.371293306350708\n", - "Batch #: 2800\n", - "loss: -2.3711659908294678\n", - "Batch #: 2900\n", - "loss: -2.3693435192108154\n", - "Batch #: 3000\n", - "loss: -2.370833396911621\n", - "Batch #: 3100\n", - "loss: -2.36956787109375\n", - "Batch #: 3200\n", - "loss: -2.3683981895446777\n", - "Batch #: 3300\n", - "loss: -2.368025064468384\n", - "Batch #: 3400\n", - "loss: -2.3673665523529053\n", - "Batch #: 3500\n", - "loss: -2.3669538497924805\n", - "Batch #: 3600\n", - "loss: -2.364877700805664\n", - "Batch #: 3700\n", - "loss: -2.3676393032073975\n", - "Batch #: 3800\n", 
- "loss: -2.3655707836151123\n", - "Batch #: 3900\n", - "loss: -2.365952253341675\n", - "Batch #: 4000\n", - "loss: -2.366527557373047\n", - "Batch #: 4100\n", - "loss: -2.364421844482422\n", - "Batch #: 4200\n", - "loss: -2.363978385925293\n", - "Batch #: 4300\n", - "loss: -2.3649704456329346\n", - "Batch #: 4400\n", - "loss: -2.364382743835449\n", - "Batch #: 4500\n", - "loss: -2.361299991607666\n", - "Batch #: 4600\n", - "loss: -2.3609752655029297\n", - "Batch #: 4700\n", - "loss: -2.3623459339141846\n", - "Batch #: 4800\n", - "loss: -2.3606176376342773\n", - "Batch #: 4900\n", - "loss: -2.3621227741241455\n", - "Batch #: 5000\n", - "loss: -2.3601856231689453\n", - "Batch #: 5100\n", - "loss: -2.3616325855255127\n", - "Batch #: 5200\n", - "loss: -2.3607864379882812\n", - "Batch #: 5300\n", - "loss: -2.3603267669677734\n", - "Batch #: 5400\n", - "loss: -2.3611979484558105\n", - "Batch #: 5500\n", - "loss: -2.36138653755188\n", - "Batch #: 5600\n", - "loss: -2.3617565631866455\n", - "Batch #: 5700\n", - "loss: -2.3602635860443115\n", - "Batch #: 5800\n", - "loss: -2.3588624000549316\n", - "Batch #: 5900\n", - "loss: -2.363048791885376\n", - "Batch #: 6000\n", - "loss: -2.357430934906006\n", - "Batch #: 6100\n", - "loss: -2.359692335128784\n", - "Batch #: 6200\n", - "loss: -2.359476327896118\n", - "Batch #: 6300\n", - "loss: -2.358708381652832\n", - "Batch #: 6400\n", - "loss: -2.3578848838806152\n", - "Batch #: 6500\n", - "loss: -2.3591620922088623\n", - "Batch #: 6600\n", - "loss: -2.3596458435058594\n", - "Batch #: 6700\n", - "loss: -2.358290672302246\n", - "Batch #: 6800\n", - "loss: -2.3569066524505615\n", - "Batch #: 6900\n", - "loss: -2.3586177825927734\n", - "Batch #: 7000\n", - "loss: -2.359415054321289\n", - "Batch #: 7100\n", - "loss: -2.358649969100952\n", - "Batch #: 7200\n", - "loss: -2.35966420173645\n", - "Batch #: 7300\n", - "loss: -2.358867883682251\n", - "Batch #: 7400\n", - "loss: -2.3568341732025146\n", - "Batch #: 7500\n", - "loss: -2.3596749305725098\n", - "Batch #: 7600\n", - "loss: -2.359412670135498\n", - "Batch #: 7700\n", - "loss: -2.357198476791382\n", - "Batch #: 7800\n", - "loss: -2.358001947402954\n", - "Batch #: 7900\n", - "loss: -2.3569891452789307\n", - "Batch #: 8000\n", - "loss: -2.3587193489074707\n", - "Batch #: 8100\n", - "loss: -2.3581130504608154\n", - "Batch #: 8200\n", - "loss: -2.3578381538391113\n", - "Batch #: 8300\n", - "loss: -2.357231855392456\n", - "Batch #: 8400\n", - "loss: -2.3578529357910156\n", - "Batch #: 8500\n", - "loss: -2.3557262420654297\n", - "Batch #: 8600\n", - "loss: -2.355126142501831\n", - "Batch #: 8700\n", - "loss: -2.3567700386047363\n", - "Batch #: 8800\n", - "loss: -2.3553476333618164\n", - "Batch #: 8900\n", - "loss: -2.356520175933838\n", - "Batch #: 9000\n", - "loss: -2.3572936058044434\n", - "Batch #: 9100\n", - "loss: -2.358710527420044\n", - "Batch #: 9200\n", - "loss: -2.3547816276550293\n", - "Batch #: 9300\n", - "loss: -2.3565027713775635\n", - "Batch #: 9400\n", - "loss: -2.3561108112335205\n", - "Batch #: 9500\n", - "loss: -2.356635808944702\n", - "Batch #: 9600\n", - "loss: -2.356121301651001\n", - "Batch #: 9700\n", - "loss: -2.3586411476135254\n", - "Batch #: 9800\n", - "loss: -2.3572912216186523\n", - "Batch #: 9900\n", - "loss: -2.35567045211792\n", - "Batch #: 10000\n", - "loss: -2.3584144115448\n", - "Batch #: 10100\n", - "loss: -2.3562276363372803\n", - "Batch #: 10200\n", - "loss: -2.3546085357666016\n", - "Batch #: 10300\n", - "loss: -2.3559350967407227\n", - "Batch #: 10400\n", - "loss: 
-2.356455087661743\n", - "Batch #: 10500\n", - "loss: -2.3574140071868896\n", - "Batch #: 10600\n", - "loss: -2.3562002182006836\n", - "Batch #: 10700\n", - "loss: -2.35746169090271\n", - "Batch #: 10800\n", - "loss: -2.3548736572265625\n", - "Batch #: 10900\n", - "loss: -2.3564090728759766\n", - "Batch #: 11000\n", - "loss: -2.3564658164978027\n", - "Batch #: 11100\n", - "loss: -2.3554699420928955\n", - "Batch #: 11200\n", - "loss: -2.3563244342803955\n", - "Batch #: 11300\n", - "loss: -2.357598066329956\n", - "Batch #: 11400\n", - "loss: -2.35477614402771\n", - "Batch #: 11500\n", - "loss: -2.3572442531585693\n", - "Batch #: 11600\n", - "loss: -2.357273817062378\n", - "Batch #: 11700\n", - "loss: -2.3560562133789062\n", - "Batch #: 11800\n", - "loss: -2.355698823928833\n", - "Batch #: 11900\n", - "loss: -2.3559463024139404\n", - "Batch #: 12000\n", - "loss: -2.35664439201355\n", - "Batch #: 12100\n", - "loss: -2.355379104614258\n", - "Batch #: 12200\n", - "loss: -2.354964256286621\n", - "Batch #: 12300\n" - ] - } - ], - "source": [ - "learning_rate = 1e-3\n", - "batch_size = 500\n", - "epochs = 25\n", - "optimizer = torch.optim.Adam(mmvec_model.parameters(), lr=learning_rate, maximize=True)\n", - "\n", - "# run the training loop \n", - "train_loop(dataset=example_data, model=mmvec_model, optimizer=optimizer, batch_size=batch_size)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 26377ffaf9d66e72f48f5b6b96b2354f82e68af8 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 5 May 2022 12:41:05 -0700 Subject: [PATCH 14/27] IMP: restore original mmvec/multimodal.py --- mmvec/multimodal.py | 280 +++++++++++++++++++++++++++++++++++++++ mmvec/old_multimodal.py | 283 ---------------------------------------- 2 files changed, 280 insertions(+), 283 deletions(-) create mode 100644 mmvec/multimodal.py delete mode 100644 mmvec/old_multimodal.py diff --git a/mmvec/multimodal.py b/mmvec/multimodal.py new file mode 100644 index 0000000..eb54ded --- /dev/null +++ b/mmvec/multimodal.py @@ -0,0 +1,280 @@ +import os +import time +from tqdm import tqdm +import numpy as np +import tensorflow as tf +from tensorflow.contrib.distributions import Multinomial, Normal +import datetime + + +class MMvec(object): + + def __init__(self, u_mean=0, u_scale=1, v_mean=0, v_scale=1, + batch_size=50, latent_dim=3, + learning_rate=0.1, beta_1=0.8, beta_2=0.9, + clipnorm=10., device_name='/cpu:0', save_path=None): + """ Build a tensorflow model for microbe-metabolite vectors + + Returns + ------- + loss : tf.Tensor + The log loss of the model. 
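The "log loss" named here is a negative MAP objective: the multinomial log-likelihood of the metabolite counts plus normal prior terms on the factor matrices, both of which are assembled in the graph construction below. As a minimal numpy/scipy sketch of that objective on toy data (the names U, V, and counts and all shapes are invented for illustration, not taken from this patch):

    import numpy as np
    from scipy.stats import multinomial, norm

    rng = np.random.default_rng(0)
    n_microbes, n_metabolites, k = 5, 4, 2

    U = rng.normal(size=(n_microbes, k))      # toy microbe factors
    V = rng.normal(size=(k, n_metabolites))   # toy metabolite factors
    counts = rng.integers(1, 10, size=(n_microbes, n_metabolites))  # toy counts

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    # predicted metabolite proportions for each microbe
    probs = softmax(U @ V, axis=1)

    # multinomial log-likelihood of each count row under its predicted proportions
    log_lik = sum(multinomial.logpmf(c, n=c.sum(), p=p)
                  for c, p in zip(counts, probs))

    # standard-normal priors on the factors (u_scale = v_scale = 1)
    log_prior = norm.logpdf(U).sum() + norm.logpdf(V).sum()

    log_loss = -(log_lik + log_prior)  # the quantity gradient descent minimizes

The real model additionally carries bias rows and columns and an identifiability constraint (a reference metabolite column pinned to zero), which this sketch omits.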
+ + Notes + ----- + To enable a GPU, set the device to '/device:GPU:x' + where x is 0 or greater + """ + p = latent_dim + self.device_name = device_name + if save_path is None: + basename = "logdir" + suffix = datetime.datetime.now().strftime("%y%m%d_%H%M%S") + save_path = "_".join([basename, suffix]) + + self.p = p + self.u_mean = u_mean + self.u_scale = u_scale + self.v_mean = v_mean + self.v_scale = v_scale + self.batch_size = batch_size + self.latent_dim = latent_dim + + self.learning_rate = learning_rate + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.clipnorm = clipnorm + self.save_path = save_path + + def __call__(self, session, trainX, trainY, testX, testY): + """ Initialize the actual graph + + Parameters + ---------- + session : tf.Session + Tensorflow session + trainX : sparse array in coo format + Test input OTU table, where rows are samples and columns are + observations + trainY : np.array + Test output metabolite table + testX : sparse array in coo format + Test input OTU table, where rows are samples and columns are + observations. This is mainly for cross validation. + testY : np.array + Test output metabolite table. This is mainly for cross validation. + """ + self.session = session + self.nnz = len(trainX.data) + self.d1 = trainX.shape[1] + self.d2 = trainY.shape[1] + self.cv_size = len(testX.data) + + # keep the multinomial sampling on the cpu + # https://github.com/tensorflow/tensorflow/issues/18058 + with tf.device('/cpu:0'): + X_ph = tf.SparseTensor( + indices=np.array([trainX.row, trainX.col]).T, + values=trainX.data, + dense_shape=trainX.shape) + Y_ph = tf.constant(trainY, dtype=tf.float32) + + X_holdout = tf.SparseTensor( + indices=np.array([testX.row, testX.col]).T, + values=testX.data, + dense_shape=testX.shape) + Y_holdout = tf.constant(testY, dtype=tf.float32) + + total_count = tf.reduce_sum(Y_ph, axis=1) + batch_ids = tf.multinomial( + tf.log(tf.reshape(X_ph.values, [1, -1])), + self.batch_size) + batch_ids = tf.squeeze(batch_ids) + X_samples = tf.gather(X_ph.indices, 0, axis=1) + X_obs = tf.gather(X_ph.indices, 1, axis=1) + sample_ids = tf.gather(X_samples, batch_ids) + + Y_batch = tf.gather(Y_ph, sample_ids) + X_batch = tf.gather(X_obs, batch_ids) + + with tf.device(self.device_name): + self.qUmain = tf.Variable( + tf.random_normal([self.d1, self.p]), name='qU') + self.qUbias = tf.Variable( + tf.random_normal([self.d1, 1]), name='qUbias') + self.qVmain = tf.Variable( + tf.random_normal([self.p, self.d2-1]), name='qV') + self.qVbias = tf.Variable( + tf.random_normal([1, self.d2-1]), name='qVbias') + + qU = tf.concat( + [tf.ones([self.d1, 1]), self.qUbias, self.qUmain], axis=1) + qV = tf.concat( + [self.qVbias, tf.ones([1, self.d2-1]), self.qVmain], axis=0) + + # regression coefficents distribution + Umain = Normal(loc=tf.zeros([self.d1, self.p]) + self.u_mean, + scale=tf.ones([self.d1, self.p]) * self.u_scale, + name='U') + Ubias = Normal(loc=tf.zeros([self.d1, 1]) + self.u_mean, + scale=tf.ones([self.d1, 1]) * self.u_scale, + name='biasU') + + Vmain = Normal(loc=tf.zeros([self.p, self.d2-1]) + self.v_mean, + scale=tf.ones([self.p, self.d2-1]) * self.v_scale, + name='V') + Vbias = Normal(loc=tf.zeros([1, self.d2-1]) + self.v_mean, + scale=tf.ones([1, self.d2-1]) * self.v_scale, + name='biasV') + + du = tf.gather(qU, X_batch, axis=0, name='du') + dv = tf.concat([tf.zeros([self.batch_size, 1]), + du @ qV], axis=1, name='dv') + + tc = tf.gather(total_count, sample_ids) + Y = Multinomial(total_count=tc, logits=dv, name='Y') + num_samples = trainX.shape[0] + 
norm = num_samples / self.batch_size + logprob_vmain = tf.reduce_sum( + Vmain.log_prob(self.qVmain), name='logprob_vmain') + logprob_vbias = tf.reduce_sum( + Vbias.log_prob(self.qVbias), name='logprob_vbias') + logprob_umain = tf.reduce_sum( + Umain.log_prob(self.qUmain), name='logprob_umain') + logprob_ubias = tf.reduce_sum( + Ubias.log_prob(self.qUbias), name='logprob_ubias') + logprob_y = tf.reduce_sum(Y.log_prob(Y_batch), name='logprob_y') + self.log_loss = - ( + logprob_y * norm + + logprob_umain + logprob_ubias + + logprob_vmain + logprob_vbias + ) + + # keep the multinomial sampling on the cpu + # https://github.com/tensorflow/tensorflow/issues/18058 + with tf.device('/cpu:0'): + # cross validation + with tf.name_scope('accuracy'): + cv_batch_ids = tf.multinomial( + tf.log(tf.reshape(X_holdout.values, [1, -1])), + self.cv_size) + cv_batch_ids = tf.squeeze(cv_batch_ids) + X_cv_samples = tf.gather(X_holdout.indices, 0, axis=1) + X_cv = tf.gather(X_holdout.indices, 1, axis=1) + cv_sample_ids = tf.gather(X_cv_samples, cv_batch_ids) + + Y_cvbatch = tf.gather(Y_holdout, cv_sample_ids) + X_cvbatch = tf.gather(X_cv, cv_batch_ids) + holdout_count = tf.reduce_sum(Y_cvbatch, axis=1) + cv_du = tf.gather(qU, X_cvbatch, axis=0, name='cv_du') + pred = tf.reshape( + holdout_count, [-1, 1]) * tf.nn.softmax( + tf.concat([tf.zeros([ + self.cv_size, 1]), + cv_du @ qV], axis=1, name='pred') + ) + + self.cv = tf.reduce_mean( + tf.squeeze(tf.abs(pred - Y_cvbatch)) + ) + + # keep all summaries on the cpu + with tf.device('/cpu:0'): + tf.summary.scalar('logloss', self.log_loss) + tf.summary.scalar('cv_rmse', self.cv) + tf.summary.histogram('qUmain', self.qUmain) + tf.summary.histogram('qVmain', self.qVmain) + tf.summary.histogram('qUbias', self.qUbias) + tf.summary.histogram('qVbias', self.qVbias) + self.merged = tf.summary.merge_all() + + self.writer = tf.summary.FileWriter( + self.save_path, self.session.graph) + + with tf.device(self.device_name): + with tf.name_scope('optimize'): + optimizer = tf.train.AdamOptimizer( + self.learning_rate, beta1=self.beta_1, beta2=self.beta_2) + + gradients, self.variables = zip( + *optimizer.compute_gradients(self.log_loss)) + self.gradients, _ = tf.clip_by_global_norm( + gradients, self.clipnorm) + self.train = optimizer.apply_gradients( + zip(self.gradients, self.variables)) + + tf.global_variables_initializer().run() + + def ranks(self): + modelU = np.hstack( + (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) + modelV = np.vstack( + (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) + + res = np.hstack((np.zeros((self.U.shape[0], 1)), modelU @ modelV)) + res = res - res.mean(axis=1).reshape(-1, 1) + return res + + def fit(self, epoch=10, summary_interval=1000, checkpoint_interval=3600, + testX=None, testY=None): + """ Fits the model. + + Parameters + ---------- + epoch : int + Number of epochs to train + summary_interval : int + Number of seconds until a summary is recorded + checkpoint_interval : int + Number of seconds until a checkpoint is recorded + + Returns + ------- + loss: float + log likelihood loss. 
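A note on the norm factor used in the log_loss above: each training step evaluates the multinomial term on only batch_size of the num_samples training samples, so the minibatch log-likelihood is multiplied by norm = num_samples / batch_size to approximate the full-data log-likelihood before the prior terms, which are counted once per model rather than once per sample, are added. A toy illustration with invented numbers:

    num_samples, batch_size = 1000, 50
    norm = num_samples / batch_size       # 20.0

    batch_log_lik = -123.4                # invented minibatch log-likelihood
    log_prior = -7.8                      # invented total prior log-density

    # mirrors the log_loss assembly in __call__ above
    log_loss = -(batch_log_lik * norm + log_prior)  # 2475.8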
+ cv : float + cross validation loss + """ + iterations = epoch * self.nnz // self.batch_size + losses, cvs = [], [] + cv = None + last_checkpoint_time = 0 + last_summary_time = 0 + saver = tf.train.Saver() + now = time.time() + for i in tqdm(range(0, iterations)): + if now - last_summary_time > summary_interval: + + res = self.session.run( + [self.train, self.merged, self.log_loss, self.cv, + self.qUmain, self.qUbias, + self.qVmain, self.qVbias] + ) + train_, summary, loss, cv, rU, rUb, rV, rVb = res + self.writer.add_summary(summary, i) + last_summary_time = now + else: + res = self.session.run( + [self.train, self.log_loss, + self.qUmain, self.qUbias, + self.qVmain, self.qVbias] + ) + train_, loss, rU, rUb, rV, rVb = res + losses.append(loss) + cvs.append(cv) + cv = None + + # checkpoint model + now = time.time() + if now - last_checkpoint_time > checkpoint_interval: + saver.save(self.session, + os.path.join(self.save_path, "model.ckpt"), + global_step=i) + last_checkpoint_time = now + + self.U = rU + self.V = rV + self.Ubias = rUb + self.Vbias = rVb + + return losses, cvs diff --git a/mmvec/old_multimodal.py b/mmvec/old_multimodal.py deleted file mode 100644 index ff741e8..0000000 --- a/mmvec/old_multimodal.py +++ /dev/null @@ -1,283 +0,0 @@ -import torch -import torch.nn as nn -from torch.distributions import Multinomial -#import os -#import time -#from tqdm import tqdm -#import numpy as np -#import tensorflow as tf -#from tensorflow.contrib.distributions import Multinomial, Normal -#import datetime -# -# -#class Old_MMvec(object): -# -# def __init__(self, u_mean=0, u_scale=1, v_mean=0, v_scale=1, -# batch_size=50, latent_dim=3, -# learning_rate=0.1, beta_1=0.8, beta_2=0.9, -# clipnorm=10., device_name='/cpu:0', save_path=None): -# """ Build a tensorflow model for microbe-metabolite vectors -# -# Returns -# ------- -# loss : tf.Tensor -# The log loss of the model. -# -# Notes -# ----- -# To enable a GPU, set the device to '/device:GPU:x' -# where x is 0 or greater -# """ -# p = latent_dim -# self.device_name = device_name -# if save_path is None: -# basename = "logdir" -# suffix = datetime.datetime.now().strftime("%y%m%d_%H%M%S") -# save_path = "_".join([basename, suffix]) -# -# self.p = p -# self.u_mean = u_mean -# self.u_scale = u_scale -# self.v_mean = v_mean -# self.v_scale = v_scale -# self.batch_size = batch_size -# self.latent_dim = latent_dim -# -# self.learning_rate = learning_rate -# self.beta_1 = beta_1 -# self.beta_2 = beta_2 -# self.clipnorm = clipnorm -# self.save_path = save_path -# -# def __call__(self, session, trainX, trainY, testX, testY): -# """ Initialize the actual graph -# -# Parameters -# ---------- -# session : tf.Session -# Tensorflow session -# trainX : sparse array in coo format -# Test input OTU table, where rows are samples and columns are -# observations -# trainY : np.array -# Test output metabolite table -# testX : sparse array in coo format -# Test input OTU table, where rows are samples and columns are -# observations. This is mainly for cross validation. -# testY : np.array -# Test output metabolite table. This is mainly for cross validation. 
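The cross-validation score referred to here, in both this commented-out copy and the restored multimodal.py above, is a mean absolute error: each held-out sample's metabolite counts are compared against the softmax predictions rescaled by that sample's total count. A small numpy sketch of the idea, using invented arrays rather than the holdout tensors from the graph:

    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    # toy held-out metabolite counts (2 samples x 3 metabolites) and model logits
    Y_holdout = np.array([[5., 3., 2.], [1., 0., 9.]])
    logits = np.array([[0.2, 0.1, -0.3], [-1.0, -2.0, 1.5]])

    holdout_count = Y_holdout.sum(axis=1, keepdims=True)  # total count per sample
    pred = holdout_count * softmax(logits, axis=1)        # expected counts

    cv_mae = np.abs(pred - Y_holdout).mean()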
-# """ -# self.session = session -# self.nnz = len(trainX.data) -# self.d1 = trainX.shape[1] -# self.d2 = trainY.shape[1] -# self.cv_size = len(testX.data) -# -# # keep the multinomial sampling on the cpu -# # https://github.com/tensorflow/tensorflow/issues/18058 -# with tf.device('/cpu:0'): -# X_ph = tf.SparseTensor( -# indices=np.array([trainX.row, trainX.col]).T, -# values=trainX.data, -# dense_shape=trainX.shape) -# Y_ph = tf.constant(trainY, dtype=tf.float32) -# -# X_holdout = tf.SparseTensor( -# indices=np.array([testX.row, testX.col]).T, -# values=testX.data, -# dense_shape=testX.shape) -# Y_holdout = tf.constant(testY, dtype=tf.float32) -# -# total_count = tf.reduce_sum(Y_ph, axis=1) -# batch_ids = tf.multinomial( -# tf.log(tf.reshape(X_ph.values, [1, -1])), -# self.batch_size) -# batch_ids = tf.squeeze(batch_ids) -# X_samples = tf.gather(X_ph.indices, 0, axis=1) -# X_obs = tf.gather(X_ph.indices, 1, axis=1) -# sample_ids = tf.gather(X_samples, batch_ids) -# -# Y_batch = tf.gather(Y_ph, sample_ids) -# X_batch = tf.gather(X_obs, batch_ids) -# -# with tf.device(self.device_name): -# self.qUmain = tf.Variable( -# tf.random_normal([self.d1, self.p]), name='qU') -# self.qUbias = tf.Variable( -# tf.random_normal([self.d1, 1]), name='qUbias') -# self.qVmain = tf.Variable( -# tf.random_normal([self.p, self.d2-1]), name='qV') -# self.qVbias = tf.Variable( -# tf.random_normal([1, self.d2-1]), name='qVbias') -# -# qU = tf.concat( -# [tf.ones([self.d1, 1]), self.qUbias, self.qUmain], axis=1) -# qV = tf.concat( -# [self.qVbias, tf.ones([1, self.d2-1]), self.qVmain], axis=0) -# -# # regression coefficents distribution -# Umain = Normal(loc=tf.zeros([self.d1, self.p]) + self.u_mean, -# scale=tf.ones([self.d1, self.p]) * self.u_scale, -# name='U') -# Ubias = Normal(loc=tf.zeros([self.d1, 1]) + self.u_mean, -# scale=tf.ones([self.d1, 1]) * self.u_scale, -# name='biasU') -# -# Vmain = Normal(loc=tf.zeros([self.p, self.d2-1]) + self.v_mean, -# scale=tf.ones([self.p, self.d2-1]) * self.v_scale, -# name='V') -# Vbias = Normal(loc=tf.zeros([1, self.d2-1]) + self.v_mean, -# scale=tf.ones([1, self.d2-1]) * self.v_scale, -# name='biasV') -# -# du = tf.gather(qU, X_batch, axis=0, name='du') -# dv = tf.concat([tf.zeros([self.batch_size, 1]), -# du @ qV], axis=1, name='dv') -# -# tc = tf.gather(total_count, sample_ids) -# Y = Multinomial(total_count=tc, logits=dv, name='Y') -# num_samples = trainX.shape[0] -# norm = num_samples / self.batch_size -# logprob_vmain = tf.reduce_sum( -# Vmain.log_prob(self.qVmain), name='logprob_vmain') -# logprob_vbias = tf.reduce_sum( -# Vbias.log_prob(self.qVbias), name='logprob_vbias') -# logprob_umain = tf.reduce_sum( -# Umain.log_prob(self.qUmain), name='logprob_umain') -# logprob_ubias = tf.reduce_sum( -# Ubias.log_prob(self.qUbias), name='logprob_ubias') -# logprob_y = tf.reduce_sum(Y.log_prob(Y_batch), name='logprob_y') -# self.log_loss = - ( -# logprob_y * norm + -# logprob_umain + logprob_ubias + -# logprob_vmain + logprob_vbias -# ) -# -# # keep the multinomial sampling on the cpu -# # https://github.com/tensorflow/tensorflow/issues/18058 -# with tf.device('/cpu:0'): -# # cross validation -# with tf.name_scope('accuracy'): -# cv_batch_ids = tf.multinomial( -# tf.log(tf.reshape(X_holdout.values, [1, -1])), -# self.cv_size) -# cv_batch_ids = tf.squeeze(cv_batch_ids) -# X_cv_samples = tf.gather(X_holdout.indices, 0, axis=1) -# X_cv = tf.gather(X_holdout.indices, 1, axis=1) -# cv_sample_ids = tf.gather(X_cv_samples, cv_batch_ids) -# -# Y_cvbatch = tf.gather(Y_holdout, 
cv_sample_ids) -# X_cvbatch = tf.gather(X_cv, cv_batch_ids) -# holdout_count = tf.reduce_sum(Y_cvbatch, axis=1) -# cv_du = tf.gather(qU, X_cvbatch, axis=0, name='cv_du') -# pred = tf.reshape( -# holdout_count, [-1, 1]) * tf.nn.softmax( -# tf.concat([tf.zeros([ -# self.cv_size, 1]), -# cv_du @ qV], axis=1, name='pred') -# ) -# -# self.cv = tf.reduce_mean( -# tf.squeeze(tf.abs(pred - Y_cvbatch)) -# ) -# -# # keep all summaries on the cpu -# with tf.device('/cpu:0'): -# tf.summary.scalar('logloss', self.log_loss) -# tf.summary.scalar('cv_rmse', self.cv) -# tf.summary.histogram('qUmain', self.qUmain) -# tf.summary.histogram('qVmain', self.qVmain) -# tf.summary.histogram('qUbias', self.qUbias) -# tf.summary.histogram('qVbias', self.qVbias) -# self.merged = tf.summary.merge_all() -# -# self.writer = tf.summary.FileWriter( -# self.save_path, self.session.graph) -# -# with tf.device(self.device_name): -# with tf.name_scope('optimize'): -# optimizer = tf.train.AdamOptimizer( -# self.learning_rate, beta1=self.beta_1, beta2=self.beta_2) -# -# gradients, self.variables = zip( -# *optimizer.compute_gradients(self.log_loss)) -# self.gradients, _ = tf.clip_by_global_norm( -# gradients, self.clipnorm) -# self.train = optimizer.apply_gradients( -# zip(self.gradients, self.variables)) -# -# tf.global_variables_initializer().run() -# -# def ranks(self): -# modelU = np.hstack( -# (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) -# modelV = np.vstack( -# (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) -# -# res = np.hstack((np.zeros((self.U.shape[0], 1)), modelU @ modelV)) -# res = res - res.mean(axis=1).reshape(-1, 1) -# return res -# -# def fit(self, epoch=10, summary_interval=1000, checkpoint_interval=3600, -# testX=None, testY=None): -# """ Fits the model. -# -# Parameters -# ---------- -# epoch : int -# Number of epochs to train -# summary_interval : int -# Number of seconds until a summary is recorded -# checkpoint_interval : int -# Number of seconds until a checkpoint is recorded -# -# Returns -# ------- -# loss: float -# log likelihood loss. 
-# cv : float -# cross validation loss -# """ -# iterations = epoch * self.nnz // self.batch_size -# losses, cvs = [], [] -# cv = None -# last_checkpoint_time = 0 -# last_summary_time = 0 -# saver = tf.train.Saver() -# now = time.time() -# for i in tqdm(range(0, iterations)): -# if now - last_summary_time > summary_interval: -# -# res = self.session.run( -# [self.train, self.merged, self.log_loss, self.cv, -# self.qUmain, self.qUbias, -# self.qVmain, self.qVbias] -# ) -# train_, summary, loss, cv, rU, rUb, rV, rVb = res -# self.writer.add_summary(summary, i) -# last_summary_time = now -# else: -# res = self.session.run( -# [self.train, self.log_loss, -# self.qUmain, self.qUbias, -# self.qVmain, self.qVbias] -# ) -# train_, loss, rU, rUb, rV, rVb = res -# losses.append(loss) -# cvs.append(cv) -# cv = None -# -# # checkpoint model -# now = time.time() -# if now - last_checkpoint_time > checkpoint_interval: -# saver.save(self.session, -# os.path.join(self.save_path, "model.ckpt"), -# global_step=i) -# last_checkpoint_time = now -# -# self.U = rU -# self.V = rV -# self.Ubias = rUb -# self.Vbias = rVb -# -# return losses, cvs From 5cf2dccdebea341d3027c0597105a5ca4ec37b3a Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 5 May 2022 12:42:40 -0700 Subject: [PATCH 15/27] IMP: restore q2 goodies --- mmvec/q2/_transformers.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 mmvec/q2/_transformers.py diff --git a/mmvec/q2/_transformers.py b/mmvec/q2/_transformers.py deleted file mode 100644 index e69de29..0000000 From 15bfa638eb1f465ef491db1a2c5fffb9c6a1bc9e Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 5 May 2022 12:43:27 -0700 Subject: [PATCH 16/27] IMP: q2 goodies --- mmvec/q2/__init__.py | 12 ++ mmvec/q2/_method.py | 126 +++++++++++++++ mmvec/q2/_stats.py | 30 ++++ mmvec/q2/_summary.py | 121 ++++++++++++++ mmvec/q2/_transformer.py | 36 +++++ mmvec/q2/_visualizers.py | 88 ++++++++++ mmvec/q2/assets/index.html | 28 ++++ mmvec/q2/plugin_setup.py | 252 +++++++++++++++++++++++++++++ mmvec/q2/tests/test_method.py | 98 +++++++++++ mmvec/q2/tests/test_visualizers.py | 97 +++++++++++ 10 files changed, 888 insertions(+) create mode 100644 mmvec/q2/__init__.py create mode 100644 mmvec/q2/_method.py create mode 100644 mmvec/q2/_stats.py create mode 100644 mmvec/q2/_summary.py create mode 100644 mmvec/q2/_transformer.py create mode 100644 mmvec/q2/_visualizers.py create mode 100644 mmvec/q2/assets/index.html create mode 100644 mmvec/q2/plugin_setup.py create mode 100644 mmvec/q2/tests/test_method.py create mode 100644 mmvec/q2/tests/test_visualizers.py diff --git a/mmvec/q2/__init__.py b/mmvec/q2/__init__.py new file mode 100644 index 0000000..c8d78a9 --- /dev/null +++ b/mmvec/q2/__init__.py @@ -0,0 +1,12 @@ +from ._stats import (Conditional, ConditionalDirFmt, ConditionalFormat, + MMvecStats, MMvecStatsFormat, MMvecStatsDirFmt) +from ._method import paired_omics +from ._visualizers import heatmap, paired_heatmap +from ._summary import summarize_single, summarize_paired + + +__all__ = ['paired_omics', + 'Conditional', 'ConditionalFormat', 'ConditionalDirFmt', + 'MMvecStats', 'MMvecStatsFormat', 'MMvecStatsDirFmt', + 'heatmap', 'paired_heatmap', + 'summarize_single', 'summarize_paired'] diff --git a/mmvec/q2/_method.py b/mmvec/q2/_method.py new file mode 100644 index 0000000..bb6dfbf --- /dev/null +++ b/mmvec/q2/_method.py @@ -0,0 +1,126 @@ +import biom +import pandas as pd +import numpy as np +import tensorflow as tf +from skbio import OrdinationResults +import qiime2 +from 
qiime2.plugin import Metadata +from mmvec.multimodal import MMvec +from mmvec.util import split_tables +from scipy.sparse import coo_matrix +from scipy.sparse.linalg import svds + + +def paired_omics(microbes: biom.Table, + metabolites: biom.Table, + metadata: Metadata = None, + training_column: str = None, + num_testing_examples: int = 5, + min_feature_count: int = 10, + epochs: int = 100, + batch_size: int = 50, + latent_dim: int = 3, + input_prior: float = 1, + output_prior: float = 1, + learning_rate: float = 1e-3, + equalize_biplot: float = False, + arm_the_gpu: bool = False, + summary_interval: int = 60) -> ( + pd.DataFrame, OrdinationResults, qiime2.Metadata + ): + + if metadata is not None: + metadata = metadata.to_dataframe() + + if arm_the_gpu: + # pick out the first GPU + device_name = '/device:GPU:0' + else: + device_name = '/cpu:0' + + # Note: there are a couple of biom -> pandas conversions taking + # place here. This is currently done on purpose, since we + # haven't figured out how to handle sparse matrix multiplication + # in the context of this algorithm. That is a future consideration. + res = split_tables( + microbes, metabolites, + metadata=metadata, training_column=training_column, + num_test=num_testing_examples, + min_samples=min_feature_count) + + (train_microbes_df, test_microbes_df, + train_metabolites_df, test_metabolites_df) = res + + train_microbes_coo = coo_matrix(train_microbes_df.values) + test_microbes_coo = coo_matrix(test_microbes_df.values) + + with tf.Graph().as_default(), tf.Session() as session: + model = MMvec( + latent_dim=latent_dim, + u_scale=input_prior, v_scale=output_prior, + batch_size=batch_size, + device_name=device_name, + learning_rate=learning_rate) + model(session, + train_microbes_coo, train_metabolites_df.values, + test_microbes_coo, test_metabolites_df.values) + + loss, cv = model.fit(epoch=epochs, summary_interval=summary_interval) + ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns, + columns=train_metabolites_df.columns) + if latent_dim > 0: + u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim) + else: + # fake it until you make it + u, s, v = svds(ranks - ranks.mean(axis=0), k=1) + + ranks = ranks.T + ranks.index.name = 'featureid' + s = s[::-1] + u = u[:, ::-1] + v = v[::-1, :] + if equalize_biplot: + microbe_embed = u @ np.sqrt(np.diag(s)) + metabolite_embed = v.T @ np.sqrt(np.diag(s)) + else: + microbe_embed = u @ np.diag(s) + metabolite_embed = v.T + + pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] + features = pd.DataFrame( + microbe_embed, columns=pc_ids, + index=train_microbes_df.columns) + samples = pd.DataFrame( + metabolite_embed, columns=pc_ids, + index=train_metabolites_df.columns) + short_method_name = 'mmvec biplot' + long_method_name = 'Multiomics mmvec biplot' + eigvals = pd.Series(s, index=pc_ids) + proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids) + biplot = OrdinationResults( + short_method_name, long_method_name, eigvals, + samples=samples, features=features, + proportion_explained=proportion_explained) + + its = np.arange(len(loss)) + convergence_stats = pd.DataFrame( + { + 'loss': loss, + 'cross-validation': cv, + 'iteration': its + } + ) + + convergence_stats.index.name = 'id' + convergence_stats.index = convergence_stats.index.astype(np.str) + + c = convergence_stats['loss'].astype(np.float) + convergence_stats['loss'] = c + + c = convergence_stats['cross-validation'].astype(np.float) + convergence_stats['cross-validation'] = c + + c = 
convergence_stats['iteration'].astype(np.int) + convergence_stats['iteration'] = c + + return ranks, biplot, qiime2.Metadata(convergence_stats) diff --git a/mmvec/q2/_stats.py b/mmvec/q2/_stats.py new file mode 100644 index 0000000..980e937 --- /dev/null +++ b/mmvec/q2/_stats.py @@ -0,0 +1,30 @@ +from qiime2.plugin import SemanticType, model +from q2_types.feature_data import FeatureData +from q2_types.sample_data import SampleData + + +Conditional = SemanticType('Conditional', + variant_of=FeatureData.field['type']) + + +class ConditionalFormat(model.TextFileFormat): + def validate(*args): + pass + + +ConditionalDirFmt = model.SingleFileDirectoryFormat( + 'ConditionalDirFmt', 'conditionals.tsv', ConditionalFormat) + + +# songbird stats summarizing loss and cv error +MMvecStats = SemanticType('MMvecStats', + variant_of=SampleData.field['type']) + + +class MMvecStatsFormat(model.TextFileFormat): + def validate(*args): + pass + + +MMvecStatsDirFmt = model.SingleFileDirectoryFormat( + 'MMvecStatsDirFmt', 'stats.tsv', MMvecStatsFormat) diff --git a/mmvec/q2/_summary.py b/mmvec/q2/_summary.py new file mode 100644 index 0000000..524b7a9 --- /dev/null +++ b/mmvec/q2/_summary.py @@ -0,0 +1,121 @@ +import os +import qiime2 +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt + + +def _convergence_plot(model, baseline, ax0, ax1): + iterations = np.array(model['iteration']) + cv_model = model.dropna() + ax0.plot(cv_model['iteration'][1:], + np.array(cv_model['cross-validation'].values)[1:], + label='model') + ax0.set_ylabel('Cross validation score', fontsize=14) + ax0.set_xlabel('# Iterations', fontsize=14) + + ax1.plot(iterations[1:], + np.array(model['loss'])[1:], label='model') + ax1.set_ylabel('Loss', fontsize=14) + ax1.set_xlabel('# Iterations', fontsize=14) + + if baseline is not None: + iterations = baseline['iteration'] + cv_baseline = baseline.dropna() + ax0.plot(cv_baseline['iteration'][1:], + np.array(cv_baseline['cross-validation'].values)[1:], + label='baseline') + ax0.set_ylabel('Cross validation score', fontsize=14) + ax0.set_xlabel('# Iterations', fontsize=14) + ax0.legend() + + ax1.plot(iterations[1:], + np.array(baseline['loss'])[1:], label='baseline') + ax1.set_ylabel('Loss', fontsize=14) + ax1.set_xlabel('# Iterations', fontsize=14) + ax1.legend() + + +def _summarize(output_dir: str, model: pd.DataFrame, + baseline: pd.DataFrame = None): + + """ Helper method for generating summary pages + Parameters + ---------- + output_dir : str + Name of output directory + model : pd.DataFrame + Model summary with column names + ['loss', 'cross-validation'] + baseline : pd.DataFrame + Baseline model summary with column names + ['loss', 'cross-validation']. Defaults to None (i.e. if only a single + set of model stats will be summarized). + Note + ---- + There may be synchronizing issues if different summary intervals + were used between analyses. For predictable results, try to use the + same summary interval. 
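The pseudo-Q-squared computed in the function body below compares the tails of the two cross-validation traces: q2 = 1 - mean(model CV) / mean(baseline CV) over the last few recorded points, so a value near 0 means no improvement over the baseline and values approaching 1 mean the model's held-out error is much smaller. A toy calculation with invented values:

    import numpy as np

    # last recorded cross-validation errors for each run (invented)
    cv_baseline = np.array([4.0, 3.9, 4.1, 4.0])
    cv_model = np.array([2.1, 2.0, 1.9, 2.0])

    end = min(10, len(cv_model))
    l0 = np.mean(cv_baseline[-end:])   # 4.0
    lm = np.mean(cv_model[-end:])      # 2.0
    q2 = 1 - lm / l0                   # 0.5: the model halves the baseline CV error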
+ """ + fig, ax = plt.subplots(2, 1, figsize=(10, 10)) + if baseline is None: + _convergence_plot(model, None, ax[0], ax[1]) + q2 = None + else: + + _convergence_plot(model, baseline, ax[0], ax[1]) + + # this provides a pseudo-r2 commonly provided in the context + # of logistic / multinomail model (proposed by Cox & Snell) + # http://www3.stat.sinica.edu.tw/statistica/oldpdf/a16n39.pdf + end = min(10, len(model.index)) + # trim only the last 10 numbers + + # compute a q2 score, which is commonly used in + # partial least squares for cross validation + cv_model = model.dropna() + cv_baseline = baseline.dropna() + + l0 = np.mean(cv_baseline['cross-validation'][-end:]) + lm = np.mean(cv_model['cross-validation'][-end:]) + q2 = 1 - lm / l0 + + plt.tight_layout() + fig.savefig(os.path.join(output_dir, 'convergence-plot.svg')) + fig.savefig(os.path.join(output_dir, 'convergence-plot.pdf')) + + index_fp = os.path.join(output_dir, 'index.html') + with open(index_fp, 'w') as index_f: + index_f.write('\n') + index_f.write('

Convergence summary

\n') + index_f.write( + "

If you don't see anything in these plots, you probably need " + "to decrease your --p-summary-interval. Try setting " + "--p-summary-interval 1, which will record the loss at " + "every second.

\n" + ) + + if q2 is not None: + index_f.write( + '

' + '' + 'Pseudo Q-squared: %f

\n' % q2 + ) + + index_f.write( + 'convergence_plots' + ) + index_f.write('') + index_f.write('Download as PDF
\n') + + +def summarize_single(output_dir: str, model_stats: qiime2.Metadata): + _summarize(output_dir, model_stats.to_dataframe()) + + +def summarize_paired(output_dir: str, + model_stats: qiime2.Metadata, + baseline_stats: qiime2.Metadata): + _summarize(output_dir, + model_stats.to_dataframe(), + baseline_stats.to_dataframe()) diff --git a/mmvec/q2/_transformer.py b/mmvec/q2/_transformer.py new file mode 100644 index 0000000..7c304df --- /dev/null +++ b/mmvec/q2/_transformer.py @@ -0,0 +1,36 @@ +import qiime2 +import pandas as pd + +from mmvec.q2 import ConditionalFormat, MMvecStatsFormat +from mmvec.q2.plugin_setup import plugin + + +@plugin.register_transformer +def _1(ff: ConditionalFormat) -> pd.DataFrame: + df = pd.read_csv(str(ff), sep='\t', comment='#', skip_blank_lines=True, + header=0, index_col=0) + return df + + +@plugin.register_transformer +def _2(df: pd.DataFrame) -> ConditionalFormat: + ff = ConditionalFormat() + df.to_csv(str(ff), sep='\t', header=True, index=True) + return ff + + +@plugin.register_transformer +def _3(ff: ConditionalFormat) -> qiime2.Metadata: + return qiime2.Metadata.load(str(ff)) + + +@plugin.register_transformer +def _4(obj: qiime2.Metadata) -> MMvecStatsFormat: + ff = MMvecStatsFormat() + obj.save(str(ff)) + return ff + + +@plugin.register_transformer +def _5(ff: MMvecStatsFormat) -> qiime2.Metadata: + return qiime2.Metadata.load(str(ff)) diff --git a/mmvec/q2/_visualizers.py b/mmvec/q2/_visualizers.py new file mode 100644 index 0000000..6861768 --- /dev/null +++ b/mmvec/q2/_visualizers.py @@ -0,0 +1,88 @@ +from os.path import join +import pandas as pd +import qiime2 +import biom +import pkg_resources +import q2templates +from mmvec.heatmap import ranks_heatmap, paired_heatmaps + + +TEMPLATES = pkg_resources.resource_filename('mmvec.q2', 'assets') + + +def heatmap(output_dir: str, + ranks: pd.DataFrame, + microbe_metadata: qiime2.CategoricalMetadataColumn = None, + metabolite_metadata: qiime2.CategoricalMetadataColumn = None, + method: str = 'average', + metric: str = 'euclidean', + color_palette: str = 'seismic', + margin_palette: str = 'cubehelix', + x_labels: bool = False, + y_labels: bool = False, + level: int = -1, + row_center: bool = True) -> None: + if microbe_metadata is not None: + microbe_metadata = microbe_metadata.to_series() + if metabolite_metadata is not None: + metabolite_metadata = metabolite_metadata.to_series() + ranks = ranks.T + + if row_center: + ranks = ranks - ranks.mean(axis=0) + + hotmap = ranks_heatmap(ranks, microbe_metadata, metabolite_metadata, + method, metric, color_palette, margin_palette, + x_labels, y_labels, level) + + hotmap.savefig(join(output_dir, 'heatmap.pdf'), bbox_inches='tight') + hotmap.savefig(join(output_dir, 'heatmap.png'), bbox_inches='tight') + + index = join(TEMPLATES, 'index.html') + q2templates.render(index, output_dir, context={ + 'title': 'Rank Heatmap', + 'pdf_fp': 'heatmap.pdf', + 'png_fp': 'heatmap.png'}) + + +def paired_heatmap(output_dir: str, + ranks: pd.DataFrame, + microbes_table: biom.Table, + metabolites_table: biom.Table, + features: str = None, + top_k_microbes: int = 2, + keep_top_samples: bool = True, + microbe_metadata: qiime2.CategoricalMetadataColumn = None, + normalize: str = 'log10', + color_palette: str = 'magma', + top_k_metabolites: int = 50, + level: int = -1, + row_center: bool = True) -> None: + if microbe_metadata is not None: + microbe_metadata = microbe_metadata.to_series() + + ranks = ranks.T + + if row_center: + ranks = ranks - ranks.mean(axis=0) + + select_microbes, 
select_metabolites, hotmaps = paired_heatmaps( + ranks, microbes_table, metabolites_table, microbe_metadata, features, + top_k_microbes, top_k_metabolites, keep_top_samples, level, normalize, + color_palette) + + hotmaps.savefig(join(output_dir, 'heatmap.pdf'), bbox_inches='tight') + hotmaps.savefig(join(output_dir, 'heatmap.png'), bbox_inches='tight') + select_microbes.to_csv(join(output_dir, 'select_microbes.tsv'), sep='\t') + select_metabolites.to_csv( + join(output_dir, 'select_metabolites.tsv'), sep='\t') + + index = join(TEMPLATES, 'index.html') + q2templates.render(index, output_dir, context={ + 'title': 'Paired Feature Abundance Heatmaps', + 'pdf_fp': 'heatmap.pdf', + 'png_fp': 'heatmap.png', + 'table1_fp': 'select_microbes.tsv', + 'download1_text': 'Download microbe abundances as TSV', + 'table2_fp': 'select_metabolites.tsv', + 'download2_text': 'Download top k metabolite abundances as TSV'}) diff --git a/mmvec/q2/assets/index.html b/mmvec/q2/assets/index.html new file mode 100644 index 0000000..a752d3b --- /dev/null +++ b/mmvec/q2/assets/index.html @@ -0,0 +1,28 @@ +{% extends 'base.html' %} + +{% block title %}rhapsody : {{ title }}{% endblock %} + +{% block fixed %}{% endblock %} + +{% block content %} + +
+<h1>{{ title }}</h1>
+
+<img src="{{ png_fp }}" alt="{{ title }}">
+
+<a href="{{ pdf_fp }}">Download as PDF</a><br>
+{% if table1_fp %}
+<a href="{{ table1_fp }}">{{ download1_text }}</a><br>
+<a href="{{ table2_fp }}">{{ download2_text }}</a><br>
+{% endif %}
+ +{% endblock %} diff --git a/mmvec/q2/plugin_setup.py b/mmvec/q2/plugin_setup.py new file mode 100644 index 0000000..4285418 --- /dev/null +++ b/mmvec/q2/plugin_setup.py @@ -0,0 +1,252 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2016--, gneiss development team. +# +# Distributed under the terms of the Modified BSD License. +# +# The full license is in the file COPYING.txt, distributed with this software. +# ---------------------------------------------------------------------------- +import importlib +import qiime2.plugin +import qiime2.sdk +from mmvec import __version__, _heatmap_choices, _cmaps +from qiime2.plugin import (Str, Properties, Int, Float, Metadata, Bool, + MetadataColumn, Categorical, Range, Choices, List) +from q2_types.feature_table import FeatureTable, Frequency +from q2_types.feature_data import FeatureData +from q2_types.sample_data import SampleData +from q2_types.ordination import PCoAResults +from mmvec.q2 import ( + Conditional, ConditionalFormat, ConditionalDirFmt, + MMvecStats, MMvecStatsFormat, MMvecStatsDirFmt, + paired_omics, heatmap, paired_heatmap, summarize_single, summarize_paired +) + +plugin = qiime2.plugin.Plugin( + name='mmvec', + version=__version__, + website="https://github.com/biocore/mmvec", + short_description='Plugin for performing microbe-metabolite ' + 'co-occurence analysis.', + description='This is a QIIME 2 plugin supporting microbe-metabolite ' + 'co-occurence analysis using mmvec.', + package='mmvec') + +plugin.methods.register_function( + function=paired_omics, + inputs={'microbes': FeatureTable[Frequency], + 'metabolites': FeatureTable[Frequency]}, + parameters={ + 'metadata': Metadata, + 'training_column': Str, + 'num_testing_examples': Int, + 'min_feature_count': Int, + 'epochs': Int, + 'batch_size': Int, + 'arm_the_gpu': Bool, + 'latent_dim': Int, + 'input_prior': Float, + 'output_prior': Float, + 'learning_rate': Float, + 'equalize_biplot': Bool, + 'summary_interval': Int + }, + outputs=[ + ('conditionals', FeatureData[Conditional]), + ('conditional_biplot', PCoAResults % Properties('biplot')), + ('model_stats', SampleData[MMvecStats]), + ], + input_descriptions={ + 'microbes': 'Input table of microbial counts.', + 'metabolites': 'Input table of metabolite intensities.', + }, + output_descriptions={ + 'conditionals': 'Mean-centered Conditional log-probabilities.', + 'conditional_biplot': 'Biplot of microbe-metabolite vectors.', + }, + parameter_descriptions={ + 'metadata': 'Sample metadata table with covariates of interest.', + 'training_column': "The metadata column specifying which " + "samples are for training/testing. " + "Entries must be marked `Train` for training " + "examples and `Test` for testing examples. ", + 'num_testing_examples': "The number of random examples to select " + "if `training_column` isn't specified.", + 'epochs': 'The total number of iterations over the entire dataset.', + 'equalize_biplot': 'Biplot arrows and points are on the same scale.', + 'batch_size': 'The number of samples to be evaluated per ' + 'training iteration.', + 'arm_the_gpu': 'Specifies whether or not to use the GPU.', + 'input_prior': 'Width of normal prior for the microbial ' + 'coefficients. Smaller values will regularize ' + 'parameters towards zero. Values must be greater ' + 'than 0.', + 'output_prior': 'Width of normal prior for the metabolite ' + 'coefficients. Smaller values will regularize ' + 'parameters towards zero. 
Values must be greater ' + 'than 0.', + 'learning_rate': 'Gradient descent decay rate.' + }, + name='Microbe metabolite vectors', + description="Performs bi-loglinear multinomial regression and calculates " + "the conditional probability ranks of metabolite " + "co-occurence given the microbe presence.", + citations=[] +) + +plugin.visualizers.register_function( + function=heatmap, + inputs={'ranks': FeatureData[Conditional]}, + parameters={ + 'microbe_metadata': MetadataColumn[Categorical], + 'metabolite_metadata': MetadataColumn[Categorical], + 'method': Str % Choices(_heatmap_choices['method']), + 'metric': Str % Choices(_heatmap_choices['metric']), + 'color_palette': Str % Choices(_cmaps['heatmap']), + 'margin_palette': Str % Choices(_cmaps['margins']), + 'x_labels': Bool, + 'y_labels': Bool, + 'level': Int % Range(-1, None), + 'row_center': Bool, + }, + input_descriptions={'ranks': 'Conditional probabilities.'}, + parameter_descriptions={ + 'microbe_metadata': 'Optional microbe metadata for annotating plots.', + 'metabolite_metadata': 'Optional metabolite metadata for annotating ' + 'plots.', + 'method': 'Hierarchical clustering method used in clustermap.', + 'metric': 'Distance metric used in clustermap.', + 'color_palette': 'Color palette for clustermap.', + 'margin_palette': 'Name of color palette to use for annotating ' + 'metadata along margin(s) of clustermap.', + 'x_labels': 'Plot x-axis (metabolite) labels?', + 'y_labels': 'Plot y-axis (microbe) labels?', + 'level': 'taxonomic level for annotating clustermap. Set to -1 if not ' + 'parsing semicolon-delimited taxonomies or wish to print ' + 'entire annotation.', + 'row_center': 'Center conditional probability table ' + 'around average row.' + }, + name='Conditional probability heatmap', + description="Generate heatmap depicting mmvec conditional probabilities.", + citations=[] +) + +plugin.visualizers.register_function( + function=paired_heatmap, + inputs={'ranks': FeatureData[Conditional], + 'microbes_table': FeatureTable[Frequency], + 'metabolites_table': FeatureTable[Frequency]}, + parameters={ + 'microbe_metadata': MetadataColumn[Categorical], + 'features': List[Str], + 'top_k_microbes': Int % Range(0, None), + 'color_palette': Str % Choices(_cmaps['heatmap']), + 'normalize': Str % Choices(['log10', 'z_score_col', 'z_score_row', + 'rel_row', 'rel_col', 'None']), + 'top_k_metabolites': Int % Range(1, None) | Str % Choices(['all']), + 'keep_top_samples': Bool, + 'level': Int % Range(-1, None), + 'row_center': Bool, + }, + input_descriptions={'ranks': 'Conditional probabilities.', + 'microbes_table': 'Microbial feature abundances.', + 'metabolites_table': 'Metabolite feature abundances.'}, + parameter_descriptions={ + 'microbe_metadata': 'Optional microbe metadata for annotating plots.', + 'features': 'Microbial feature IDs to display in heatmap. Use this ' + 'parameter to include named feature IDs in the heatmap. ' + 'Can be used in conjunction with top_k_microbes, in which ' + 'case named features will be displayed first, then top ' + 'microbial features in order of log conditional ' + 'probability maximum values.', + 'top_k_microbes': 'Select top k microbes (those with the highest ' + 'relative abundances) to display on the heatmap. 
' + 'Set to "all" to display all metabolites.', + 'color_palette': 'Color palette for clustermap.', + 'normalize': 'Optionally normalize heatmap values by columns or rows.', + 'top_k_metabolites': 'Select top k metabolites associated with each ' + 'of the chosen features to display on heatmap.', + 'keep_top_samples': 'Display only samples in which at least one of ' + 'the selected microbes is the most abundant ' + 'feature.', + 'level': 'taxonomic level for annotating clustermap. Set to -1 if not ' + 'parsing semicolon-delimited taxonomies or wish to print ' + 'entire annotation.', + 'row_center': 'Center conditional probability table ' + 'around average row.' + }, + name='Paired feature abundance heatmaps', + description="Generate paired heatmaps that depict microbial and " + "metabolite feature abundances. The left panel displays the " + "abundance of each selected microbial feature in each sample. " + "The right panel displays the abundances of the top k " + "metabolites most highly correlated with these microbes in " + "each sample. The y-axis (sample axis) is shared between each " + "panel.", + citations=[] +) + + +plugin.visualizers.register_function( + function=summarize_single, + inputs={ + 'model_stats': SampleData[MMvecStats] + }, + parameters={}, + input_descriptions={ + 'model_stats': ( + "Summary information produced by running " + "`qiime mmvec paired-omics`." + ) + }, + parameter_descriptions={ + }, + name='MMvec summary statistics', + description=( + "Visualize the convergence statistics from running " + "`qiime mmvec paired-omics`, giving insight " + "into how the model fit to your data." + ) +) + +plugin.visualizers.register_function( + function=summarize_paired, + inputs={ + 'model_stats': SampleData[MMvecStats], + 'baseline_stats': SampleData[MMvecStats] + }, + parameters={}, + input_descriptions={ + + 'model_stats': ( + "Summary information for the reference model, produced by running " + "`qiime mmvec paired-omics`." + ), + 'baseline_stats': ( + "Summary information for the baseline model, produced by running " + "`qiime mmvec paired-omics`." + ) + + }, + parameter_descriptions={ + }, + name='Paired MMvec summary statistics', + description=( + "Visualize the convergence statistics from two MMvec models, " + "giving insight into how the models fit to your data. " + "The produced visualization includes a 'pseudo-Q-squared' value." 
+ ) +) + +# Register types +plugin.register_formats(MMvecStatsFormat, MMvecStatsDirFmt) +plugin.register_semantic_types(MMvecStats) +plugin.register_semantic_type_to_format( + SampleData[MMvecStats], MMvecStatsDirFmt) + +plugin.register_formats(ConditionalFormat, ConditionalDirFmt) +plugin.register_semantic_types(Conditional) +plugin.register_semantic_type_to_format( + FeatureData[Conditional], ConditionalDirFmt) + +importlib.import_module('mmvec.q2._transformer') diff --git a/mmvec/q2/tests/test_method.py b/mmvec/q2/tests/test_method.py new file mode 100644 index 0000000..2bae849 --- /dev/null +++ b/mmvec/q2/tests/test_method.py @@ -0,0 +1,98 @@ +import biom +import unittest +import numpy as np +import tensorflow as tf +from mmvec.q2._method import paired_omics +from mmvec.util import random_multimodal +from skbio.stats.composition import clr_inv +from scipy.stats import spearmanr +import numpy.testing as npt + + +class TestMMvec(unittest.TestCase): + + def setUp(self): + np.random.seed(1) + res = random_multimodal( + num_microbes=8, num_metabolites=8, num_samples=150, + latent_dim=2, sigmaQ=2, + microbe_total=1000, metabolite_total=10000, seed=1 + ) + (self.microbes, self.metabolites, self.X, self.B, + self.U, self.Ubias, self.V, self.Vbias) = res + n, d1 = self.microbes.shape + n, d2 = self.metabolites.shape + + self.microbes = biom.Table(self.microbes.values.T, + self.microbes.columns, + self.microbes.index) + self.metabolites = biom.Table(self.metabolites.values.T, + self.metabolites.columns, + self.metabolites.index) + U_ = np.hstack( + (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) + V_ = np.vstack( + (self.Vbias, np.ones((1, self.V.shape[1])), self.V)) + + uv = U_ @ V_ + h = np.zeros((d1, 1)) + self.exp_ranks = clr_inv(np.hstack((h, uv))) + + def test_fit(self): + np.random.seed(1) + tf.reset_default_graph() + tf.set_random_seed(0) + latent_dim = 2 + res_ranks, res_biplot, _ = paired_omics( + self.microbes, self.metabolites, + epochs=1000, latent_dim=latent_dim, + min_feature_count=1, learning_rate=0.1 + ) + res_ranks = clr_inv(res_ranks.T) + s_r, s_p = spearmanr(np.ravel(res_ranks), np.ravel(self.exp_ranks)) + + self.assertGreater(s_r, 0.5) + self.assertLess(s_p, 1e-2) + + # make sure the biplot is of the correct dimensions + npt.assert_allclose( + res_biplot.samples.shape, + np.array([self.microbes.shape[0], latent_dim])) + npt.assert_allclose( + res_biplot.features.shape, + np.array([self.metabolites.shape[0], latent_dim])) + + # make sure that the biplot has the correct ordering + self.assertGreater(res_biplot.proportion_explained[0], + res_biplot.proportion_explained[1]) + self.assertGreater(res_biplot.eigvals[0], + res_biplot.eigvals[1]) + + def test_equalize_sv(self): + np.random.seed(1) + tf.reset_default_graph() + tf.set_random_seed(0) + latent_dim = 2 + res_ranks, res_biplot, _ = paired_omics( + self.microbes, self.metabolites, + epochs=1000, latent_dim=latent_dim, + min_feature_count=1, learning_rate=0.1, + equalize_biplot=True + ) + # make sure the biplot is of the correct dimensions + npt.assert_allclose( + res_biplot.samples.shape, + np.array([self.microbes.shape[0], latent_dim])) + npt.assert_allclose( + res_biplot.features.shape, + np.array([self.metabolites.shape[0], latent_dim])) + + # make sure that the biplot has the correct ordering + self.assertGreater(res_biplot.proportion_explained[0], + res_biplot.proportion_explained[1]) + self.assertGreater(res_biplot.eigvals[0], + res_biplot.eigvals[1]) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/mmvec/q2/tests/test_visualizers.py b/mmvec/q2/tests/test_visualizers.py new file mode 100644 index 0000000..6171670 --- /dev/null +++ b/mmvec/q2/tests/test_visualizers.py @@ -0,0 +1,97 @@ +import unittest +import pandas as pd +from qiime2 import Artifact, CategoricalMetadataColumn +from qiime2.plugins import mmvec +import biom +import numpy as np + + +# these tests just make sure the visualizer runs; nuts + bolts are tested in +# the main package. +class TestHeatmap(unittest.TestCase): + + def setUp(self): + _ranks = pd.DataFrame([[4.1, 1.3, 2.1], [0.1, 0.3, 0.2], + [2.2, 4.3, 3.2], [-6.3, -4.4, 2.1]], + index=pd.Index([c for c in 'ABCD'], name='id'), + columns=['m1', 'm2', 'm3']).T + self.ranks = Artifact.import_data('FeatureData[Conditional]', _ranks) + self.taxa = CategoricalMetadataColumn(pd.Series([ + 'k__Bacteria; p__Proteobacteria; c__Deltaproteobacteria; ' + 'o__Desulfobacterales; f__Desulfobulbaceae; g__; s__', + 'k__Bacteria; p__Cyanobacteria; c__Chloroplast; o__Streptophyta', + 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' + 'o__Rickettsiales; f__mitochondria; g__Lardizabala; s__biternata', + 'k__Archaea; p__Euryarchaeota; c__Methanomicrobia; ' + 'o__Methanosarcinales; f__Methanosarcinaceae; g__Methanosarcina'], + index=pd.Index([c for c in 'ABCD'], name='feature-id'), + name='Taxon')) + self.metabolites = CategoricalMetadataColumn(pd.Series([ + 'amino acid', 'carbohydrate', 'drug metabolism'], + index=pd.Index(['m1', 'm2', 'm3'], name='feature-id'), + name='Super Pathway')) + + def test_heatmap_default(self): + mmvec.actions.heatmap(self.ranks, self.taxa, self.metabolites) + + def test_heatmap_no_metadata(self): + mmvec.actions.heatmap(self.ranks) + + def test_heatmap_one_metadata(self): + mmvec.actions.heatmap(self.ranks, self.taxa, None) + + def test_heatmap_no_taxonomy_parsing(self): + mmvec.actions.heatmap(self.ranks, self.taxa, None, level=-1) + + def test_heatmap_plot_axis_labels(self): + mmvec.actions.heatmap(self.ranks, x_labels=True, y_labels=True) + + +class TestPairedHeatmap(unittest.TestCase): + + def setUp(self): + _ranks = pd.DataFrame([[4.1, 1.3, 2.1], [0.1, 0.3, 0.2], + [2.2, 4.3, 3.2], [-6.3, -4.4, 2.1]], + index=pd.Index([c for c in 'ABCD'], name='id'), + columns=['m1', 'm2', 'm3']).T + self.ranks = Artifact.import_data('FeatureData[Conditional]', _ranks) + self.taxa = CategoricalMetadataColumn(pd.Series([ + 'k__Bacteria; p__Proteobacteria; c__Deltaproteobacteria; ' + 'o__Desulfobacterales; f__Desulfobulbaceae; g__; s__', + 'k__Bacteria; p__Cyanobacteria; c__Chloroplast; o__Streptophyta', + 'k__Bacteria; p__Proteobacteria; c__Alphaproteobacteria; ' + 'o__Rickettsiales; f__mitochondria; g__Lardizabala; s__biternata', + 'k__Archaea; p__Euryarchaeota; c__Methanomicrobia; ' + 'o__Methanosarcinales; f__Methanosarcinaceae; g__Methanosarcina'], + index=pd.Index([c for c in 'ABCD'], name='feature-id'), + name='Taxon')) + metabolites = biom.Table( + np.array([[9, 8, 2], [2, 1, 2], [9, 4, 5], [8, 8, 7]]), + sample_ids=['s1', 's2', 's3'], + observation_ids=['m1', 'm2', 'm3', 'm4']) + self.metabolites = Artifact.import_data( + 'FeatureTable[Frequency]', metabolites) + microbes = biom.Table( + np.array([[1, 2, 3], [3, 6, 3], [1, 9, 9], [8, 8, 7]]), + sample_ids=['s1', 's2', 's3'], observation_ids=[i for i in 'ABCD']) + self.microbes = Artifact.import_data( + 'FeatureTable[Frequency]', microbes) + + def test_paired_heatmaps_single_feature(self): + mmvec.actions.paired_heatmap( + self.ranks, self.microbes, self.metabolites, features=['C'], + 
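+            # ('features' expects microbial feature IDs, i.e. row labels of
+            # the conditional rank table; unknown IDs trigger the ValueError
+            # exercised in test_paired_heatmaps_fail_on_unknown_feature)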
microbe_metadata=self.taxa) + + def test_paired_heatmaps_multifeature(self): + mmvec.actions.paired_heatmap( + self.ranks, self.microbes, self.metabolites, features=['A', 'C']) + + def test_paired_heatmaps_fail_on_unknown_feature(self): + with self.assertRaisesRegex(ValueError, "must represent feature IDs"): + mmvec.actions.paired_heatmap( + self.ranks, self.microbes, self.metabolites, + features=['A', 'barf']) + + +if __name__ == "__main__": + unittest.main() From e7c7f895bd196f081d4a3c6c1fb6e78ec69d9999 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 5 May 2022 12:45:31 -0700 Subject: [PATCH 17/27] IMP: add min pytorch version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4198c7a..fcc0583 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ 'scikit-bio', 'seaborn', 'tqdm', - 'pytorch' + 'pytorch>=1.9.0' ], classifiers=classifiers, entry_points={ From edc554ac67707109d99679e8190d30197d579866 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 5 May 2022 12:50:00 -0700 Subject: [PATCH 18/27] DEBUG: conda-> pip pytorch install name --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fcc0583..d568f26 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ 'scikit-bio', 'seaborn', 'tqdm', - 'pytorch>=1.9.0' + 'torch>=1.9.0' ], classifiers=classifiers, entry_points={ From 5308aa1c68eefdcbb38ffe8fb1378ea595465504 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 5 May 2022 16:11:03 -0700 Subject: [PATCH 19/27] IMP: name tweaks --- mmvec/ALR.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mmvec/ALR.py b/mmvec/ALR.py index 1f8ca45..b97f32d 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -72,13 +72,13 @@ def forward(self, X): z = z + self.encoder_bias[X].reshape((*X.shape, 1)) y_pred = self.decoder(z) - forward_dist = Multinomial(total_count=0, - validate_args=False, - probs=y_pred) + result_dist = Multinomial(total_count=0, + validate_args=False, + probs=y_pred) - forward_dist = forward_dist.log_prob(self.metabolites) + prior = result_dist.log_prob(self.metabolites) - l_y = forward_dist.sum(0).sum() + l_y = prior.sum(0).sum() u_weights = self.encoder.weight l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum() From 6d75cd194a2cc3e5324fcdc823980053ab432d4e Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Fri, 6 May 2022 13:17:59 -0700 Subject: [PATCH 20/27] IMP: var renaming and observation based epochs --- mmvec/ALR.py | 11 ++++++----- mmvec/tests/test_multimodal.py | 7 +++---- mmvec/train.py | 19 +++++++++++-------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/mmvec/ALR.py b/mmvec/ALR.py index b97f32d..2302f14 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -24,9 +24,10 @@ def structure_data(microbes, metabolites): metabolites = torch.tensor(metabolites.values, dtype=torch.int64) microbe_relative_frequency = (microbes.T/microbes.sum(1)).T + nnz = torch.count_nonzero(microbes).item() return (microbes, metabolites, microbe_idx, metabolite_idx, microbe_count, - metabolite_count, microbe_relative_frequency) + metabolite_count, microbe_relative_frequency, nnz) class LinearALR(nn.Module): @@ -51,7 +52,7 @@ def __init__(self, microbes, metabolites, latent_dim, sigma_u=1, (self.microbes, self.metabolites, self.microbe_idx, self. 
metabolite_idx, self.num_microbes, self.num_metabolites, - self.microbe_relative_freq) = structure_data(microbes, + self.microbe_relative_freq, self.nnz) = structure_data(microbes, metabolites) self.sigma_u = sigma_u self.sigma_v = sigma_v @@ -72,13 +73,13 @@ def forward(self, X): z = z + self.encoder_bias[X].reshape((*X.shape, 1)) y_pred = self.decoder(z) - result_dist = Multinomial(total_count=0, + predicted = torch.distributions.multinomial.Multinomial(total_count=0, validate_args=False, probs=y_pred) - prior = result_dist.log_prob(self.metabolites) + data_likelihood = predicted.log_prob(self.metabolites) - l_y = prior.sum(0).sum() + l_y = data_likelihood.sum(0).sum() u_weights = self.encoder.weight l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum() diff --git a/mmvec/tests/test_multimodal.py b/mmvec/tests/test_multimodal.py index e312f1c..d09456d 100644 --- a/mmvec/tests/test_multimodal.py +++ b/mmvec/tests/test_multimodal.py @@ -42,8 +42,8 @@ def test_fit(self): n, d1 = self.trainX.shape n, d2 = self.trainY.shape model = MMvecALR(self.trainX, self.trainY, latent_dim=2) - mmvec_training_loop(model=model, learning_rate=0.1, batch_size=1000, - epochs=1000) + mmvec_training_loop(model=model, learning_rate=0.01, batch_size=50, + epochs=10000) U_ = np.hstack( (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) @@ -79,8 +79,7 @@ def test_fit(self): # # # sanity check cross validation # self.assertLess(model.cv.eval(), 500) -# - + #class TestMMvecSoilsBenchmark(unittest.TestCase): # def setUp(self): diff --git a/mmvec/train.py b/mmvec/train.py index e2328cb..f42d866 100644 --- a/mmvec/train.py +++ b/mmvec/train.py @@ -1,19 +1,22 @@ import torch def mmvec_training_loop(model, learning_rate, batch_size, epochs): - optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.8, 0.9), maximize=True) for epoch in range(epochs): + batch_iterations = int(model.nnz / batch_size) - draws = torch.multinomial(model.microbe_relative_freq, - batch_size, - replacement=True).T + for batch in range(batch_iterations): - mmvec_model = model(draws) + draws = torch.multinomial(model.microbe_relative_freq, + batch_size, + replacement=True).T - optimizer.zero_grad() - mmvec_model.backward() - optimizer.step() + mmvec_model = model(draws) + + optimizer.zero_grad() + mmvec_model.backward() + optimizer.step() if epoch % 500 == 0: print(f"loss: {mmvec_model.item()}\nBatch #: {epoch}") From b38a1d2f3831cec272ebd019ca9248b28479702f Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 19 May 2022 12:04:31 -0700 Subject: [PATCH 21/27] FEAT: plugin-standup --- mmvec/ALR.py | 6 +- mmvec/multimodal.py | 1 + mmvec/q2/_method.py | 138 ++++++++++++++++----------------- mmvec/q2/tests/test_method.py | 24 ++++-- mmvec/tests/test_multimodal.py | 3 +- mmvec/train.py | 5 +- scripts/mmvec | 2 - setup.py | 2 +- 8 files changed, 97 insertions(+), 84 deletions(-) diff --git a/mmvec/ALR.py b/mmvec/ALR.py index 2302f14..992dabe 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -10,8 +10,8 @@ def structure_data(microbes, metabolites): - #microbes = microbes.to_dataframe().T - #metabolites = metabolites.to_dataframe().T + microbes = microbes.to_dataframe().T + metabolites = metabolites.to_dataframe().T microbes = microbes.loc[metabolites.index] microbe_idx = microbes.columns @@ -79,7 +79,7 @@ def forward(self, X): data_likelihood = predicted.log_prob(self.metabolites) - l_y = data_likelihood.sum(0).sum() + l_y = data_likelihood.sum(0).mean() u_weights 
= self.encoder.weight l_u = Normal(0, self.sigma_u).log_prob(u_weights).sum() diff --git a/mmvec/multimodal.py b/mmvec/multimodal.py index eb54ded..2c7017b 100644 --- a/mmvec/multimodal.py +++ b/mmvec/multimodal.py @@ -4,6 +4,7 @@ import numpy as np import tensorflow as tf from tensorflow.contrib.distributions import Multinomial, Normal +from mmvec.ALR import MMvecALR import datetime diff --git a/mmvec/q2/_method.py b/mmvec/q2/_method.py index bb6dfbf..dbb0fdf 100644 --- a/mmvec/q2/_method.py +++ b/mmvec/q2/_method.py @@ -1,11 +1,11 @@ import biom import pandas as pd import numpy as np -import tensorflow as tf from skbio import OrdinationResults import qiime2 from qiime2.plugin import Metadata -from mmvec.multimodal import MMvec +from mmvec.train import mmvec_training_loop +from mmvec.ALR import MMvecALR from mmvec.util import split_tables from scipy.sparse import coo_matrix from scipy.sparse.linalg import svds @@ -54,73 +54,71 @@ def paired_omics(microbes: biom.Table, train_microbes_coo = coo_matrix(train_microbes_df.values) test_microbes_coo = coo_matrix(test_microbes_df.values) - with tf.Graph().as_default(), tf.Session() as session: - model = MMvec( - latent_dim=latent_dim, - u_scale=input_prior, v_scale=output_prior, - batch_size=batch_size, - device_name=device_name, - learning_rate=learning_rate) - model(session, - train_microbes_coo, train_metabolites_df.values, - test_microbes_coo, test_metabolites_df.values) - - loss, cv = model.fit(epoch=epochs, summary_interval=summary_interval) - ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns, - columns=train_metabolites_df.columns) - if latent_dim > 0: - u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim) - else: - # fake it until you make it - u, s, v = svds(ranks - ranks.mean(axis=0), k=1) - - ranks = ranks.T - ranks.index.name = 'featureid' - s = s[::-1] - u = u[:, ::-1] - v = v[::-1, :] - if equalize_biplot: - microbe_embed = u @ np.sqrt(np.diag(s)) - metabolite_embed = v.T @ np.sqrt(np.diag(s)) - else: - microbe_embed = u @ np.diag(s) - metabolite_embed = v.T - - pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] - features = pd.DataFrame( - microbe_embed, columns=pc_ids, - index=train_microbes_df.columns) - samples = pd.DataFrame( - metabolite_embed, columns=pc_ids, - index=train_metabolites_df.columns) - short_method_name = 'mmvec biplot' - long_method_name = 'Multiomics mmvec biplot' - eigvals = pd.Series(s, index=pc_ids) - proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids) - biplot = OrdinationResults( - short_method_name, long_method_name, eigvals, - samples=samples, features=features, - proportion_explained=proportion_explained) - - its = np.arange(len(loss)) - convergence_stats = pd.DataFrame( - { - 'loss': loss, - 'cross-validation': cv, - 'iteration': its - } + #with tf.Graph().as_default(), tf.Session() as session: + model = MMvecALR( + microbes=microbes, + metabolites= metabolites, + latent_dim=latent_dim, + sigma_u=input_prior, sigma_v=output_prior, ) - convergence_stats.index.name = 'id' - convergence_stats.index = convergence_stats.index.astype(np.str) - - c = convergence_stats['loss'].astype(np.float) - convergence_stats['loss'] = c - - c = convergence_stats['cross-validation'].astype(np.float) - convergence_stats['cross-validation'] = c - - c = convergence_stats['iteration'].astype(np.int) - convergence_stats['iteration'] = c - - return ranks, biplot, qiime2.Metadata(convergence_stats) + mmvec_training_loop(model=model, learning_rate=learning_rate, epochs=epochs, 
batch_size=batch_size, summary_interval=summary_interval) + ranks = model.ranks_dataframe() + #ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns, + # columns=train_metabolites_df.columns) + if latent_dim > 0: + u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim) + else: + # fake it until you make it + u, s, v = svds(ranks - ranks.mean(axis=0), k=1) + + ranks = ranks.T + ranks.index.name = 'featureid' + s = s[::-1] + u = u[:, ::-1] + v = v[::-1, :] + if equalize_biplot: + microbe_embed = u @ np.sqrt(np.diag(s)) + metabolite_embed = v.T @ np.sqrt(np.diag(s)) + else: + microbe_embed = u @ np.diag(s) + metabolite_embed = v.T + + pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] + features = pd.DataFrame( + microbe_embed, columns=pc_ids, + index=train_microbes_df.columns) + samples = pd.DataFrame( + metabolite_embed, columns=pc_ids, + index=train_metabolites_df.columns) + short_method_name = 'mmvec biplot' + long_method_name = 'Multiomics mmvec biplot' + eigvals = pd.Series(s, index=pc_ids) + proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids) + biplot = OrdinationResults( + short_method_name, long_method_name, eigvals, + samples=samples, features=features, + proportion_explained=proportion_explained) + + its = np.arange(len(loss)) + convergence_stats = pd.DataFrame( + { + 'loss': loss, + 'cross-validation': cv, + 'iteration': its + } + ) + + convergence_stats.index.name = 'id' + convergence_stats.index = convergence_stats.index.astype(np.str) + + c = convergence_stats['loss'].astype(np.float) + convergence_stats['loss'] = c + + c = convergence_stats['cross-validation'].astype(np.float) + convergence_stats['cross-validation'] = c + + c = convergence_stats['iteration'].astype(np.int) + convergence_stats['iteration'] = c + + return ranks, biplot, qiime2.Metadata(convergence_stats) diff --git a/mmvec/q2/tests/test_method.py b/mmvec/q2/tests/test_method.py index 2bae849..37759d5 100644 --- a/mmvec/q2/tests/test_method.py +++ b/mmvec/q2/tests/test_method.py @@ -1,7 +1,6 @@ import biom import unittest import numpy as np -import tensorflow as tf from mmvec.q2._method import paired_omics from mmvec.util import random_multimodal from skbio.stats.composition import clr_inv @@ -11,6 +10,21 @@ class TestMMvec(unittest.TestCase): + def setUp(self): + # build small simulation + np.random.seed(1) + res = random_multimodal( + num_microbes=8, num_metabolites=8, num_samples=150, + latent_dim=2, sigmaQ=2, + microbe_total=1000, metabolite_total=10000, seed=1 + ) + (self.microbes, self.metabolites, self.X, self.B, + self.U, self.Ubias, self.V, self.Vbias) = res + num_train = 10 + self.trainX = self.microbes.iloc[:-num_train] + self.testX = self.microbes.iloc[-num_train:] + self.trainY = self.metabolites.iloc[:-num_train] + self.testY = self.metabolites.iloc[-num_train:] def setUp(self): np.random.seed(1) res = random_multimodal( @@ -40,8 +54,8 @@ def setUp(self): def test_fit(self): np.random.seed(1) - tf.reset_default_graph() - tf.set_random_seed(0) + #tf.reset_default_graph() + #tf.set_random_seed(0) latent_dim = 2 res_ranks, res_biplot, _ = paired_omics( self.microbes, self.metabolites, @@ -70,8 +84,8 @@ def test_fit(self): def test_equalize_sv(self): np.random.seed(1) - tf.reset_default_graph() - tf.set_random_seed(0) + #tf.reset_default_graph() + #tf.set_random_seed(0) latent_dim = 2 res_ranks, res_biplot, _ = paired_omics( self.microbes, self.metabolites, diff --git a/mmvec/tests/test_multimodal.py b/mmvec/tests/test_multimodal.py index d09456d..6b8e9b6 100644 
--- a/mmvec/tests/test_multimodal.py +++ b/mmvec/tests/test_multimodal.py @@ -14,7 +14,7 @@ from mmvec.util import random_multimodal -class TestMMvec(unittest.TestCase): +class TestALR(unittest.TestCase): def setUp(self): # build small simulation np.random.seed(1) @@ -44,6 +44,7 @@ def test_fit(self): model = MMvecALR(self.trainX, self.trainY, latent_dim=2) mmvec_training_loop(model=model, learning_rate=0.01, batch_size=50, epochs=10000) + assert False U_ = np.hstack( (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) diff --git a/mmvec/train.py b/mmvec/train.py index f42d866..246ac84 100644 --- a/mmvec/train.py +++ b/mmvec/train.py @@ -1,6 +1,7 @@ import torch -def mmvec_training_loop(model, learning_rate, batch_size, epochs): +def mmvec_training_loop(model, learning_rate, batch_size, epochs, + summary_interval): optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.8, 0.9), maximize=True) for epoch in range(epochs): @@ -18,5 +19,5 @@ def mmvec_training_loop(model, learning_rate, batch_size, epochs): mmvec_model.backward() optimizer.step() - if epoch % 500 == 0: + if epoch % summary_interval == 0: print(f"loss: {mmvec_model.item()}\nBatch #: {epoch}") diff --git a/scripts/mmvec b/scripts/mmvec index 4976c99..2e03634 100644 --- a/scripts/mmvec +++ b/scripts/mmvec @@ -14,8 +14,6 @@ from skbio.stats.composition import clr_inv as softmax from scipy.stats import entropy, spearmanr from scipy.sparse import coo_matrix from scipy.sparse.linalg import svds -import tensorflow as tf -from tensorflow.contrib.distributions import Multinomial, Normal from mmvec.multimodal import MMvec from mmvec.util import split_tables, format_params import matplotlib.pyplot as plt diff --git a/setup.py b/setup.py index fcc0583..d568f26 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ 'scikit-bio', 'seaborn', 'tqdm', - 'pytorch>=1.9.0' + 'torch>=1.9.0' ], classifiers=classifiers, entry_points={ From ae8c948a007eb7d558e16a87abd5766c3b213bbb Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Thu, 19 May 2022 15:46:02 -0700 Subject: [PATCH 22/27] BUG: getting test interface to work after plugin --- mmvec/ALR.py | 6 ++++-- mmvec/tests/test_multimodal.py | 4 +--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mmvec/ALR.py b/mmvec/ALR.py index 992dabe..167e033 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -10,8 +10,10 @@ def structure_data(microbes, metabolites): - microbes = microbes.to_dataframe().T - metabolites = metabolites.to_dataframe().T + if type(microbes) is not pd.core.frame.DataFrame: + microbes = microbes.to_dataframe().T + if type(metabolites) is not pd.core.frame.DataFrame: + metabolites = metabolites.to_dataframe().T microbes = microbes.loc[metabolites.index] microbe_idx = microbes.columns diff --git a/mmvec/tests/test_multimodal.py b/mmvec/tests/test_multimodal.py index 6b8e9b6..1d809da 100644 --- a/mmvec/tests/test_multimodal.py +++ b/mmvec/tests/test_multimodal.py @@ -43,9 +43,7 @@ def test_fit(self): n, d2 = self.trainY.shape model = MMvecALR(self.trainX, self.trainY, latent_dim=2) mmvec_training_loop(model=model, learning_rate=0.01, batch_size=50, - epochs=10000) - assert False - + epochs=500, summary_interval=100) U_ = np.hstack( (np.ones((self.U.shape[0], 1)), self.Ubias, self.U)) V_ = np.vstack( From 88d5dbc431b47969efe23ae13073239b19cc58e0 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Wed, 25 May 2022 14:20:54 -0700 Subject: [PATCH 23/27] IMP: q2 paired-omics method wired up --- mmvec/q2/_method.py | 73 +++++------------------ mmvec/train.py | 31 
++++++++++++++++++----
 scripts/mmvec                  | 141 ++++++++++++++++++++-----------------------
 3 files changed, 110 insertions(+), 135 deletions(-)

diff --git a/mmvec/q2/_method.py b/mmvec/q2/_method.py
index dbb0fdf..0f19e15 100644
--- a/mmvec/q2/_method.py
+++ b/mmvec/q2/_method.py
@@ -13,7 +13,7 @@
 def paired_omics(microbes: biom.Table,
                  metabolites: biom.Table,
-                 metadata: Metadata = None,
+                 metadata: qiime2.Metadata = None,
                  training_column: str = None,
                  num_testing_examples: int = 5,
                  min_feature_count: int = 10,
@@ -26,8 +26,8 @@ def paired_omics(microbes: biom.Table,
                  equalize_biplot: float = False,
                  arm_the_gpu: bool = False,
                  summary_interval: int = 60) -> (
-                     pd.DataFrame, OrdinationResults, qiime2.Metadata
-                     ):
+        pd.DataFrame, OrdinationResults, qiime2.Metadata
+        ):
 
     if metadata is not None:
         metadata = metadata.to_dataframe()
@@ -54,7 +54,6 @@ def paired_omics(microbes: biom.Table,
     train_microbes_coo = coo_matrix(train_microbes_df.values)
     test_microbes_coo = coo_matrix(test_microbes_df.values)
 
-    #with tf.Graph().as_default(), tf.Session() as session:
     model = MMvecALR(
         microbes=microbes,
         metabolites= metabolites,
@@ -62,63 +61,19 @@ def paired_omics(microbes: biom.Table,
         sigma_u=input_prior, sigma_v=output_prior,
     )
 
-    mmvec_training_loop(model=model, learning_rate=learning_rate, epochs=epochs,
-                        batch_size=batch_size, summary_interval=summary_interval)
-    ranks = model.ranks_dataframe()
-    #ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns,
-    #                     columns=train_metabolites_df.columns)
-    if latent_dim > 0:
-        u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim)
-    else:
-        # fake it until you make it
-        u, s, v = svds(ranks - ranks.mean(axis=0), k=1)
-
-    ranks = ranks.T
-    ranks.index.name = 'featureid'
-    s = s[::-1]
-    u = u[:, ::-1]
-    v = v[::-1, :]
-    if equalize_biplot:
-        microbe_embed = u @ np.sqrt(np.diag(s))
-        metabolite_embed = v.T @ np.sqrt(np.diag(s))
-    else:
-        microbe_embed = u @ np.diag(s)
-        metabolite_embed = v.T
+    convergence_stats = pd.DataFrame.from_records(
+        mmvec_training_loop(model=model, learning_rate=learning_rate,
+                            epochs=epochs, batch_size=batch_size,
+                            summary_interval=summary_interval),
+        columns=['iteration', 'loss', 'cross-validation'])
 
-    pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])]
-    features = pd.DataFrame(
-        microbe_embed, columns=pc_ids,
-        index=train_microbes_df.columns)
-    samples = pd.DataFrame(
-        metabolite_embed, columns=pc_ids,
-        index=train_metabolites_df.columns)
-    short_method_name = 'mmvec biplot'
-    long_method_name = 'Multiomics mmvec biplot'
-    eigvals = pd.Series(s, index=pc_ids)
-    proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids)
-    biplot = OrdinationResults(
-        short_method_name, long_method_name, eigvals,
-        samples=samples, features=features,
-        proportion_explained=proportion_explained)
-
-    its = np.arange(len(loss))
-    convergence_stats = pd.DataFrame(
-        {
-            'loss': loss,
-            'cross-validation': cv,
-            'iteration': its
-        }
-    )
-
-    convergence_stats.index.name = 'id'
-    convergence_stats.index = convergence_stats.index.astype(np.str)
-
-    c = convergence_stats['loss'].astype(np.float)
-    convergence_stats['loss'] = c
-
-    c = convergence_stats['cross-validation'].astype(np.float)
-    convergence_stats['cross-validation'] = c
+    # note: astype returns a new frame, so assign it back; copy=False by
+    # itself does not cast in place
+    convergence_stats = convergence_stats.astype(
+        {'loss': 'float', 'cross-validation': 'float'})
 
-    c = convergence_stats['iteration'].astype(np.int)
-    convergence_stats['iteration'] = c
+    convergence_stats.set_index("iteration", inplace=True)
+    convergence_stats.index.name = "id"
 
+    biplot = model.get_ordination()
+    ranks = 
model.ranks_dataframe() return ranks, biplot, qiime2.Metadata(convergence_stats) diff --git a/mmvec/train.py b/mmvec/train.py index 246ac84..c6249b7 100644 --- a/mmvec/train.py +++ b/mmvec/train.py @@ -8,16 +8,35 @@ def mmvec_training_loop(model, learning_rate, batch_size, epochs, batch_iterations = int(model.nnz / batch_size) for batch in range(batch_iterations): + #iteration = epoch*batch_iterations + batch + 1 + iteration = epoch * model.nnz // batch_size - draws = torch.multinomial(model.microbe_relative_freq, - batch_size, - replacement=True).T + draws = torch.multinomial( + model.microbe_relative_freq(model.microbes_train), + batch_size, + replacement=True).T - mmvec_model = model(draws) + loss = model(draws, model.metabolites_train) optimizer.zero_grad() - mmvec_model.backward() + loss.backward() optimizer.step() + + + + + if epoch % summary_interval == 0: - print(f"loss: {mmvec_model.item()}\nBatch #: {epoch}") + + with torch.no_grad(): + cv_draw = torch.multinomial( + model.microbe_relative_freq(model.microbes_test), + batch_size, + replacement=True).T + cv_loss = model(cv_draw, model.metabolites_test) + yield (str(iteration), loss.item(), cv_loss.item()) + + else: + yield (str(iteration), loss.item(), None) + diff --git a/scripts/mmvec b/scripts/mmvec index 2e03634..0fd8822 100644 --- a/scripts/mmvec +++ b/scripts/mmvec @@ -14,6 +14,7 @@ from skbio.stats.composition import clr_inv as softmax from scipy.stats import entropy, spearmanr from scipy.sparse import coo_matrix from scipy.sparse.linalg import svds +from tensorflow.contrib.distributions import Multinomial, Normal from mmvec.multimodal import MMvec from mmvec.util import split_tables, format_params import matplotlib.pyplot as plt @@ -155,76 +156,76 @@ def paired_omics(microbe_file, metabolite_file, else: device_name='/cpu:0' - config = tf.ConfigProto() - with tf.Graph().as_default(), tf.Session(config=config) as session: - model = MMvec( - latent_dim=latent_dim, - u_scale=input_prior, v_scale=output_prior, - learning_rate = learning_rate, - beta_1=beta1, beta_2=beta2, - device_name=device_name, - batch_size=batch_size, - clipnorm=clipnorm, save_path=sname) - - model(session, - train_microbes_coo, train_metabolites_df.values, - test_microbes_coo, test_metabolites_df.values) - - loss, cv = model.fit(epoch=epochs, summary_interval=summary_interval, - checkpoint_interval=checkpoint_interval) - - pc_ids = list(range(latent_dim)) - vdim = model.V.shape[0] - V = np.hstack((np.zeros((vdim, 1)), model.V)) - V = V.T - Vbias = np.hstack((np.zeros(1), model.Vbias.ravel())) - - # Save to an embeddings file - Uparam = format_params(model.U, pc_ids, list(train_microbes_df.columns), 'microbe') - Vparam = format_params(V, pc_ids, list(train_metabolites_df.columns), 'metabolite') - df = pd.concat( - ( - Uparam, Vparam, - format_params(model.Ubias, ['bias'], train_microbes_df.columns, 'microbe'), - format_params(Vbias, ['bias'], train_metabolites_df.columns, 'metabolite') - ), axis=0) - - df.to_csv(embeddings_file, sep='\t') - - # Save to a ranks file - ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns, - columns=train_metabolites_df.columns) - - u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim) - ranks = ranks.T - ranks.index.name = 'featureid' - ranks.to_csv(ranks_file, sep='\t') - # Save to an ordination file - s = s[::-1] - u = u[:, ::-1] - v = v[::-1, :] - if equalize_biplot: - microbe_embed = u @ np.sqrt(np.diag(s)) - metabolite_embed = v.T @ np.sqrt(np.diag(s)) - else: - microbe_embed = u @ np.diag(s) - 
metabolite_embed = v.T - pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] - features = pd.DataFrame( - microbe_embed, columns=pc_ids, - index=train_microbes_df.columns) - samples = pd.DataFrame( - metabolite_embed, columns=pc_ids, - index=train_metabolites_df.columns) - short_method_name = 'mmvec biplot' - long_method_name = 'Multiomics mmvec biplot' - eigvals = pd.Series(s, index=pc_ids) - proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids) - biplot = OrdinationResults( - short_method_name, long_method_name, eigvals, - samples=samples, features=features, - proportion_explained=proportion_explained) - biplot.write(ordination_file) + #config = tf.ConfigProto() + #with tf.Graph().as_default(), tf.Session(config=config) as session: + model = MMvec( + latent_dim=latent_dim, + u_scale=input_prior, v_scale=output_prior, + learning_rate = learning_rate, + beta_1=beta1, beta_2=beta2, + device_name=device_name, + batch_size=batch_size, + clipnorm=clipnorm, save_path=sname) + + model(session, + train_microbes_coo, train_metabolites_df.values, + test_microbes_coo, test_metabolites_df.values) + + loss, cv = model.fit(epoch=epochs, summary_interval=summary_interval, + checkpoint_interval=checkpoint_interval) + + pc_ids = list(range(latent_dim)) + vdim = model.V.shape[0] + V = np.hstack((np.zeros((vdim, 1)), model.V)) + V = V.T + Vbias = np.hstack((np.zeros(1), model.Vbias.ravel())) + + # Save to an embeddings file + Uparam = format_params(model.U, pc_ids, list(train_microbes_df.columns), 'microbe') + Vparam = format_params(V, pc_ids, list(train_metabolites_df.columns), 'metabolite') + df = pd.concat( + ( + Uparam, Vparam, + format_params(model.Ubias, ['bias'], train_microbes_df.columns, 'microbe'), + format_params(Vbias, ['bias'], train_metabolites_df.columns, 'metabolite') + ), axis=0) + + df.to_csv(embeddings_file, sep='\t') + + # Save to a ranks file + ranks = pd.DataFrame(model.ranks(), index=train_microbes_df.columns, + columns=train_metabolites_df.columns) + + u, s, v = svds(ranks - ranks.mean(axis=0), k=latent_dim) + ranks = ranks.T + ranks.index.name = 'featureid' + ranks.to_csv(ranks_file, sep='\t') + # Save to an ordination file + s = s[::-1] + u = u[:, ::-1] + v = v[::-1, :] + if equalize_biplot: + microbe_embed = u @ np.sqrt(np.diag(s)) + metabolite_embed = v.T @ np.sqrt(np.diag(s)) + else: + microbe_embed = u @ np.diag(s) + metabolite_embed = v.T + pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] + features = pd.DataFrame( + microbe_embed, columns=pc_ids, + index=train_microbes_df.columns) + samples = pd.DataFrame( + metabolite_embed, columns=pc_ids, + index=train_metabolites_df.columns) + short_method_name = 'mmvec biplot' + long_method_name = 'Multiomics mmvec biplot' + eigvals = pd.Series(s, index=pc_ids) + proportion_explained = pd.Series(s**2 / np.sum(s**2), index=pc_ids) + biplot = OrdinationResults( + short_method_name, long_method_name, eigvals, + samples=samples, features=features, + proportion_explained=proportion_explained) + biplot.write(ordination_file) if __name__ == '__main__': From aa967e4dae6746b440546d37f42ff86522cdd17c Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Wed, 25 May 2022 15:21:49 -0700 Subject: [PATCH 24/27] BUG: fixing index filtering on biom-tables --- mmvec/ALR.py | 59 ++++++++++++++++++++++++++++++++------------- mmvec/q2/_method.py | 1 + 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/mmvec/ALR.py b/mmvec/ALR.py index 167e033..ebae822 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -1,3 +1,4 @@ 
--- a/mmvec/ALR.py
+++ b/mmvec/ALR.py
@@ -1,3 +1,4 @@
+from numpy import float64 import pandas as pd import torch @@ -9,27 +10,38 @@ from skbio import OrdinationResults -def structure_data(microbes, metabolites): +def structure_data(microbes, metabolites, holdout_num): if type(microbes) is not pd.core.frame.DataFrame: microbes = microbes.to_dataframe().T if type(metabolites) is not pd.core.frame.DataFrame: metabolites = metabolites.to_dataframe().T - microbes = microbes.loc[metabolites.index] + idx = microbes.index.intersection(metabolites.index) + microbes = microbes.loc[idx] + metabolites = metabolites.loc[idx] microbe_idx = microbes.columns metabolite_idx = metabolites.columns - microbe_count = microbes.shape[1] - metabolite_count = metabolites.shape[1] + microbes_train = torch.tensor(microbes[:-holdout_num].values, + dtype=torch.float64) + metabolites_train = torch.tensor(metabolites[:-holdout_num].values, + dtype=torch.float64) + + microbes_test = torch.tensor(microbes[-holdout_num:].values, + dtype=torch.float64) + metabolites_test = torch.tensor(metabolites[-holdout_num:].values, + dtype=torch.float64) + microbe_count = microbes_train.shape[1] + metabolite_count = metabolites_train.shape[1] microbes = torch.tensor(microbes.values, dtype=torch.int) metabolites = torch.tensor(metabolites.values, dtype=torch.int64) - microbe_relative_frequency = (microbes.T/microbes.sum(1)).T nnz = torch.count_nonzero(microbes).item() - return (microbes, metabolites, microbe_idx, metabolite_idx, microbe_count, - metabolite_count, microbe_relative_frequency, nnz) + return (microbes_train, microbes_test, metabolites_train, + metabolites_test, microbe_idx, metabolite_idx, microbe_count, + metabolite_count, nnz) class LinearALR(nn.Module): @@ -47,15 +59,16 @@ def forward(self, x): class MMvecALR(nn.Module): def __init__(self, microbes, metabolites, latent_dim, sigma_u=1, - sigma_v=1): + sigma_v=1, holdout_num=4): super().__init__() # Data setup - (self.microbes, self.metabolites, - self.microbe_idx, self. metabolite_idx, - self.num_microbes, self.num_metabolites, - self.microbe_relative_freq, self.nnz) = structure_data(microbes, - metabolites) + (self.microbes_train, self.microbes_test, self.metabolites_train, + self.metabolites_test, self.microbe_idx, self. 
metabolite_idx, + self.num_microbes, self.num_metabolites, + self.nnz) = structure_data( + microbes, metabolites, holdout_num) + self.sigma_u = sigma_u self.sigma_v = sigma_v self.latent_dim = latent_dim @@ -67,7 +80,7 @@ def __init__(self, microbes, metabolites, latent_dim, sigma_u=1, self.decoder = LinearALR(self.latent_dim, self.num_metabolites) - def forward(self, X): + def forward(self, X, Y): # Three likelihoods, the likelihood of each weight and the likelihood # of the data fitting in the way that we thought # LYs @@ -79,7 +92,7 @@ def forward(self, X): validate_args=False, probs=y_pred) - data_likelihood = predicted.log_prob(self.metabolites) + data_likelihood = predicted.log_prob(Y) l_y = data_likelihood.sum(0).mean() @@ -146,8 +159,9 @@ def v_bias(self): @property def U(self): + print (self.encoder.weight.shape) U = torch.cat( - (torch.ones((self.num_microbes, 1)), + (torch.ones((self.encoder.weight.shape[0], 1)), self.u_bias, self.encoder.weight.detach()), dim=1) @@ -157,7 +171,7 @@ def U(self): def V(self): V = torch.cat( (self.v_bias.unsqueeze(dim=0), - torch.ones((1, self.num_metabolites - 1)), + torch.ones((1, self.decoder.linear.weight.shape[0] )), self.decoder.linear.weight.detach().T), dim=0) return V @@ -174,3 +188,14 @@ def ranks(self): ), dim=1) res = res - res.mean(axis=1).reshape(-1, 1) return res + def microbe_relative_freq(self,microbes): + return (microbes.T / microbes.sum(1)).T + + #def loss_fn(self, y_pred, observed): + + # predicted = torch.distributions.multinomial.Multinomial(total_count=0, + # validate_args=False, + # probs=y_pred) + # + # data_likelihood = predicted.log_prob(observed) + # l_y = data_likelihood.sum(1).mean() diff --git a/mmvec/q2/_method.py b/mmvec/q2/_method.py index 0f19e15..e395a7d 100644 --- a/mmvec/q2/_method.py +++ b/mmvec/q2/_method.py @@ -32,6 +32,7 @@ def paired_omics(microbes: biom.Table, if metadata is not None: metadata = metadata.to_dataframe() + #TODO refactor for pytorch! 
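+    # (a hypothetical sketch of that refactor, not wired up here: swap the
+    # TF device string below for a torch.device, e.g.
+    #     device = torch.device('cuda:0' if arm_the_gpu and
+    #                           torch.cuda.is_available() else 'cpu')
+    #     model.to(device)
+    # )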
if arm_the_gpu: # pick out the first GPU device_name = '/device:GPU:0' From a26f56ef8aeeb5e8c4c5e8b428c1fd6c88a5fb28 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Wed, 25 May 2022 16:51:11 -0700 Subject: [PATCH 25/27] BUG: index dfs based on intersection of indexes --- examples/refactor/ALR.ipynb | 135 +++++++++++++----------------------- 1 file changed, 50 insertions(+), 85 deletions(-) diff --git a/examples/refactor/ALR.ipynb b/examples/refactor/ALR.ipynb index 9e84413..97a957f 100644 --- a/examples/refactor/ALR.ipynb +++ b/examples/refactor/ALR.ipynb @@ -31,7 +31,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -60,7 +60,7 @@ "microbes = biom.load_table(\"./soil_microbes.biom\")\n", "metabolites = biom.load_table(\"./soil_metabolites.biom\")\n", "\n", - "model = mmvec.ALR.MMvecALR(microbes, metabolites, 15, sigma_u=1, sigma_v=1)\n", + "model = mmvec.ALR.MMvecALR(microbes, metabolites, 15, sigma_u=1, sigma_v=1, holdout_num=5)\n", "\n", "microbes = microbes.to_dataframe().T\n", "metabolites = metabolites.to_dataframe().T\n", @@ -74,6 +74,7 @@ "microbe_relative_frequency = (microbes.T/microbes.sum(1)).T\n", "\n", "microbe_count = microbes.shape[1]\n", + "\n", "metabolite_count = metabolites.shape[1]" ] }, @@ -86,60 +87,65 @@ "source": [ "#model = mmvec.ALR.MMvecALR(microbe_count, metabolite_count, 15, sigma_u=1, sigma_v=1)\n", "learning_rate = 1e-3\n", - "batch_size = 200\n", - "epochs = 100\n", + "batch_size = 500\n", + "epochs = 3000\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, maximize=True)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "29799ea5", "metadata": {}, + "outputs": [], + "source": [ + "model.microbe_relative_freq(model.microbes_train)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b977e212", + "metadata": {}, + "outputs": [], + "source": [ + "maybe = mmvec.train.mmvec_training_loop(\n", + " model=model,\n", + "\n", + " batch_size=batch_size,\n", + " epochs=epochs,\n", + " learning_rate=learning_rate,\n", + " summary_interval = 25)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "30db58c0", + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[0.2017, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [0.0950, 0.1752, 0.0000, ..., 0.0000, 0.0013, 0.0000],\n", - " [0.1416, 0.0973, 0.0000, ..., 0.0000, 0.0064, 0.0000],\n", - " ...,\n", - " [0.0507, 0.0000, 0.0090, ..., 0.0058, 0.0000, 0.0051],\n", - " [0.0382, 0.0076, 0.0025, ..., 0.0000, 0.0009, 0.0100],\n", - " [0.0027, 0.0135, 0.0000, ..., 0.0000, 0.0008, 0.0036]])" + "" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "model.microbe_relative_freq" + "test_stats = pd.Dataframe.from_records(maybe, )" ] }, { "cell_type": "code", - "execution_count": 7, - "id": "b977e212", + "execution_count": null, + "id": "2eb531fe", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loss: -14241466368.0\n", - "Batch #: 0\n" - ] - } - ], - "source": [ - "maybe = mmvec.train.mmvec_training_loop(\n", - " model=model,\n", - " optimizer=optimizer,\n", - " batch_size=batch_size,\n", - " epochs=epochs)" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", @@ -147,54 +153,26 @@ "id": "1f27301f", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "model.microbe_relative_freq(model.microbes_train)" + ] }, { "cell_type": "code", - "execution_count": 9, + 
"execution_count": null, "id": "8a16a24a", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 1.1758, 0.0677, -0.5526, ..., -0.7825, -0.5267, -0.1815],\n", - " [ 0.2582, -0.7442, 0.0576, ..., -1.0060, -0.2658, -0.9660],\n", - " [ 0.9580, -0.5079, -0.8934, ..., -0.5971, -0.0238, -0.3293],\n", - " ...,\n", - " [ 0.0645, 0.1041, -0.8826, ..., 0.0947, -0.1519, -0.1774],\n", - " [ 0.4002, 0.1209, -0.4545, ..., -0.8281, -1.0184, 0.7384],\n", - " [ 1.1294, -0.6548, -0.0167, ..., -1.0560, -0.1584, -0.8694]])" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "model.ranks" + "l" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "3663ecb4", "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'MMvecALR' object has no attribute 'ranks_matrix'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [10]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# h = model.ranks_df - model.ranks_df.mean(axis=0)\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m h \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mranks_matrix\u001b[49m\n\u001b[1;32m 4\u001b[0m k \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mlatent_dim\n", - "File \u001b[0;32m~/opt/miniconda3/envs/mmvec/lib/python3.10/site-packages/torch/nn/modules/module.py:1185\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1185\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, name))\n", - "\u001b[0;31mAttributeError\u001b[0m: 'MMvecALR' object has no attribute 'ranks_matrix'" - ] - } - ], + "outputs": [], "source": [ "# h = model.ranks_df - model.ranks_df.mean(axis=0)\n", "h = model.ranks_matrix\n", @@ -254,23 +232,10 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "752acb7a", "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'MMvecALR' object has no attribute 'U'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Input \u001b[0;32mIn [11]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mU\u001b[49m\n", - "File \u001b[0;32m~/opt/miniconda3/envs/mmvec/lib/python3.10/site-packages/torch/nn/modules/module.py:1185\u001b[0m, in 
\u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1185\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, name))\n", - "\u001b[0;31mAttributeError\u001b[0m: 'MMvecALR' object has no attribute 'U'" - ] - } - ], + "outputs": [], "source": [ "model.U" ] @@ -405,7 +370,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.8.12" } }, "nbformat": 4, From 9d4d26e8dea37d23b6d80655ec8c040b4de9f331 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Tue, 31 May 2022 17:58:36 -0700 Subject: [PATCH 26/27] BUG: paired-omics working now. --- mmvec/ALR.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/mmvec/ALR.py b/mmvec/ALR.py index ebae822..a6111a4 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -15,9 +15,12 @@ def structure_data(microbes, metabolites, holdout_num): microbes = microbes.to_dataframe().T if type(metabolites) is not pd.core.frame.DataFrame: metabolites = metabolites.to_dataframe().T - idx = microbes.index.intersection(metabolites.index) - microbes = microbes.loc[idx] - metabolites = metabolites.loc[idx] + # idx = microbes.index.intersection(metabolites.index) + # microbes = microbes.loc[idx] + # metabolites = metabolites.loc[idx] + #make sure none sum to zero + #microbes = microbes.loc[:, microbes.sum(axis=0) > 0] + # microbes = microbes.loc[microbes.sum(axis=1) > 0] microbe_idx = microbes.columns metabolite_idx = metabolites.columns @@ -64,7 +67,7 @@ def __init__(self, microbes, metabolites, latent_dim, sigma_u=1, # Data setup (self.microbes_train, self.microbes_test, self.metabolites_train, - self.metabolites_test, self.microbe_idx, self. 
metabolite_idx, + self.metabolites_test, self.microbe_idx, self.metabolite_idx, self.num_microbes, self.num_metabolites, self.nnz) = structure_data( microbes, metabolites, holdout_num) @@ -117,11 +120,19 @@ def get_ordination(self, equalize_biplot=False): # us torch.diag to go from vector to matrix with the vector on dia if equalize_biplot: - microbe_embed = u @ torch.sqrt(torch.diag(s_diag)) + #microbe_embed = u @ torch.sqrt( + # torch.diag(s_diag)).detach().numpy() + microbe_embed = u @ torch.sqrt( + torch.diag(s_diag)) + microbe_embed = microbe_embed.detach().numpy() + #metabolite_embed = v.T @ torch.sqrt(s_diag).detach().numpy() metabolite_embed = v.T @ torch.sqrt(s_diag) + metabolite_embed = metabolite_embed.detach().numpy() else: microbe_embed = u @ torch.diag(s_diag) + microbe_embed = microbe_embed.detach().numpy() metabolite_embed = v.T + metabolite_embed = metabolite_embed.detach().numpy() pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] @@ -135,8 +146,10 @@ def get_ordination(self, equalize_biplot=False): short_method_name = 'mmvec biplot' long_method_name = 'Multiomics mmvec biplot' eigvals = pd.Series(s_diag, index=pc_ids) - proportion_explained = pd.Series(torch.square(s_diag) / - torch.sum(torch.square(s_diag)), index=pc_ids) + proportion_explained = pd.Series( + torch.square(s_diag).detach().numpy() / torch.sum( + torch.square(s_diag)).detach().numpy(), + index=pc_ids, dtype=float64) biplot = OrdinationResults( short_method_name, long_method_name, eigvals, @@ -159,7 +172,6 @@ def v_bias(self): @property def U(self): - print (self.encoder.weight.shape) U = torch.cat( (torch.ones((self.encoder.weight.shape[0], 1)), self.u_bias, From da1bdb85cfe2842e6022f000af6115931e89d9e7 Mon Sep 17 00:00:00 2001 From: Keegan Evans Date: Fri, 3 Jun 2022 14:56:28 -0700 Subject: [PATCH 27/27] TEST: adding tests for ranks_bare --- mmvec/ALR.py | 166 +++++++++++++++++++++---------- mmvec/q2/tests/test_functions.py | 17 ++++ mmvec/q2/tests/test_method.py | 31 +++--- mmvec/train.py | 4 - 4 files changed, 147 insertions(+), 71 deletions(-) create mode 100644 mmvec/q2/tests/test_functions.py diff --git a/mmvec/ALR.py b/mmvec/ALR.py index a6111a4..30b8499 100644 --- a/mmvec/ALR.py +++ b/mmvec/ALR.py @@ -111,53 +111,56 @@ def forward(self, X, Y): return likelihood_sum def get_ordination(self, equalize_biplot=False): - - ranks = self.ranks() - ranks = ranks - ranks.mean(dim=0) - - u, s_diag, v = linalg.svd(ranks, full_matrices=False) - - - # us torch.diag to go from vector to matrix with the vector on dia - if equalize_biplot: - #microbe_embed = u @ torch.sqrt( - # torch.diag(s_diag)).detach().numpy() - microbe_embed = u @ torch.sqrt( - torch.diag(s_diag)) - microbe_embed = microbe_embed.detach().numpy() - #metabolite_embed = v.T @ torch.sqrt(s_diag).detach().numpy() - metabolite_embed = v.T @ torch.sqrt(s_diag) - metabolite_embed = metabolite_embed.detach().numpy() - else: - microbe_embed = u @ torch.diag(s_diag) - microbe_embed = microbe_embed.detach().numpy() - metabolite_embed = v.T - metabolite_embed = metabolite_embed.detach().numpy() - - pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] - - - features = pd.DataFrame( - microbe_embed, columns=pc_ids, index=self.microbe_idx) - - samples = pd.DataFrame(metabolite_embed, columns=pc_ids, - index=self.metabolite_idx) - - short_method_name = 'mmvec biplot' - long_method_name = 'Multiomics mmvec biplot' - eigvals = pd.Series(s_diag, index=pc_ids) - proportion_explained = pd.Series( - torch.square(s_diag).detach().numpy() / 
torch.sum( - torch.square(s_diag)).detach().numpy(), - index=pc_ids, dtype=float64) - - biplot = OrdinationResults( - short_method_name, long_method_name, eigvals, - samples=samples, features=features, - proportion_explained=proportion_explained) - + biplot = get_ordination_bare(self.ranks(), self.microbe_idx, + self.metabolite_idx, equalize_biplot=False) return biplot +# ranks = self.ranks() +# ranks = ranks - ranks.mean(dim=0) +# +# u, s_diag, v = linalg.svd(ranks, full_matrices=False) +# +# +# # us torch.diag to go from vector to matrix with the vector on dia +# if equalize_biplot: +# #microbe_embed = u @ torch.sqrt( +# # torch.diag(s_diag)).detach().numpy() +# microbe_embed = u @ torch.sqrt( +# torch.diag(s_diag)) +# microbe_embed = microbe_embed.detach().numpy() +# #metabolite_embed = v.T @ torch.sqrt(s_diag).detach().numpy() +# metabolite_embed = v.T @ torch.sqrt(s_diag) +# metabolite_embed = metabolite_embed.detach().numpy() +# else: +# microbe_embed = u @ torch.diag(s_diag) +# microbe_embed = microbe_embed.detach().numpy() +# metabolite_embed = v.T +# metabolite_embed = metabolite_embed.detach().numpy() +# +# pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])] +# +# +# features = pd.DataFrame( +# microbe_embed, columns=pc_ids, index=self.microbe_idx) +# +# samples = pd.DataFrame(metabolite_embed, columns=pc_ids, +# index=self.metabolite_idx) +# +# short_method_name = 'mmvec biplot' +# long_method_name = 'Multiomics mmvec biplot' +# eigvals = pd.Series(s_diag, index=pc_ids) +# proportion_explained = pd.Series( +# torch.square(s_diag).detach().numpy() / torch.sum( +# torch.square(s_diag)).detach().numpy(), +# index=pc_ids, dtype=float64) +# +# biplot = OrdinationResults( +# short_method_name, long_method_name, eigvals, +# samples=samples, features=features, +# proportion_explained=proportion_explained) +# +# return biplot +# @property @@ -193,13 +196,16 @@ def ranks_dataframe(self): columns=self.metabolite_idx) def ranks(self): - # Adding the zeros is part of the inverse ALR. - res = torch.cat(( - torch.zeros((self.num_microbes, 1)), - self.U @ self.V - ), dim=1) - res = res - res.mean(axis=1).reshape(-1, 1) - return res + return ranks_bare(self.U, self.V) + #def ranks(self): + # # Adding the zeros is part of the inverse ALR. + # res = torch.cat(( + # torch.zeros((self.num_microbes, 1)), + # self.U @ self.V + # ), dim=1) + # res = res - res.mean(axis=1).reshape(-1, 1) + # return res + def microbe_relative_freq(self,microbes): return (microbes.T / microbes.sum(1)).T @@ -211,3 +217,59 @@ def microbe_relative_freq(self,microbes): # # data_likelihood = predicted.log_prob(observed) # l_y = data_likelihood.sum(1).mean() +### bare functions for testing/method creation + +def ranks_bare(u, v): + # Adding the zeros is part of the inverse ALR. 
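+    # (sketch of the algebra: entry (i, j) of u @ v is the log-ratio of
+    # metabolite j+1 against reference metabolite 0 for microbe i;
+    # prepending a zero column restores the reference coordinate, and
+    # subtracting the row mean re-centers the log-ratios, CLR-style)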
+    res = torch.cat((
+        torch.zeros((u.shape[0], 1)),
+        u @ v
+    ), dim=1)
+    res = res - res.mean(axis=1).reshape(-1, 1)
+    return res
+
+def get_ordination_bare(ranks, microbe_index, metabolite_index,
+                        equalize_biplot=False):
+
+    ranks = ranks - ranks.mean(dim=0)
+
+    u, s_diag, v = linalg.svd(ranks, full_matrices=False)
+
+    # use torch.diag to go from vector to matrix with the vector on the
+    # diagonal
+    if equalize_biplot:
+        microbe_embed = u @ torch.sqrt(
+            torch.diag(s_diag))
+        microbe_embed = microbe_embed.detach().numpy()
+        # torch.diag is needed on this side as well (mirroring the numpy
+        # version's np.sqrt(np.diag(s))); without it the product
+        # collapses to a vector
+        metabolite_embed = v.T @ torch.sqrt(torch.diag(s_diag))
+        metabolite_embed = metabolite_embed.detach().numpy()
+    else:
+        microbe_embed = u @ torch.diag(s_diag)
+        microbe_embed = microbe_embed.detach().numpy()
+        metabolite_embed = v.T
+        metabolite_embed = metabolite_embed.detach().numpy()
+
+    pc_ids = ['PC%d' % i for i in range(microbe_embed.shape[1])]
+
+    features = pd.DataFrame(
+        microbe_embed, columns=pc_ids, index=microbe_index)
+
+    samples = pd.DataFrame(metabolite_embed, columns=pc_ids,
+                           index=metabolite_index)
+
+    short_method_name = 'mmvec biplot'
+    long_method_name = 'Multiomics mmvec biplot'
+    eigvals = pd.Series(s_diag, index=pc_ids)
+    proportion_explained = pd.Series(
+        torch.square(s_diag).detach().numpy() / torch.sum(
+            torch.square(s_diag)).detach().numpy(),
+        index=pc_ids, dtype=float64)
+
+    biplot = OrdinationResults(
+        short_method_name, long_method_name, eigvals,
+        samples=samples, features=features,
+        proportion_explained=proportion_explained)
+
+    return biplot
+
diff --git a/mmvec/q2/tests/test_functions.py b/mmvec/q2/tests/test_functions.py
new file mode 100644
index 0000000..c4ba9bc
--- /dev/null
+++ b/mmvec/q2/tests/test_functions.py
@@ -0,0 +1,16 @@
+import unittest
+from numpy.testing import assert_equal
+import torch
+from mmvec.ALR import ranks_bare
+
+class TestBareFunctions(unittest.TestCase):
+    def setUp(self) -> None:
+        self.u = torch.rand(4, 5)
+        self.v = torch.rand(5, 3)
+        return super().setUp()
+
+    def test_ranks(self):
+        ranks = ranks_bare(self.u, self.v)
+
+        assert_equal(ranks.shape[0], self.u.shape[0])
+        assert_equal(ranks.shape[1], (self.v.shape[1] + 1))
diff --git a/mmvec/q2/tests/test_method.py b/mmvec/q2/tests/test_method.py
index 37759d5..9062394 100644
--- a/mmvec/q2/tests/test_method.py
+++ b/mmvec/q2/tests/test_method.py
@@ -10,21 +10,22 @@ class TestMMvec(unittest.TestCase):
-    def setUp(self):
-        # build small simulation
-        np.random.seed(1)
-        res = random_multimodal(
-            num_microbes=8, num_metabolites=8, num_samples=150,
-            latent_dim=2, sigmaQ=2,
-            microbe_total=1000, metabolite_total=10000, seed=1
-        )
-        (self.microbes, self.metabolites, self.X, self.B,
-         self.U, self.Ubias, self.V, self.Vbias) = res
-        num_train = 10
-        self.trainX = self.microbes.iloc[:-num_train]
-        self.testX = self.microbes.iloc[-num_train:]
-        self.trainY = self.metabolites.iloc[:-num_train]
-        self.testY = self.metabolites.iloc[-num_train:]
+    #def setUp(self):
+    #    # build small simulation
+    #    np.random.seed(1)
+    #    res = random_multimodal(
+    #        num_microbes=8, num_metabolites=8, num_samples=150,
+    #        latent_dim=2, sigmaQ=2,
+    #        microbe_total=1000, metabolite_total=10000, seed=1
+    #    )
+    #    (self.microbes, self.metabolites, self.X, self.B,
+    #     self.U, self.Ubias, self.V, self.Vbias) = res
+    #    num_train = 10
+    #    self.trainX = self.microbes.iloc[:-num_train]
+    #    self.testX = self.microbes.iloc[-num_train:]
+    #    self.trainY = self.metabolites.iloc[:-num_train]
+    #    self.testY = self.metabolites.iloc[-num_train:]
+
     def setUp(self):
         np.random.seed(1)
         res = 
random_multimodal( diff --git a/mmvec/train.py b/mmvec/train.py index c6249b7..f6b7123 100644 --- a/mmvec/train.py +++ b/mmvec/train.py @@ -36,7 +36,3 @@ def mmvec_training_loop(model, learning_rate, batch_size, epochs, replacement=True).T cv_loss = model(cv_draw, model.metabolites_test) yield (str(iteration), loss.item(), cv_loss.item()) - - else: - yield (str(iteration), loss.item(), None) -
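Taken together, the final state of the port exposes a small PyTorch API: MMvecALR holds the paired tables and parameters, mmvec_training_loop is a generator of convergence records, and the ranks and ordination come from the bare helper functions. A minimal end-to-end sketch of that API (the file paths and hyperparameters are illustrative values taken from the draft notebook, not prescribed defaults):

import biom
import pandas as pd

from mmvec.ALR import MMvecALR
from mmvec.train import mmvec_training_loop

microbes = biom.load_table("./soil_microbes.biom")
metabolites = biom.load_table("./soil_metabolites.biom")

# structure_data() aligns the two tables and holds out the last
# holdout_num samples for cross-validation
model = MMvecALR(microbes, metabolites, latent_dim=15,
                 sigma_u=1, sigma_v=1, holdout_num=5)

# the training loop yields (iteration, loss, cross-validation loss)
# records on summary epochs, collected here exactly as paired_omics does
stats = pd.DataFrame.from_records(
    mmvec_training_loop(model=model, learning_rate=1e-3, batch_size=500,
                        epochs=3000, summary_interval=25),
    columns=['iteration', 'loss', 'cross-validation'])

ranks = model.ranks_dataframe()  # microbe x metabolite conditional ranks
biplot = model.get_ordination()  # skbio OrdinationResults biplot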