diff --git a/notebooks/03_sgd_from_scratch.ipynb b/notebooks/03_sgd_from_scratch.ipynb index ad66cb4..0916a0f 100644 --- a/notebooks/03_sgd_from_scratch.ipynb +++ b/notebooks/03_sgd_from_scratch.ipynb @@ -1,4097 +1,4101 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "D7ExutUSGoY-" - }, - "source": [ - "# บทที่ 3 - Stochastic Gradient Descent ตั้งแต่เริ่มต้น" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ai-builders/curriculum/blob/main/notebooks/03_sgd_from_scratch.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jUUXviwGayJP" - }, - "source": [ - "ในบทเรียนนี้ เราจะทำการสร้างวิธีที่โมเดลของเราเรียนรู้ในบทเรียนที่แล้วๆมา เรียกว่า stochastic gradient descent ขึ้นมาเองตั้งแต่ต้นโดยใช้เพียงแค่ Pytorch สำหรับ linear algebra และการทำ partial derivatives เท่านั้น ด้วยตัวอย่างการจำแนกรูปภาพตัวเลข 3 และ 7 ออกจากกัน\n", - "\n", - "บทเรียนแปล-สรุปมาจาก [04_mnist_basics.ipynb](https://github.com/fastai/fastbook/blob/master/04_mnist_basics.ipynb) ของ [fastai](https://course.fast.ai/) ผู้ที่สนใจสามารถไปติดตามบทเรียนต้นทางได้ที่ [course.fast.ai](https://course.fast.ai/)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "6IlUoU8YGlQ7", - "outputId": "2aae9add-1fe0-40fe-ed88-c7b83114ce47" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[K |████████████████████████████████| 720 kB 26.4 MB/s \n", - "\u001b[K |████████████████████████████████| 48 kB 5.5 MB/s \n", - "\u001b[K |████████████████████████████████| 1.2 MB 46.0 MB/s \n", - "\u001b[K |████████████████████████████████| 189 kB 63.3 MB/s \n", - "\u001b[K |████████████████████████████████| 56 kB 5.0 MB/s \n", - "\u001b[K |████████████████████████████████| 51 kB 310 kB/s \n", - "\u001b[K 
|████████████████████████████████| 558 kB 63.9 MB/s \n", - "\u001b[K |████████████████████████████████| 130 kB 55.7 MB/s \n", - "\u001b[?25h" - ] - } - ], - "source": [ - "#ติดตั้ง fastai\n", - "!pip install -q fastbook\n", - "import fastbook" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "-2VyXV7kpoju" - }, - "outputs": [], - "source": [ - "from fastai.vision.all import *\n", - "from fastbook import *\n", - "import torch\n", - "\n", - "matplotlib.rc('image', cmap='Greys')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gBWvqPiOpoju" - }, - "source": [ - "# เทรนโมเดลจำแนกรูปเลข 3 และเลข 7 จาก [ชุดข้อมูล MNIST](http://yann.lecun.com/exdb/mnist/)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_5WWrXnD6Vpj" - }, - "source": [ - "## โหลดข้อมูลรูปเลข 3 และเลข 7" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 37 - }, - "id": "wkh3fHu5pojw", - "outputId": "d204e01b-6cb2-4ad4-f8de-119b5bcf34b4" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " 100.14% [3219456/3214948 00:00<00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "D7ExutUSGoY-" + }, + "source": [ + "# บทที่ 3 - Stochastic Gradient Descent ตั้งแต่เริ่มต้น" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "path = untar_data(URLs.MNIST_SAMPLE)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "Z6OjxHEPpojw", - "outputId": "d76dc5d8-0237-45e4-9ece-b94ee12f6bbc" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(#3) [Path('/root/.fastai/data/mnist_sample/valid'),Path('/root/.fastai/data/mnist_sample/train'),Path('/root/.fastai/data/mnist_sample/labels.csv')]" + "cell_type": "markdown", + "metadata": { + "id": "GtlT-MCeHftk" + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ai-builders/curriculum/blob/main/notebooks/03_sgd_from_scratch.ipynb)" ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#แบ่งเป็น train, validation, test \n", - "path.ls()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "mMmCdDscpojw", - "outputId": "68396480-226d-41b6-e0d0-81977e28ba38" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(#2) [Path('/root/.fastai/data/mnist_sample/train/3'),Path('/root/.fastai/data/mnist_sample/train/7')]" + "cell_type": "markdown", + "metadata": { + "id": "jUUXviwGayJP" + }, + "source": [ + "ในบทเรียนนี้ เราจะทำการสร้างวิธีที่โมเดลของเราเรียนรู้ในบทเรียนที่แล้วๆมา เรียกว่า stochastic gradient descent ขึ้นมาเองตั้งแต่ต้นโดยใช้เพียงแค่ Pytorch สำหรับ linear algebra และการทำ partial derivatives เท่านั้น ด้วยตัวอย่างการจำแนกรูปภาพตัวเลข 3 และ 7 ออกจากกัน\n", + "\n", + "บทเรียนแปล-สรุปมาจาก 
[04_mnist_basics.ipynb](https://github.com/fastai/fastbook/blob/master/04_mnist_basics.ipynb) ของ [fastai](https://course.fast.ai/) ผู้ที่สนใจสามารถไปติดตามบทเรียนต้นทางได้ที่ [course.fast.ai](https://course.fast.ai/)" ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#ในแต่ละ set จะมีเลข 3 และ 7\n", - "(path/'train').ls()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "EilsFTaWpojx", - "outputId": "6da64090-f81d-46b1-ae05-86a5b3fd4187" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(#6131) [Path('/root/.fastai/data/mnist_sample/train/3/10.png'),Path('/root/.fastai/data/mnist_sample/train/3/10000.png'),Path('/root/.fastai/data/mnist_sample/train/3/10011.png'),Path('/root/.fastai/data/mnist_sample/train/3/10031.png'),Path('/root/.fastai/data/mnist_sample/train/3/10034.png'),Path('/root/.fastai/data/mnist_sample/train/3/10042.png'),Path('/root/.fastai/data/mnist_sample/train/3/10052.png'),Path('/root/.fastai/data/mnist_sample/train/3/1007.png'),Path('/root/.fastai/data/mnist_sample/train/3/10074.png'),Path('/root/.fastai/data/mnist_sample/train/3/10091.png')...]" + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6IlUoU8YGlQ7", + "outputId": "96ea0400-f259-4d81-dcbf-4a7665998010" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[K |████████████████████████████████| 720 kB 12.8 MB/s \n", + "\u001b[K |████████████████████████████████| 188 kB 49.3 MB/s \n", + "\u001b[K |████████████████████████████████| 1.2 MB 45.6 MB/s \n", + "\u001b[K |████████████████████████████████| 60 kB 5.6 MB/s \n", + "\u001b[?25h" + ] + } + ], + "source": [ + "#ติดตั้ง fastai\n", + "!pip install -q fastbook\n", + "import fastbook" ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } 
- ], - "source": [ - "#ในละ folder 3 และ 7 จะเป็นไฟล์รูป\n", - "threes = (path/'train'/'3').ls().sorted()\n", - "sevens = (path/'train'/'7').ls().sorted()\n", - "threes" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "2nKh8v3OBG5G", - "outputId": "bf8f9a5e-c62c-48ec-a199-83ac22bac45e" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(6131, 6265)" + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "-2VyXV7kpoju" + }, + "outputs": [], + "source": [ + "from fastai.vision.all import *\n", + "from fastbook import *\n", + "import torch\n", + "\n", + "matplotlib.rc('image', cmap='Greys')\n", + "\n", + "#fix plot_function as new pytorch requires steps argument for torch.linspace\n", + "def plot_function(f, tx=None, ty=None, title=None, min=-2, max=2, figsize=(6,4)):\n", + " x = torch.linspace(min, max, steps=100)\n", + " fig,ax = plt.subplots(figsize=figsize)\n", + " ax.plot(x,f(x))" ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#จำนวนรูปในแต่ละ class\n", - "len(threes), len(sevens)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 45 - }, - "id": "jJEUNNQKpojx", - "outputId": "c84f0303-b2e4-47e5-c7bd-a823c28f4b71" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA3UlEQVR4nGNgGNxA1XnOxX//vxlgkVIqe//379+///4+EkCX4lv//e+/v38vnf33968wumTd379/rzUZ8lf9e1XChC4Zf39RAAMDA+eFf2txuIg18htuObW/f9/y45CM/Pt3iwYOuaxvv7Zz4pBjmPvvoxwSF9XNrxhZ/XFpZOCf9+tjMk5Zhsl/d+AyloHhI4MsTo22T//uF8cu5Xbo/79/8eiiXGVlyqyOk3/9/fe7nRld0urv3/uH/v79+/d1GKZ5In///v339+/zDdiClbVu9sF/X2cr4XQoVQEA4o1d+YAEyFQAAAAASUVORK5CYII=\n", - "text/plain": [ - "" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "im3_path = threes[6000]\n", - "im3 
= Image.open(im3_path)\n", - "im3" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "GBYX4Mf0BTb6", - "outputId": "02cc967e-9295-4da6-806b-8b3e58ca3ae3" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(28, 28)" + "cell_type": "markdown", + "metadata": { + "id": "gBWvqPiOpoju" + }, + "source": [ + "# เทรนโมเดลจำแนกรูปเลข 3 และเลข 7 จาก [ชุดข้อมูล MNIST](http://yann.lecun.com/exdb/mnist/)" ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#รูปขนาด 28 x 28 pixels \n", - "array(im3).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RZphr7qzpojx", - "outputId": "257cba1a-9290-4a12-98f8-f3d9d5d9a2dc" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0]], dtype=uint8)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#เปลี่ยนจาก numpy array\n", - "array(im3)[4:10,4:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "o52Dig-Vpojx", - "outputId": "c2aae0e8-5e86-4eeb-88a8-180112bf9675" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0],\n", - " [0, 0, 0, 0, 0, 0]], dtype=torch.uint8)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#เป็น torch tensor\n", - "tensor(im3)[4:10,4:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": 
"https://localhost:8080/", - "height": 394 - }, - "id": "0yZCaBRNpojy", - "outputId": "f4d506e0-bf3b-473f-fe3d-bf17277cd1b4" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
 01234567891011121314151617
0000000000000000000
10000000000376715620925425524648
20000000034118239253253253253254253226
300000014175247253254253253210205254253253
40000001262532532532141304915122254234116
5000000952231628000092082541730
60000000000058924625417300
700000000005382532532371500
800000000008925325318040000
900000000010624625018390000
100000000001572542413000000
\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "im3_t = tensor(im3)\n", - "df = pd.DataFrame(im3_t[4:15,4:22])\n", - "df.style.set_properties(**{'font-size':'6pt'}).background_gradient('Greys')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jMdNhmiXpojy" - }, - "source": [ - "## วิธีที่ง่ายที่สุด: ดูว่ามี pixel เหมือนกันแค่ไหน" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MqYq_GbFT7XA" - }, - "source": [ - "สมมุติเราไม่ทำ ML อะไรเลย แต่เราพยามสร้างกฎแบบ Rule-based Systems โดยบอกว่า \"รูปที่มี pixels ใกล้เคียงกับ pixels เฉลี่ยของเลข 3 และ 7 มากกว่า ให้ถือว่าเป็นเลขนั้น\"" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "JDsXZWi8pojy", - "outputId": "deb4fb4e-a3dd-44ce-d5c3-93272004588b" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(6131, 6265)" + "cell_type": "markdown", + "metadata": { + "id": "_5WWrXnD6Vpj" + }, + "source": [ + "## โหลดข้อมูลรูปเลข 3 และเลข 7" ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "seven_tensors = [tensor(Image.open(o)) for o in sevens]\n", - "three_tensors = [tensor(Image.open(o)) for o in threes]\n", - "len(three_tensors),len(seven_tensors)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 85 - }, - "id": "9h_vpbbvpojy", - "outputId": "cadff4bb-f297-4f52-e3ec-17508f92f686" - }, - "outputs": [ - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAJHElEQVR4nO2bXXMSZxuAL1jYXYQsiCYxIWIIjImJ0TaVTu2H41FnnGmPPOtMf0NP+i/6H9oDx+OOOtMjW6cdGz/Sk1ZqxkRDhCTQEAjfsLDsvgeWbbMmxgrEzDtcR5ln+bi5eJ5n7/t+iM0wDPr8g/1tB3DY6Aux0BdioS/EQl+IBcc+1/+fb0G23Qb7M8RCX4iFvhALfSEW+kIs9IVY6Aux0BdiYb/ErGMMw6DVatFsNqlUKmiaRrPZpNls0mg0Xnq8LMtIkoSu6+i6jtPpxOFw4HQ6EQQBWZZxOHoXds+FaJpGvV4nHo/zww8/kMlkSKVSxONxnjx5wr/7MTabjffee4+pqSmq1SqqqjIyMsKxY8eIRCIMDw8zOzuLz+frWbxdF6LrujkDisUilUqFbDbLn3/+yeLiIoVCgfX1dTKZDJVKBVEUEUWRer1OrVZjZWUFh8NhCqlUKmxublIulzl+/DgTExM9FWLbp2P2n2sZVVXZ3t5mdXWV77//nlQqxaNHj8jlcqTTaQzDwDAMZFnG7XYzODhIIBDgyZMnPH/+HLvdjt1uN2eOzWbDZrPh8/nwer1cv36daDT6hh93B7vWMh3PEE3TqNVq1Go1UqkUlUqFVCrFysoKT58+pVQqIQgCExMTRKNRnE4nkiQhiiIulwtFUfB6vczOzpLJZFheXubZs2eUy2VqtZr5Po1GA1VV0XW905BfScdCVFUlFovx22+/8c0331CpVGg2m7RaLRqNBoFAgGg0ysWLF7l69Soejwe3220+3263Y7PZ0HUdwzC4ceMG165dIxaLkUwmOw3vP9OxEMMwaDQaVKtVSqUS1WoVXdex2+2IokggEGBubo5z587h8/lwOp04nU7z+e0l8e+lJAgCNtvOGe33+wkGg8iy3GnIr6QrQqrVKrVaDVVV0TQNAFEUOXr0KHNzc3zxxRf4fD4URXnpg7ZpjzudTux2+0vXpqenmZmZQVGUTkN+JR0LcTqdhEIhRFEkn8/TbDaBF0Lcbjdzc3N4vV5EUdxTBrzYizRNI5/Pk8/nzRzFZrMhCALDw8OEw2FcLlenIb+SjoVIksTp06cJh8N88MEH5nh7KQiCgCiK+75Oo9GgVCqxtrZGMpmkWq0CmM+PRCJEo9Ed+08v6FhI+1sXBGHH3tC+Zp3+e7G+vs7PP//M77//TqlUQtM0BEEgHA4zPj7OzMwMJ06ceOk9uk1XErP2bHidmbAXd+7c4auvvkLTNHRdx+FwIIoiFy9eJBqN8u6773LixIluhPtKep66WzEMA13XqdfrlMtlM2FbWFgwN2S73U4kEiESifDhhx8SjUZ7vpm2OXAhuq6jaRrZbJbHjx/z448/cvPmTbLZrHm7djgcXLhwgY8//phPP/2UkydPHlh8B1LtGoZBsVhkdXXVrGWSySSrq6ssLS2Ry+Wo1+vAi3zD7/dz+vRpzp8/z8DAQK9D3MGBlf/JZJLvvvuO5eVlHjx4gKqqO1LzNkNDQ0xPT3PhwgUmJyd7fpu1ciBLRtd1CoUCjx49Yn19HVVVzXzFSiaTIRaLcevWLeLxOMeOHcPj8TA2NobX62VwcLCnkg5syWxtbfHgwQM0TTM31t3IZDJkMhnW19dxu90oioLH4+HKlSucP3+eS5cuIcvyK5O8Tuh6+f/SC/y9ZDY2Nrh9+zblcplCoUCxWCSXy5mPi8fjLC0tmT0UURRxOp1mv2RqaorR0VE++ugjzpw5w9mzZ/F6vQiC8Nq5joVdjfZciJVarWbKWFtbM8d/+uknfvnlF1ZXV0mn07s+t91Ri0QifP3110x
OTiJJEoIgvEkovemH/FecTieKoiDL8o7O19DQEJcuXWJtbY10Ok0mkyGXy3H//n3i8TjwYrYlEgmq1SrPnz9ncHCQ48ePv6mQXTlwIQ6HA4fDgcvlwuv1muPDw8NMT09TrVbNW/TKygrZbNYUArC5uUkulyMejxMKhTh69Gh34+vqq3VAuxBsd9VlWSYYDBKPx/nrr79IJBJsb28DL2bKxsYG8XicU6dOdTWOQ3Mu0y4EJUkye63BYJDZ2VmmpqZ2pO6GYZgzR1XVrsZxaITsRbtwtI55vd6eVL+HXgjAbndCt9uN3+/v6oYKh2gPsVIsFtne3mZhYYGHDx/uyFlsNhunTp0iEol01HLYjUMppF0MJpNJEokEiURiR2YrCAJDQ0P4/f6uH2seOiHVapVKpcK9e/e4c+cOf/zxh3lEATA5Ocn4+DiBQABZlt80S92TQyOk/YHr9TrZbJZYLMb8/DypVMq8ZrfbCQaDRCIRvF4vDoej6zXNoRFSKBRIp9Pcvn2bu3fv8vjxY5LJpNknGRgYwO12c/XqVS5fvszo6Oiu5zed8laFtCthwzDI5/M8ffqUhw8fcuPGDbO3Ci820SNHjjA4OMi5c+cIhUI9kQFvUYiqqqiqysbGBktLS9y9e5f5+XkSiQTNZtNcJi6XC1mW+fLLL/nkk08Ih8M9kwFdFvLvb9yaULXH2383Gg0KhQLxeJz5+XkWFha4d++e+fj2me/AwACKonD27FneeecdPB5Pz2RAF4Vomka1WqVcLrO2toaiKIyMjJiH3pVKha2tLUqlEpubmywvL7O4uGjWKfl8HvgnM41EIoTDYa5cuUI0GiUUCqEoSk9/PQRdFNJqtSiVSmxtbRGLxRgdHUWSJPMXRLlcjuXlZba2tsxl0u6tqqq645RPFEXGx8cJh8NEo1FmZmbMhlGv6ZqQfD7Pt99+SzKZ5P79++Zhd6vVotVqmecwqqqaM6ZSqZgb5/DwMGNjY7z//vtmk/nkyZMoioIkSV3PN/aia0IajQaJRIJnz56xuLj4Wj9ssdlsSJKEJEmMjY0RiUQ4c+YM0WiUiYkJ/H5/t8J7bbomxOPxcPnyZTweD7/++uu+QmRZ5siRI3z++ed89tlnhEIhRkZGcLlcB7Y8dqNrQhwOB8FgkHQ6TSAQMI8l98LlcuHz+ZicnGR2dpahoaEdHbS3RdeazK1WyzxvKRaL+7/x3w0ht9ttdsm6XcrvF8KugwfddT9E9P+j6nXoC7HQF2KhL8TCfrfd3lVRh5T+DLHQF2KhL8RCX4iFvhALfSEW/gcMlBno19ugeQAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "show_image(three_tensors[1]);" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "D5B7mZNspojy", - "outputId": "bc93149a-adf0-42f7-fe24-bd1dc286b5e9" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(torch.Size([6131, 28, 28]), torch.Size([6265, 28, 28]))" + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "id": "wkh3fHu5pojw", + "outputId": "c0fe5213-0052-462a-b8ee-84882e0c73cd" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + " 100.14% [3219456/3214948 00:00<00:00]\n", + "
\n", + " " + ] + }, + "metadata": {} + } + ], + "source": [ + "path = untar_data(URLs.MNIST_SAMPLE)" ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#เรามัดเลข 3 และ 7 ทั้งหมดรวมกันตามคลาสแล้วหารมันด้วย 255 เพื่อให้ได้ค่าระหว่าง 0 และ 1\n", - "stacked_sevens = torch.stack(seven_tensors).float()/255\n", - "stacked_threes = torch.stack(three_tensors).float()/255\n", - "stacked_threes.shape, stacked_sevens.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "K4mmFyngpojz", - "outputId": "4b1ba479-b693-4f5a-bcd3-260a3283ed07" - }, - "outputs": [ { - "data": { - "text/plain": [ - "3" + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Z6OjxHEPpojw", + "outputId": "b1c65a6c-315f-4233-fe92-215e817ecdb3" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(#3) [Path('/root/.fastai/data/mnist_sample/valid'),Path('/root/.fastai/data/mnist_sample/labels.csv'),Path('/root/.fastai/data/mnist_sample/train')]" + ] + }, + "metadata": {}, + "execution_count": 4 + } + ], + "source": [ + "#แบ่งเป็น train, validation, test \n", + "path.ls()" ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(stacked_threes.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "Jck_enHBpojz", - "outputId": "1bc5de1f-50f5-45cd-ed75-65f44ac5c8fe" - }, - "outputs": [ { - "data": { - "text/plain": [ - "3" + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mMmCdDscpojw", + "outputId": "fc98608c-0131-4be1-d56e-922e924a03af" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + 
"text/plain": [ + "(#2) [Path('/root/.fastai/data/mnist_sample/train/7'),Path('/root/.fastai/data/mnist_sample/train/3')]" + ] + }, + "metadata": {}, + "execution_count": 5 + } + ], + "source": [ + "#ในแต่ละ set จะมีเลข 3 และ 7\n", + "(path/'train').ls()" ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "stacked_threes.ndim" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 103 - }, - "id": "cxcx0MIlg-fI", - "outputId": "0cf1655d-5ce7-4044-b1f5-b36dd0a93b3b" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAIEElEQVR4nO2bS08TbRuArx6ntJQONBQoYIFgEPCQYBQ0Rl2YuDFGoy504Q/xJ7j1B7hwR+LGjboQSZRo8IQBLIdWC5aDI1Baep525luYzgulKF++TjHv1yvpZmae9u41z9z3c2gNqqpS5R+MBx3A30ZVSBFVIUVUhRRRFVKE+Q/n/80lyFDqYLWHFFEVUkRVSBFVIUVUhRRRFVLEn8pu2VEUhXw+Tz6fR5ZlMpkMqVQKs9mMxWLBbDZjMpm06wVBwGKxYDCUrJJlp+JCstksW1tbrK6uMjs7y/v373nz5g0+nw+v16u9Cpw5c4bm5maMRmNFpFRMiCzLpNNp1tfX+fbtG9++fWN+fp5AIEAoFCKbzZJMJolEIqysrABgNBrp7u6mvr4eq9WK2ax/uBUTEo1GmZiY4MWLFzx69IhUKkUymSSXy5HP51laWuL9+/cYjUaMxl+pzWAwIIoiDoeDlpYWamtrdY+zYkIsFgv19fXYbDaSySTpdJp0Oq2dz+fzJdvNzc3x7t07hoaGMJlMWp7RDVVVf/cqG7Isq4lEQh0eHlZbWlpUp9Op8muu9NuXy+VS29vb1QcPHqh+v1+Nx+PlCqnkd65Y2TUajZjNZlpaWjh9+jRtbW2YTCYtUZrNZux2+667n0ql2Nzc5MePH0iShCzL+sap67tv/yCjEavVSldXF3fv3uXixYsIgqAJqK2tpampCYfDsaNdNpslHo8TDAYZHx8nHo/rG6eu716C2tpaenp6aG9vx+VyIQgCACaTCUEQdoxBtuPxeDhy5Ag1NTW6xldxIU6nk76+PoaGhvD5fLhcLgCsVit1dXVYLJZdbQwGA8ePH+fs2bO6V5qKD8wKOcPtdjM0NIQoilq1kSRpR+U5CCoupIDX6+XWrVuMjIwQi8UIhUKEQqE9ry8M+fXmwCZ3NTU1tLe3c/78eW7cuMG5c+dwu90lc4SqqgQCAfx+P6lUSte4DqyHOBwOHA4HTU1NDAwM4Ha7CYVCLC4u7vrSqq
oyPj7O1tYWXq8XURR1i+vAhKTTaeLxOIlEgmg0SiAQYHV1lUQisetag8GA2+2mpaUFq9Wqa1wHJmRrawu/38/y8jLhcJhPnz7x/ft31D32mj0eDz09PbqX3YoJyeVyyLLM2toac3NzzM/PMzMzQywWIxqNMjc3t6cMgKamJrq7u7HZbLrGWVEh8XicsbEx7t+/jyRJ/PjxA0VRUBTlt20NBgM+n4/Ozk5tIKcXFasy+XyeWCyGJEmsra2RSCRQFOW3vaKAqqrMz8/z5cuXf0+VyeVybGxsIEkSP3/+RJblP/aM7czNzWGz2Whra9NGt3pQMSGCINDR0cGFCxcIh8P4/X4+fvy4r0cGIJlMEovFdB+cVUyIzWbDZrPR39/PlStXEASBycnJffUUVVVJJBLE43FyuZyucVa87IqiyODgIKIo0tDQQDabJZvNaudVVUVRFMbHx5menkaW5YoM2QtUXIjdbsdut+N0OmlsbCSTyZDJZLTzBSHpdJpgMKhtWVSKAxuY1dTU4PP5Sk7aFEXh4sWLKIrC69evCQaDrK6u4nA42NjYIJ1OY7FY9lw7+V84MCGCIOw5plBVlRMnThCLxVhYWCAYDLK5ucni4iKxWIxMJoPJZNJFyF+3lVl4ZKampnj8+DEzMzPAP2uyem9YHVgP2YuCkEAgwOjoqHa8sAXxfydkdXWVYDDI/Py8dsxgMDAwMMCxY8doa2vDZrPp8riADkIK+xvb7+J/c0clSWJsbIzv37/vOO71eunv70cUxZLrruWibEIURdGG569evaKhoYGOjg5EUcTtdv+xfS6XI5fLMT09zdOnT3f1kK6uLk6ePIndbi9XyCUpW1JVVRVZlvn58yfPnj1jdHSUUChEJBLZcxK3fccsl8uRyWRYXFzk8+fPrK+vA79kmEwmmpqaaGxs1LV3QBl7yNbWFs+fP2dycpKXL1/idrtZWlri1KlTXL9+HavVitVqxWKxIAgC6XSaVCqlbXr7/X4+fPjAq1evtJkwwODgIL29vfT19emaOwqUTUgqlWJiYoJAIEA4HCYSiZBOp3E6nUiSpK2hOhwOzGYz2WyWSCRCNBolEonw9u1bRkZGCIVC2nzFaDTS2dnJsWPHaGho0MqunpRNiNPp5OrVq3z69InFxUXW1tZYWFjgyZMnzM7OarmksDYaDof5+vUr8XicaDTK8vIy6+vr2r5MQ0MDoihy6dIlLl++TH19/Y69YL0omxCz2UxrayvxeJy2tjby+TwrKyuEw2FCoRC1tbV4PB6am5s5dOgQX79+xe/3k8lkdkzu4FfecLlc+Hw+Dh8+jMfjqYgMAMMfVqz2/dNuRVHIZrMsLS3x8OFDVFXFarXi9/sZHh7WNrutVqv2G5FkMrkr4dpsNgRB4N69e9y8eROPx4PdbsdgMJRbSMk3K1sPMRqN2Gw26urqaG1tRRAEvF4vsixjt9u1JJnJZDQRpWaxdrsdURQ5ceIEXV1d5Qpv35R9YCaKIrdv3wZ+df3CD+YKQjY3N5EkCb/fz9TUlNauUH3u3LnDtWvXOHr0aLlD2xdlF2KxWBBFUZuTNDc3MzAwoPUGSZJwuVwoioIkSVq7worakSNH6O3txel0lju0fVG2HFKy8bYBV+FzCo9KJpPZsdNfyBGiKFJTU1OJElsyh+gq5C+n+n+Z/VAVUkRVSBFVIUVUhRRRFVLEnwZmlfmTyl9EtYcUURVSRFVIEVUhRVSFFFEVUsR/AP0FXN1zCRLUAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "#เลข 3 ที่ index 0\n", - "show_image(stacked_threes[0,:,:])" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 85 - }, - "id": "2Es3hPcXpojz", - "outputId": "14820e4c-7b0d-42b4-cb01-fe33175307c2" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAJtUlEQVR4nO1b2XLiWhJM7QsChDG22x3h//+qfnKzWVhoX5HmoaNqDufK9jRge2aCiiCEAS0nVUtWlqz0fY+r/dvU776A/za7AiLZFRDJroBIdgVEMv2D7/+fS5Ay9OHVQyS7AiLZFRDJroBI9lFS/RQ7tV1QlME8eFH7dEDkxdPf4udDAMmLVxQFfd8Pfn5JuyggQ4vs+/7o1XUdfy6+l01RFCiKAlX9E9WqqvJn9JJ/fwm7CCDy4ruu423XdTgcDjgcDqiqCm3boq5rtG2LpmlwOBzQti3vr6oqVFWFaZrQNA2WZUHXdViWBU3ToOs6NE3j3xFQZOcCcxYgQ15AIHRdh7Zt0bYtqqpCXdcoigJlWaIoCtR1jTzPUdc1yrLk4+i6DlVVMRqNYFnW0dY0Tdi2DcMwYBgGNE1jEGRgTrWTARHBkD2haRpUVYU8z1EUBcIwRJIk2Gw2CMMQu90OaZoiiiKUZYk8z3E4HNB1HUzThGEY8DwPo9EIi8UCs9kMj4+PmM1mmM/ncF0X4/EYlmUdeY4YYqeCc7aHDAFSliWqqkKSJMjzHLvdDmEYYrPZII5jBEGALMsQRRGyLEOWZRw6dOdnsxk8z0PXdSiKAoqioKoq6LqOw+EAXdfR9z17CYXPUOL9dEDkXCHmiKqqEMcx0jTFdrtFGIZ4fn7Gfr/HZrNBmqYIggBJkiCKIlRVhbIs0bYtDocDn8NxHFiWhYeHB9ze3iIMQ9zc3CDLMtze3qJtW0wmEyiKAsdxjpLvOXZ2UhXzBy1K3LZty+GlaRpM08RoNIKiKNA0jQFpmgZt27Knkaf0fY+6rlHXNYdhlmX8N53DNE2+ji/3EAJiKHfQhdZ1zVVEVVVYlsVxb9s270P7U/WhvNM0DZqmga7rnJgp76iqivl8Dk3TMB6Poes6uq7jkKEbcAowJ4eMaHRiimNd19kTyHNs24Zt27xQApRAITAJkKIokOc5NE07SpbyfqKHXsIuRszoonVdh+M4nOw8z4Pv++wB8j60QAJgv98jjmOuTMRZdP3PpYqe+R4Q31Jl6MQiGADQdR2Tp6ZpOESImYpGn2dZBl3XUVUViqJgPkILo5xD5zEM44ikib87x04CRD45ubVpmnyRXdfBcZyjaiTfTSJvTdPAsiyYpvmPUAFwlBNM02SCRpxlCLxT7SwPES9AVVVehBgKsnuLpZq25BVhGGK/32O/3yNNU+R5zhVLVVUYhnHEXonWEykb6nG+HBDxLhIQcqKTgaBS3Pc9V4+Xlxes12usViu8vLwgiiLs93tesKZpsG0bk8kEvu9jNBrBdV0GhRL6uXYyIDIQiqKg67pBUESeIvKJJEmYwa7
Xa2y3WwRBgN1uhyzLkOc5JpMJl2rXdeF5HsbjMRzHOQpRsRv+FkBEUAgEIlJvtfvU4NHdXy6XWK1W7BUESBiGHIau60LTNDiOg8lkwpTecRw4jnPkHd+WQwgA8b28JRAOhwN3tHEcI4oiDo3lcsmhst1usdlsUBQFqqri5EkVi7jNUHW5VIU5GZD/BBSRxRZFgSRJsN1usVqt8OvXL6zXazw/P+P3799YLpeI4xhJkvAxPc+D53kA/lQxSqhi6y+LRnQt59jFRGbxQihcyDuKouCmbrPZYLPZIAgCLJdLBEGANE1R1zULRCLPAHBUiaiBpLZAZqvnMtazc8iQZirS67qukWUZ4jjGer1mQJbLJZbLJZIkQZIkXIYVRTnyAjpWVVUsFbiue9QMUh/zrSHznhFIIg+h5EpyoO/7WCwWGI1GmEwmDKSmaZxEKUwo7KIoQhAE3NRR9yxrr8A3UnfRZJFZJmZEvy3Lgud5uLu7Y3WNTNRLAcAwDHRdhzzPoaoqXl9foWkabm5uoOs6bNvm4367h7w3SlAUhXOB67po2xY/fvxghlkUBbIs49AiI/BkQbrve65UqqoiiiJomgbXdTnviJ5C1/C3dramKr8XL0am2/P5HI7jwHVdFn1kSk+9DVH3OI4ZOGK1qqpiv9/DsixUVcVeJHriqXaWHjK0FUsxjRNc14VhGDBNE3VdYz6fs2fIgJCC9vr6ijiOmXiRStY0DVet0WiEsixhmib3O8SWvyyHvFVVxO+o+pD7ip0pdbhDs5yu61igtm2bQ4tAAoC6rqFpGqv1ojInMmU69t8Cc5aHyG39EDAU3zRzeav5I05BJZd6HgLGMAxOvqJsQFLkECf58hwytDBxS6CQejaUa2QPIbNtm0uvyErl38tgnGt/BYjsGVQd3tM236PW4gLp+LRYyhUkWFOYDc13L8FQyU7OISIAlBxliVAeWL8Finx8keWSQCQn7KF9ZftS1V2Me3FoTQuStVZRURMBot9TcozjmGn+arXCbrfjkSclTmK7NBCn7ve9pwM+DRBZ6xABaZrmKBcQ42zbli9cFKMVRWFQReEoTVNW37MsQ1EURyFD9F7WU7+NqYqA0BCpbVu+8Kqq+Ht5VkPAkFG1oMZts9mwRvL6+ookSVAUBatjJBT5vo/pdMojTzr2uU3eSUlVBoW8g2a0RVFwDiAKL48OyGhARSradrvFbrdDEAQcKsRGiehRBSJuQ99dwlP+ChBZFBJPTqSqLEsEQcC0m7xIfIaDRo9Ex8uyRJqmSNOU5YA8z1GWJfMQz/MwnU5xf3+Pu7s73N3dYTKZwPM82LY9mEc+HRARGPmkYnIkgTiKIvYcAk0UpEV5kYbYSZLw4xF93zMpcxwHnudhMplgMpnAcRxmwEO66qn214CIYFBiI+2TdFAAR3khDEPUdY04jo9yDok8IuOkhOn7PsbjMR4fH+H7Pp6ennB/f4+npydMp1PMZjMGhfb58pAZAoUSpvwiLxCbsd1ux5WE8g6Vb3J3Yqeu68L3ffi+j/l8jsVigdvb26MwuVQiPRkQOinxCHERNL60bRsAMJ1Ooes69vs9DMNAkiQwDIPlRNI66O7SvGU+n2M8HuP+/h6z2Qw/f/7EYrHAzc0NC88iGEPc5ssAkcERgSHdAwCr5VmWQVGOH4Ui8IjIkUfNZjOMRiPMZjP4vs9PDj08PGA8HnPeoFmMqKxdcgyhfNADDH451NOQuEP6Z9M0XCnSNOVHrUg9pypDi6PRpLilZ0rkofZQiT0BjMEdzgJkiLVSmaXRgchPaEu5gzQTEotN02SSRS+x2x16NvUMr7gcIPzlAFED/tn9fjQ7kZO03JO81aOcGSKDO19ktiv/LbblQ9uPjvfW9q3zXtLO8pCP7BIaxScu/vIe8uEZP/FOfpZd/4FIsisgkn0UMv97Pn+mXT1Esisgkl0BkewKiGRXQCS7AiLZvwBtCZqwAvXF1QAAAABJRU5ErkJggg==\n", - 
"text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "#เลข 3 โดยเฉลี่ย\n", - "mean3 = stacked_threes.mean(0)\n", - "show_image(mean3);" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 85 - }, - "id": "ambkeHzzpojz", - "outputId": "2ef89ced-bead-4371-ff36-d472f5fa0121" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAI6klEQVR4nO1baVPiTBc9ZF/IhoOWOjUf5v//KgelNEZICCErvB+eutemjaNC8N04VanG7H1y99uOdrsdzniF8u9+gf80nAmRcCZEwpkQCWdCJGgfHP9fdkGjvp1nCZFwJkTCmRAJZ0IknAmRcCZEwpkQCWdCJJwJkfBRpHoQPlNjkc8ZjXoDxzf47HmHYhBCxMnR74/Gz0IkYDQa8d/y2Hf+ITiKEHmS2+2Wx91uxxv9TcfFY+L1BHGy9FtRFIxGo96RjsvXH4KDCBEnIk6aJt51HbbbLbquQ9d1aNuWR9pP58n3kUET1zQNqqrCMAyoqgrTNKGqKjRNg6IoexvhEGK+TIj44kQCEdA0DZqmQVmWaJoGm80GdV1jvV6jrmvkeY6qqrDZbNA0Deq6Rtu2TFSfJKmqCkVR4LouTNPExcUFXNdFFEVwHAe+78M0TViWxQSNRiOoqordbvdlUr5EiCwZ9KVJAoqiQNM0WK/XKMsSWZahKAokSYLNZoMsy1CWJRPVNA1fS6SKKgaACRmPxzBNE9PpFJ7n4devX/B9H5qmYbfbMRHb7RaKohxExpcI6VMPmgxNcLVaoSgKxHGMNE0xn8+RZRniOEae51gsFlitVlgulywpoiqJ6kQbffUgCOB5Hn7//o0wDJGmKS4vLwEAQRBA0/6ZCqnYyQnpI0ckhlSFJCHLMqRpymOe50iSBOv1GqvVCm3boq7rvfuJIHLatkVVVRiNRmiaBlEUQVEU5HkO13VZ0kjCjsWXVUYkous6NE2DqqpQliXyPEeapojjGMvlEnEcI8sy3N/fY7VaIUkS1HWNsiyZAEVR2DCS1xCfQXamrmvoug7TNFHXNS4vL2EYBvI8h23bLK3vGeeTEPI3kIskEdd1HbquwzAM+L4PVVUBgL+6fA55ETKyy+US6/UaWZZhvV7zMwDskSkSSb9Fd/1VfIqQjxgXyaAJmqYJ27aZhPF4jDAMoaoqu03HcfhcXdehqip7qiRJkGUZ7u7u8PT0xAabDCdBdrnHkPFpQshIyQQoisISsd1uYZomuq5DEARQFAV1XcO2bei6ziqg6zosy4Jt23BdF5ZlwbIslpDNZoOyLHl/nucoioLVgQik0bIsjk2+jRCRiD5CDMMAALiuC1VV0XUdTNOEoiioqgphGLKtoNjB8zy4rssTo3sSIZ7nYTabIc9zrNdrVFWFtm1hWRYcx4Ft27Btm4kjCRNV5tu8jKizAKDrOn89ALBtm41j13WoqgqapsE0Tbiuy5LhOA50XedYYrfb7UWmwD/qRt5IVVU4jgPXdTEejxEEAUuIaJiPwacJER+kKAoHQP
TypNuqqrL6ULS42+1YVSzLYsmwLIuJJZUiIk3TZEKqquKYxHEcOI4Dz/M4WqUoVbQjJydEJIaCHjHxAv6RFABMhghR98muEJF0Pbnbuq6RZRkHcuRlDMPAeDyG7/sIw5CjV13X3+Qxh+JglQFeJYV0V9d1Jqxt2z2dFr0P6TtNgK6h2KYsS6RpisVigcViwUEYqVwQBIii6A0hx7rcLxMiehvRsJLEkOEkkogQ2i8TQRCDvDzP8fz8jMfHRzw/PyNNU9R1zaG77/sIggCu68K27T3SAey938mTO3oQPVj0OuRxSBrETJU2MVUX70PGl/Kh+XyOp6cnzGYzpGmKpmmgaRo8z4PneQjDkCWGpGMoHBypytICvNoSsh80igUdoL+EUBQF0jTF/f097u7uEMcx4jhG0zRQVRVhGGIymfBIrvZYFZFxVOjeV84Tv5ZMmLyfJKNtW2w2G86QHx4eMJ/PkaYpezFys7TJrnYoUg42qrItod8A9ryGDLmMQMlekiT48+cPZrMZ5vM5FosFmqZhb3J1dYXLy0tMp1OMx2OOP0TJG4KUoyVEtCVkYGkTiZOLS2RI67pm6Xh4eEAcx3h4eECe5+i6DoZhIIoihGGIi4sLNqhiZErvMgSOtiF9Ve/tdvsm/wH2SaGSY1EUeHl5wWw2w2w2w8vLC6uK53m4vb3Fz58/cX19jZubGwRBwEmhHIx9u9v9G+RIVm5NyG6RCktUR0mSBEmSYLFYoCxLqKoK27bx48cPXFxcYDKZYDKZwHEcmKb5xn70kfBtucxnH0xSIhdtKHCjuuv9/T3HHGVZQtM0RFGEIAhwfX2N29tb3NzcIIoi2LbNkXBfyn+sCg1iQ+S/33sZUWWoClYUBZbLJfI8R5Zl7GZ938d0OkUURZhMJvB9H7Zt73mX98g4BkdLSB8phPdUheqvaZri6ekJSZJwi4LynNvbW0ynU5YQMqaidPTZjW/Ldv8GedLv7SPVIemgEiGRQV5FzGajKILneRyIDW1EZQza7O7zLMCrVxGN6HK5xOPjI6vLdruF4zgIggC2bePq6gpXV1fchxErY3LkKz7/WJx8OURfy4J6Mnmec08HANdIKGchcvq8CtAfFB6LQSVEVg+5b0M9mcVigefnZywWC2w2G4xGI1aJyWSCKIrY1ZLdoJrr0KG6jJOtD/lbD2ez2aAoCpRlia7rOF/RdZ2Lz5TeExGniEr7cJL1IWJoTipSFAVWqxWyLMPLywvyPOe2AnkWUUJEdZELQKfE4CrTJx3Ua6mqipO5uq65qKxpGtsPalGIlfTvIgMYgJC+tSLUZyXpoJ7ver1mQ0pVNbHOats2wjCE7/uwLOvDmOMUGNyG9NkOMqo00nIHsTBMTScav8uIyhhsSdV7RpQWxtBGrQbTNN8siKEikOhqRYP6X6EyIvoW1IgrgyiUp8YUqYSqqiwdhmFwi2KIPstXMWhy99451AR3XZcnKa8CIBsitiffa1mI49A4SRwC7Lc7adLU6gReWw+apnGb0zAM3khy/lbzOAUpow++8KdWnvTZEHHJFdkSMqxt2+6pEEkRkUNumEj5qDJ2IDG9Fw1eMaNqmdinESdMC+xEQoDXxXWiIZXvcYp0/808hpAQPrknYhX391XPel+qJ4EbqiImPqZ355CE8EXvlADeO953ft/EB5aKgwj5v8P530MknAmRcCZEwpkQCWdCJJwJkfAv6ObhbeIGuNEAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "#เลข 7 โดยเฉลี่ย\n", - "mean7 = stacked_sevens.mean(0)\n", - "show_image(mean7);" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 85 - }, - "id": "aKESuqYHpojz", - "outputId": "fc68b38d-ba24-4c02-9b92-154e5e4908f9" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAJUElEQVR4nO2by08bVxuHn/HM2OP7BbsmYJNQKCQUCGkuyratmi7aRbfZt1L/ga66Sfbtv9FVF11E6aKLKG2RqBIlaUpSih0IoSF2YkrAl7l5xl1830w/GxJKGBP0yY+EkOacmfPOb95zed9zLLRaLXr8g+91G3DY6AnSQU+QDnqCdNATpANpl/L/5ylI2Oliz0M66AnSQU+QDnqCdNATpIOeIB3sNu16hm3bWJaFaZo0Gg0Mw0BV1RfW9/l8+Hw+UqkUiqIgiiI+X/e/34EJYhgG1WqVxcVFrl69ysLCAtevX39h/VgsRiwW48svv+Tdd98lGo0SCAS6bmfXBFFVlSdPniAIAqIoUq/XefbsGXfu3KFQKLC4uMj6+vpL769Wq9y6dYu+vj6GhoZIJBJdF0bYJR/yyivV3377jc8//xxJkkin0zx//pwnT57w/Plz1tfXsSwLy7JebJggIAgCwWCQQCDAxx9/zNtvv83FixfJ5/OvalZbEztd9NxDTNNEVVXK5TKlUgmAjY0N6vU6m5ubaJqGaZr/sUjY0SaXVquFqqrous6jR48IBAJsbm6SzWaRJKkrY4rngjQaDW7evMns7Cy1Wo1Go4GqqjieuNcMXavVwrIs7ty5w9LSEh9++CHZbJZYLNaVruOZIM7s8fjxY2ZnZykUCqiqSrPZxLZtt14qlSKfz5PJZHjjjTe2PadYLLKyssLW1lbbLNRsNjEMA13XMU1zz8L+WzwTRFVVbty4wezsLF999RWapu04RkxPT/PJJ59w9uxZJicnt5VfuXKFb7/9llu3brGysuJet23b7Y6GYRx+QVqtFoZhYJomhmFsE0OSJBRFIZfLMTk5STabbXN5VVWp1WqUy2XW1tbQNK3tflmWURSFvr4+otEooih6ZXq7nV4+zBFkp68XCoXIZrNMTExw+vRpAoEAfr/fLS+VSszPz3P79m3u3r2LYRht9weDQZLJJLlcjlQqteuA/Kp4JojP5yORSBCJRJAkyZ0yE4kEw8PDDA4OMjo6yvnz5/H7/QiCgGmaaJpGtVplfn6e69evs7y8jGma7rgjiiKSJHHy5EnGx8dJJpNdEwM8FEQURdLpNOl0GkmSCAaDZLNZTp06xYULFzhx4gQTExPIskwgEHAHyKWlJX755ReuXbvGlStX0HWdZrPpPtfv96MoCu+99x4XLlzYcSD2Es8EkSSJTCbD6dOn+fTTTxFFkUQiwdDQEDMzM6RSKWRZdvu+rutuN5mbm6NQKKDruusZ8XiccDjM5OQkg4ODnDlzhv7+fmRZ9srkHfF8pWoYBrVaDZ/PhyiKrkd0uvnq6i
o//fQT33//Pd988822cWd8fJzR0VE++ugjzp07x8jICPF4fK/mvIyDWamKooiiKAiC4IqyU5+vVqvcu3ePUqnUJkYmkyGdTvPBBx/wzjvvMD09TX9//4EEdtAlQUKh0K71KpUKc3NzLC0ttV0fHh7m5MmTXLx4kZmZGSRJ6toUuxOvNUHUarW2dRVd1914Z6fybnPoMma6rlOv19F1HcuyDlwQ8dKlSy8rf2nhflBVlXq9jm3bPHr0yPWGVqtFo9EAoFarYVkWzWbT/e+MSx5weaeLr00QQRBIJpM8ffqU+fl5bNum2WyiqiqVSoWVlRV+/fVXNzgUBIFWq0UwGPRq6j18goTDYSKRCOFwGFEUKZfLbeOGruuoqsrjx49ZXl5maWmJ/v5+r3KsOwpyYDnVToLBIMFgkJmZGaLRKLZtc/fuXQAsy6JWq7nBHkAymSQejzM0NEQkEkGWZSTJe/Nfm4c4+Hw+IpEImUyGY8eOceTIESKRCK1Wa9tMY5ompmny4MEDMpkMsiy3rX73yOHyEAfHUxRF4ejRo8zNzREOhwHcQdeyLDRNQ9M0vvvuOxRFYWBgAEmSCIVCbVHzfnntHuIgCAKKopDJZDhx4gTHjx9ncnKSTCaDoihomuZm0GzbJhaLUalUiEQi2LaNoih77UKH00Mc/H4/fr+fcDhMLpfj6NGjnDp1ilgsht/v56+//mJ9fR3btmm1Wty4cYM///yTaDTK1NQUyWQSRVH2bUfXtiH2i5NsevbsGZVKhZ9//plCocAPP/xAsVgkFosRDAaZmpoin89z6dKlvW5PHExw5xWOxwQCAY4cOYKiKAwNDbGwsECxWGRra4utrS2q1SqpVIovvvjCk3YPrSAOTrSczWaRZZlMJtPV9g69IM5WaDgcRhCEfxVJ74dDL4iz5bm2tsbq6qq7G9gtPBWkc4B+1WTw/+7ymaZJs9mkXC5z//59NjY2trXh5ZamZ4Lous7Tp09pNpvouk40GiWdTrsGO67vBHGw87Zms9nENE1WV1dZXV1lcXGRUqlEsVjk4cOHLC8vu3VFUeTMmTOMjIy4i7n94okgzpcslUrutkI2myUSibhBmLO8dl74RckfVVVRVZU//viDmzdv8vvvv7OwsEC5XGZra8ut5/P5kGWZfD7PyMiIZynGfQtiGAabm5sUCgW+/vprd082FArR19dHOBwmnU6TzWY5fvw4Dx484N69e+5SvBNHsFKpRKVSoVqt0mg02uoKgsDZs2d58803+eyzz3jrrbdIJBL7fRXAA0Fs26Zer7O2tsaPP/5IrVZD13W3PB6Pk8vlGB8fZ2Njg8XFRa5du0aj0aBer++pLafbybLMsWPHGB0dZWRkhP7+/v2+xj9t7Helats2hmHw8OFDLl++TLFY5Pbt2207b4qiEAqFiEaj1Go1NjY2sCyr7VTANsP+e2DG+QuFQgQCAd5//32Gh4c5f/48uVyOsbGxVx0/urNS9fl8KIpCLBYjn89jmiYLCwvu0tuyLOr1unuk6qUWCgKSJCFJEn6/3824OymCcDjM2NgY09PTbjfxOifiWSxTr9e5f/8+a2trzM3Nsby8zNWrVzFNs60LvYhIJEI0GuXcuXOMjY0xMTHB8PAw8XicUCjkzlTOGTO/34/P59vPSaLuxjKSJDEwMEAwGETTNCRJYnBwEFVV0TRt1+x5KpUilUoxOjrKzMwMU1NTDAwMEA6HD2yTCjz0EGd94RxsMU2TarXqlu2GMzUrikIgEHBThM551S6wo4cc2vD/AOj9Xubf0BOkg54gHfQE6aAnSAc9QTrYbWHWveN+h5Seh3TQE6SDniAd9ATpoCdIBz1BOvgb74E4FYpj4GsAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "#เลข 3 อันที่ 125\n", - "a_3 = stacked_threes[125]\n", - "show_image(a_3);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "k4Rm7EZjU-Gb" - }, - "source": [ - "เราจะเห็นได้ว่าหากเราคำนวน mean absolute error และ mean squared error ของแต่ละ pixel ระหว่าง \"เลข 3 อันที่ 125\" กับ \"เลข 3 โดยเฉลี่ย\" และ \"เลข 7 โดยเฉลี่ย\" เราจะเห็นกว่าค่าของ \"เลข 3 อันที่ 125\" กับ \"เลข 3 โดยเฉลี่ย\" มีค่าน้อยกว่า และในกรณีนี้ระบบ (ที่ไม่ใช่ ML) ของเราจะทำนายถูกว่า \"เลข 3 อันที่ 125\" คือเลข 3" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "JbfdH2zWpoj0", - "outputId": "68513061-cf01-4337-adba-950be0372b4b" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(tensor(0.1259), tensor(0.2290))" + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "EilsFTaWpojx", + "outputId": "97aecb3b-fce5-4f7c-9b62-e6d43253e616" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(#6131) [Path('/root/.fastai/data/mnist_sample/train/3/10.png'),Path('/root/.fastai/data/mnist_sample/train/3/10000.png'),Path('/root/.fastai/data/mnist_sample/train/3/10011.png'),Path('/root/.fastai/data/mnist_sample/train/3/10031.png'),Path('/root/.fastai/data/mnist_sample/train/3/10034.png'),Path('/root/.fastai/data/mnist_sample/train/3/10042.png'),Path('/root/.fastai/data/mnist_sample/train/3/10052.png'),Path('/root/.fastai/data/mnist_sample/train/3/1007.png'),Path('/root/.fastai/data/mnist_sample/train/3/10074.png'),Path('/root/.fastai/data/mnist_sample/train/3/10091.png')...]" + ] + }, + "metadata": {}, + "execution_count": 6 + } + ], + "source": [ + "#ในละ folder 3 และ 7 จะเป็นไฟล์รูป\n", + "threes = (path/'train'/'3').ls().sorted()\n", + "sevens = 
(path/'train'/'7').ls().sorted()\n", + "threes" ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#distance between \"the 3 at index 125\" and \"the average 3\"\n", - "dist_3_abs = (a_3 - mean3).abs().mean()\n", - "dist_3_sqr = ((a_3 - mean3)**2).mean().sqrt()\n", - "dist_3_abs,dist_3_sqr" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "s5xmB8f1poj0", - "outputId": "d9967f38-5193-4528-b890-5de8e16a42a8" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(tensor(0.1836), tensor(0.3390))" + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2nKh8v3OBG5G", + "outputId": "9442d7c1-6125-4d88-c81a-e7e87d8b49ce" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(6131, 6265)" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ], + "source": [ + "#number of images in each class\n", + "len(threes), len(sevens)" ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#distance between \"the 3 at index 125\" and \"the average 7\"\n", - "dist_7_abs = (a_3 - mean7).abs().mean()\n", - "dist_7_sqr = ((a_3 - mean7)**2).mean().sqrt()\n", - "dist_7_abs,dist_7_sqr" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "lDlxHuO8Vazc", - "outputId": "a20160f4-2575-47c5-a8dc-bb5169feb261" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(tensor(0.1836), tensor(0.3390))" + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 45 + }, + "id": "jJEUNNQKpojx", + "outputId": "16277235-20fb-47c8-d1f4-bb6bbb0483a9" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ], +
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA3UlEQVR4nGNgGNxA1XnOxX//vxlgkVIqe//379+///4+EkCX4lv//e+/v38vnf33968wumTd379/rzUZ8lf9e1XChC4Zf39RAAMDA+eFf2txuIg18htuObW/f9/y45CM/Pt3iwYOuaxvv7Zz4pBjmPvvoxwSF9XNrxhZ/XFpZOCf9+tjMk5Zhsl/d+AyloHhI4MsTo22T//uF8cu5Xbo/79/8eiiXGVlyqyOk3/9/fe7nRld0urv3/uH/v79+/d1GKZ5In///v339+/zDdiClbVu9sF/X2cr4XQoVQEA4o1d+YAEyFQAAAAASUVORK5CYII=\n" + }, + "metadata": {}, + "execution_count": 8 + } + ], + "source": [ + "im3_path = threes[6000]\n", + "im3 = Image.open(im3_path)\n", + "im3" ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#ใช้ function ของ pytorch คิดก็ได้\n", - "F.l1_loss(a_3.float(), mean7), F.mse_loss(a_3, mean7).sqrt()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4Lx6GAdSVgD2" - }, - "source": [ - "เรื่องน่าคิดถึงความแตกต่างระหว่าง mean squared error และ mean absolute error" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "IcAtO7IwDzUp", - "outputId": "cb2fc2a2-cf78-4ada-895d-29ab5a4e3cbb" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(1.1666666269302368, 1.1902379989624023)" + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GBYX4Mf0BTb6", + "outputId": "b77ba285-8176-455e-bd86-6d16b75f9acd" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(28, 28)" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ], + "source": [ + "#รูปขนาด 28 x 28 pixels \n", + "array(im3).shape" ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = torch.tensor([1, 2, 3]).float()\n", - "b = torch.tensor([2.,3.,4.5])\n", - "c = torch.tensor([2.,3.,40.])\n", - "\n", - "#mse และ mae ไม่ต่างกันเท่าไหร่สำหรับ a และ b\n", - "(a-b).abs().mean().item(), \\\n", - 
"((a-b)**2).mean().sqrt().item()," - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "y5zmXGk2poj0", - "outputId": "96a206ec-10da-4054-8de4-13e538559074" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(13.0, 21.3775577545166)" + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RZphr7qzpojx", + "outputId": "ba2bf4b9-1cdd-4583-8734-abf34a064f96" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([[0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0]], dtype=uint8)" + ] + }, + "metadata": {}, + "execution_count": 10 + } + ], + "source": [ + "#เปลี่ยนจาก numpy array\n", + "array(im3)[4:10,4:10]" ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#mse และ mae ต่างกันเกือบเท่าตัวสำหรับ a และ c\n", - "(a-c).abs().mean().item(),\\\n", - "((a-c)**2).mean().sqrt().item()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ErTcHuGkpoj4" - }, - "source": [ - "## Stochastic Gradient Descent (SGD)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XTAhc8bJWi6y" - }, - "source": [ - "หากยังจำได้จากบทที่ 1 เราเรียนรู้ว่าเราจะคำนวณ `Gradients` จาก `Loss` แล้วให้ Optimizer ทำหน้าที่ update `Weights` ในบทเรียนนี้เราจะมาเรียนรู้ขั้นตอนเหล่านี้กัน" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "V9fzeVckWhP9" - }, - "source": [ - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OISepsl-poj5" - }, - "source": [ - "### คำนวณ Gradients เพื่อทำ Backpropagation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_9WrHtiu0ScN" - }, - "source": [ - "หากคุณยังไม่เคยเรียนเกี่ยวกับ [partial 
derivative](https://en.wikipedia.org/wiki/Partial_derivative) และ [chain rule](https://www.khanacademy.org/math/ap-calculus-ab/ab-differentiation-2-new/ab-3-1a/a/chain-rule-review) ในชั้นเรียนมัธยมปลาย คุณอาจจะไม่จำเป็นต้องเข้าใจเนื้อหาส่วนนี้ทั้งหมดก็ได้ ใจความสำคัญคือเราสามารถปรับแต่ง `Weights` ได้ด้วย `Gradients` ที่ถูกคำนวณมาจาก `Loss` เพื่อให้ได้ `Loss` ที่น้อยที่สุดเท่าที่จะทำได้ใน iteration ต่อๆไป" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tP7xMjbi6er4" - }, - "source": [ - "A simple backpropagation example from [cs231n](https://cs231n.github.io/optimization-2/#backprop)\n", - "\n", - "Independent variables:\n", - "(we can think of `x, y, z` as the `Inputs` or `Weights` of a model)\n", - "\\begin{align}\n", - "x & = -2 \\\\\n", - "y & = 5 \\\\\n", - "z & = -4 \\\\\n", - "\\end{align}\n", - "\n", - "Dependent variables: \n", - "(`q` and `f` are some functions, for example the model's `Loss Function`)\n", - "\n", - "Substituting `x, y, z` into the functions `q, f` gives\n", - "\n", - "\\begin{align}\n", - "q & = x+y = -2+5 = 3\\\\\n", - "f & = q*z = 3*(-4) = -12\n", - "\\end{align}\n", - "\n", - "This is analogous to turning `Inputs` into `Predictions` using the `Weights`, and is called the `Forward Pass`\n", - "\n", - "After that, we can compute the `Gradients`, which in general are the rates of change of the final function (here `f`) with respect to the initial variables (here `x, y, z`), namely $\\frac{df}{dx}$, $\\frac{df}{dy}$, $\\frac{df}{dz}$. We can find them using [partial derivatives](https://en.wikipedia.org/wiki/Partial_derivative) and the [chain rule](https://www.khanacademy.org/math/ap-calculus-ab/ab-differentiation-2-new/ab-3-1a/a/chain-rule-review); this is called the `Backward Pass`\n", - "\n", - "\\begin{align}\n", - "\\frac{df}{dq} & = z = -4\\\\\n", - "\\frac{dq}{dx} & = 1\\\\\n", - "\\frac{df}{dx} & = \\frac{df}{dq} * \\frac{dq}{dx}\\\\\n", - "& = -4*1\\\\\n", - "& = -4\\\\\n", - "\\end{align}\n", - "\n" - ] - }, - { - "cell_type": "code", -
"execution_count": 29, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "VXdDIlJl6jYC", - "outputId": "326dcf7c-8210-49c7-918f-e4067ee240ab" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(-4.0, -4.0)" + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "o52Dig-Vpojx", + "outputId": "9295e887-7371-4494-fca8-829823e96a64" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0]], dtype=torch.uint8)" + ] + }, + "metadata": {}, + "execution_count": 11 + } + ], + "source": [ + "#เป็น torch tensor\n", + "tensor(im3)[4:10,4:10]" ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# set some inputs\n", - "x = -2; y = 5; z = -4\n", - "\n", - "# perform the forward pass\n", - "q = x + y # q becomes 3\n", - "f = q * z # f becomes -12\n", - "\n", - "# perform the backward pass (backpropagation) in reverse order:\n", - "# first backprop through f = q * z\n", - "dfdz = q # df/dz = q, so gradient on z becomes 3\n", - "dfdq = z # df/dq = z, so gradient on q becomes -4\n", - "# now backprop through q = x + y\n", - "dfdx = 1.0 * dfdq # dq/dx = 1. 
And the multiplication here is the chain rule!\n", - "dfdy = 1.0 * dfdq # dq/dy = 1\n", - "\n", - "dfdx, dfdy" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kt9JIcVZ3CuU" - }, - "source": [ - "Pytorch สามารถทำ `Backward Pass` ให้เราโดยอัตโนมัติด้วยฟังชั่น Autograd โดยที่เราไม่ต้องคิด partial derivative เอง ผลข้างเคียงอีกอย่างคือเราสามารถใช้ Pytorch ช่วยทำการบ้านวิชาแคลคูลัสเวลาเราหา derivative ที่ยากเกินไปไม่ออกได้อีกด้วย" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "SWg90ZHApoj6", - "outputId": "96fde65e-54bc-43c4-a8e1-e4217f5441c1" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor([ 3., 4., 10.], requires_grad=True)" + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 394 + }, + "id": "0yZCaBRNpojy", + "outputId": "9e3f04a1-6be4-48db-ac47-3ee450e0c536" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 01234567891011121314151617
0000000000000000000
10000000000376715620925425524648
20000000034118239253253253253254253226
300000014175247253254253253210205254253253
40000001262532532532141304915122254234116
5000000952231628000092082541730
60000000000058924625417300
700000000005382532532371500
800000000008925325318040000
900000000010624625018390000
100000000001572542413000000
\n" + ] + }, + "metadata": {}, + "execution_count": 12 + } + ], + "source": [ + "im3_t = tensor(im3)\n", + "df = pd.DataFrame(im3_t[4:15,4:22])\n", + "df.style.set_properties(**{'font-size':'6pt'}).background_gradient('Greys')" ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "xt = torch.tensor([3.,4.,10.]).requires_grad_()\n", - "xt" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "WpM1Q5nApoj6", - "outputId": "f969c1d9-e911-4bf7-9622-2bf843b63be3" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor(125., grad_fn=)" + "cell_type": "markdown", + "metadata": { + "id": "jMdNhmiXpojy" + }, + "source": [ + "## The simplest approach: check how similar the pixels are" ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#function f(x) = (x1^2 + x2^2 +...+xn^2)\n", - "def f(x): return (x**2).sum()\n", - "\n", - "yt = f(xt) #plugging in 3, 4, 10 gives 3^2+4^2+10^2 = 125\n", - "yt" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "POJB2ewmpoj6", - "outputId": "de1feba8-9f02-4654-8622-11a1b4b30b51" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor([ 6., 8., 20.])" + "cell_type": "markdown", + "metadata": { + "id": "MqYq_GbFT7XA" + }, + "source": [ + "Suppose we do no ML at all and instead try to build a rule-based system that says: \"classify an image as whichever digit, 3 or 7, has average pixels closer to the image's pixels\"" ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#by hand, df(x)/dx = 2x \n", - "#for x1=3, x2=4, x3=10 this gives \n", - "#df(x1)/dx1 = 6, df(x2)/dx2 = 8, df(x3)/dx3 = 20\n", - "\n", - "#use autograd to find df(x1)/dx1, df(x2)/dx2, df(x3)/dx3\n", - "yt.backward()\n", - "xt.grad" - ] - }, - { -
"cell_type": "markdown", - "metadata": { - "id": "mHHqwvz2poj_" - }, - "source": [ - "## สร้าง Loss Function สำหรับจำแนกรูปเลข 3 และเลข 7" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "U0EQ6WD4zEf0" - }, - "source": [ - "#### สร้าง X และ y" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "t2aoyhac4ZrI" - }, - "source": [ - "จัดการ `Inputs` คือรูปตัวเลข 28x28 pixels และ `Labels` คือ `1 ถ้าเป็นเลข 3` และ `0 ถ้าเป็นเลข 7`" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "id": "-4LTBmPcpoj_" - }, - "outputs": [], - "source": [ - "train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "Q5bU0JpipokA", - "outputId": "fa8b6406-e3a7-4721-fef1-218c6e38fbf7" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(torch.Size([12396, 784]), torch.Size([12396, 1]))" + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JDsXZWi8pojy", + "outputId": "34c39b1d-84e9-43ed-8364-e30f8b4969a3" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(6131, 6265)" + ] + }, + "metadata": {}, + "execution_count": 13 + } + ], + "source": [ + "seven_tensors = [tensor(Image.open(o)) for o in sevens]\n", + "three_tensors = [tensor(Image.open(o)) for o in threes]\n", + "len(three_tensors),len(seven_tensors)" ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)\n", - "train_x.shape,train_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "rkavLIe1IcmD", - "outputId": "10446b01-f0b3-470d-cdcd-fb1752c76a8e" - }, - "outputs": [ { - "data": { - 
"text/plain": [ - "[(1, 4), (2, 5), (3, 6)]" + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 85 + }, + "id": "9h_vpbbvpojy", + "outputId": "3873c893-9aad-43cb-be95-7ba2aa239f63" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAJHElEQVR4nO2bXXMSZxuAL1jYXYQsiCYxIWIIjImJ0TaVTu2H41FnnGmPPOtMf0NP+i/6H9oDx+OOOtMjW6cdGz/Sk1ZqxkRDhCTQEAjfsLDsvgeWbbMmxgrEzDtcR5ln+bi5eJ5n7/t+iM0wDPr8g/1tB3DY6Aux0BdioS/EQl+IBcc+1/+fb0G23Qb7M8RCX4iFvhALfSEW+kIs9IVY6Aux0BdiYb/ErGMMw6DVatFsNqlUKmiaRrPZpNls0mg0Xnq8LMtIkoSu6+i6jtPpxOFw4HQ6EQQBWZZxOHoXds+FaJpGvV4nHo/zww8/kMlkSKVSxONxnjx5wr/7MTabjffee4+pqSmq1SqqqjIyMsKxY8eIRCIMDw8zOzuLz+frWbxdF6LrujkDisUilUqFbDbLn3/+yeLiIoVCgfX1dTKZDJVKBVEUEUWRer1OrVZjZWUFh8NhCqlUKmxublIulzl+/DgTExM9FWLbp2P2n2sZVVXZ3t5mdXWV77//nlQqxaNHj8jlcqTTaQzDwDAMZFnG7XYzODhIIBDgyZMnPH/+HLvdjt1uN2eOzWbDZrPh8/nwer1cv36daDT6hh93B7vWMh3PEE3TqNVq1Go1UqkUlUqFVCrFysoKT58+pVQqIQgCExMTRKNRnE4nkiQhiiIulwtFUfB6vczOzpLJZFheXubZs2eUy2VqtZr5Po1GA1VV0XW905BfScdCVFUlFovx22+/8c0331CpVGg2m7RaLRqNBoFAgGg0ysWLF7l69Soejwe3220+3263Y7PZ0HUdwzC4ceMG165dIxaLkUwmOw3vP9OxEMMwaDQaVKtVSqUS1WoVXdex2+2IokggEGBubo5z587h8/lwOp04nU7z+e0l8e+lJAgCNtvOGe33+wkGg8iy3GnIr6QrQqrVKrVaDVVV0TQNAFEUOXr0KHNzc3zxxRf4fD4URXnpg7ZpjzudTux2+0vXpqenmZmZQVGUTkN+JR0LcTqdhEIhRFEkn8/TbDaBF0Lcbjdzc3N4vV5EUdxTBrzYizRNI5/Pk8/nzRzFZrMhCALDw8OEw2FcLlenIb+SjoVIksTp06cJh8N88MEH5nh7KQiCgCiK+75Oo9GgVCqxtrZGMpmkWq0CmM+PRCJEo9Ed+08v6FhI+1sXBGHH3tC+Zp3+e7G+vs7PP//M77//TqlUQtM0BEEgHA4zPj7OzMwMJ06ceOk9uk1XErP2bHidmbAXd+7c4auvvkLTNHRdx+FwIIoiFy9eJBqN8u6773LixIluhPtKep66WzEMA13XqdfrlMtlM2FbWFgwN2S73U4kEiESifDhhx8SjUZ7vpm2OXAhuq6jaRrZbJbHjx/z448/cvPmTbLZrHm7djgcXLhwgY8//phPP/2UkydPHlh8B1LtGoZBsVhkdXXVrGWSySSrq6ssLS2Ry+Wo1+vAi3zD7/dz+vRpzp8/z8DAQK9D3MGBlf/JZJLvvvuO5eVlHjx4gKqqO1LzNkNDQ0xPT3PhwgUmJyd7fpu1ciBLRtd1CoUCjx49Yn19HVVVzXzFSiaTIRaLcevWLeLxOMeOHcPj8TA2NobX62VwcLCnkg5syWxtbfHgwQM0TTM31t3IZDJkMhnW19dxu90oioLH4+HKlSucP3+eS5cuIcvyK5O8Tuh6+f/SC/y9ZDY2Nrh9+zblcplCoUCxWCSXy5mPi8fjLC0tmT0UURRxOp1mv2RqaorR0VE++ugjzpw5w9mzZ/F6vQiC8Nq5joVdjfZciJVarWbKWFtbM8d/+uknfvnlF1ZXV
0mn07s+t91Ri0QifP3110xOTiJJEoIgvEkovemH/FecTieKoiDL8o7O19DQEJcuXWJtbY10Ok0mkyGXy3H//n3i8TjwYrYlEgmq1SrPnz9ncHCQ48ePv6mQXTlwIQ6HA4fDgcvlwuv1muPDw8NMT09TrVbNW/TKygrZbNYUArC5uUkulyMejxMKhTh69Gh34+vqq3VAuxBsd9VlWSYYDBKPx/nrr79IJBJsb28DL2bKxsYG8XicU6dOdTWOQ3Mu0y4EJUkye63BYJDZ2VmmpqZ2pO6GYZgzR1XVrsZxaITsRbtwtI55vd6eVL+HXgjAbndCt9uN3+/v6oYKh2gPsVIsFtne3mZhYYGHDx/uyFlsNhunTp0iEol01HLYjUMppF0MJpNJEokEiURiR2YrCAJDQ0P4/f6uH2seOiHVapVKpcK9e/e4c+cOf/zxh3lEATA5Ocn4+DiBQABZlt80S92TQyOk/YHr9TrZbJZYLMb8/DypVMq8ZrfbCQaDRCIRvF4vDoej6zXNoRFSKBRIp9Pcvn2bu3fv8vjxY5LJpNknGRgYwO12c/XqVS5fvszo6Oiu5zed8laFtCthwzDI5/M8ffqUhw8fcuPGDbO3Ci820SNHjjA4OMi5c+cIhUI9kQFvUYiqqqiqysbGBktLS9y9e5f5+XkSiQTNZtNcJi6XC1mW+fLLL/nkk08Ih8M9kwFdFvLvb9yaULXH2383Gg0KhQLxeJz5+XkWFha4d++e+fj2me/AwACKonD27FneeecdPB5Pz2RAF4Vomka1WqVcLrO2toaiKIyMjJiH3pVKha2tLUqlEpubmywvL7O4uGjWKfl8HvgnM41EIoTDYa5cuUI0GiUUCqEoSk9/PQRdFNJqtSiVSmxtbRGLxRgdHUWSJPMXRLlcjuXlZba2tsxl0u6tqqq645RPFEXGx8cJh8NEo1FmZmbMhlGv6ZqQfD7Pt99+SzKZ5P79++Zhd6vVotVqmecwqqqaM6ZSqZgb5/DwMGNjY7z//vtmk/nkyZMoioIkSV3PN/aia0IajQaJRIJnz56xuLj4Wj9ssdlsSJKEJEmMjY0RiUQ4c+YM0WiUiYkJ/H5/t8J7bbomxOPxcPnyZTweD7/++uu+QmRZ5siRI3z++ed89tlnhEIhRkZGcLlcB7Y8dqNrQhwOB8FgkHQ6TSAQMI8l98LlcuHz+ZicnGR2dpahoaEdHbS3RdeazK1WyzxvKRaL+7/x3w0ht9ttdsm6XcrvF8KugwfddT9E9P+j6nXoC7HQF2KhL8TCfrfd3lVRh5T+DLHQF2KhL8RCX4iFvhALfSEW/gcMlBno19ugeQAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "show_image(three_tensors[1]);" ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#เราสามารถนำ iterables สองอันมาต่อกันแบบนี้ได้ด้วย zip\n", - "a = [1,2,3]\n", - "b = [4,5,6]\n", - "list(zip(a,b))" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "08BIemoApokA", - "outputId": "1ed03c6d-72af-4ff9-cb58-11251e7790e6" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(torch.Size([784]), torch.Size([1]))" + 
"cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D5B7mZNspojy", + "outputId": "0d8770e5-0ccc-4af4-8781-f871c04e7b14" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(torch.Size([6131, 28, 28]), torch.Size([6265, 28, 28]))" + ] + }, + "metadata": {}, + "execution_count": 15 + } + ], + "source": [ + "#เรามัดเลข 3 และ 7 ทั้งหมดรวมกันตามคลาสแล้วหารมันด้วย 255 เพื่อให้ได้ค่าระหว่าง 0 และ 1\n", + "stacked_sevens = torch.stack(seven_tensors).float()/255\n", + "stacked_threes = torch.stack(three_tensors).float()/255\n", + "stacked_threes.shape, stacked_sevens.shape" ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dset = list(zip(train_x,train_y))\n", - "example = dset[0]\n", - "\n", - "#คู่ Inputs และ Labels\n", - "example[0].shape, example[1].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "id": "Y6ImID5ApokA" - }, - "outputs": [], - "source": [ - "#สร้าง validation set ในแบบเดียวกัน\n", - "valid_3_tens = torch.stack([tensor(Image.open(o)) \n", - " for o in (path/'valid'/'3').ls()])\n", - "valid_3_tens = valid_3_tens.float()/255\n", - "valid_7_tens = torch.stack([tensor(Image.open(o)) \n", - " for o in (path/'valid'/'7').ls()])\n", - "valid_7_tens = valid_7_tens.float()/255\n", - "\n", - "valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)\n", - "valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)\n", - "valid_dset = list(zip(valid_x,valid_y))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RR1v1Z2tzHmo" - }, - "source": [ - "#### Initiate `Weights`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7slIAUaOI0_d" - }, - "source": [ - "สมมุติว่าเราจะใช้ architecture สุดเรียบง่าย แค่คูณค่า pixels ของรูป `Inputs` ด้วย `W` และบวกด้วย `b`\n", - "\n", - "$$prediction = \\Sigma(xW^T) + 
b$$\n", - "\n", - "We can initialize the `Weights` by sampling randomly from a standard normal distribution" - ] - }, - { - "cell_type": "code", - "execution_count": 62, - "metadata": { - "id": "WQ9vpYHIpokA" - }, - "outputs": [], - "source": [ - "def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "itXv1TZ9pokA", - "outputId": "50f3e5d0-626e-4633-a395-92a02991d36f" - }, - "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([784, 1])" + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K4mmFyngpojz", + "outputId": "03cb8688-0ce1-44a6-c239-25654368389a" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "3" + ] + }, + "metadata": {}, + "execution_count": 16 + } + ], + "source": [ + "len(stacked_threes.shape)" ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#Inputs have dimension (batch_size, 28*28)\n", - "#so to multiply element-wise by W^T, W must have dimension (28*28,1)\n", - "weights = init_params((28*28,1))\n", - "weights.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "7Ef5fK0KpokA", - "outputId": "767be632-37cc-4583-b733-0d8c382d6ffb" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor([0.6971], requires_grad=True)" + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Jck_enHBpojz", + "outputId": "7c732263-92aa-40d7-caec-2c98b91026aa" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "3" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ], + "source": [ +
"stacked_threes.ndim" ] - }, - "execution_count": 103, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#bias มีแค่ (1) dimension แล้วจะถูก broadcast ไปทุก dimension ของ batch size เอง\n", - "# (1) -> (batch_size,1)\n", - "bias = init_params(1)\n", - "bias" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "pnMvUz7kFoIU", - "outputId": "903243e1-5149-4c37-c02c-554d8a5fe665" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(torch.Size([784, 1]),\n", - " torch.Size([1]),\n", - " torch.Size([12396, 784]),\n", - " torch.Size([784]))" - ] - }, - "execution_count": 104, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "weights.shape, bias.shape, train_x.shape, train_x[7].shape" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KXK2LAZqzLYk" - }, - "source": [ - "#### Forward Pass" - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "S9md33qTEu6n", - "outputId": "4b257356-e622-40c9-ea2c-aac1966e4fcb" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor([[0.5587]], grad_fn=)" + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 103 + }, + "id": "cxcx0MIlg-fI", + "outputId": "b7af8b59-816c-4937-fe79-d9b1514d5a3d" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 18 + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAIEElEQVR4nO2bS08TbRuArx6ntJQONBQoYIFgEPCQYBQ0Rl2YuDFGoy504Q/xJ7j1B7hwR+LGjboQSZRo8IQBLIdWC5aDI1Baep525luYzgulKF++TjHv1yvpZmae9u41z9z3c2gNqqpS5R+MBx3A30ZVSBFVIUVUhRRRFVKE+Q/n/80lyFDqYLWHFFEVUkRVSBFVIUVUhRRRFVLEn8pu2VEUhXw+Tz6fR5ZlMpkMqVQKs9mMxWLBbDZjMpm06wVBwGKxYDCUrJJlp+JCstksW1tbrK6uMjs7y/v373nz5g0+nw+v16u9Cpw5c4bm5maMRmNFpFRMiCzLpNNp1tfX+fbtG9++fWN+fp5AIEAoFCKbzZJMJolEIqysrABgNBrp7u6mvr4eq9WK2ax/uBUTEo1GmZiY4MWLFzx69IhUKkUymSSXy5HP51laWuL9+/cYjUaMxl+pzWAwIIoiDoeDlpYWamtrdY+zYkIsFgv19fXYbDaSySTpdJp0Oq2dz+fzJdvNzc3x7t07hoaGMJlMWp7RDVVVf/cqG7Isq4lEQh0eHlZbWlpUp9Op8muu9NuXy+VS29vb1QcPHqh+v1+Nx+PlCqnkd65Y2TUajZjNZlpaWjh9+jRtbW2YTCYtUZrNZux2+667n0ql2Nzc5MePH0iShCzL+sap67tv/yCjEavVSldXF3fv3uXixYsIgqAJqK2tpampCYfDsaNdNpslHo8TDAYZHx8nHo/rG6eu716C2tpaenp6aG9vx+VyIQgCACaTCUEQdoxBtuPxeDhy5Ag1NTW6xldxIU6nk76+PoaGhvD5fLhcLgCsVit1dXVYLJZdbQwGA8ePH+fs2bO6V5qKD8wKOcPtdjM0NIQoilq1kSRpR+U5CCoupIDX6+XWrVuMjIwQi8UIhUKEQqE9ry8M+fXmwCZ3NTU1tLe3c/78eW7cuMG5c+dwu90lc4SqqgQCAfx+P6lUSte4DqyHOBwOHA4HTU1NDAwM4Ha7CYVCLC4u7vrSqqoyPj7O1tYWXq8XURR1i+vAhKTTaeLxOIlEgmg0SiAQYHV1lUQisetag8GA2+2mpaUFq9Wqa1wHJmRrawu/38/y8jLhcJhPnz7x/ft31D32mj0eDz09PbqX3YoJyeVyyLLM2toac3NzzM/PMzMzQywWIxqNMjc3t6cMgKamJrq7u7HZbLrGWVEh8XicsbEx7t+/jyRJ/PjxA0VRUBTlt20NBgM+n4/Ozk5tIKcXFasy+XyeWCyGJEmsra2RSCRQFOW3vaKAqqrMz8/z5cuXf0+VyeVybGxsIEkSP3/+RJblP/aM7czNzWGz2Whra9NGt3pQMSGCINDR0cGFCxcIh8P4/X4+fvy4r0cGIJlMEovFdB+cVUyIzWbDZrPR39/PlStXEASBycnJffUUVVVJJBLE43FyuZyucVa87IqiyODgIKIo0tDQQDabJZvNaudVVUVRFMbHx5menkaW5YoM2QtUXIjdbsdut+N0OmlsbCSTyZDJZLTzBSHpdJpgMKhtWVSKAxuY1dTU4PP5Sk7aFEXh4sWLKIrC69evCQaDrK6u4nA42NjYIJ1OY7FY9lw7+V84MCGCIOw5plBVlRMnThCLxVhYWCAYDLK5ucni4iKxWIxMJoPJZNJFyF+3lVl4ZKampnj8+DEzMzPAP2uyem9YHVgP2YuCkEAgwOjoqHa8sAXxfydkdXWVYDDI/Py8dsxgMDAwMMCxY8doa2vDZrPp8riADkIK+xvb7+J/c0clSWJsbIzv37/vOO71eunv70cUxZLrr
uWibEIURdGG569evaKhoYGOjg5EUcTtdv+xfS6XI5fLMT09zdOnT3f1kK6uLk6ePIndbi9XyCUpW1JVVRVZlvn58yfPnj1jdHSUUChEJBLZcxK3fccsl8uRyWRYXFzk8+fPrK+vA79kmEwmmpqaaGxs1LV3QBl7yNbWFs+fP2dycpKXL1/idrtZWlri1KlTXL9+HavVitVqxWKxIAgC6XSaVCqlbXr7/X4+fPjAq1evtJkwwODgIL29vfT19emaOwqUTUgqlWJiYoJAIEA4HCYSiZBOp3E6nUiSpK2hOhwOzGYz2WyWSCRCNBolEonw9u1bRkZGCIVC2nzFaDTS2dnJsWPHaGho0MqunpRNiNPp5OrVq3z69InFxUXW1tZYWFjgyZMnzM7OarmksDYaDof5+vUr8XicaDTK8vIy6+vr2r5MQ0MDoihy6dIlLl++TH19/Y69YL0omxCz2UxrayvxeJy2tjby+TwrKyuEw2FCoRC1tbV4PB6am5s5dOgQX79+xe/3k8lkdkzu4FfecLlc+Hw+Dh8+jMfjqYgMAMMfVqz2/dNuRVHIZrMsLS3x8OFDVFXFarXi9/sZHh7WNrutVqv2G5FkMrkr4dpsNgRB4N69e9y8eROPx4PdbsdgMJRbSMk3K1sPMRqN2Gw26urqaG1tRRAEvF4vsixjt9u1JJnJZDQRpWaxdrsdURQ5ceIEXV1d5Qpv35R9YCaKIrdv3wZ+df3CD+YKQjY3N5EkCb/fz9TUlNauUH3u3LnDtWvXOHr0aLlD2xdlF2KxWBBFUZuTNDc3MzAwoPUGSZJwuVwoioIkSVq7worakSNH6O3txel0lju0fVG2HFKy8bYBV+FzCo9KJpPZsdNfyBGiKFJTU1OJElsyh+gq5C+n+n+Z/VAVUkRVSBFVIUVUhRRRFVLEnwZmlfmTyl9EtYcUURVSRFVIEVUhRVSFFFEVUsR/AP0FXN1zCRLUAAAAAElFTkSuQmCC\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "#เลข 3 ที่ index 0\n", + "show_image(stacked_threes[0,:,:])" ] - }, - "execution_count": 105, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#คำนวณ forward pass สำหรับตัวอย่าง 7\n", - "(train_x[7]*weights.T).sum(1)[:,None] + bias" - ] - }, - { - "cell_type": "code", - "execution_count": 106, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "22qYPEQDEwjC", - "outputId": "a4b12d0c-c2fa-435b-ae26-611940166c7e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 7.6066],\n", - " [22.2802],\n", - " [ 9.7072],\n", - " ...,\n", - " [ 8.1622],\n", - " [ 9.2091],\n", - " [-0.7656]], grad_fn=)" - ] - }, - "execution_count": 106, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#คำนวณ forward pass สำหรับทุกตัวอย่าง\n", - "(train_x*weights.T).sum(1)[:,None] + bias" - ] - }, - { - "cell_type": "code", - 
"execution_count": 107, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "I65On42BpokB", - "outputId": "3c2f1528-6aae-4999-ee30-c086090ea535" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 7.6066],\n", - " [22.2802],\n", - " [ 9.7072],\n", - " ...,\n", - " [ 8.1622],\n", - " [ 9.2091],\n", - " [-0.7656]], grad_fn=)" - ] - }, - "execution_count": 107, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#เขียนเป็นฟังชั่น; @ คือ matrix multiplication ใน pytorch\n", - "def linear1(xb): return xb@weights + bias\n", - "preds = linear1(train_x)\n", - "preds" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XJKeTmrPzOr3" - }, - "source": [ - "#### Metric ใช้ Accuracy" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "SLnbggPopokB", - "outputId": "e41a9fc7-78f5-4d9e-f46d-b9e08a62eaf9" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ True],\n", - " [ True],\n", - " [ True],\n", - " ...,\n", - " [False],\n", - " [False],\n", - " [ True]])" - ] - }, - "execution_count": 111, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#ถ้าสมมุติว่า Predictions >0 ให้ทายเป็น 1 (เลข 3)\n", - "corrects = (preds>0.0).float() == train_y\n", - "corrects" - ] - }, - { - "cell_type": "code", - "execution_count": 112, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "5DVZlAVzJqbV", - "outputId": "60c481ad-1b1b-40a8-e74c-0fff4962eda5" - }, - "outputs": [ { - "data": { - "text/plain": [ - "Counter({0.0: 647, 1.0: 11749})" + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 85 + }, + "id": "2Es3hPcXpojz", + "outputId": "492d2c79-b892-48a9-f20d-8c4eba02024e" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAJtUlEQVR4nO1b2XLiWhJM7QsChDG22x3h//+qfnKzWVhoX5HmoaNqDufK9jRge2aCiiCEAS0nVUtWlqz0fY+r/dvU776A/za7AiLZFRDJroBIdgVEMv2D7/+fS5Ay9OHVQyS7AiLZFRDJroBI9lFS/RQ7tV1QlME8eFH7dEDkxdPf4udDAMmLVxQFfd8Pfn5JuyggQ4vs+/7o1XUdfy6+l01RFCiKAlX9E9WqqvJn9JJ/fwm7CCDy4ruu423XdTgcDjgcDqiqCm3boq5rtG2LpmlwOBzQti3vr6oqVFWFaZrQNA2WZUHXdViWBU3ToOs6NE3j3xFQZOcCcxYgQ15AIHRdh7Zt0bYtqqpCXdcoigJlWaIoCtR1jTzPUdc1yrLk4+i6DlVVMRqNYFnW0dY0Tdi2DcMwYBgGNE1jEGRgTrWTARHBkD2haRpUVYU8z1EUBcIwRJIk2Gw2CMMQu90OaZoiiiKUZYk8z3E4HNB1HUzThGEY8DwPo9EIi8UCs9kMj4+PmM1mmM/ncF0X4/EYlmUdeY4YYqeCc7aHDAFSliWqqkKSJMjzHLvdDmEYYrPZII5jBEGALMsQRRGyLEOWZRw6dOdnsxk8z0PXdSiKAoqioKoq6LqOw+EAXdfR9z17CYXPUOL9dEDkXCHmiKqqEMcx0jTFdrtFGIZ4fn7Gfr/HZrNBmqYIggBJkiCKIlRVhbIs0bYtDocDn8NxHFiWhYeHB9ze3iIMQ9zc3CDLMtze3qJtW0wmEyiKAsdxjpLvOXZ2UhXzBy1K3LZty+GlaRpM08RoNIKiKNA0jQFpmgZt27Knkaf0fY+6rlHXNYdhlmX8N53DNE2+ji/3EAJiKHfQhdZ1zVVEVVVYlsVxb9s270P7U/WhvNM0DZqmga7rnJgp76iqivl8Dk3TMB6Poes6uq7jkKEbcAowJ4eMaHRiimNd19kTyHNs24Zt27xQApRAITAJkKIokOc5NE07SpbyfqKHXsIuRszoonVdh+M4nOw8z4Pv++wB8j60QAJgv98jjmOuTMRZdP3PpYqe+R4Q31Jl6MQiGADQdR2Tp6ZpOESImYpGn2dZBl3XUVUViqJgPkILo5xD5zEM44ikib87x04CRD45ubVpmnyRXdfBcZyjaiTfTSJvTdPAsiyYpvmPUAFwlBNM02SCRpxlCLxT7SwPES9AVVVehBgKsnuLpZq25BVhGGK/32O/3yNNU+R5zhVLVVUYhnHEXonWEykb6nG+HBDxLhIQcqKTgaBS3Pc9V4+Xlxes12usViu8vLwgiiLs93tesKZpsG0bk8kEvu9jNBrBdV0GhRL6uXYyIDIQiqKg67pBUESeIvKJJEmYwa7Xa2y3WwRBgN1uhyzLkOc5JpMJl2rXdeF5HsbjMRzHOQpRsRv+FkBEUAgEIlJvtfvU4NHdXy6XWK1W7BUESBiGHIau60LTNDiOg8lkwpTecRw4jnPkHd+WQwgA8b28JRAOhwN3tHEcI4oiDo3lcsmhst1usdlsUBQFqqri5EkVi7jNUHW5VIU5GZD/BBSRxRZFgSRJsN1usVqt8OvXL6zXazw/P+P3799YLpeI4xhJkvAxPc+D53kA/lQxSqhi6y+LRnQt59jFRGbxQihcyDuKouCmbrPZYLPZIAgCLJdLBEGANE1R1zULRCLPAHBUiaiBpLZAZqvnMtazc8iQZirS67qukWUZ4jjGer1mQJbLJZbLJZIkQZIkXIYVRTnyAjpWVVUsFbiue9QMUh/zrSHznhFIIg+h5EpyoO/7WCwWGI1GmEwmDKSma
ZxEKUwo7KIoQhAE3NRR9yxrr8A3UnfRZJFZJmZEvy3Lgud5uLu7Y3WNTNRLAcAwDHRdhzzPoaoqXl9foWkabm5uoOs6bNvm4367h7w3SlAUhXOB67po2xY/fvxghlkUBbIs49AiI/BkQbrve65UqqoiiiJomgbXdTnviJ5C1/C3dramKr8XL0am2/P5HI7jwHVdFn1kSk+9DVH3OI4ZOGK1qqpiv9/DsixUVcVeJHriqXaWHjK0FUsxjRNc14VhGDBNE3VdYz6fs2fIgJCC9vr6ijiOmXiRStY0DVet0WiEsixhmib3O8SWvyyHvFVVxO+o+pD7ip0pdbhDs5yu61igtm2bQ4tAAoC6rqFpGqv1ojInMmU69t8Cc5aHyG39EDAU3zRzeav5I05BJZd6HgLGMAxOvqJsQFLkECf58hwytDBxS6CQejaUa2QPIbNtm0uvyErl38tgnGt/BYjsGVQd3tM236PW4gLp+LRYyhUkWFOYDc13L8FQyU7OISIAlBxliVAeWL8Finx8keWSQCQn7KF9ZftS1V2Me3FoTQuStVZRURMBot9TcozjmGn+arXCbrfjkSclTmK7NBCn7ve9pwM+DRBZ6xABaZrmKBcQ42zbli9cFKMVRWFQReEoTVNW37MsQ1EURyFD9F7WU7+NqYqA0BCpbVu+8Kqq+Ht5VkPAkFG1oMZts9mwRvL6+ookSVAUBatjJBT5vo/pdMojTzr2uU3eSUlVBoW8g2a0RVFwDiAKL48OyGhARSradrvFbrdDEAQcKsRGiehRBSJuQ99dwlP+ChBZFBJPTqSqLEsEQcC0m7xIfIaDRo9Ex8uyRJqmSNOU5YA8z1GWJfMQz/MwnU5xf3+Pu7s73N3dYTKZwPM82LY9mEc+HRARGPmkYnIkgTiKIvYcAk0UpEV5kYbYSZLw4xF93zMpcxwHnudhMplgMpnAcRxmwEO66qn214CIYFBiI+2TdFAAR3khDEPUdY04jo9yDok8IuOkhOn7PsbjMR4fH+H7Pp6ennB/f4+npydMp1PMZjMGhfb58pAZAoUSpvwiLxCbsd1ux5WE8g6Vb3J3Yqeu68L3ffi+j/l8jsVigdvb26MwuVQiPRkQOinxCHERNL60bRsAMJ1Ooes69vs9DMNAkiQwDIPlRNI66O7SvGU+n2M8HuP+/h6z2Qw/f/7EYrHAzc0NC88iGEPc5ssAkcERgSHdAwCr5VmWQVGOH4Ui8IjIkUfNZjOMRiPMZjP4vs9PDj08PGA8HnPeoFmMqKxdcgyhfNADDH451NOQuEP6Z9M0XCnSNOVHrUg9pypDi6PRpLilZ0rkofZQiT0BjMEdzgJkiLVSmaXRgchPaEu5gzQTEotN02SSRS+x2x16NvUMr7gcIPzlAFED/tn9fjQ7kZO03JO81aOcGSKDO19ktiv/LbblQ9uPjvfW9q3zXtLO8pCP7BIaxScu/vIe8uEZP/FOfpZd/4FIsisgkn0UMv97Pn+mXT1Esisgkl0BkewKiGRXQCS7AiLZvwBtCZqwAvXF1QAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "#เลข 3 โดยเฉลี่ย\n", + "mean3 = stacked_threes.mean(0)\n", + "show_image(mean3);" ] - }, - "execution_count": 112, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from collections import Counter\n", - "Counter((preds>0.0).float().numpy()[:,0])" - ] - }, - { - "cell_type": "code", - "execution_count": 115, - "metadata": { - 
"colab": { - "base_uri": "https://localhost:8080/" }, - "id": "CiSVOFonpokB", - "outputId": "8a2b8106-f5f1-4b7f-dab9-4dae99c5bf88" - }, - "outputs": [ { - "data": { - "text/plain": [ - "0.5121006965637207" + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 85 + }, + "id": "ambkeHzzpojz", + "outputId": "e21cad11-daa9-459a-ebd7-a5a7169567c4" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAI6klEQVR4nO1baVPiTBc9ZF/IhoOWOjUf5v//KgelNEZICCErvB+eutemjaNC8N04VanG7H1y99uOdrsdzniF8u9+gf80nAmRcCZEwpkQCWdCJGgfHP9fdkGjvp1nCZFwJkTCmRAJZ0IknAmRcCZEwpkQCWdCJJwJkfBRpHoQPlNjkc8ZjXoDxzf47HmHYhBCxMnR74/Gz0IkYDQa8d/y2Hf+ITiKEHmS2+2Wx91uxxv9TcfFY+L1BHGy9FtRFIxGo96RjsvXH4KDCBEnIk6aJt51HbbbLbquQ9d1aNuWR9pP58n3kUET1zQNqqrCMAyoqgrTNKGqKjRNg6IoexvhEGK+TIj44kQCEdA0DZqmQVmWaJoGm80GdV1jvV6jrmvkeY6qqrDZbNA0Deq6Rtu2TFSfJKmqCkVR4LouTNPExcUFXNdFFEVwHAe+78M0TViWxQSNRiOoqordbvdlUr5EiCwZ9KVJAoqiQNM0WK/XKMsSWZahKAokSYLNZoMsy1CWJRPVNA1fS6SKKgaACRmPxzBNE9PpFJ7n4devX/B9H5qmYbfbMRHb7RaKohxExpcI6VMPmgxNcLVaoSgKxHGMNE0xn8+RZRniOEae51gsFlitVlgulywpoiqJ6kQbffUgCOB5Hn7//o0wDJGmKS4vLwEAQRBA0/6ZCqnYyQnpI0ckhlSFJCHLMqRpymOe50iSBOv1GqvVCm3boq7rvfuJIHLatkVVVRiNRmiaBlEUQVEU5HkO13VZ0kjCjsWXVUYkous6NE2DqqpQliXyPEeapojjGMvlEnEcI8sy3N/fY7VaIUkS1HWNsiyZAEVR2DCS1xCfQXamrmvoug7TNFHXNS4vL2EYBvI8h23bLK3vGeeTEPI3kIskEdd1HbquwzAM+L4PVVUBgL+6fA55ETKyy+US6/UaWZZhvV7zMwDskSkSSb9Fd/1VfIqQjxgXyaAJmqYJ27aZhPF4jDAMoaoqu03HcfhcXdehqip7qiRJkGUZ7u7u8PT0xAabDCdBdrnHkPFpQshIyQQoisISsd1uYZomuq5DEARQFAV1XcO2bei6ziqg6zosy4Jt23BdF5ZlwbIslpDNZoOyLHl/nucoioLVgQik0bIsjk2+jRCRiD5CDMMAALiuC1VV0XUdTNOEoiioqgphGLKtoNjB8zy4rssTo3sSIZ7nYTabIc9zrNdrVFWFtm1hWRYcx4Ft27Btm4kjCRNV5tu8jKizAKDrOn89ALBtm41j13WoqgqapsE0Tbiuy5LhOA50XedYYrfb7UWmwD/qRt5IVVU4jgPXdTEejxEEAUuIaJiPwacJER+kKAoHQPTypNuqqrL6ULS42+1YVSzLYsmwLIuJJZUiIk3TZEKqquKYxHEcOI4Dz/M4WqUoVbQjJydEJIaCHjHxAv6RFABMhghR98muEJF0Pbnbuq6RZRkHcuRlDMPAeDyG7/sIw5CjV13X3+Qxh+JglQFeJYV0V9d1Jqxt2z2dFr0P6TtNgK6h2KYsS6RpisVigcViwUEYqVwQBIii6A0hx7rcLxMiehvRsJLEkOEkkogQ2i8TQRCDvDzP8fz8jMfHRzw/PyNNU9R1zaG77/sIggCu68K27T3SAey938mTO3oQPVj0OuRxSBrETJU2MVUX70PGl/Kh+XyOp6cnzGYzpGmKpmmgaRo8z4PneQjDkCWGpGMoHBypytICvNoSsh80igUdoL+EUBQF0jTF/f097u7uEMcx4jhG0zRQVRVhGGIymfBIrvZYFZFxVOjeV84Tv5ZMmLyfJKNtW2w2G
86QHx4eMJ/PkaYpezFys7TJrnYoUg42qrItod8A9ryGDLmMQMlekiT48+cPZrMZ5vM5FosFmqZhb3J1dYXLy0tMp1OMx2OOP0TJG4KUoyVEtCVkYGkTiZOLS2RI67pm6Xh4eEAcx3h4eECe5+i6DoZhIIoihGGIi4sLNqhiZErvMgSOtiF9Ve/tdvsm/wH2SaGSY1EUeHl5wWw2w2w2w8vLC6uK53m4vb3Fz58/cX19jZubGwRBwEmhHIx9u9v9G+RIVm5NyG6RCktUR0mSBEmSYLFYoCxLqKoK27bx48cPXFxcYDKZYDKZwHEcmKb5xn70kfBtucxnH0xSIhdtKHCjuuv9/T3HHGVZQtM0RFGEIAhwfX2N29tb3NzcIIoi2LbNkXBfyn+sCg1iQ+S/33sZUWWoClYUBZbLJfI8R5Zl7GZ938d0OkUURZhMJvB9H7Zt73mX98g4BkdLSB8phPdUheqvaZri6ekJSZJwi4LynNvbW0ynU5YQMqaidPTZjW/Ldv8GedLv7SPVIemgEiGRQV5FzGajKILneRyIDW1EZQza7O7zLMCrVxGN6HK5xOPjI6vLdruF4zgIggC2bePq6gpXV1fchxErY3LkKz7/WJx8OURfy4J6Mnmec08HANdIKGchcvq8CtAfFB6LQSVEVg+5b0M9mcVigefnZywWC2w2G4xGI1aJyWSCKIrY1ZLdoJrr0KG6jJOtD/lbD2ez2aAoCpRlia7rOF/RdZ2Lz5TeExGniEr7cJL1IWJoTipSFAVWqxWyLMPLywvyPOe2AnkWUUJEdZELQKfE4CrTJx3Ua6mqipO5uq65qKxpGtsPalGIlfTvIgMYgJC+tSLUZyXpoJ7ver1mQ0pVNbHOats2wjCE7/uwLOvDmOMUGNyG9NkOMqo00nIHsTBMTScav8uIyhhsSdV7RpQWxtBGrQbTNN8siKEikOhqRYP6X6EyIvoW1IgrgyiUp8YUqYSqqiwdhmFwi2KIPstXMWhy99451AR3XZcnKa8CIBsitiffa1mI49A4SRwC7Lc7adLU6gReWw+apnGb0zAM3khy/lbzOAUpow++8KdWnvTZEHHJFdkSMqxt2+6pEEkRkUNumEj5qDJ2IDG9Fw1eMaNqmdinESdMC+xEQoDXxXWiIZXvcYp0/808hpAQPrknYhX391XPel+qJ4EbqiImPqZ355CE8EXvlADeO953ft/EB5aKgwj5v8P530MknAmRcCZEwpkQCWdCJJwJkfAv6ObhbeIGuNEAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "#เลข 7 โดยเฉลี่ย\n", + "mean7 = stacked_sevens.mean(0)\n", + "show_image(mean7);" ] - }, - "execution_count": 115, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#จะเห็นได้ว่ามันถูกเครื่องๆ (เพราะเราเดาสุ่มด้วย weights ที่สุ่มมา)\n", - "corrects.float().mean().item()" - ] - }, - { - "cell_type": "code", - "execution_count": 116, - "metadata": { - "id": "RwkizTBqpokC" - }, - "outputs": [], - "source": [ - "#ถ้าเราเปลี่ยน weights เล็กน้อย\n", - "with torch.no_grad():\n", - " weights *= 5.0001\n", - " bias+=10" - ] - }, - { - "cell_type": "code", - "execution_count": 118, - "metadata": { - 
"colab": { - "base_uri": "https://localhost:8080/" }, - "id": "d5Kv_vFrpokC", - "outputId": "11814646-8f8b-44ed-eb89-9a6deed931ef" - }, - "outputs": [ { - "data": { - "text/plain": [ - "0.5012907385826111" + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 85 + }, + "id": "aKESuqYHpojz", + "outputId": "2ef7a17e-7e6e-4844-b49a-c094483c5679" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEQAAABECAYAAAA4E5OyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAJUElEQVR4nO2by08bVxuHn/HM2OP7BbsmYJNQKCQUCGkuyratmi7aRbfZt1L/ga66Sfbtv9FVF11E6aKLKG2RqBIlaUpSih0IoSF2YkrAl7l5xl1830w/GxJKGBP0yY+EkOacmfPOb95zed9zLLRaLXr8g+91G3DY6AnSQU+QDnqCdNATpANpl/L/5ylI2Oliz0M66AnSQU+QDnqCdNATpIOeIB3sNu16hm3bWJaFaZo0Gg0Mw0BV1RfW9/l8+Hw+UqkUiqIgiiI+X/e/34EJYhgG1WqVxcVFrl69ysLCAtevX39h/VgsRiwW48svv+Tdd98lGo0SCAS6bmfXBFFVlSdPniAIAqIoUq/XefbsGXfu3KFQKLC4uMj6+vpL769Wq9y6dYu+vj6GhoZIJBJdF0bYJR/yyivV3377jc8//xxJkkin0zx//pwnT57w/Plz1tfXsSwLy7JebJggIAgCwWCQQCDAxx9/zNtvv83FixfJ5/OvalZbEztd9NxDTNNEVVXK5TKlUgmAjY0N6vU6m5ubaJqGaZr/sUjY0SaXVquFqqrous6jR48IBAJsbm6SzWaRJKkrY4rngjQaDW7evMns7Cy1Wo1Go4GqqjieuNcMXavVwrIs7ty5w9LSEh9++CHZbJZYLNaVruOZIM7s8fjxY2ZnZykUCqiqSrPZxLZtt14qlSKfz5PJZHjjjTe2PadYLLKyssLW1lbbLNRsNjEMA13XMU1zz8L+WzwTRFVVbty4wezsLF999RWapu04RkxPT/PJJ59w9uxZJicnt5VfuXKFb7/9llu3brGysuJet23b7Y6GYRx+QVqtFoZhYJomhmFsE0OSJBRFIZfLMTk5STabbXN5VVWp1WqUy2XW1tbQNK3tflmWURSFvr4+otEooih6ZXq7nV4+zBFkp68XCoXIZrNMTExw+vRpAoEAfr/fLS+VSszPz3P79m3u3r2LYRht9weDQZLJJLlcjlQqteuA/Kp4JojP5yORSBCJRJAkyZ0yE4kEw8PDDA4OMjo6yvnz5/H7/QiCgGmaaJpGtVplfn6e69evs7y8jGma7rgjiiKSJHHy5EnGx8dJJpNdEwM8FEQURdLpNOl0GkmSCAaDZLNZTp06xYULFzhx4gQTExPIskwgEHAHyKWlJX755ReuXbvGlStX0HWdZrPpPtfv96MoCu+99x4XLlzYcSD2Es8EkSSJTCbD6dOn+fTTTxFFkUQiwdDQEDMzM6RSKWRZdvu+rutuN5mbm6NQKKDruusZ8XiccDjM5OQkg4ODnDlzhv7+fmRZ9srkHfF8pWoYBrVaDZ/PhyiKrkd0uvnq6io//fQT33//Pd988822cWd8fJzR0VE++ugjzp07x8jICPF4fK/mvIyDWamKooiiKAiC4IqyU5+vVqvcu3ePUqnUJkYmkyGdTvPBBx/wzjvvMD09TX9//4EEdtAlQUKh0K71KpUKc3NzLC0ttV0fHh7m5MmTXLx4kZmZGSRJ6toUuxOvNUHUarW2dRVd1914Z6fybnPoMma6rlOv19F1HcuyDlwQ8dKlSy8rf2nhflBVlXq9jm3bPHr0yPWGVqtFo9EAoFarYVkWzWbT/e+MSx5weaeLr00QQRBIJpM8ffqU+fl5bNum2WyiqiqVSoWVlRV+/fVXNzgUBIFWq0UwGPRq6j18goTDYSKRCOFwGFEUKZfLbeOGruuoqsrjx49ZXl5maWmJ/v5+r3KsOwpyYDnVToLBIMFgkJmZGaLRKLZtc/fuXQAsy6JWq7nBHkAymSQejzM0NEQkE
kGWZSTJe/Nfm4c4+Hw+IpEImUyGY8eOceTIESKRCK1Wa9tMY5ompmny4MEDMpkMsiy3rX73yOHyEAfHUxRF4ejRo8zNzREOhwHcQdeyLDRNQ9M0vvvuOxRFYWBgAEmSCIVCbVHzfnntHuIgCAKKopDJZDhx4gTHjx9ncnKSTCaDoihomuZm0GzbJhaLUalUiEQi2LaNoih77UKH00Mc/H4/fr+fcDhMLpfj6NGjnDp1ilgsht/v56+//mJ9fR3btmm1Wty4cYM///yTaDTK1NQUyWQSRVH2bUfXtiH2i5NsevbsGZVKhZ9//plCocAPP/xAsVgkFosRDAaZmpoin89z6dKlvW5PHExw5xWOxwQCAY4cOYKiKAwNDbGwsECxWGRra4utrS2q1SqpVIovvvjCk3YPrSAOTrSczWaRZZlMJtPV9g69IM5WaDgcRhCEfxVJ74dDL4iz5bm2tsbq6qq7G9gtPBWkc4B+1WTw/+7ymaZJs9mkXC5z//59NjY2trXh5ZamZ4Lous7Tp09pNpvouk40GiWdTrsGO67vBHGw87Zms9nENE1WV1dZXV1lcXGRUqlEsVjk4cOHLC8vu3VFUeTMmTOMjIy4i7n94okgzpcslUrutkI2myUSibhBmLO8dl74RckfVVVRVZU//viDmzdv8vvvv7OwsEC5XGZra8ut5/P5kGWZfD7PyMiIZynGfQtiGAabm5sUCgW+/vprd082FArR19dHOBwmnU6TzWY5fvw4Dx484N69e+5SvBNHsFKpRKVSoVqt0mg02uoKgsDZs2d58803+eyzz3jrrbdIJBL7fRXAA0Fs26Zer7O2tsaPP/5IrVZD13W3PB6Pk8vlGB8fZ2Njg8XFRa5du0aj0aBer++pLafbybLMsWPHGB0dZWRkhP7+/v2+xj9t7Helats2hmHw8OFDLl++TLFY5Pbt2207b4qiEAqFiEaj1Go1NjY2sCyr7VTANsP+e2DG+QuFQgQCAd5//32Gh4c5f/48uVyOsbGxVx0/urNS9fl8KIpCLBYjn89jmiYLCwvu0tuyLOr1unuk6qUWCgKSJCFJEn6/3824OymCcDjM2NgY09PTbjfxOifiWSxTr9e5f/8+a2trzM3Nsby8zNWrVzFNs60LvYhIJEI0GuXcuXOMjY0xMTHB8PAw8XicUCjkzlTOGTO/34/P59vPSaLuxjKSJDEwMEAwGETTNCRJYnBwEFVV0TRt1+x5KpUilUoxOjrKzMwMU1NTDAwMEA6HD2yTCjz0EGd94RxsMU2TarXqlu2GMzUrikIgEHBThM551S6wo4cc2vD/AOj9Xubf0BOkg54gHfQE6aAnSAc9QTrYbWHWveN+h5Seh3TQE6SDniAd9ATpoCdIBz1BOvgb74E4FYpj4GsAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "#เลข 3 อันที่ 125\n", + "a_3 = stacked_threes[125]\n", + "show_image(a_3);" ] - }, - "execution_count": 118, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#ความแม่นยำก็เปลี่ยนนิดนึง\n", - "preds = linear1(train_x)\n", - "((preds>0.0).float() == train_y).float().mean().item()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PTOZYvLRy-Vz" - }, - "source": [ - "#### Loss Function อย่างง่าย" - ] - }, - { - "cell_type": "code", - 
"execution_count": 120, - "metadata": { - "id": "w7wVT8TjpokC" - }, - "outputs": [], - "source": [ - "trgts = tensor([1,0,1])\n", - "prds = tensor([0.9, 0.4, 0.2])" - ] - }, - { - "cell_type": "code", - "execution_count": 121, - "metadata": { - "id": "9HgV5HQppokC" - }, - "outputs": [], - "source": [ - "#loss function ที่จะเป็นจะเป็น Predictions ของ class ที่ผิด\n", - "#เพราะงั้นยิ่งทาย class ที่ผิดแบบมั่นใจมาก Loss ก็จะยิ่งสูง\n", - "def mnist_loss(predictions, targets):\n", - " return torch.where(targets==1, 1-predictions, predictions).mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 122, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "5xGnIzeHpokC", - "outputId": "f6f63235-c18f-4fb3-ce36-ecb7a1b332d0" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor([0.1000, 0.4000, 0.8000])" + "cell_type": "markdown", + "metadata": { + "id": "k4Rm7EZjU-Gb" + }, + "source": [ + "เราจะเห็นได้ว่าหากเราคำนวน mean absolute error และ mean squared error ของแต่ละ pixel ระหว่าง \"เลข 3 อันที่ 125\" กับ \"เลข 3 โดยเฉลี่ย\" และ \"เลข 7 โดยเฉลี่ย\" เราจะเห็นกว่าค่าของ \"เลข 3 อันที่ 125\" กับ \"เลข 3 โดยเฉลี่ย\" มีค่าน้อยกว่า และในกรณีนี้ระบบ (ที่ไม่ใช่ ML) ของเราจะทำนายถูกว่า \"เลข 3 อันที่ 125\" คือเลข 3" ] - }, - "execution_count": 122, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#if target == 1, loss is how far it is from 1\n", - "#if target == 0, loss is how far it is from 0\n", - "torch.where(trgts==1, 1-prds, prds)" - ] - }, - { - "cell_type": "code", - "execution_count": 123, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "dhQH45dYpokC", - "outputId": "6eb43351-b5a6-4092-b7ac-ccc2ad266ce3" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor(0.4333)" + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JbfdH2zWpoj0", + "outputId": "a2d7d370-8bae-45ab-fa02-e10e6ad40d2d" 
+ }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(tensor(0.1259), tensor(0.2290))" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ], + "source": [ + "#distance between \"the 3 at index 125\" and \"the average 3\"\n", + "dist_3_abs = (a_3 - mean3).abs().mean()\n", + "dist_3_sqr = ((a_3 - mean3)**2).mean().sqrt()\n", + "dist_3_abs,dist_3_sqr" ] - }, - "execution_count": 123, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mnist_loss(prds,trgts)" - ] - }, - { - "cell_type": "code", - "execution_count": 125, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "uiUjD7aJpokD", - "outputId": "63c11348-6649-4688-f672-032837dbd862" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor(0.2333)" + "cell_type": "code", + "execution_count": 23, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "s5xmB8f1poj0", + "outputId": "fd56a041-cae9-4195-be48-44491f38be2e" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(tensor(0.1836), tensor(0.3390))" + ] + }, + "metadata": {}, + "execution_count": 23 + } + ], + "source": [ + "#distance between \"the 3 at index 125\" and \"the average 7\"\n", + "dist_7_abs = (a_3 - mean7).abs().mean()\n", + "dist_7_sqr = ((a_3 - mean7)**2).mean().sqrt()\n", + "dist_7_abs,dist_7_sqr" ] - }, - "execution_count": 125, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#if the predictions get more accurate, the loss goes down\n", - "mnist_loss(tensor([0.9, 0.4, 0.8]), trgts)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dqsrmegupokD" - }, - "source": [ - "### Activation Function - Sigmoid" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-0DXOc8w0YeU" - }, - "source": [ - "One problem we have now is that our model outputs `Predictions` as arbitrary real numbers, but our `Loss Function` expects values between 0 and 1, so we use an `Activation Function` called 
Sigmoid to squash those numbers into the range 0 to 1\n", - "\n", - "There are many kinds of `Activation Function`. Popular ones include [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))(x), which returns 0 if x is less than 0 and x otherwise, and [tanh](https://mathworld.wolfram.com/HyperbolicTangent.html)(x), which maps real numbers into the range -1 to 1. Read more about [Activation Functions](https://en.wikipedia.org/wiki/Activation_function)" - ] - }, - { - "cell_type": "code", - "execution_count": 126, - "metadata": { - "id": "sQIUTk5ZpokD" - }, - "outputs": [], - "source": [ - "def sigmoid(x): return 1/(1+torch.exp(-x))" - ] - }, - { - "cell_type": "code", - "execution_count": 127, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 340 - }, - "id": "TY9tdedYpokD", - "outputId": "966a16d1-0eb5-42b7-d056-08c80a35909a" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.7/dist-packages/fastbook/__init__.py:74: UserWarning: Not providing a value for linspace's steps is deprecated and will throw a runtime error in a future release. This warning will appear only once per process. (Triggered internally at ../aten/src/ATen/native/RangeFactories.cpp:23.)\n", - " x = torch.linspace(min,max)\n" - ] - }, - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plot_function(torch.sigmoid, title='Sigmoid', min=-4, max=4)" - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": { - "id": "htbxxyxlpokD" - }, - "outputs": [], - "source": [ - "def mnist_loss(predictions, targets):\n", - " predictions = predictions.sigmoid() #ใส่ sigmoid ไป; pytorch tensor มี built-in function ให้แล้ว\n", - " return torch.where(targets==1, 1-predictions, predictions).mean()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UUxkL3_2pokD" - }, - "source": [ - "### SGD and Mini-Batches" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rxCnTrNRInAz" - }, - "source": [ - "หลายครั้งข้อมูลทั้งชุดของเราใหญ่เกินไปที่จะใส่เข้าไปใน memory ของเครื่อง เราจึงต้องใส่เข้าไปทีละ batch ด้วย `DataLoader`" - ] - }, - { - "cell_type": "code", - "execution_count": 132, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2-twSiAspokD", - "outputId": "f22c96c8-2188-4b31-c97a-2eccb24f46c7" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[tensor([ 8, 9, 5, 11, 4]),\n", - " tensor([ 7, 13, 14, 2, 3]),\n", - " tensor([ 6, 12, 10, 1, 0])]" - ] - }, - "execution_count": 132, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#ทดลองใช้ Dataloader สุ่มตัวเลขจาก 0-14; batch ละ 5 ตัวอย่าง; ให้สุ่มด้วย\n", - "coll = range(15)\n", - "dl = DataLoader(coll, batch_size=5, shuffle=True)\n", - "list(dl)" - ] - }, - { - "cell_type": "code", - "execution_count": 133, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "7W43iIU_pokE", - "outputId": "8cedff83-1aa4-458d-b8e2-0372eb1520d5" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(#26) [(0, 'a'),(1, 'b'),(2, 'c'),(3, 'd'),(4, 'e'),(5, 'f'),(6, 'g'),(7, 'h'),(8, 'i'),(9, 'j')...]" + "cell_type": "code", + "execution_count": 24, + "metadata": { + "colab": { + 
"base_uri": "https://localhost:8080/" + }, + "id": "lDlxHuO8Vazc", + "outputId": "11eda802-037f-4abb-8308-f5045f90d468" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(tensor(0.1836), tensor(0.3390))" + ] + }, + "metadata": {}, + "execution_count": 24 + } + ], + "source": [ + "#ใช้ function ของ pytorch คิดก็ได้\n", + "F.l1_loss(a_3.float(), mean7), F.mse_loss(a_3, mean7).sqrt()" ] - }, - "execution_count": 133, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#สมมุติเรามี DataSet แบบนี้\n", - "ds = L(enumerate(string.ascii_lowercase))\n", - "ds" - ] - }, - { - "cell_type": "code", - "execution_count": 134, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "GOccHc_rpokE", - "outputId": "f9664233-48a4-4f7d-f5ef-5be94f04a9de" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[(tensor([19, 14, 0, 24, 20, 12]), ('t', 'o', 'a', 'y', 'u', 'm')),\n", - " (tensor([23, 8, 9, 3, 16, 6]), ('x', 'i', 'j', 'd', 'q', 'g')),\n", - " (tensor([ 4, 7, 1, 13, 2, 22]), ('e', 'h', 'b', 'n', 'c', 'w')),\n", - " (tensor([ 5, 17, 18, 10, 11, 15]), ('f', 'r', 's', 'k', 'l', 'p')),\n", - " (tensor([25, 21]), ('z', 'v'))]" - ] - }, - "execution_count": 134, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#ก็ใส่เข้าไปใน dataloader ได้เช่นกัน\n", - "dl = DataLoader(ds, batch_size=6, shuffle=True)\n", - "list(dl)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tvWlUYU8pokE" - }, - "source": [ - "## Putting It All Together" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fjUk_EOS0deu" - }, - "source": [ - "#### Initialize weights" - ] - }, - { - "cell_type": "code", - "execution_count": 186, - "metadata": { - "id": "lK9dlPP4pokE" - }, - "outputs": [], - "source": [ - "weights = init_params((28*28,1))\n", - "bias = init_params(1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L4ejZTDi0fds" - }, - 
"source": [ - "#### Initialize dataloader" - ] - }, - { - "cell_type": "code", - "execution_count": 169, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "XesYQeEUpokE", - "outputId": "fc6aab3d-26af-4611-ea8d-bcb308310b41" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(torch.Size([256, 784]), torch.Size([256, 1]))" + "cell_type": "markdown", + "metadata": { + "id": "4Lx6GAdSVgD2" + }, + "source": [ + "เรื่องน่าคิดถึงความแตกต่างระหว่าง mean squared error และ mean absolute error" ] - }, - "execution_count": 169, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#สร้าง dataloader สำหรับ train\n", - "dl = DataLoader(dset, batch_size=256, shuffle=True)\n", - "xb,yb = first(dl)\n", - "xb.shape,yb.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 170, - "metadata": { - "id": "hlw0qizVpokE" - }, - "outputs": [], - "source": [ - "#เราจะไม่ shuffle validation set\n", - "valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 171, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "opSgyb70pokF", - "outputId": "eac98990-8410-46e5-8356-68c03c412118" - }, - "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([4, 784])" + "cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IcAtO7IwDzUp", + "outputId": "a4f00685-f56e-4a9b-ca7b-051ba45ff758" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(1.1666666269302368, 1.1902379989624023)" + ] + }, + "metadata": {}, + "execution_count": 25 + } + ], + "source": [ + "a = torch.tensor([1, 2, 3]).float()\n", + "b = torch.tensor([2.,3.,4.5])\n", + "c = torch.tensor([2.,3.,40.])\n", + "\n", + "#mse และ mae ไม่ต่างกันเท่าไหร่สำหรับ a และ b\n", + "(a-b).abs().mean().item(), \\\n", + "((a-b)**2).mean().sqrt().item()," ] - }, - "execution_count": 
171, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#(batch size, 28*28)\n", - "batch = train_x[:4]\n", - "batch.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d_OZXa2l0iwq" - }, - "source": [ - "#### Forward pass" - ] - }, - { - "cell_type": "code", - "execution_count": 196, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "IQ5eVzE5pokF", - "outputId": "91143405-c419-41be-ffa8-1f485be9b9ef" - }, - "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([256, 1])" + "cell_type": "code", + "execution_count": 26, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "y5zmXGk2poj0", + "outputId": "5457536b-f88b-4077-f57f-8b3aa5021cdf" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(13.0, 21.3775577545166)" + ] + }, + "metadata": {}, + "execution_count": 26 + } + ], + "source": [ + "#for a and c, rmse is about 1.6 times mae because squaring amplifies the single large error\n", + "(a-c).abs().mean().item(),\\\n", + "((a-c)**2).mean().sqrt().item()" ] - }, - "execution_count": 196, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def linear1(xb): return xb@weights + bias\n", - "preds = linear1(xb)\n", - "preds.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "P0dawUHM0kpf" - }, - "source": [ - "#### Calculate loss" - ] - }, - { - "cell_type": "code", - "execution_count": 197, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "_zK9HJE7pokF", - "outputId": "84527bdf-28f4-4a64-86a5-17ad6c3c08ec" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor(0.2793, grad_fn=)" + "cell_type": "markdown", + "metadata": { + "id": "ErTcHuGkpoj4" + }, + "source": [ + "## Stochastic Gradient Descent (SGD)" ] - }, - "execution_count": 197, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def mnist_loss(predictions, targets):\n", - " predictions = 
predictions.sigmoid()\n", - " return torch.where(targets==1, 1-predictions, predictions).mean()\n", - "loss = mnist_loss(preds, xb)\n", - "loss" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tMIHxJcw0mW5" - }, - "source": [ - "#### Backward pass - get gradients and update weights" - ] - }, - { - "cell_type": "code", - "execution_count": 198, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "lMdSQrorpokF", - "outputId": "30ef9ce3-f371-434e-8598-f81970ce1cbf" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(torch.Size([784, 1]), tensor(0.0098), tensor([0.0854]))" + "cell_type": "markdown", + "metadata": { + "id": "XTAhc8bJWi6y" + }, + "source": [ + "As you may recall from chapter 1, we compute `Gradients` from the `Loss` and then let the Optimizer update the `Weights`. In this lesson we will walk through these steps." ] - }, - "execution_count": 198, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "loss.backward()\n", - "weights.grad.shape,weights.grad.mean(), bias.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 199, - "metadata": { - "id": "K4x5HLp4pokF" - }, - "outputs": [], - "source": [ - "#compute gradients with .backward()\n", - "def calc_grad(xb, yb, model):\n", - " preds = model(xb)\n", - " loss = mnist_loss(preds, yb)\n", - " loss.backward()" - ] - }, - { - "cell_type": "code", - "execution_count": 217, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "CZzqYNtgpokF", - "outputId": "433c868f-f1d4-449f-a127-7187ec620785" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(tensor(-3.8863e-05), tensor([-0.0003]))" + "cell_type": "markdown", + "metadata": { + "id": "V9fzeVckWhP9" + }, + "source": [ + "" ] - }, - "execution_count": 217, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "calc_grad(batch, train_y[:4], linear1)\n", - "weights.grad.mean(),bias.grad" - ] - }, - { - "cell_type": "code", -
"execution_count": 218, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "R5yY7Az-pokF", - "outputId": "ffe44369-ddd7-42fe-f61d-37ed1101133f" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(tensor(-4.2396e-05), tensor([-0.0003]))" + "cell_type": "markdown", + "metadata": { + "id": "OISepsl-poj5" + }, + "source": [ + "### Computing Gradients for Backpropagation" ] - }, - "execution_count": 218, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "calc_grad(batch, train_y[:4], linear1)\n", - "weights.grad.mean(),bias.grad" - ] - }, - { - "cell_type": "code", - "execution_count": 219, - "metadata": { - "id": "R-Bzu26wpokG" - }, - "outputs": [], - "source": [ - "#we must reset the gradients so that they do not keep accumulating\n", - "weights.grad.zero_();\n", - "bias.grad.zero_();" - ] - }, - { - "cell_type": "code", - "execution_count": 232, - "metadata": { - "id": "gKkqYbaqpokG" - }, - "outputs": [], - "source": [ - "#train for 1 epoch\n", - "def train_epoch(model, lr, params):\n", - " #feed every batch to the model, one batch at a time\n", - " for xb,yb in dl:\n", - " #compute the loss and gradients\n", - " calc_grad(xb, yb, model)\n", - " #update the weights by gradient * learning rate (lr)\n", - " for p in params:\n", - " p.data -= p.grad*lr\n", - " p.grad.zero_()" - ] - }, - { - "cell_type": "code", - "execution_count": 233, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "2XWh1J1ppokG", - "outputId": "5d18d3ee-0fa7-4394-f366-86a8f68351a2" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor(0.3438)" + "cell_type": "markdown", + "metadata": { + "id": "_9WrHtiu0ScN" + }, + "source": [ + "If you have not yet studied the [partial derivative](https://en.wikipedia.org/wiki/Partial_derivative) and the [chain rule](https://www.khanacademy.org/math/ap-calculus-ab/ab-differentiation-2-new/ab-3-1a/a/chain-rule-review) in high school, you do not need to understand all of this section.
The key idea is that we can adjust the `Weights` using the `Gradients` computed from the `Loss`, so that the `Loss` becomes as small as possible in later iterations." ] - }, - "execution_count": 233, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#compute accuracy\n", - "((preds>0.0).float() == yb).float().mean()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4ZfrHoey0w2z" - }, - "source": [ - "### Calculate metric (accuracy in this case)" - ] - }, - { - "cell_type": "code", - "execution_count": 234, - "metadata": { - "id": "5QPI55L4pokG" - }, - "outputs": [], - "source": [ - "def batch_accuracy(xb, yb):\n", - " preds = xb.sigmoid()\n", - " correct = (preds>0.5) == yb\n", - " return correct.float().mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 235, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "8UVCx6RApokG", - "outputId": "29839e53-a855-4967-b6f1-e4edc8d51571" - }, - "outputs": [ { - "data": { - "text/plain": [ - "tensor(0.)" + "cell_type": "markdown", + "metadata": { + "id": "tP7xMjbi6er4" + }, + "source": [ + "A simple Backpropagation example from [cs231n](https://cs231n.github.io/optimization-2/#backprop)\n", + "\n", + "Independent variables:\n", + "(we can think of `x, y, z` as the `Inputs` or `Weights` of a model)\n", + "\\begin{align}\n", + "x & = -2 \\\\\n", + "y & = 5 \\\\\n", + "z & = -4 \\\\\n", + "\\end{align}\n", + "\n", + "Dependent variables: \n", + "(`q` and `f` are some functions, for example the `Loss Function` of a model)\n", + "\n", + "Substituting `x, y, z` into the functions `q, f` gives\n", + "\n", + "\\begin{align}\n", + "q & = x+y = -2+5 = 3\\\\\n", + "f & = q*z = 3*-4 = -12\n", + "\\end{align}\n", + "\n", + "This is analogous to turning `Inputs` into `Predictions` using the `Weights`, and is called the `Forward Pass`\n", + "\n", + "After that, we can compute the `Gradients`, which generally means the rate of change of the final function (here `f`)
with respect to the initial variables (here `x, y, z`), namely $\\frac{df}{dx}$, $\\frac{df}{dy}$, $\\frac{df}{dz}$. We find these using the [partial derivative](https://en.wikipedia.org/wiki/Partial_derivative) and the [chain rule](https://www.khanacademy.org/math/ap-calculus-ab/ab-differentiation-2-new/ab-3-1a/a/chain-rule-review); this is called the `Backward Pass`\n", + "\n", + "\\begin{align}\n", + "\\frac{df}{dq} & = z = -4\\\\\n", + "\\frac{dq}{dx} & = 1\\\\\n", + "\\frac{df}{dx} & = \\frac{df}{dq} * \\frac{dq}{dx}\\\\\n", + "& = -4*1\\\\\n", + "& = -4\\\\\n", + "\\end{align}\n", + "\n" ] - }, - "execution_count": 235, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "batch_accuracy(linear1(batch), train_y[:4])" - ] - }, - { - "cell_type": "code", - "execution_count": 236, - "metadata": { - "id": "GbJQiOgIpokG" - }, - "outputs": [], - "source": [ - "#validate on the validation set\n", - "def validate_epoch(model):\n", - " accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]\n", - " return round(torch.stack(accs).mean().item(), 4)" - ] - }, - { - "cell_type": "code", - "execution_count": 237, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "txrLkdJQpokH", - "outputId": "83acd8ad-c9dd-4852-d105-fe35af5c3f3b" - }, - "outputs": [ { - "data": { - "text/plain": [ - "0.3447" + "cell_type": "code", + "execution_count": 27, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VXdDIlJl6jYC", + "outputId": "9770d555-8e05-4b73-bf31-48d1d6b782c7" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(-4.0, -4.0)" + ] + }, + "metadata": {}, + "execution_count": 27 + } + ], + "source": [ + "# set some inputs\n", + "x = -2; y = 5; z = -4\n", + "\n", + "# perform the forward pass\n", + "q = x + y # q becomes 3\n", + "f = q * z # f becomes -12\n", + "\n", + "# perform the backward pass (backpropagation) in reverse order:\n", + "# first backprop through f = q *
z\n", + "dfdz = q # df/dz = q, so gradient on z becomes 3\n", + "dfdq = z # df/dq = z, so gradient on q becomes -4\n", + "# now backprop through q = x + y\n", + "dfdx = 1.0 * dfdq # dq/dx = 1. And the multiplication here is the chain rule!\n", + "dfdy = 1.0 * dfdq # dq/dy = 1\n", + "\n", + "dfdx, dfdy" ] - }, - "execution_count": 237, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "validate_epoch(linear1)" - ] - }, - { - "cell_type": "code", - "execution_count": 238, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "HRIapIkfpokH", - "outputId": "f26dd6f4-c57f-481a-dcc1-b8199c3803ad" - }, - "outputs": [ { - "data": { - "text/plain": [ - "0.714" + "cell_type": "markdown", + "metadata": { + "id": "kt9JIcVZ3CuU" + }, + "source": [ + "PyTorch can perform the `Backward Pass` for us automatically with its Autograd feature, so we do not have to work out the partial derivatives ourselves. A side benefit is that we can use PyTorch to help with calculus homework when a derivative is too hard to find by hand." ] - }, - "execution_count": 238, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#train for 1 epoch; accuracy roughly doubles!\n", - "lr = 1.\n", - "\n", - "params = weights, bias\n", - "train_epoch(linear1, lr, params)\n", - "validate_epoch(linear1)" - ] - }, - { - "cell_type": "code", - "execution_count": 239, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "hvBUTwULpokH", - "outputId": "20f9b9df-8efb-4ead-baee-9e644f908d79" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.9068 0.9328 0.947 0.9569 0.9603 0.9618 0.9613 0.9627 0.9642 0.9677 0.9686 0.9696 0.9701 0.9706 0.972 0.972 0.973 0.973 0.9735 0.974 " - ] - } - ], - "source": [ - "#train for 20 epochs; almost every prediction is correct\n", - "for i in range(20):\n", - " train_epoch(linear1, lr, params)\n", - " print(validate_epoch(linear1), end=' ')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id":
"uVt65caepokH" - }, - "source": [ - "### Optimizer as a class" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gyvY33oqMdZk" - }, - "source": [ - "The job of updating the weights is often wrapped up in a class called an optimizer" - ] - }, - { - "cell_type": "code", - "execution_count": 240, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "ETZ0sdlXpokH", - "outputId": "5be434b1-64b6-4080-de5b-51216745e07b" - }, - "outputs": [ { - "data": { - "text/plain": [ - "Linear(in_features=784, out_features=1, bias=True)" + "cell_type": "code", + "execution_count": 28, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SWg90ZHApoj6", + "outputId": "d7fbd5b9-1170-424e-e049-26bb3ca5a1a1" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([ 3., 4., 10.], requires_grad=True)" + ] + }, + "metadata": {}, + "execution_count": 28 + } + ], + "source": [ + "xt = torch.tensor([3.,4.,10.]).requires_grad_()\n", + "xt" ] - }, - "execution_count": 240, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#nn.Linear is exactly the x@W.T + b function we just wrote\n", - "linear_model = nn.Linear(28*28,1, bias=True)\n", - "linear_model" - ] - }, - { - "cell_type": "code", - "execution_count": 242, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "Op_9VxfipokH", - "outputId": "6096f6bb-ffd3-4f4b-8c6b-46412d73cd1a" - }, - "outputs": [ { - "data": { - "text/plain": [ - "(torch.Size([1, 784]), torch.Size([1]))" + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "WpM1Q5nApoj6", + "outputId": "826205b5-d40f-4b04-986d-7c0beb6a0265" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor(125., grad_fn=)" + ] + }, + "metadata": {}, + "execution_count": 29 + } + ], + "source": [ + "#function f(x) = (x1^2 + x2^2
+...+xn^2)\n", + "def f(x): return (x**2).sum()\n", + "\n", + "yt = f(xt) #plugging in 3, 4, 10 gives 3^2+4^2+10^2 = 125\n", + "yt" ] - }, - "execution_count": 242, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "w,b = linear_model.parameters()\n", - "w.shape,b.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 243, - "metadata": { - "id": "HDPzEqO-pokI" - }, - "outputs": [], - "source": [ - "#the simplest possible optimizer\n", - "class BasicOptim:\n", - " def __init__(self,params,lr): \n", - " self.params,self.lr = list(params),lr\n", - "\n", - " #step updates the weights\n", - " def step(self, *args, **kwargs):\n", - " for p in self.params: p.data -= p.grad.data * self.lr\n", - "\n", - " #zero_grad resets the gradients\n", - " def zero_grad(self, *args, **kwargs):\n", - " for p in self.params: p.grad = None" - ] - }, - { - "cell_type": "code", - "execution_count": 244, - "metadata": { - "id": "UIAhC389pokI" - }, - "outputs": [], - "source": [ - "opt = BasicOptim(linear_model.parameters(), lr)" - ] - }, - { - "cell_type": "code", - "execution_count": 245, - "metadata": { - "id": "zNokDJCupokI" - }, - "outputs": [], - "source": [ - "def train_epoch(model):\n", - " for xb,yb in dl:\n", - " calc_grad(xb, yb, model)\n", - " #use the optimizer instead of the manual update\n", - " opt.step()\n", - " opt.zero_grad()" - ] - }, - { - "cell_type": "code", - "execution_count": 246, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "pUr7KQecpokI", - "outputId": "0515d9ee-847a-497e-8f13-20fe853d9548" - }, - "outputs": [ { - "data": { - "text/plain": [ - "0.3604" + "cell_type": "code", + "execution_count": 30, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "POJB2ewmpoj6", + "outputId": "83cd31b2-2071-4a9e-af97-f3b301b2c807" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([ 6., 8., 20.])" + ] + }, + "metadata": {}, + "execution_count": 30 + } +
], + "source": [ + "#by hand, df(x)/dx = 2x \n", + "#so for x1=3, x2=4, x3=10 we get \n", + "#df(x1)/dx1 = 6, df(x2)/dx2 = 8, df(x3)/dx3 = 20\n", + "\n", + "#use autograd to find df(x1)/dx1, df(x2)/dx2, df(x3)/dx3\n", + "yt.backward()\n", + "xt.grad" ] - }, - "execution_count": 246, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "validate_epoch(linear_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 247, - "metadata": { - "id": "wQptVRJfpokI" - }, - "outputs": [], - "source": [ - "def train_model(model, epochs):\n", - " for i in range(epochs):\n", - " train_epoch(model)\n", - " print(validate_epoch(model), end=' ')" - ] - }, - { - "cell_type": "code", - "execution_count": 248, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "kRbgjPURpokI", - "outputId": "e95b776d-8a7c-46e5-b56e-b596a5237fce" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.9702 0.9745 0.9765 0.9785 0.9794 0.9789 0.9789 0.9799 0.9794 0.9799 0.9809 0.9809 0.9814 0.9819 0.9819 0.9819 0.9824 0.9819 0.9829 0.9829 " - ] - } - ], - "source": [ - "#the results are as good as before\n", - "train_model(linear_model, 20)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rBjvWLYwNII8" - }, - "source": [ - "## Making Life Easier with PyTorch" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "60qaqoMbNMA4" - }, - "source": [ - "PyTorch already provides the functions and classes we have been writing as built-in parts of the package, ready for us to call" - ] - }, - { - "cell_type": "code", - "execution_count": 249, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "DY1kWv7lpokI", - "outputId": "1da51068-8204-48d0-b3cb-10136c208ec5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.9696 0.9741 0.976 0.977 0.977 0.978 0.9775 0.9789 0.9794 0.9799 0.9804 0.9794 0.9794 0.9804 0.9804 0.9799 0.9799 0.9814 0.9809 0.9804 " - ] - } - ], - "source":
[ - "linear_model = nn.Linear(28*28,1,bias=False) #same as init_params plus the matrix multiplication we just did\n", - "opt = SGD(linear_model.parameters(), lr) #same as the BasicOptim we just wrote\n", - "\n", - "#training the model gives results as good as before\n", - "train_model(linear_model, 20)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6ctjZxC3Nxi8" - }, - "source": [ - "## Using fastai for Even More Convenience" - ] - }, - { - "cell_type": "code", - "execution_count": 250, - "metadata": { - "id": "jtzVcFkxpokJ" - }, - "outputs": [], - "source": [ - "dls = DataLoaders(dl, valid_dl)\n", - "#bring everything together with the Learner class\n", - "learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,\n", - " loss_func=mnist_loss, metrics=batch_accuracy)" - ] - }, - { - "cell_type": "code", - "execution_count": 251, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 676 - }, - "id": "0WcORvumpokJ", - "outputId": "e0988105-59a7-4c16-9d8a-466003017eb5" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n" + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mHHqwvz2poj_" + }, + "source": [ + "## Building a Loss Function to Classify Images of 3s and 7s" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "U0EQ6WD4zEf0" + }, + "source": [ + "#### Create X and y" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t2aoyhac4ZrI" + }, + "source": [ + "Prepare the `Inputs`, which are 28x28-pixel digit images, and the `Labels`, which are `1 for a 3` and `0 for a 7`" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "id": "-4LTBmPcpoj_" + }, + "outputs": [], + "source": [ + "train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Q5bU0JpipokA", + "outputId": "798e931e-2e7b-41bf-e94c-231da4a25c83" + }, + "outputs": [ + { + "output_type": "execute_result", +
"data": { + "text/plain": [ + "(torch.Size([12396, 784]), torch.Size([12396, 1]))" + ] + }, + "metadata": {}, + "execution_count": 32 + } ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
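| epoch | train_loss | valid_loss | batch_accuracy | time |
|---|---|---|---|---|
| 0 | 0.057238 | 0.041364 | 0.970559 | 00:00 |
| 1 | 0.040398 | 0.034219 | 0.975466 | 00:00 |
| 2 | 0.031890 | 0.031548 | 0.975957 | 00:00 |
| 3 | 0.027781 | 0.029428 | 0.977429 | 00:00 |
| 4 | 0.025407 | 0.028166 | 0.977920 | 00:00 |
| 5 | 0.023545 | 0.026728 | 0.979392 | 00:00 |
| 6 | 0.022274 | 0.026164 | 0.978410 | 00:00 |
| 7 | 0.021285 | 0.025303 | 0.979392 | 00:00 |
| 8 | 0.020478 | 0.024779 | 0.978901 | 00:00 |
| 9 | 0.019747 | 0.024561 | 0.980373 | 00:00 |
| 10 | 0.019739 | 0.024287 | 0.980373 | 00:00 |
| 11 | 0.019199 | 0.023361 | 0.979882 | 00:00 |
| 12 | 0.018602 | 0.022802 | 0.980864 | 00:00 |
| 13 | 0.018049 | 0.022669 | 0.980864 | 00:00 |
| 14 | 0.017765 | 0.022454 | 0.981845 | 00:00 |
| 15 | 0.017472 | 0.022046 | 0.981354 | 00:00 |
| 16 | 0.017703 | 0.022397 | 0.981354 | 00:00 |
| 17 | 0.017360 | 0.021773 | 0.982336 | 00:00 |
| 18 | 0.016990 | 0.021565 | 0.982336 | 00:00 |
| 19 | 0.016655 | 0.021336 | 0.981845 | 00:00 |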
" + "source": [ + "train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)\n", + "train_x.shape,train_y.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rkavLIe1IcmD", + "outputId": "c7cabbb3-e47f-4403-b5d7-b09f38036698" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[(1, 4), (2, 5), (3, 6)]" + ] + }, + "metadata": {}, + "execution_count": 33 + } ], - "text/plain": [ - "" + "source": [ + "#เราสามารถนำ iterables สองอันมาต่อกันแบบนี้ได้ด้วย zip\n", + "a = [1,2,3]\n", + "b = [4,5,6]\n", + "list(zip(a,b))" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "learn.fit(20, lr=lr)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "H_UoKZ24pokJ" - }, - "source": [ - "## สร้าง Architecture ที่ Deep ขึ้น" - ] - }, - { - "cell_type": "code", - "execution_count": 252, - "metadata": { - "id": "4s1EGNszpokK" - }, - "outputs": [], - "source": [ - "#แทนที่จะทำ matrix multipilcation + bias ครั้งเดียว เราใส่ ReLU activation ไปเพิ่ม\n", - "def simple_net(xb): \n", - " res = xb@w1 + b1\n", - " res = res.max(tensor(0.0)) #นี่คือ relu\n", - " res = res@w2 + b2 #แล้วทำ matrix multiplication + bias อีกรอบ\n", - " return res" - ] - }, - { - "cell_type": "code", - "execution_count": 253, - "metadata": { - "id": "bKDW0LWlpokK" - }, - "outputs": [], - "source": [ - "w1 = init_params((28*28,30))\n", - "b1 = init_params(30)\n", - "w2 = init_params((30,1))\n", - "b2 = init_params(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 254, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 268 - }, - "id": "Fza6OpHHpokK", - "outputId": "55ede7e1-b109-4bd7-ba79-bdb62f172b02" - }, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plot_function(F.relu)" - ] - }, - { - "cell_type": "code", - "execution_count": 255, - "metadata": { - "id": "QyuXxFphpokK" - }, - "outputs": [], - "source": [ - "#แน่นอนว่าอ่านง่ายกว่าถ้าใช้ pytorch แทนที่จะเขียนฟังชั่นเอง\n", - "simple_net = nn.Sequential(\n", - " nn.Linear(28*28,30),\n", - " nn.ReLU(),\n", - " nn.Linear(30,1)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 256, - "metadata": { - "id": "94urzFOSpokK" - }, - "outputs": [], - "source": [ - "learn = Learner(dls, simple_net, opt_func=SGD,\n", - " loss_func=mnist_loss, metrics=batch_accuracy)" - ] - }, - { - "cell_type": "code", - "execution_count": 257, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "Z8WM5ImkpokK", - "outputId": "4773537c-6e5c-410a-e839-14181e3a192d" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n" + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "08BIemoApokA", + "outputId": "a209e1e9-e5ec-4c78-c2ed-3e9842c8974c" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(torch.Size([784]), torch.Size([1]))" + ] + }, + "metadata": {}, + "execution_count": 34 + } ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
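| epoch | train_loss | valid_loss | batch_accuracy | time |
|---|---|---|---|---|
| 0 | 0.201971 | 0.084863 | 0.964671 | 00:00 |
| 1 | 0.099862 | 0.053445 | 0.967125 | 00:00 |
| 2 | 0.063457 | 0.044292 | 0.970069 | 00:00 |
| 3 | 0.046691 | 0.039799 | 0.970559 | 00:00 |
| 4 | 0.038784 | 0.036757 | 0.972522 | 00:00 |
| 5 | 0.033928 | 0.034791 | 0.972522 | 00:00 |
| 6 | 0.030335 | 0.033378 | 0.973013 | 00:00 |
| 7 | 0.028198 | 0.031829 | 0.975466 | 00:00 |
| 8 | 0.026769 | 0.030901 | 0.974975 | 00:00 |
| 9 | 0.025149 | 0.030088 | 0.974975 | 00:00 |
| 10 | 0.024689 | 0.029290 | 0.975957 | 00:00 |
| 11 | 0.024279 | 0.028769 | 0.975466 | 00:00 |
| 12 | 0.023255 | 0.027909 | 0.976448 | 00:00 |
| 13 | 0.022221 | 0.027352 | 0.976448 | 00:00 |
| 14 | 0.021722 | 0.026784 | 0.976448 | 00:00 |
| 15 | 0.020742 | 0.026248 | 0.977429 | 00:00 |
| 16 | 0.020520 | 0.025757 | 0.978901 | 00:00 |
| 17 | 0.020088 | 0.025452 | 0.978410 | 00:00 |
| 18 | 0.019252 | 0.025185 | 0.978410 | 00:00 |
| 19 | 0.019314 | 0.025037 | 0.978901 | 00:00 |
| 20 | 0.019119 | 0.024581 | 0.978410 | 00:00 |
| 21 | 0.018715 | 0.024080 | 0.978410 | 00:00 |
| 22 | 0.018939 | 0.023808 | 0.978410 | 00:00 |
| 23 | 0.018436 | 0.023563 | 0.978410 | 00:00 |
| 24 | 0.018533 | 0.023074 | 0.978901 | 00:00 |
| 25 | 0.018065 | 0.023019 | 0.979882 | 00:00 |
| 26 | 0.017894 | 0.022776 | 0.979882 | 00:00 |
| 27 | 0.017287 | 0.022701 | 0.980864 | 00:00 |
| 28 | 0.017405 | 0.022514 | 0.981354 | 00:00 |
| 29 | 0.017315 | 0.022108 | 0.980864 | 00:00 |
| 30 | 0.016765 | 0.022043 | 0.981354 | 00:00 |
| 31 | 0.016790 | 0.021832 | 0.981354 | 00:00 |
| 32 | 0.016587 | 0.021730 | 0.981845 | 00:00 |
| 33 | 0.015950 | 0.021304 | 0.981354 | 00:00 |
| 34 | 0.015897 | 0.021351 | 0.982336 | 00:00 |
| 35 | 0.016121 | 0.021057 | 0.981845 | 00:00 |
| 36 | 0.016024 | 0.020962 | 0.981845 | 00:00 |
| 37 | 0.015987 | 0.020956 | 0.982336 | 00:00 |
| 38 | 0.015544 | 0.020772 | 0.982336 | 00:00 |
| 39 | 0.015544 | 0.020369 | 0.982336 | 00:00 |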
" + "source": [ + "dset = list(zip(train_x,train_y))\n", + "example = dset[0]\n", + "\n", + "#คู่ Inputs และ Labels\n", + "example[0].shape, example[1].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "id": "Y6ImID5ApokA" + }, + "outputs": [], + "source": [ + "#สร้าง validation set ในแบบเดียวกัน\n", + "valid_3_tens = torch.stack([tensor(Image.open(o)) \n", + " for o in (path/'valid'/'3').ls()])\n", + "valid_3_tens = valid_3_tens.float()/255\n", + "valid_7_tens = torch.stack([tensor(Image.open(o)) \n", + " for o in (path/'valid'/'7').ls()])\n", + "valid_7_tens = valid_7_tens.float()/255\n", + "\n", + "valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)\n", + "valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)\n", + "valid_dset = list(zip(valid_x,valid_y))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RR1v1Z2tzHmo" + }, + "source": [ + "#### Initiate `Weights`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7slIAUaOI0_d" + }, + "source": [ + "สมมุติว่าเราจะใช้ architecture สุดเรียบง่าย แค่คูณค่า pixels ของรูป `Inputs` ด้วย `W` และบวกด้วย `b`\n", + "\n", + "$$prediction = \\Sigma(xW^T) + b$$\n", + "\n", + "เราสามารถเริ่มตั้ง `Weights` เป็นการ random จาก standard normal distribution" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "id": "WQ9vpYHIpokA" + }, + "outputs": [], + "source": [ + "def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "itXv1TZ9pokA", + "outputId": "cc10560d-27d6-4ff4-b29e-99fca3878da9" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([784, 1])" + ] + }, + "metadata": {}, + "execution_count": 37 + } ], - "text/plain": [ - "" + "source": [ + "#Inputs มี dimension (batch_size, 
28*28)\n", + "#so to multiply element-wise (element-wise multiplication) by W^T, W must have dimension (28*28,1)\n", + "weights = init_params((28*28,1))\n", + "weights.shape" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "learn.fit(40, 0.1)" - ] - }, - { - "cell_type": "code", - "execution_count": 258, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 268 - }, - "id": "Q0d__m-FpokL", - "outputId": "846b03cf-5b37-45f8-ef7b-cbe4691e0b7a" - }, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "#watch accuracy improve over the epochs\n", - "plt.plot(L(learn.recorder.values).itemgot(2));" - ] - }, - { - "cell_type": "code", - "execution_count": 259, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "-pnmg5QqpokL", - "outputId": "91382f1d-566d-4cfb-82d9-70e7696a4498" - }, - "outputs": [ { - "data": { - "text/plain": [ - "0.98233562707901" + "cell_type": "code", + "execution_count": 38, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7Ef5fK0KpokA", + "outputId": "087d1c9a-adf1-4cf2-a3ea-b24fea8066e0" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([0.3472], requires_grad=True)" + ] + }, + "metadata": {}, + "execution_count": 38 + } + ], + "source": [ + "#bias has shape (1) and is automatically broadcast across the batch dimension\n", + "# (1) -> (batch_size,1)\n", + "bias = init_params(1)\n", + "bias" ] - }, - "execution_count": 259, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "learn.recorder.values[-1][2]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fQQNpNZ-pokL" - }, - "source": [ - "## Going Even Deeper" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jq3YclVaOjCk" - }, - "source": [ - "Instead of the simple Architecture we came up with ourselves, let's try a more complex Architecture such as `resnet18`" - ] - }, - { - "cell_type": "code", - "execution_count": 260, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 80 - }, - "id": "LB3qg01jpokL", - "outputId": "93503072-abe0-4fb1-a2d0-ff79af9e72fa" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n" + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pnMvUz7kFoIU", + "outputId":
"07d34405-6a39-4b0f-b374-56fe7f0a7272" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(torch.Size([784, 1]),\n", + " torch.Size([1]),\n", + " torch.Size([12396, 784]),\n", + " torch.Size([784]))" + ] + }, + "metadata": {}, + "execution_count": 39 + } + ], + "source": [ + "weights.shape, bias.shape, train_x.shape, train_x[7].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KXK2LAZqzLYk" + }, + "source": [ + "#### Forward Pass" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S9md33qTEu6n", + "outputId": "1469e6f5-bd86-45e3-ed3d-831e52ded4ab" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[-15.2588]], grad_fn=)" + ] + }, + "metadata": {}, + "execution_count": 40 + } + ], + "source": [ + "#compute the forward pass for example 7\n", + "(train_x[7]*weights.T).sum(1)[:,None] + bias" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "22qYPEQDEwjC", + "outputId": "528c7509-1eea-49cf-d5b6-181fd62059ca" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[ -6.2330],\n", + " [-10.6388],\n", + " [-20.8865],\n", + " ...,\n", + " [-15.9176],\n", + " [ -1.6866],\n", + " [-11.3568]], grad_fn=)" + ] + }, + "metadata": {}, + "execution_count": 41 + } + ], + "source": [ + "#compute the forward pass for every example\n", + "(train_x*weights.T).sum(1)[:,None] + bias" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "I65On42BpokB", + "outputId": "d77c1723-af03-485b-b036-c98e1b0f50af" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[ -6.2330],\n", + " [-10.6388],\n", + " [-20.8865],\n", + " ...,\n", + "
[-15.9176],\n", + " [ -1.6866],\n", + " [-11.3568]], grad_fn=)" + ] + }, + "metadata": {}, + "execution_count": 42 + } + ], + "source": [ + "#write it as a function; @ is matrix multiplication in PyTorch\n", + "def linear1(xb): return xb@weights + bias\n", + "preds = linear1(train_x)\n", + "preds" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XJKeTmrPzOr3" + }, + "source": [ + "#### Metric: Accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SLnbggPopokB", + "outputId": "478badc3-ae34-444e-c7ea-ff028b18cb23" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[False],\n", + " [False],\n", + " [False],\n", + " ...,\n", + " [ True],\n", + " [ True],\n", + " [ True]])" + ] + }, + "metadata": {}, + "execution_count": 43 + } + ], + "source": [ + "#if a prediction is > 0, predict 1 (the digit 3)\n", + "corrects = (preds>0.0).float() == train_y\n", + "corrects" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5DVZlAVzJqbV", + "outputId": "299bf0b0-09cf-4822-a55a-9a43537c0888" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Counter({0.0: 10788, 1.0: 1608})" + ] + }, + "metadata": {}, + "execution_count": 44 + } + ], + "source": [ + "from collections import Counter\n", + "Counter((preds>0.0).float().numpy()[:,0])" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CiSVOFonpokB", + "outputId": "4f584989-ecc1-47ef-a7e1-b95075af4f18" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.5379961133003235" + ] + }, + "metadata": {}, + "execution_count": 45 + } + ], + "source": [ + "#we can see it is right about half the time (because we are guessing with 
randomly
initialized weights)\n", + "corrects.float().mean().item()" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "id": "RwkizTBqpokC" + }, + "outputs": [], + "source": [ + "#if we change the weights a little\n", + "with torch.no_grad():\n", + "    weights *= 5.0001\n", + "    bias+=10" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d5Kv_vFrpokC", + "outputId": "351fac1c-2a3a-485c-f48e-af448f3141c3" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.548402726650238" + ] + }, + "metadata": {}, + "execution_count": 47 + } + ], + "source": [ + "#the accuracy changes only a little\n", + "preds = linear1(train_x)\n", + "((preds>0.0).float() == train_y).float().mean().item()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PTOZYvLRy-Vz" + }, + "source": [ + "#### A Simple Loss Function" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "id": "w7wVT8TjpokC" + }, + "outputs": [], + "source": [ + "trgts = tensor([1,0,1])\n", + "prds = tensor([0.9, 0.4, 0.2])" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "id": "9HgV5HQppokC" + }, + "outputs": [], + "source": [ + "#the loss is the prediction assigned to the wrong class\n", + "#so the more confidently we predict the wrong class, the higher the loss\n", + "def mnist_loss(predictions, targets):\n", + "    return torch.where(targets==1, 1-predictions, predictions).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5xGnIzeHpokC", + "outputId": "29656655-4d49-491a-abfa-b4f3c87a63c7" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([0.1000, 0.4000, 0.8000])" + ] + }, + "metadata": {}, + "execution_count": 50 + } + ], + "source": [ + "#if target == 1, loss is how far it is from 1\n", 
+ "#if target == 0, loss is how far it is from 0\n", + "torch.where(trgts==1, 1-prds, prds)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dhQH45dYpokC", + "outputId": "2a3c5978-ecea-493d-cb92-02b89b64c610" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor(0.4333)" + ] + }, + "metadata": {}, + "execution_count": 51 + } + ], + "source": [ + "mnist_loss(prds,trgts)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uiUjD7aJpokD", + "outputId": "c2fabcb4-0df8-4d71-9196-d110404ea9bb" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor(0.2333)" + ] + }, + "metadata": {}, + "execution_count": 52 + } + ], + "source": [ + "#if the predictions get more accurate, the loss goes down\n", + "mnist_loss(tensor([0.9, 0.4, 0.8]), trgts)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dqsrmegupokD" + }, + "source": [ + "### Activation Function - Sigmoid" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-0DXOc8w0YeU" + }, + "source": [ + "One problem we have right now is that our model outputs `Predictions` as arbitrary real numbers, but our `Loss Function` expects values between 0 and 1. We therefore use an `Activation Function` called Sigmoid to map the numbers into the range 0 to 1\n", + "\n", + "There are many kinds of `Activation Function`. Popular ones include [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))(x), which returns 0 when x is less than 0 and x otherwise, and [tanh](https://mathworld.wolfram.com/HyperbolicTangent.html)(x), which maps real numbers into the range -1 to 1. Read more about [Activation Functions](https://en.wikipedia.org/wiki/Activation_function)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "id": "sQIUTk5ZpokD" + }, + "outputs": [], + "source": [ + 
"def sigmoid(x): return 1/(1+torch.exp(-x))" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + }, + "id": "TY9tdedYpokD", + "outputId": "f1b91b5a-9267-4d78-93e5-6fa068f979f2" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "plot_function(torch.sigmoid, title='Sigmoid', min=-4, max=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "id": "htbxxyxlpokD" + }, + "outputs": [], + "source": [ + "def mnist_loss(predictions, targets):\n", + "    predictions = predictions.sigmoid() #apply sigmoid; PyTorch tensors have it built in\n", + "    return torch.where(targets==1, 1-predictions, predictions).mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UUxkL3_2pokD" + }, + "source": [ + "### SGD and Mini-Batches" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rxCnTrNRInAz" + }, + "source": [ + "Often the full dataset is too large to fit in the machine's memory, so we feed it in one batch at a time with a `DataLoader`" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2-twSiAspokD", + "outputId": "c3b8269f-b27a-4fb0-bda5-cbe2e18f9deb" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[tensor([ 3, 12, 8, 10, 2]),\n", + " tensor([ 9, 4, 7, 14, 5]),\n", + " tensor([ 1, 13, 0, 6, 11])]" + ] + }, + "metadata": {}, + "execution_count": 56 + } + ], + "source": [ + "#try a DataLoader on the numbers 0-14; 5 examples per batch; shuffled\n", + "coll = range(15)\n", + "dl = DataLoader(coll, batch_size=5, shuffle=True)\n", + "list(dl)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7W43iIU_pokE", + "outputId": "1b46408a-2fdc-40f7-a20d-fa1ab4c988a8" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(#26) [(0, 'a'),(1, 'b'),(2, 'c'),(3, 'd'),(4, 'e'),(5, 'f'),(6, 'g'),(7, 'h'),(8, 'i'),(9, 'j')...]" + ] + }, + "metadata": {}, + "execution_count": 57 + } + ], + 
"source": [ + "#suppose we have a dataset like this\n", + "ds = L(enumerate(string.ascii_lowercase))\n", + "ds" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GOccHc_rpokE", + "outputId": "1d364e87-dedd-455f-8e95-96271d3e9f2e" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[(tensor([17, 18, 10, 22, 8, 14]), ('r', 's', 'k', 'w', 'i', 'o')),\n", + " (tensor([20, 15, 9, 13, 21, 12]), ('u', 'p', 'j', 'n', 'v', 'm')),\n", + " (tensor([ 7, 25, 6, 5, 11, 23]), ('h', 'z', 'g', 'f', 'l', 'x')),\n", + " (tensor([ 1, 3, 0, 24, 19, 16]), ('b', 'd', 'a', 'y', 't', 'q')),\n", + " (tensor([2, 4]), ('c', 'e'))]" + ] + }, + "metadata": {}, + "execution_count": 58 + } + ], + "source": [ + "#it can be passed to a DataLoader in the same way\n", + "dl = DataLoader(ds, batch_size=6, shuffle=True)\n", + "list(dl)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tvWlUYU8pokE" + }, + "source": [ + "## Putting It All Together" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fjUk_EOS0deu" + }, + "source": [ + "#### Initialize weights" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": { + "id": "lK9dlPP4pokE" + }, + "outputs": [], + "source": [ + "weights = init_params((28*28,1))\n", + "bias = init_params(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L4ejZTDi0fds" + }, + "source": [ + "#### Initialize dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XesYQeEUpokE", + "outputId": "b0490d69-191f-49a7-c487-eb96763a85a3" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(torch.Size([256, 784]), torch.Size([256, 1]))" + ] + }, + "metadata": {}, + "execution_count": 60 + } + ], + "source": [ + "#create the DataLoader for training\n", + "dl = DataLoader(dset, 
batch_size=256, shuffle=True)\n", + "xb,yb = first(dl)\n", + "xb.shape,yb.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": { + "id": "hlw0qizVpokE" + }, + "outputs": [], + "source": [ + "#we do not shuffle the validation set\n", + "valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "opSgyb70pokF", + "outputId": "3f99d56f-92b8-4b35-f887-72298850466e" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([4, 784])" + ] + }, + "metadata": {}, + "execution_count": 62 + } + ], + "source": [ + "#(batch size, 28*28)\n", + "batch = train_x[:4]\n", + "batch.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d_OZXa2l0iwq" + }, + "source": [ + "#### Forward pass" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IQ5eVzE5pokF", + "outputId": "8c2e6586-aa3f-4a20-ef4c-7d0a4b8b0d1b" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([256, 1])" + ] + }, + "metadata": {}, + "execution_count": 63 + } + ], + "source": [ + "def linear1(xb): return xb@weights + bias\n", + "preds = linear1(xb)\n", + "preds.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P0dawUHM0kpf" + }, + "source": [ + "#### Calculate loss" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_zK9HJE7pokF", + "outputId": "5dc8119e-fac6-4fb6-d34f-67632d490ec2" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor(0.7690, grad_fn=)" + ] + }, + "metadata": {}, + "execution_count": 64 + } + ], + "source": [ + "def mnist_loss(predictions, targets):\n", + "    predictions = 
predictions.sigmoid()\n", + "    return torch.where(targets==1, 1-predictions, predictions).mean()\n", + "loss = mnist_loss(preds, yb)\n", + "loss" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tMIHxJcw0mW5" + }, + "source": [ + "#### Backward pass - get gradients and update weights" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lMdSQrorpokF", + "outputId": "65a29105-64da-4008-a8a3-37b98344972e" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(torch.Size([784, 1]), tensor(0.0049), tensor([0.0425]))" + ] + }, + "metadata": {}, + "execution_count": 65 + } + ], + "source": [ + "loss.backward()\n", + "weights.grad.shape,weights.grad.mean(), bias.grad" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": { + "id": "K4x5HLp4pokF" + }, + "outputs": [], + "source": [ + "#compute the gradients with .backward()\n", + "def calc_grad(xb, yb, model):\n", + "    preds = model(xb)\n", + "    loss = mnist_loss(preds, yb)\n", + "    loss.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CZzqYNtgpokF", + "outputId": "8c50d09c-6e74-4943-ce7a-a9ec8dee08f7" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(tensor(-0.0007), tensor([0.0070]))" + ] + }, + "metadata": {}, + "execution_count": 67 + } + ], + "source": [ + "calc_grad(batch, train_y[:4], linear1)\n", + "weights.grad.mean(),bias.grad" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "R5yY7Az-pokF", + "outputId": "9596cc88-13f5-4f32-c0d4-41a9c2de62c4" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(tensor(-0.0064), tensor([-0.0285]))" + ] + }, + "metadata": {}, + "execution_count": 68 
+ } + ], + "source": [ + "calc_grad(batch, train_y[:4], linear1)\n", + "weights.grad.mean(),bias.grad" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "id": "R-Bzu26wpokG" + }, + "outputs": [], + "source": [ + "#we need to reset the gradients so they do not keep accumulating\n", + "weights.grad.zero_();\n", + "bias.grad.zero_();" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": { + "id": "gKkqYbaqpokG" + }, + "outputs": [], + "source": [ + "#train for 1 epoch\n", + "def train_epoch(model, lr, params):\n", + "    #feed the model one batch at a time until every batch has been used\n", + "    for xb,yb in dl:\n", + "        #compute the loss and gradients\n", + "        calc_grad(xb, yb, model)\n", + "        #update the weights by gradient * learning rate (lr)\n", + "        for p in params:\n", + "            p.data -= p.grad*lr\n", + "            p.grad.zero_()" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2XWh1J1ppokG", + "outputId": "cfe8a98d-d53f-45e8-fed6-bb81899f0325" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor(0.5469)" + ] + }, + "metadata": {}, + "execution_count": 71 + } + ], + "source": [ + "#compute accuracy\n", + "((preds>0.0).float() == yb).float().mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4ZfrHoey0w2z" + }, + "source": [ + "### Calculate metric (accuracy in this case)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": { + "id": "5QPI55L4pokG" + }, + "outputs": [], + "source": [ + "def batch_accuracy(xb, yb):\n", + "    preds = xb.sigmoid()\n", + "    correct = (preds>0.5) == yb\n", + "    return correct.float().mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8UVCx6RApokG", + "outputId": "fcdbc646-2421-4dfe-bb67-6d51e28d62ef" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { 
+ "text/plain": [ + "tensor(0.5000)" + ] + }, + "metadata": {}, + "execution_count": 73 + } + ], + "source": [ + "batch_accuracy(linear1(batch), train_y[:4])" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": { + "id": "GbJQiOgIpokG" + }, + "outputs": [], + "source": [ + "#validate on the validation set\n", + "def validate_epoch(model):\n", + "    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]\n", + "    return round(torch.stack(accs).mean().item(), 4)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "txrLkdJQpokH", + "outputId": "eeafe072-31d3-438f-e99e-f2377b03eac0" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.5484" + ] + }, + "metadata": {}, + "execution_count": 75 + } + ], + "source": [ + "validate_epoch(linear1)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HRIapIkfpokH", + "outputId": "8537b6d8-70fc-4d5b-cca9-40ab62cf8d2e" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.9253" + ] + }, + "metadata": {}, + "execution_count": 76 + } + ], + "source": [ + "#after training for 1 epoch, accuracy nearly doubles!\n", + "lr = 1.\n", + "\n", + "params = weights, bias\n", + "train_epoch(linear1, lr, params)\n", + "validate_epoch(linear1)" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hvBUTwULpokH", + "outputId": "79511794-ad36-4c6b-a2a9-02df71aa8e64" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.9554 0.9608 0.9632 0.9657 0.9681 0.9696 0.9716 0.973 0.974 0.974 0.975 0.9775 0.9745 0.976 0.9779 0.9779 0.9784 0.9784 0.9784 0.9789 " + ] + } + ], + "source": [ + "#train for 20 epochs; almost every prediction is correct\n", + "for i in range(20):\n", + 
"    train_epoch(linear1, lr, params)\n", + "    print(validate_epoch(linear1), end=' ')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uVt65caepokH" + }, + "source": [ + "### Optimizer as a class" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gyvY33oqMdZk" + }, + "source": [ + "The job of updating the weights is often bundled into a class called an optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ETZ0sdlXpokH", + "outputId": "acb76f13-b382-48dd-d2a4-f19d8b93b984" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Linear(in_features=784, out_features=1, bias=True)" + ] + }, + "metadata": {}, + "execution_count": 78 + } + ], + "source": [ + "#nn.Linear is the same x@W.T + b function we just wrote\n", + "linear_model = nn.Linear(28*28,1, bias=True)\n", + "linear_model" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Op_9VxfipokH", + "outputId": "52bc6520-0ddf-4954-9f5e-083a5ab08813" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(torch.Size([1, 784]), torch.Size([1]))" + ] + }, + "metadata": {}, + "execution_count": 79 + } + ], + "source": [ + "w,b = linear_model.parameters()\n", + "w.shape,b.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": { + "id": "HDPzEqO-pokI" + }, + "outputs": [], + "source": [ + "#the simplest possible optimizer\n", + "class BasicOptim:\n", + "    def __init__(self,params,lr): \n", + "        self.params,self.lr = list(params),lr\n", + "\n", + "    #step updates the weights\n", + "    def step(self, *args, **kwargs):\n", + "        for p in self.params: p.data -= p.grad.data * self.lr\n", + "\n", + "    #zero_grad resets the gradients\n", + "    def zero_grad(self, *args, **kwargs):\n", + "        for p in self.params: p.grad = None" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 81, + "metadata": { + "id": "UIAhC389pokI" + }, + "outputs": [], + "source": [ + "opt = BasicOptim(linear_model.parameters(), lr)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": { + "id": "zNokDJCupokI" + }, + "outputs": [], + "source": [ + "def train_epoch(model):\n", + "    for xb,yb in dl:\n", + "        calc_grad(xb, yb, model)\n", + "        #use the optimizer instead of the manual update\n", + "        opt.step()\n", + "        opt.zero_grad()" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pUr7KQecpokI", + "outputId": "03eae252-fa47-49ef-dae7-7cf2500497e5" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.3567" + ] + }, + "metadata": {}, + "execution_count": 83 + } + ], + "source": [ + "validate_epoch(linear_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": { + "id": "wQptVRJfpokI" + }, + "outputs": [], + "source": [ + "def train_model(model, epochs):\n", + "    for i in range(epochs):\n", + "        train_epoch(model)\n", + "        print(validate_epoch(model), end=' ')" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kRbgjPURpokI", + "outputId": "8a680b82-56a8-4e0a-8444-e5dde4b97520" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.9706 0.975 0.975 0.9779 0.9779 0.9784 0.9794 0.9794 0.9799 0.9804 0.9804 0.9804 0.9804 0.9814 0.9819 0.9823 0.9823 0.9819 0.9823 0.9828 " + ] + } + ], + "source": [ + "#the results are just as good as before\n", + "train_model(linear_model, 20)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rBjvWLYwNII8" + }, + "source": [ + "## Making Life Easier with PyTorch" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "60qaqoMbNMA4" + }, + "source": [ + "PyTorch already provides the functions and classes 
we have been writing as built-in parts of the package, ready for us to call" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DY1kWv7lpokI", + "outputId": "21c7a564-f0be-4d80-e07c-ff0fd34667e1" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0.9726 0.9755 0.976 0.976 0.977 0.9764 0.9789 0.9799 0.9789 0.9789 0.9799 0.9799 0.9794 0.9803 0.9803 0.9808 0.9803 0.9809 0.9823 0.9813 " + ] + } + ], + "source": [ + "linear_model = nn.Linear(28*28,1,bias=False) #same as the init_params and matrix multiplication we just did (here without a bias term)\n", + "opt = SGD(linear_model.parameters(), lr) #same as the BasicOptim we just wrote\n", + "\n", + "#training the model gives results as good as before\n", + "train_model(linear_model, 20)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6ctjZxC3Nxi8" + }, + "source": [ + "## Making It Even More Convenient with fastai" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": { + "id": "jtzVcFkxpokJ" + }, + "outputs": [], + "source": [ + "dls = DataLoaders(dl, valid_dl)\n", + "#bundle everything together with the Learner class\n", + "learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,\n", + "                loss_func=mnist_loss, metrics=batch_accuracy)" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 676 + }, + "id": "0WcORvumpokJ", + "outputId": "0a15116b-a4ca-4557-e01e-0be547013147" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochtrain_lossvalid_lossbatch_accuracytime
00.0596730.0413070.97203100:00
10.0403680.0348970.97546600:00
20.0326610.0323650.97497500:00
30.0282520.0294630.97742900:00
40.0251780.0284480.97792000:00
50.0236420.0276700.97792000:00
60.0226530.0263400.97841000:00
70.0214880.0252640.97988200:00
80.0209880.0246700.97939200:00
90.0204300.0242370.97939200:00
100.0195260.0239210.98037300:00
110.0190020.0236560.98037300:00
120.0188140.0236990.98135400:00
130.0180700.0226860.98135400:00
140.0174310.0227080.98184500:00
150.0173560.0224090.98184500:00
160.0172960.0219270.98135400:00
170.0169580.0218550.98184500:00
180.0170050.0216770.98184500:00
190.0165770.0217020.98184500:00
" + ] + }, + "metadata": {} + } + ], + "source": [ + "learn.fit(20, lr=lr)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H_UoKZ24pokJ" + }, + "source": [ + "## สร้าง Architecture ที่ Deep ขึ้น" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": { + "id": "4s1EGNszpokK" + }, + "outputs": [], + "source": [ + "#แทนที่จะทำ matrix multipilcation + bias ครั้งเดียว เราใส่ ReLU activation ไปเพิ่ม\n", + "def simple_net(xb): \n", + " res = xb@w1 + b1\n", + " res = res.max(tensor(0.0)) #นี่คือ relu\n", + " res = res@w2 + b2 #แล้วทำ matrix multiplication + bias อีกรอบ\n", + " return res" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": { + "id": "bKDW0LWlpokK" + }, + "outputs": [], + "source": [ + "w1 = init_params((28*28,30))\n", + "b1 = init_params(30)\n", + "w2 = init_params((30,1))\n", + "b2 = init_params(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + }, + "id": "Fza6OpHHpokK", + "outputId": "b23d6134-c2bc-4d4f-d92c-af0c8a2c14df" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": { + "needs_background": "light" + } + } + ], + "source": [ + "plot_function(F.relu)" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": { + "id": "QyuXxFphpokK" + }, + "outputs": [], + "source": [ + "#แน่นอนว่าอ่านง่ายกว่าถ้าใช้ pytorch แทนที่จะเขียนฟังชั่นเอง\n", + "simple_net = nn.Sequential(\n", + " nn.Linear(28*28,30),\n", + " nn.ReLU(),\n", + " nn.Linear(30,1)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": { + "id": "94urzFOSpokK" + }, + "outputs": [], + "source": [ + "learn = Learner(dls, simple_net, opt_func=SGD,\n", + " loss_func=mnist_loss, metrics=batch_accuracy)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "Z8WM5ImkpokK", + "outputId": "cb7de009-0755-421c-c896-4ecbeee52390" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
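| epoch | train_loss | valid_loss | batch_accuracy | time |
|---|---|---|---|---|
| 0 | 0.253839 | 0.101416 | 0.962709 | 00:00 |
| 1 | 0.118263 | 0.055467 | 0.969578 | 00:00 |
| 2 | 0.071047 | 0.044955 | 0.970559 | 00:00 |
| 3 | 0.050435 | 0.039946 | 0.972031 | 00:00 |
| 4 | 0.039966 | 0.036746 | 0.972522 | 00:00 |
| 5 | 0.034656 | 0.034883 | 0.973994 | 00:00 |
| 6 | 0.030598 | 0.033072 | 0.974485 | 00:00 |
| 7 | 0.028436 | 0.031563 | 0.974485 | 00:00 |
| 8 | 0.026895 | 0.030731 | 0.975957 | 00:00 |
| 9 | 0.025961 | 0.029774 | 0.977429 | 00:00 |
| 10 | 0.025033 | 0.028958 | 0.977429 | 00:00 |
| 11 | 0.023533 | 0.028051 | 0.977429 | 00:00 |
| 12 | 0.022895 | 0.027639 | 0.978410 | 00:00 |
| 13 | 0.021776 | 0.026808 | 0.978410 | 00:00 |
| 14 | 0.021131 | 0.026761 | 0.977429 | 00:00 |
| 15 | 0.020641 | 0.026007 | 0.978901 | 00:00 |
| 16 | 0.020077 | 0.025396 | 0.978901 | 00:00 |
| 17 | 0.019857 | 0.024965 | 0.978901 | 00:00 |
| 18 | 0.019774 | 0.024951 | 0.978901 | 00:00 |
| 19 | 0.018943 | 0.024718 | 0.978901 | 00:00 |
| 20 | 0.018828 | 0.024210 | 0.978901 | 00:00 |
| 21 | 0.018704 | 0.023884 | 0.978901 | 00:00 |
| 22 | 0.018680 | 0.023585 | 0.979392 | 00:00 |
| 23 | 0.018192 | 0.023434 | 0.979882 | 00:00 |
| 24 | 0.017612 | 0.022855 | 0.979882 | 00:00 |
| 25 | 0.017698 | 0.022698 | 0.980373 | 00:00 |
| 26 | 0.017734 | 0.022616 | 0.980373 | 00:00 |
| 27 | 0.017673 | 0.022282 | 0.980373 | 00:00 |
| 28 | 0.016994 | 0.022069 | 0.980864 | 00:00 |
| 29 | 0.017185 | 0.021872 | 0.980864 | 00:00 |
| 30 | 0.016830 | 0.021661 | 0.980864 | 00:00 |
| 31 | 0.016727 | 0.021472 | 0.980373 | 00:00 |
| 32 | 0.016233 | 0.021404 | 0.981354 | 00:00 |
| 33 | 0.015987 | 0.021101 | 0.980373 | 00:00 |
| 34 | 0.015671 | 0.020923 | 0.981354 | 00:00 |
| 35 | 0.015617 | 0.020793 | 0.981354 | 00:00 |
| 36 | 0.015403 | 0.020697 | 0.981845 | 00:00 |
| 37 | 0.015225 | 0.020656 | 0.981354 | 00:00 |
| 38 | 0.015505 | 0.020305 | 0.982336 | 00:00 |
| 39 | 0.015315 | 0.020342 | 0.981354 | 00:00 |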
" + ] + }, + "metadata": {} + } ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.112140 | 0.021877 | 0.995093 | 00:31 |
" + "source": [ + "learn.fit(40, 0.1)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 265 + }, + "id": "Q0d__m-FpokL", + "outputId": "b39fc03d-2bb8-4a3e-8036-e57a88e15c62" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": { + "needs_background": "light" + } + } ], - "text/plain": [ - "" + "source": [ + "#Plot the accuracy; it keeps improving from epoch to epoch\n", + "plt.plot(L(learn.recorder.values).itemgot(2));" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-pnmg5QqpokL", + "outputId": "f6df8472-bbd5-4b94-c0a0-91cc4a363bbb" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.981354296207428" + ] + }, + "metadata": {}, + "execution_count": 96 + } + ], + "source": [ + "learn.recorder.values[-1][2]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fQQNpNZ-pokL" + }, + "source": [ + "## Going Even Deeper" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jq3YclVaOjCk" + }, + "source": [ + "Instead of the simple architecture we came up with ourselves, let's try a more complex architecture such as `resnet18`." + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 136 + }, + "id": "LB3qg01jpokL", + "outputId": "d2515ddb-c29c-43f5-9612-f04e927985e5" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/fastai/vision/learner.py:265: UserWarning: `cnn_learner` has been renamed to `vision_learner` -- please update your code\n", + " warn(\"`cnn_learner` has been renamed to `vision_learner` -- please update your code\")\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.085800 | 0.008205 | 0.997547 | 00:23 |
" + ] + }, + "metadata": {} + } + ], + "source": [ + "dls = ImageDataLoaders.from_folder(path)\n", + "learn = cnn_learner(dls, resnet18, pretrained=False,\n", + " loss_func=F.cross_entropy, metrics=accuracy)\n", + "\n", + "#Just one epoch of training already beats the 20-40 epochs we trained a moment ago\n", + "learn.fit_one_cycle(1, 0.1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uxfYQcZfJEsN" + }, + "source": [ + "# End-of-Chapter Checkpoint" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wB1TBtie2VPG" + }, + "source": [ + "### ☑️ Watch the [3Blue1Brown](https://www.youtube.com/watch?v=IHZwWFHWa-w) video on SGD (Thai subtitles available)\n", + "\n", + "This video explains how SGD works in the neural networks we studied in Chapters 1 and 3, using well-made graphics that help deepen your understanding.\n", + "\n", + "YouTube link: [https://www.youtube.com/watch?v=IHZwWFHWa-w](https://www.youtube.com/watch?v=IHZwWFHWa-w)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4a0-VMv6Pau1" + }, + "source": [ + "## Questions to Think About" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t_NIkQBMPgEh" + }, + "source": [ + "1. In this lesson we learned how almost every part of an ML model works. For your own project, which part of the model do you think would improve results the most if you tuned it?\n", + "\n", + "2. Do you think weight initialization affects the quality of the trained model? If we do not use random weights, which weights should we start training from?\n", + "\n", + "3. How do a loss and a metric differ? Why do you think we train the model to get the best loss but evaluate it with a metric?\n", + "\n", + "4. How important do you think the `Activation Function` is to an ML model's performance? Which one should we choose in which situation?\n", + "\n", + "5. Why do you think an architecture like resnet18 does better than the simple architecture we built ourselves, even though it was trained for only 1 epoch while ours was trained for 20-40 epochs?\n", + "\n", + "6. 
Do you think there are better ways to update the weights than subtracting the gradients one iteration at a time, as we did? Which ones? Or is it possible to train a model without backpropagation, that is, to update the weights without using gradients?\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HeM7xTrJPc8w" + }, + "source": [ + "## What to Prepare for Your Project" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oPvpAnmLJGgU" + }, + "source": [ + "### ☑️ Review the project grading rubric\n", + "\n", + "AI Builders awards a certificate of completion only to participants whose submitted project scores at least 70 out of 100 points under the following criteria:\n", + "\n", + "1. problem statement; the rationale for solving a business or everyday problem with machine learning - 15 points\n", + "2. metrics and baselines; the reasoning that links the problem to the chosen metrics / evaluation against current solutions - 15 points\n", + "3. data collection and cleaning; collecting and cleaning the data - 15 points\n", + "\n", + "**Today we will focus on these 2 items in particular**\n", + "\n", + "**4. exploratory data analysis; understanding the data - 20 points**\n", + "\n", + "**5. modeling, validation and error analysis; building the model, validating it, and analyzing its errors - 20 points**\n", + "\n", + "6. 
deployment; putting the model to use on the real problem - 15 points" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9u8rUq1uPSuE" + }, + "source": [ + "### ☑️ Understand your dataset\n", + "\n", + "How much do you know about the labels you are predicting? \n", + "* For classification, how many label classes are there? Can each example belong to more than one class (multi-label) or to exactly one (multi-class)? \n", + "* For regression, how are your labels distributed?\n", + "* Do you have enough examples in each class for the model to learn from?\n", + "\n", + "What do your features look like?\n", + "\n", + "* For images, are they taken from a single viewpoint or from several? Are the image sizes and resolutions the same?\n", + "* For text, how are the words distributed, by both word count and tfidf? Which words occur most often under each label?\n", + "* For tabular data, how are the numerical and categorical features distributed? Try computing the correlations among the features, and between each feature and the labels, to see whether anything interesting turns up.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xcPUWlA9JGio" + }, + "source": [ + "### ☑️ Train your first model end to end and get to a result as quickly as you can\n", + "\n", + "Of course, cleaning the data and understanding the data are important (15+20=35 points in total), but the best way to find out how well our model performs is to just train it without doing much to the data first. Afterwards, we can go through the model's errors to get ideas for how to handle the data and the model next. \n", + "\n", + "You might think building an ML model goes like this:\n", + "\n", + "\n", + "\n", + "But in reality it goes like this:\n", + "\n", + "\n", + "\n", + "So don't be afraid to run through every step once first, then go back and forth to redo things better."
 ] - }, - "metadata": {}, - "output_type": "display_data" } - ], - "source": [ - "dls = ImageDataLoaders.from_folder(path)\n", - "learn = cnn_learner(dls, resnet18, pretrained=False,\n", - " loss_func=F.cross_entropy, metrics=accuracy)\n", - "\n", - "#Just one epoch of training already beats the 20-40 epochs we trained a moment ago\n", - "learn.fit_one_cycle(1, 0.1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uxfYQcZfJEsN" - }, - "source": [ - "# End-of-Chapter Checkpoint" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wB1TBtie2VPG" - }, - "source": [ - "### ☑️ Watch the [3Blue1Brown](https://www.youtube.com/watch?v=IHZwWFHWa-w) video on SGD (Thai subtitles available)\n", - "\n", - "This video explains how SGD works in the neural networks we studied in Chapters 1 and 3, using well-made graphics that help deepen your understanding.\n", - "\n", - "YouTube link: [https://www.youtube.com/watch?v=IHZwWFHWa-w](https://www.youtube.com/watch?v=IHZwWFHWa-w)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4a0-VMv6Pau1" - }, - "source": [ - "## Questions to Think About" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "t_NIkQBMPgEh" - }, - "source": [ - "1. In this lesson we learned how almost every part of an ML model works. For your own project, which part of the model do you think would improve results the most if you tuned it?\n", - "\n", - "2. Do you think weight initialization affects the quality of the trained model? If we do not use random weights, which weights should we start training from?\n", - "\n", - "3. How do a loss and a metric differ? Why do you think we train the model to get the best loss but evaluate it with a metric?\n", - "\n", - "4. How important do you think the `Activation Function` is to an ML model's performance? Which one should we choose in which situation?\n", - "\n", - "5. Why do you think an architecture like resnet18 does better than the simple architecture we built ourselves, even though it was trained for only 1 epoch while ours was trained for 20-40 epochs?\n", - "\n", - "6. 
Do you think there are better ways to update the weights than subtracting the gradients one iteration at a time, as we did? Which ones? Or is it possible to train a model without backpropagation, that is, to update the weights without using gradients?\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HeM7xTrJPc8w" - }, - "source": [ - "## What to Prepare for Your Project" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oPvpAnmLJGgU" - }, - "source": [ - "### ☑️ Review the project grading rubric\n", - "\n", - "AI Builders awards a certificate of completion only to participants whose submitted project scores at least 70 out of 100 points under the following criteria:\n", - "\n", - "1. problem statement; the rationale for solving a business or everyday problem with machine learning - 15 points\n", - "2. metrics and baselines; the reasoning that links the problem to the chosen metrics / evaluation against current solutions - 15 points\n", - "3. data collection and cleaning; collecting and cleaning the data - 15 points\n", - "\n", - "**Today we will focus on these 2 items in particular**\n", - "\n", - "**4. exploratory data analysis; understanding the data - 20 points**\n", - "\n", - "**5. modeling, validation and error analysis; building the model, validating it, and analyzing its errors - 20 points**\n", - "\n", - "6. 
deployment; putting the model to use on the real problem - 15 points" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9u8rUq1uPSuE" - }, - "source": [ - "### ☑️ Understand your dataset\n", - "\n", - "How much do you know about the labels you are predicting? \n", - "* For classification, how many label classes are there? Can each example belong to more than one class (multi-label) or to exactly one (multi-class)? \n", - "* For regression, how are your labels distributed?\n", - "* Do you have enough examples in each class for the model to learn from?\n", - "\n", - "What do your features look like?\n", - "\n", - "* For images, are they taken from a single viewpoint or from several? Are the image sizes and resolutions the same?\n", - "* For text, how are the words distributed, by both word count and tfidf? Which words occur most often under each label?\n", - "* For tabular data, how are the numerical and categorical features distributed? Try computing the correlations among the features, and between each feature and the labels, to see whether anything interesting turns up.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xcPUWlA9JGio" - }, - "source": [ - "### ☑️ Train your first model end to end and get to a result as quickly as you can\n", - "\n", - "Of course, cleaning the data and understanding the data are important (15+20=35 points in total), but the best way to find out how well our model performs is to just train it without doing much to the data first. Afterwards, we can go through the model's errors to get ideas for how to handle the data and the model next. \n", - "\n", - "You might think building an ML model goes like this:\n", - "\n", - "\n", - "\n", - "But in reality it goes like this:\n", - "\n", - "\n", - "\n", - "So don't be afraid to run through every step once first, then go back and forth to redo things better." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "03_sgd_from_scratch.ipynb", - "provenance": [], - "toc_visible": true - }, - "jupytext": { - "split_at_heading": true - }, - "kernelspec": { - "display_name": 
"conda_amazonei_pytorch_latest_p36", - "language": "python", - "name": "conda_amazonei_pytorch_latest_p36" + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "03_sgd_from_scratch.ipynb", + "provenance": [], + "toc_visible": true + }, + "jupytext": { + "split_at_heading": true + }, + "kernelspec": { + "display_name": "conda_amazonei_pytorch_latest_p36", + "language": "python", + "name": "conda_amazonei_pytorch_latest_p36" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.13" + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.13" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file
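
The notebook trains `simple_net` with `opt_func=SGD`, a smooth loss (`mnist_loss`) and a separate metric (`batch_accuracy`), and its discussion questions ask how a loss differs from a metric and what "subtracting the gradients" means. As a rough illustration only, here is a minimal pure-Python sketch of those ideas on a one-weight toy model. It is not fastai's or PyTorch's implementation. The names `soft_loss`, `accuracy`, `sgd_step`, `train`, `predict` and the toy data `XS`/`YS` are hypothetical, and the loss is assumed to be the mean distance between `sigmoid(logit)` and the 0/1 target.

```python
import math
import random


def relu(z):
    """ReLU activation: negative inputs become 0, positive inputs pass through."""
    return max(0.0, z)


def sigmoid(z):
    """Squash a raw score (logit) into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def soft_loss(logits, targets):
    """Smooth loss (assumed form): mean distance between sigmoid(logit) and the 0/1 target."""
    total = 0.0
    for z, t in zip(logits, targets):
        p = sigmoid(z)
        total += (1.0 - p) if t == 1 else p
    return total / len(targets)


def accuracy(logits, targets):
    """Metric: fraction of examples whose thresholded prediction matches the target."""
    hits = sum((sigmoid(z) > 0.5) == (t == 1) for z, t in zip(logits, targets))
    return hits / len(targets)


def predict(w, b, xs):
    """Toy model: one weight and one bias, z = w*x + b."""
    return [w * x + b for x in xs]


def sgd_step(w, b, xs, ys, lr):
    """One SGD update on a single mini-batch."""
    grad_w = grad_b = 0.0
    for x, t in zip(xs, ys):
        p = sigmoid(w * x + b)
        # d(loss)/dz is -p(1-p) for a positive target and +p(1-p) for a negative one
        dz = -p * (1.0 - p) if t == 1 else p * (1.0 - p)
        grad_w += dz * x
        grad_b += dz
    n = len(xs)
    # The update itself: subtract learning rate times the mean gradient
    return w - lr * grad_w / n, b - lr * grad_b / n


def train(xs, ys, epochs=40, lr=0.1, batch_size=4, seed=0):
    """Shuffle the data each epoch and apply one SGD step per mini-batch."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            w, b = sgd_step(w, b, [xs[i] for i in batch], [ys[i] for i in batch], lr)
    return w, b


# Toy data (hypothetical): label 1 when x is positive, 0 otherwise
XS = [-3.0, -2.0, -1.0, -0.5, 0.5, 1.0, 2.0, 3.0]
YS = [0, 0, 0, 0, 1, 1, 1, 1]

if __name__ == "__main__":
    w, b = train(XS, YS)
    logits = predict(w, b, XS)
    print("loss:", soft_loss(logits, YS), "accuracy:", accuracy(logits, YS))
```

The loss changes a little with every small change in the weights, so it gives SGD a usable gradient. The accuracy only changes when a prediction crosses the 0.5 threshold, which is why it is reported as a metric rather than optimized directly.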