Commit

Modified: doc/recipes/vision/models.ipynb
liuxinwei committed Nov 23, 2023
1 parent 6d3095c commit 265c0af
Showing 2 changed files with 20 additions and 92 deletions.
25 changes: 1 addition & 24 deletions doc/recipes/vision/models.ipynb
@@ -166,7 +166,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -175,29 +175,6 @@
"text": [
"Downloading: \"https://github.com/pytorch/vision/zipball/main\" to /home/ai/.cache/torch/hub/main.zip\n"
]
},
{
"ename": "RemoteDisconnected",
"evalue": "Remote end closed connection without response",
"output_type": "error",
"traceback": [
"---------------------------------------------------------------------------",
"RemoteDisconnected                        Traceback (most recent call last)",
"/media/pc/data/lxw/ai/torch-book/doc/recipes/vision/models.ipynb Cell 13 line 4",
"      1 import torch",
"      3 # Option 1: passing weights param as string",
"----> 4 model = torch.hub.load(\"pytorch/vision\", \"resnet50\", weights=\"IMAGENET1K_V2\")",
"      6 # Option 2: passing weights param as enum",
"      7 weights = torch.hub.load(\"pytorch/vision\", \"get_weight\", weights=\"ResNet50_Weights.IMAGENET1K_V2\")",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/hub.py:563, in load(repo_or_dir, model, source, trust_repo, force_reload, verbose, skip_validation, *args, **kwargs)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/hub.py:238, in _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose, skip_validation)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/hub.py:620, in download_url_to_file(url, dst, hash_prefix, progress)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/urllib/request.py:216, in urlopen(url, data, timeout, cafile, capath, cadefault, context)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/urllib/request.py:519, in OpenerDirector.open(self, fullurl, data, timeout)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/urllib/request.py:536, in OpenerDirector._open(self, req, data)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/urllib/request.py:496, in OpenerDirector._call_chain(self, chain, kind, meth_name, *args)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/urllib/request.py:1391, in HTTPSHandler.https_open(self, req)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/urllib/request.py:1352, in AbstractHTTPHandler.do_open(self, http_class, req, **http_conn_args)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/http/client.py:1375, in HTTPConnection.getresponse(self)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/http/client.py:318, in HTTPResponse.begin(self)",
"File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/http/client.py:287, in HTTPResponse._read_status(self)",
"RemoteDisconnected: Remote end closed connection without response"
]
}
],
"source": [
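The cell whose failing output is deleted above loads ResNet50 through torch.hub, passing the weights either as a plain string or as an enum resolved via get_weight. Below is a minimal sketch of that pattern; the try/except fallback to torchvision.models is an addition for flaky connections (the GitHub zipball download is what raised RemoteDisconnected) and is not part of the notebook.

```python
import torch
from torchvision.models import resnet50, ResNet50_Weights

# Option 1: pass the weights parameter as a string (as in the removed cell).
try:
    model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
except Exception as err:  # e.g. http.client.RemoteDisconnected on a flaky network
    # Fallback (not in the notebook): build the model through torchvision directly,
    # which skips downloading the pytorch/vision repository archive from GitHub.
    print(f"torch.hub.load failed ({err!r}); falling back to torchvision.models")
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Option 2: pass the weights parameter as an enum resolved through get_weight.
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
```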
87 changes: 19 additions & 68 deletions doc/recipes/vision/plot_transforms.ipynb
@@ -276,87 +276,38 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We passed a tuple so we get a tuple back, and the second element is the\n",
"tranformed target dict. Transforms don't really care about the structure of\n",
"the input; as mentioned above, they only care about the **type** of the\n",
"objects and transforms them accordingly.\n",
"我们传递了一个元组,所以我们得到了一个元组作为输出,第二个元素是转换后的目标字典。转换并不真正关心输入的结构;如上所述,它们只关心对象的类型并相应地进行转换。\n",
"\n",
"*Foreign* objects like strings or ints are simply passed-through. This can be\n",
"useful e.g. if you want to associate a path with every single sample when\n",
"debugging!\n",
"\n",
"\n",
"<div class=\"alert alert-info\"><h4>Note</h4><p>**Disclaimer** This note is slightly advanced and can be safely skipped on\n",
" a first read.\n",
"\n",
" Pure :class:`torch.Tensor` objects are, in general, treated as images (or\n",
" as videos for video-specific transforms). Indeed, you may have noticed\n",
" that in the code above we haven't used the\n",
" :class:`~torchvision.tv_tensors.Image` class at all, and yet our images\n",
" got transformed properly. Transforms follow the following logic to\n",
" determine whether a pure Tensor should be treated as an image (or video),\n",
" or just ignored:\n",
"\n",
" * If there is an :class:`~torchvision.tv_tensors.Image`,\n",
" :class:`~torchvision.tv_tensors.Video`,\n",
" or :class:`PIL.Image.Image` instance in the input, all other pure\n",
" tensors are passed-through.\n",
" * If there is no :class:`~torchvision.tv_tensors.Image` or\n",
" :class:`~torchvision.tv_tensors.Video` instance, only the first pure\n",
" :class:`torch.Tensor` will be transformed as image or video, while all\n",
" others will be passed-through. Here \"first\" means \"first in a depth-wise\n",
" traversal\".\n",
"\n",
" This is what happened in the detection example above: the first pure\n",
" tensor was the image so it got transformed properly, and all other pure\n",
" tensor instances like the ``labels`` were passed-through (although labels\n",
" can still be transformed by some transforms like\n",
" :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`!).</p></div>\n",
"\n",
"\n",
"## Transforms and Datasets intercompatibility\n",
"\n",
"Roughly speaking, the output of the datasets must correspond to the input of\n",
"the transforms. How to do that depends on whether you're using the torchvision\n",
"`built-in datatsets <datasets>`, or your own custom datasets.\n",
"对于像字符串或整数这样的外部对象,它们会原样传递。这在调试时可能会很有用,例如,如果你想为每个样本关联一个路径!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 转换和数据集的兼容性\n",
"\n",
"### Using built-in datasets\n",
"粗略地说,数据集的输出必须与转换的输入相对应。如何做到这一点取决于您是使用 torchvision 内置的数据集还是自己的自定义数据集。\n",
"\n",
"If you're just doing image classification, you don't need to do anything. Just\n",
"use ``transform`` argument of the dataset e.g. ``ImageNet(...,\n",
"transform=transforms)`` and you're good to go.\n",
"### 内置的数据集\n",
"\n",
"Torchvision also supports datasets for object detection or segmentation like\n",
":class:`torchvision.datasets.CocoDetection`. Those datasets predate\n",
"the existence of the :mod:`torchvision.transforms.v2` module and of the\n",
"TVTensors, so they don't return TVTensors out of the box.\n",
"如果您只是进行图像分类,则无需执行任何操作。只需使用数据集的 `transform` 参数,例如 ``ImageNet(..., transform=transforms)``,然后就可以开始了。\n",
"\n",
"An easy way to force those datasets to return TVTensors and to make them\n",
"compatible with v2 transforms is to use the\n",
":func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:\n",
"Torchvision 还支持像 {class}`torchvision.datasets.CocoDetection` 这样的目标检测或分割数据集。这些数据集比 {mod}`torchvision.transforms.v2` 模块和 TVTensors 更早出现,因此它们不会自动返回 TVTensors。强制使这些数据集返回 TVTensors 并使它们与 v2 转换兼容的简单方法是使用 {func}`torchvision.datasets.wrap_dataset_for_transforms_v2` 函数:\n",
"\n",
"```python\n",
"from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2\n",
"\n",
"dataset = CocoDetection(..., transforms=my_transforms)\n",
"dataset = wrap_dataset_for_transforms_v2(dataset)\n",
"# Now the dataset returns TVTensors!\n",
"```\n",
"### Using your own datasets\n",
"\n",
"If you have a custom dataset, then you'll need to convert your objects into\n",
"the appropriate TVTensor classes. Creating TVTensor instances is very easy,\n",
"refer to `tv_tensor_creation` for more details.\n",
"\n",
"There are two main places where you can implement that conversion logic:\n",
"\n",
"- At the end of the datasets's ``__getitem__`` method, before returning the\n",
" sample (or by sub-classing the dataset).\n",
"- As the very first step of your transforms pipeline\n",
"\n",
"Either way, the logic will depend on your specific dataset.\n",
"\n"
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
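The English text removed from plot_transforms.ipynb above explains that a custom dataset should convert its samples into the appropriate TVTensor classes, either at the end of `__getitem__` or as the very first step of the transforms pipeline. Below is a minimal sketch of the first option, assuming a hypothetical detection dataset whose raw samples are (image, boxes, labels) tensors; the class and field names are illustrative only.

```python
import torch
from torch.utils.data import Dataset
from torchvision import tv_tensors


class MyDetectionDataset(Dataset):
    """Hypothetical custom dataset that returns TVTensors, so v2 transforms
    can recognize the image and bounding boxes by their type."""

    def __init__(self, samples, transforms=None):
        # samples: list of (image_tensor, boxes_tensor, labels_tensor)
        self.samples = samples
        self.transforms = transforms

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, boxes, labels = self.samples[idx]
        # Convert to TVTensor classes just before returning the sample.
        img = tv_tensors.Image(img)
        target = {
            "boxes": tv_tensors.BoundingBoxes(
                boxes, format="XYXY", canvas_size=img.shape[-2:]
            ),
            "labels": labels,  # plain tensors/ints are passed through by v2 transforms
        }
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target
```

Doing the conversion inside `__getitem__` keeps the transforms pipeline itself dataset-agnostic; the alternative mentioned in the removed text is to perform the same wrapping as the first transform of the pipeline.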

0 comments on commit 265c0af