diff options
Diffstat (limited to 'docs/pytorch_vol2.ipynb')
-rw-r--r-- | docs/pytorch_vol2.ipynb | 388 |
1 files changed, 388 insertions, 0 deletions
diff --git a/docs/pytorch_vol2.ipynb b/docs/pytorch_vol2.ipynb new file mode 100644 index 00000000..7f931149 --- /dev/null +++ b/docs/pytorch_vol2.ipynb @@ -0,0 +1,388 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "pytorch.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lxhgSo4rOWCg", + "outputId": "5242373c-1d80-4a96-ef17-1243f7eea994" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (1.10.0+cu111)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (3.10.0.2)\n" + ] + } + ], + "source": [ + "pip install torch" + ] + }, + { + "cell_type": "code", + "source": [ + "print(torch.__version__)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "458HSdE3WK68", + "outputId": "be8888f7-f9ee-474f-b393-1d3545382b27" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "1.10.0+cu111\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torchvision import datasets, transforms\n", + "\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torch.optim as optim\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from tqdm import tqdm\n", + "\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda:0\")\n", + " print(\"GPU\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + " print(\"CPU\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "WoWY-yVoPKiI", + "outputId": "707345d2-53fb-4b01-f417-0c5b2db0aafa" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "GPU\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "class Network(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.input_layer = nn.Linear(784, 64)\n", + " self.hidden1 = nn.Linear(64, 64)\n", + " self.hidden2 = nn.Linear(64, 64)\n", + " self.output = nn.Linear(64, 10)\n", + "\n", + " def forward(self, data):\n", + " data = F.relu(self.input_layer(data))\n", + " data = F.relu(self.hidden1(data))\n", + " data = F.relu(self.hidden2(data))\n", + " data = F.relu(self.output(data))\n", + "\n", + " return F.log_softmax(data, dim=1)" + ], + "metadata": { + "id": "zHPjts1vPjDo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from torch.utils.data.dataset import T\n", + "training = datasets.MNIST(\"\", train = True, download = True,\n", + " transform = transforms.Compose([transforms.ToTensor()]))\n", + "\n", + "testing = datasets.MNIST(\"\", train = False, download = True,\n", + " transform = transforms.Compose([transforms.ToTensor()]))\n", + "\n", + "train_set = torch.utils.data.DataLoader(training, batch_size=10, shuffle=True)\n", + "test_set = torch.utils.data.DataLoader(testing, batch_size=10, shuffle=True)" + ], + "metadata": { + "id": "Phj4o7piR4FU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from torch.autograd import backward\n", + "network = Network().to(device)\n", + "learn_rate = optim.Adam(network.parameters(), lr=0.001)\n", + "epochs = 4\n", + "\n", + "for i in tqdm(range(epochs)):\n", + " for data in train_set:\n", + " image, output = data\n", + " image = image.to(device)\n", + " output = output.to(device)\n", + " network.zero_grad()\n", + " result = network(image.view(-1, 784))\n", + " loss = F.nll_loss(result, output)\n", + " loss.backward()\n", + " learn_rate.step()\n", + " print(loss)\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eqIk--nMTYIm", + "outputId": "44638ff7-e230-4f03-e4a7-cda5f75ea10d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + " 25%|██▌ | 1/4 [00:15<00:45, 15.29s/it]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor(0.7714, device='cuda:0', grad_fn=<NllLossBackward0>)\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\r 50%|█████ | 2/4 [00:30<00:30, 15.16s/it]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor(0.9215, device='cuda:0', grad_fn=<NllLossBackward0>)\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\r 75%|███████▌ | 3/4 [00:45<00:15, 15.05s/it]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor(1.3817, device='cuda:0', grad_fn=<NllLossBackward0>)\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 4/4 [01:00<00:00, 15.09s/it]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor(1.6124, device='cuda:0', grad_fn=<NllLossBackward0>)\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Test Network\n", + "network.eval()\n", + "\n", + "correct = 0\n", + "total = 0\n", + "\n", + "with torch.no_grad():\n", + " for data in test_set:\n", + " image, output = data\n", + " image = image.to(device)\n", + " output = output.to(device)\n", + " result = network(image.view(-1, 784))\n", + " for index, tensor_value in enumerate(result):\n", + " total += 1\n", + " if torch.argmax(tensor_value) == output[index]:\n", + " correct += 1\n", + "\n", + "accuracy = correct/total\n", + "print(f\"Accuracy: {accuracy}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fLRRyQnuYmRn", + "outputId": "0dc1fc3c-bf92-40dc-fd3a-d7d0335fc94e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Accuracy: 0.7795\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/gdrive', force_remount=True)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "97UyaTsHcIft", + "outputId": "be688931-b92b-4efb-b20e-12be438f03d8" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/gdrive\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!ls '/content/gdrive/My Drive/TestPytorch'" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RQop3IL3csZo", + "outputId": "635017e3-b7d2-42a1-9e5c-72e6fefc2f73" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "5.png 7.png\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from PIL import Image\n", + "import numpy as np\n", + "import PIL.ImageOps\n", + "\n", + "img = Image.open(\"gdrive/My Drive/TestPytorch/5.png\")\n", + "img = img.resize((28, 28))\n", + "img = img.convert(\"L\")\n", + "img = PIL.ImageOps.invert(img)\n", + "\n", + "plt.imshow(img)\n", + "\n", + "img = np.array(img)\n", + "img = img / 255\n", + "image = torch.from_numpy(img)\n", + "image = image.float()\n", + "image = image.to(device)\n", + "\n", + "res = network.forward(image.view(-1, 28*28))\n", + "res = res.to(device)\n", + "print(torch.argmax(output))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 283 + }, + "id": "GMf65cgLdlfa", + "outputId": "422c46d3-526c-402f-f7b9-cfa713d3fd72" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor(6, device='cuda:0')\n" + ] + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAALYklEQVR4nO3dX6hchZ3A8e+vSUwwbSWpbYh/utriw4ZlN13uRpfK4iIV60ssBWkeJELhtqDQQtldaRfqoyzbyj7sCmkNTf/ZLVZrHmS32VAQwYpXNyZR2+q6kZqNydo8GIXGRH/7cI/da7wz9zrnzJy5/X0/cJmZc2YyPwa/nplz5t4TmYmkP3zv63sASZNh7FIRxi4VYexSEcYuFbF6kk92XqzNdayf5FNKpfyO13kjT8di61rFHhHXA/8ErAK+nZl3Drv/OtZzZVzb5iklDfFY7h+4buS38RGxCvhn4NPAFmBHRGwZ9d+TNF5tPrNvA57PzBcy8w3gR8D2bsaS1LU2sV8M/GbB7ZeaZe8QEbMRMRcRc2c43eLpJLUx9r3xmbkrM2cyc2YNa8f9dJIGaBP7UeDSBbcvaZZJmkJtYn8cuCIiLo+I84DPAXu7GUtS10Y+9JaZZyPiNuDfmT/0tjszn+5ssvcqFj20+P/87T4V1+o4e2Y+BDzU0SySxsivy0pFGLtUhLFLRRi7VISxS0UYu1TERH+fvbWr/nTgqlu/d9/Qh/7Nv+4cuv6yv390pJGklcItu1SEsUtFGLtUhLFLRRi7VISxS0WsqENvpy47f+C6pQ6tXf7ga0PX+wuw+kPnll0qwtilIoxdKsLYpSKMXSrC2KUijF0qYkUdZ39r1eB1614Z/qek8/FDHU8jrSxu2aUijF0qwtilIoxdKsLYpSKMXSrC2KUiVtZx9tWDj6U/9bf/MvSxV73yxaHrL/j+L0aaSVopWsUeEUeAU8CbwNnMnOliKEnd62LL/teZ+UoH/46kMfIzu1RE29gT+FlEPBERs4vdISJmI2IuIubOcLrl00kaVdu38Vdn5tGI+AiwLyJ+mZkPL7xDZu4CdgF8MDb6dx2lnrTasmfm0ebyBPAAsK2LoSR1b+TYI2J9RHzg7evAdcDhrgaT1K02b+M3AQ9ExNv/zg8z8986mWqAD913cOC6656/Zehjf7t9+O+7XzDKQNIKMnLsmfkC8GcdziJpjDz0JhVh7FIRxi4VYexSEcYuFbGyfsX19dcHrotHh/+p6NWfurLrcaQVxS27VISxS0UYu1SEsUtFGLtUhLFLRRi7VMSKOs7++mcHHyu/4D+PD33sm+f5R3JUm1t2qQhjl4owdqkIY5eKMHapCGOXijB2qYgVdZz91CWrBq77y6+9OPSx/33/RV2PI60obtmlIoxdKsLYpSKMXSrC2KUijF0qwtilIlbUcfaLvv3UwHX3/fFfDH3slh/+z9D1Z0eaSFo5ltyyR8TuiDgREYcXLNsYEfsi4rnmcsN4x5TU1nLexn8HuP6cZbcD+zPzCmB/c1vSFFsy9sx8GDh5zuLtwJ7m+h7gxo7nktSxUT+zb8rMY831l4FNg+4YEbPALMA6zh/x6SS11XpvfGYmMPCvOWbmrsycycyZNaxt+3SSRjRq7McjYjNAc3miu5EkjcOose8FdjbXdwIPdjOOpHGJ+XfhQ+4QcS9wDXAhcBz4OvBT4MfAR4EXgZsy89ydeO/ywdiYV8a1LUeWNMhjuZ9X82Qstm7JHXSZuWPAKquVVhC/LisVYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRS8YeEbsj4kREHF6w7I6IOBoRB5qfG8Y7pqS2lrNl/w5w/SLL78rMrc3PQ92OJalrS8aemQ8DJycwi6QxavOZ/baIONi8zd8w6E4RMRsRcxExd4bTLZ5OUhujxn438HFgK3AM+MagO2bmrsycycyZNawd8ekktTVS7Jl5PDPfzMy3gG8B27odS1LXRoo9IjYvuPkZ4PCg+0qaDquXukNE3AtcA1wYES8BXweuiYitQAJHgC+McUZJHVgy9szcscjie8Ywi6Qx8ht0UhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFbFk7BFxaUT8PCKeiYinI+JLzfKNEbEvIp5rLjeMf1xJo1rOlv0s8JXM3AJcBdwaEVuA24H9mXkFsL+5LWlKLRl7Zh7LzCeb66eAZ4GLge3AnuZue4AbxzWkpPZWv5c7R8RlwCeAx4BNmXmsWfUysGnAY2aBWYB1nD/qnJJaWvYOuoh4P/AT4MuZ+erCdZmZQC72uMzclZkzmTmzhrWthpU0umXFHhFrmA/9B5l5f7P4eERsbtZvBk6MZ0RJXVjO3vgA7gGezcxvLli1F9jZXN8JPNj9eJK6spzP7J8EbgYORcSBZtlXgTuBH0fE54EXgZvGM6KkLiwZe2Y+AsSA1dd2O46kcfEbdFIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHLOT/7pRHx84h4JiKejogvNcvviIijEXGg+blh/ONKGtVyzs9+FvhKZj4ZER8AnoiIfc26uzLzH8c3nqSuLOf87MeAY831UxHxLHDxuAeT1K339Jk9Ii4DPgE81iy6LSIORsTuiNgw4DGzETEXEXNnON1qWEmjW3bsEfF+4CfAlzPzVeBu4OPAVua3/N9Y7HGZuSszZzJzZg1rOxhZ0iiWFXtErGE+9B9k5v0AmXk8M9/MzLeAbwHbxjempLaWszc+gHuAZzPzmwuWb15wt88Ah7sfT1JXlrM3/pPAzcChiDjQLPsqsCMitgIJHAG+MJYJJXViOXvjHwFikVUPdT+OpHHxG3RSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFRGZObkni/hf4MUFiy4EXpnYAO/NtM42rXOBs42qy9n+KDM/vNiKicb+riePmMvMmd4GGGJaZ5vWucDZRjWp2XwbLxVh7FIRfce+q+fnH2ZaZ5vWucDZRjWR2Xr9zC5pcvreskuaEGOXiugl9oi4PiJ+FRHPR8TtfcwwSEQciYhDzWmo53qeZXdEnIiIwwuWbYyIfRHxXHO56Dn2epptKk7jPeQ0472+dn2f/nzin9kjYhXwa+BTwEvA48COzHxmooMMEBFHgJnM7P0LGBHxV8BrwHcz80+aZf8AnMzMO5v/UW7IzL+bktnuAF7r+zTezdmKNi88zThwI3ALPb52Q+a6iQm8bn1s2bcBz2fmC5n5BvAjYHsPc0y9zHwYOHnO4u3Anub6Hub/Y5m4AbNNhcw8lplPNtdPAW+fZrzX127IXBPRR+wXA79ZcPslput87wn8LCKeiIjZvodZxKbMPNZcfxnY1Ocwi1jyNN6TdM5pxqfmtRvl9OdtuYPu3a7OzD8HPg3c2rxdnUo5/xlsmo6dLus03pOyyGnGf6/P127U05+31UfsR4FLF9y+pFk2FTLzaHN5AniA6TsV9fG3z6DbXJ7oeZ7fm6bTeC92mnGm4LXr8/TnfcT+OHBFRFweEecBnwP29jDHu0TE+mbHCRGxHriO6TsV9V5gZ3N9J/Bgj7O8w7ScxnvQacbp+bXr/fTnmTnxH+AG5vfI/xfwtT5mGDDXx4Cnmp+n+54NuJf5t3VnmN+38XngQ8B+4DngP4CNUzTb94BDwEHmw9rc02xXM/8W/SBwoPm5oe/XbshcE3nd/LqsVIQ76KQijF0qwtilIoxdKsLYpSKMXSrC2KUi/g+c6njxmLUDLAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + } + } + ] + } + ] +}
\ No newline at end of file |