{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "pytorch.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# MNIST digit classifier in PyTorch\n",
        "\n",
        "Trains a small fully connected network on MNIST, measures its accuracy on the test split, and then classifies a digit image loaded from Google Drive.\n",
        "\n",
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lxhgSo4rOWCg",
        "outputId": "5242373c-1d80-4a96-ef17-1243f7eea994"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (1.10.0+cu111)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (3.10.0.2)\n"
          ]
        }
      ],
      "source": [
        "%pip install torch"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# All imports live in this one cell so a fresh kernel can run the notebook top to bottom.\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import PIL.ImageOps\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from PIL import Image\n",
        "from torchvision import datasets, transforms\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Seed PyTorch's RNG so weight initialisation and batch shuffling start from the same state on every run.\n",
        "SEED = 0\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
        "if torch.cuda.is_available():\n",
        "    device = torch.device(\"cuda:0\")\n",
        "    print(\"GPU\")\n",
        "else:\n",
        "    device = torch.device(\"cpu\")\n",
        "    print(\"CPU\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WoWY-yVoPKiI",
        "outputId": "707345d2-53fb-4b01-f417-0c5b2db0aafa"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "GPU\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Runs after the imports cell: printing the version before `import torch` raises NameError on a fresh kernel.\n",
        "print(torch.__version__)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "458HSdE3WK68",
        "outputId": "be8888f7-f9ee-474f-b393-1d3545382b27"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "1.10.0+cu111\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "class Network(nn.Module):\n",
        "    \"\"\"Fully connected classifier: 784 inputs -> three 64-unit ReLU layers -> 10 classes.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.input_layer = nn.Linear(784, 64)\n",
        "        self.hidden1 = nn.Linear(64, 64)\n",
        "        self.hidden2 = nn.Linear(64, 64)\n",
        "        self.output = nn.Linear(64, 10)\n",
        "\n",
        "    def forward(self, data):\n",
        "        \"\"\"Return log-probabilities over the 10 classes for a (batch, 784) float tensor.\"\"\"\n",
        "        data = F.relu(self.input_layer(data))\n",
        "        data = F.relu(self.hidden1(data))\n",
        "        data = F.relu(self.hidden2(data))\n",
        "        # The output layer must stay linear. A ReLU here would clamp every negative\n",
        "        # logit to zero before log_softmax, so the network could never push a wrong\n",
        "        # class below that floor and those units would receive no gradient.\n",
        "        logits = self.output(data)\n",
        "        return F.log_softmax(logits, dim=1)"
      ],
      "metadata": {
        "id": "zHPjts1vPjDo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Data"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "to_tensor = transforms.Compose([transforms.ToTensor()])\n",
        "\n",
        "# An empty root makes torchvision store the dataset relative to the current working directory.\n",
        "training = datasets.MNIST(\"\", train=True, download=True, transform=to_tensor)\n",
        "testing = datasets.MNIST(\"\", train=False, download=True, transform=to_tensor)\n",
        "\n",
        "train_set = torch.utils.data.DataLoader(training, batch_size=10, shuffle=True)\n",
        "# Accuracy does not depend on sample order, so the test loader is left unshuffled.\n",
        "test_set = torch.utils.data.DataLoader(testing, batch_size=10, shuffle=False)"
      ],
      "metadata": {
        "id": "Phj4o7piR4FU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "network = Network().to(device)\n",
        "optimizer = optim.Adam(network.parameters(), lr=0.001)\n",
        "epochs = 4\n",
        "\n",
        "network.train()\n",
        "for epoch in tqdm(range(epochs)):\n",
        "    running_loss = 0.0\n",
        "    for images, labels in train_set:\n",
        "        images = images.to(device)\n",
        "        labels = labels.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        # Flatten each 28x28 image into the 784-long vector the first layer expects.\n",
        "        result = network(images.view(-1, 784))\n",
        "        loss = F.nll_loss(result, labels)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        running_loss += loss.item()\n",
        "    # Report the mean loss over the epoch; the loss of the last batch alone is too noisy to read a trend from.\n",
        "    print(f\"Epoch {epoch + 1}/{epochs}: mean loss {running_loss / len(train_set):.4f}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eqIk--nMTYIm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Test Network\n",
        "network.eval()\n",
        "\n",
        "correct = 0\n",
        "total = 0\n",
        "\n",
        "with torch.no_grad():\n",
        "    for images, labels in test_set:\n",
        "        images = images.to(device)\n",
        "        labels = labels.to(device)\n",
        "        result = network(images.view(-1, 784))\n",
        "        # Compare the whole batch at once instead of looping over samples in Python.\n",
        "        predictions = result.argmax(dim=1)\n",
        "        correct += (predictions == labels).sum().item()\n",
        "        total += labels.size(0)\n",
        "\n",
        "accuracy = correct / total\n",
        "print(f\"Accuracy: {accuracy}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fLRRyQnuYmRn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Classify an image from Google Drive"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Colab-only import, kept next to its single use so the imports cell also works outside Colab.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive', force_remount=True)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "97UyaTsHcIft",
        "outputId": "be688931-b92b-4efb-b20e-12be438f03d8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/gdrive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!ls '/content/gdrive/My Drive/TestPytorch'"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RQop3IL3csZo",
        "outputId": "635017e3-b7d2-42a1-9e5c-72e6fefc2f73"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "5.png 7.png\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Absolute path under the mount point used above, so the cell does not depend on the working directory.\n",
        "img = Image.open(\"/content/gdrive/My Drive/TestPytorch/5.png\")\n",
        "img = img.resize((28, 28))\n",
        "img = img.convert(\"L\")\n",
        "img = PIL.ImageOps.invert(img)\n",
        "\n",
        "plt.imshow(img)\n",
        "\n",
        "# Scale the 8-bit pixels to [0, 1], the range ToTensor() produced for the training images.\n",
        "pixels = np.array(img) / 255\n",
        "image = torch.from_numpy(pixels).float().to(device)\n",
        "\n",
        "network.eval()\n",
        "with torch.no_grad():\n",
        "    res = network(image.view(-1, 28 * 28))\n",
        "\n",
        "# Print the class the network scored highest for this image. `res` is the network's\n",
        "# answer; `output`/`labels` from the earlier loops only hold dataset labels.\n",
        "print(torch.argmax(res, dim=1).item())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 283
        },
        "id": "GMf65cgLdlfa",
        "outputId": "422c46d3-526c-402f-f7b9-cfa713d3fd72"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAALYklEQVR4nO3dX6hchZ3A8e+vSUwwbSWpbYh/utriw4ZlN13uRpfK4iIV60ssBWkeJELhtqDQQtldaRfqoyzbyj7sCmkNTf/ZLVZrHmS32VAQwYpXNyZR2+q6kZqNydo8GIXGRH/7cI/da7wz9zrnzJy5/X0/cJmZc2YyPwa/nplz5t4TmYmkP3zv63sASZNh7FIRxi4VYexSEcYuFbF6kk92XqzNdayf5FNKpfyO13kjT8di61rFHhHXA/8ErAK+nZl3Drv/OtZzZVzb5iklDfFY7h+4buS38RGxCvhn4NPAFmBHRGwZ9d+TNF5tPrNvA57PzBcy8w3gR8D2bsaS1LU2sV8M/GbB7ZeaZe8QEbMRMRcRc2c43eLpJLUx9r3xmbkrM2cyc2YNa8f9dJIGaBP7UeDSBbcvaZZJmkJtYn8cuCIiLo+I84DPAXu7GUtS10Y+9JaZZyPiNuDfmT/0tjszn+5ssvcqFj20+P/87T4V1+o4e2Y+BDzU0SySxsivy0pFGLtUhLFLRRi7VISxS0UYu1TERH+fvbWr/nTgqlu/d9/Qh/7Nv+4cuv6yv390pJGklcItu1SEsUtFGLtUhLFLRRi7VISxS0WsqENvpy47f+C6pQ6tXf7ga0PX+wuw+kPnll0qwtilIoxdKsLYpSKMXSrC2KUijF0qYkUdZ39r1eB1614Z/qek8/FDHU8jrSxu2aUijF0qwtilIoxdKsLYpSKMXSrC2KUiVtZx9tWDj6U/9bf/MvSxV73yxaHrL/j+L0aaSVopWsUeEUeAU8CbwNnMnOliKEnd62LL/teZ+UoH/46kMfIzu1RE29gT+FlEPBERs4vdISJmI2IuIubOcLrl00kaVdu38Vdn5tGI+AiwLyJ+mZkPL7xDZu4CdgF8MDb6dx2lnrTasmfm0ebyBPAAsK2LoSR1b+TYI2J9RHzg7evAdcDhrgaT1K02b+M3AQ9ExNv/zg8z8986mWqAD913cOC6656/Zehjf7t9+O+7XzDKQNIKMnLsmfkC8GcdziJpjDz0JhVh7FIRxi4VYexSEcYuFbGyfsX19dcHrotHh/+p6NWfurLrcaQVxS27VISxS0UYu1SEsUtFGLtUhLFLRRi7VMSKOs7++mcHHyu/4D+PD33sm+f5R3JUm1t2qQhjl4owdqkIY5eKMHapCGOXijB2qYgVdZz91CWrBq77y6+9OPSx/33/RV2PI60obtmlIoxdKsLYpSKMXSrC2KUijF0qwtilIlbUcfaLvv3UwHX3/fFfDH3slh/+z9D1Z0eaSFo5ltyyR8TuiDgREYcXLNsYEfsi4rnmcsN4x5TU1nLexn8HuP6cZbcD+zPzCmB/c1vSFFsy9sx8GDh5zuLtwJ7m+h7gxo7nktSxUT+zb8rMY831l4FNg+4YEbPALMA6zh/x6SS11XpvfGYmMPCvOWbmrsycycyZNaxt+3SSRjRq7McjYjNAc3miu5EkjcOose8FdjbXdwIPdjOOpHGJ+XfhQ+4QcS9wDXAhcBz4OvBT4MfAR4EXgZsy89ydeO/ywdiYV8a1LUeWNMhjuZ9X82Qstm7JHXSZuWPAKquVVhC/LisVYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRS8YeEbsj4kREHF6w7I6IOBoRB5qfG8Y7pqS2lrNl/w5w/SLL78rMrc3PQ92OJalrS8aemQ8DJycwi6QxavO
Z/baIONi8zd8w6E4RMRsRcxExd4bTLZ5OUhujxn438HFgK3AM+MagO2bmrsycycyZNawd8ekktTVS7Jl5PDPfzMy3gG8B27odS1LXRoo9IjYvuPkZ4PCg+0qaDquXukNE3AtcA1wYES8BXweuiYitQAJHgC+McUZJHVgy9szcscjie8Ywi6Qx8ht0UhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFbFk7BFxaUT8PCKeiYinI+JLzfKNEbEvIp5rLjeMf1xJo1rOlv0s8JXM3AJcBdwaEVuA24H9mXkFsL+5LWlKLRl7Zh7LzCeb66eAZ4GLge3AnuZue4AbxzWkpPZWv5c7R8RlwCeAx4BNmXmsWfUysGnAY2aBWYB1nD/qnJJaWvYOuoh4P/AT4MuZ+erCdZmZQC72uMzclZkzmTmzhrWthpU0umXFHhFrmA/9B5l5f7P4eERsbtZvBk6MZ0RJXVjO3vgA7gGezcxvLli1F9jZXN8JPNj9eJK6spzP7J8EbgYORcSBZtlXgTuBH0fE54EXgZvGM6KkLiwZe2Y+AsSA1dd2O46kcfEbdFIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHLOT/7pRHx84h4JiKejogvNcvviIijEXGg+blh/ONKGtVyzs9+FvhKZj4ZER8AnoiIfc26uzLzH8c3nqSuLOf87MeAY831UxHxLHDxuAeT1K339Jk9Ii4DPgE81iy6LSIORsTuiNgw4DGzETEXEXNnON1qWEmjW3bsEfF+4CfAlzPzVeBu4OPAVua3/N9Y7HGZuSszZzJzZg1rOxhZ0iiWFXtErGE+9B9k5v0AmXk8M9/MzLeAbwHbxjempLaWszc+gHuAZzPzmwuWb15wt88Ah7sfT1JXlrM3/pPAzcChiDjQLPsqsCMitgIJHAG+MJYJJXViOXvjHwFikVUPdT+OpHHxG3RSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFRGZObkni/hf4MUFiy4EXpnYAO/NtM42rXOBs42qy9n+KDM/vNiKicb+riePmMvMmd4GGGJaZ5vWucDZRjWp2XwbLxVh7FIRfce+q+fnH2ZaZ5vWucDZRjWR2Xr9zC5pcvreskuaEGOXiugl9oi4PiJ+FRHPR8TtfcwwSEQciYhDzWmo53qeZXdEnIiIwwuWbYyIfRHxXHO56Dn2epptKk7jPeQ0472+dn2f/nzin9kjYhXwa+BTwEvA48COzHxmooMMEBFHgJnM7P0LGBHxV8BrwHcz80+aZf8AnMzMO5v/UW7IzL+bktnuAF7r+zTezdmKNi88zThwI3ALPb52Q+a6iQm8bn1s2bcBz2fmC5n5BvAjYHsPc0y9zHwYOHnO4u3Anub6Hub/Y5m4AbNNhcw8lplPNtdPAW+fZrzX127IXBPRR+wXA79ZcPslput87wn8LCKeiIjZvodZxKbMPNZcfxnY1Ocwi1jyNN6TdM5pxqfmtRvl9OdtuYPu3a7OzD8HPg3c2rxdnUo5/xlsmo6dLus03pOyyGnGf6/P127U05+31UfsR4FLF9y+pFk2FTLzaHN5AniA6TsV9fG3z6DbXJ7oeZ7fm6bTeC92mnGm4LXr8/TnfcT+OHBFRFweEecBnwP29jDHu0TE+mbHCRGxHriO6TsV9V5gZ3N9J/Bgj7O8w7ScxnvQacbp+bXr/fTnmTnxH+AG5vfI/xfwtT5mGDDXx4Cnmp+n+54NuJf5t3VnmN+38XngQ8B+4DngP4CNUzTb94BDwEH
mw9rc02xXM/8W/SBwoPm5oe/XbshcE3nd/LqsVIQ76KQijF0qwtilIoxdKsLYpSKMXSrC2KUi/g+c6njxmLUDLAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ] } ] }