aboutsummaryrefslogtreecommitdiff
path: root/docs/pytorch_vol2.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'docs/pytorch_vol2.ipynb')
-rw-r--r--docs/pytorch_vol2.ipynb388
1 files changed, 388 insertions, 0 deletions
diff --git a/docs/pytorch_vol2.ipynb b/docs/pytorch_vol2.ipynb
new file mode 100644
index 00000000..7f931149
--- /dev/null
+++ b/docs/pytorch_vol2.ipynb
@@ -0,0 +1,388 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "pytorch.ipynb",
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "lxhgSo4rOWCg",
+ "outputId": "5242373c-1d80-4a96-ef17-1243f7eea994"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (1.10.0+cu111)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (3.10.0.2)\n"
+ ]
+ }
+ ],
+ "source": [
+ "pip install torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(torch.__version__)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "458HSdE3WK68",
+ "outputId": "be8888f7-f9ee-474f-b393-1d3545382b27"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "1.10.0+cu111\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "from torchvision import datasets, transforms\n",
+ "\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "import torch.optim as optim\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "if torch.cuda.is_available():\n",
+ " device = torch.device(\"cuda:0\")\n",
+ " print(\"GPU\")\n",
+ "else:\n",
+ " device = torch.device(\"cpu\")\n",
+ " print(\"CPU\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "WoWY-yVoPKiI",
+ "outputId": "707345d2-53fb-4b01-f417-0c5b2db0aafa"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "GPU\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class Network(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.input_layer = nn.Linear(784, 64)\n",
+ " self.hidden1 = nn.Linear(64, 64)\n",
+ " self.hidden2 = nn.Linear(64, 64)\n",
+ " self.output = nn.Linear(64, 10)\n",
+ "\n",
+ " def forward(self, data):\n",
+ " data = F.relu(self.input_layer(data))\n",
+ " data = F.relu(self.hidden1(data))\n",
+ " data = F.relu(self.hidden2(data))\n",
+ " data = F.relu(self.output(data))\n",
+ "\n",
+ " return F.log_softmax(data, dim=1)"
+ ],
+ "metadata": {
+ "id": "zHPjts1vPjDo"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.utils.data.dataset import T\n",
+ "training = datasets.MNIST(\"\", train = True, download = True,\n",
+ " transform = transforms.Compose([transforms.ToTensor()]))\n",
+ "\n",
+ "testing = datasets.MNIST(\"\", train = False, download = True,\n",
+ " transform = transforms.Compose([transforms.ToTensor()]))\n",
+ "\n",
+ "train_set = torch.utils.data.DataLoader(training, batch_size=10, shuffle=True)\n",
+ "test_set = torch.utils.data.DataLoader(testing, batch_size=10, shuffle=True)"
+ ],
+ "metadata": {
+ "id": "Phj4o7piR4FU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.autograd import backward\n",
+ "network = Network().to(device)\n",
+ "learn_rate = optim.Adam(network.parameters(), lr=0.001)\n",
+ "epochs = 4\n",
+ "\n",
+ "for i in tqdm(range(epochs)):\n",
+ " for data in train_set:\n",
+ " image, output = data\n",
+ " image = image.to(device)\n",
+ " output = output.to(device)\n",
+ " network.zero_grad()\n",
+ " result = network(image.view(-1, 784))\n",
+ " loss = F.nll_loss(result, output)\n",
+ " loss.backward()\n",
+ " learn_rate.step()\n",
+ " print(loss)\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "eqIk--nMTYIm",
+ "outputId": "44638ff7-e230-4f03-e4a7-cda5f75ea10d"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ " 25%|██▌ | 1/4 [00:15<00:45, 15.29s/it]"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "tensor(0.7714, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "\r 50%|█████ | 2/4 [00:30<00:30, 15.16s/it]"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "tensor(0.9215, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "\r 75%|███████▌ | 3/4 [00:45<00:15, 15.05s/it]"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "tensor(1.3817, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "100%|██████████| 4/4 [01:00<00:00, 15.09s/it]"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "tensor(1.6124, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Test Network\n",
+ "network.eval()\n",
+ "\n",
+ "correct = 0\n",
+ "total = 0\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for data in test_set:\n",
+ " image, output = data\n",
+ " image = image.to(device)\n",
+ " output = output.to(device)\n",
+ " result = network(image.view(-1, 784))\n",
+ " for index, tensor_value in enumerate(result):\n",
+ " total += 1\n",
+ " if torch.argmax(tensor_value) == output[index]:\n",
+ " correct += 1\n",
+ "\n",
+ "accuracy = correct/total\n",
+ "print(f\"Accuracy: {accuracy}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "fLRRyQnuYmRn",
+ "outputId": "0dc1fc3c-bf92-40dc-fd3a-d7d0335fc94e"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Accuracy: 0.7795\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/gdrive', force_remount=True)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "97UyaTsHcIft",
+ "outputId": "be688931-b92b-4efb-b20e-12be438f03d8"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Mounted at /content/gdrive\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!ls '/content/gdrive/My Drive/TestPytorch'"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "RQop3IL3csZo",
+ "outputId": "635017e3-b7d2-42a1-9e5c-72e6fefc2f73"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "5.png 7.png\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from PIL import Image\n",
+ "import numpy as np\n",
+ "import PIL.ImageOps\n",
+ "\n",
+ "img = Image.open(\"gdrive/My Drive/TestPytorch/5.png\")\n",
+ "img = img.resize((28, 28))\n",
+ "img = img.convert(\"L\")\n",
+ "img = PIL.ImageOps.invert(img)\n",
+ "\n",
+ "plt.imshow(img)\n",
+ "\n",
+ "img = np.array(img)\n",
+ "img = img / 255\n",
+ "image = torch.from_numpy(img)\n",
+ "image = image.float()\n",
+ "image = image.to(device)\n",
+ "\n",
+ "res = network.forward(image.view(-1, 28*28))\n",
+ "res = res.to(device)\n",
+ "print(torch.argmax(output))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 283
+ },
+ "id": "GMf65cgLdlfa",
+ "outputId": "422c46d3-526c-402f-f7b9-cfa713d3fd72"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "tensor(6, device='cuda:0')\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAALYklEQVR4nO3dX6hchZ3A8e+vSUwwbSWpbYh/utriw4ZlN13uRpfK4iIV60ssBWkeJELhtqDQQtldaRfqoyzbyj7sCmkNTf/ZLVZrHmS32VAQwYpXNyZR2+q6kZqNydo8GIXGRH/7cI/da7wz9zrnzJy5/X0/cJmZc2YyPwa/nplz5t4TmYmkP3zv63sASZNh7FIRxi4VYexSEcYuFbF6kk92XqzNdayf5FNKpfyO13kjT8di61rFHhHXA/8ErAK+nZl3Drv/OtZzZVzb5iklDfFY7h+4buS38RGxCvhn4NPAFmBHRGwZ9d+TNF5tPrNvA57PzBcy8w3gR8D2bsaS1LU2sV8M/GbB7ZeaZe8QEbMRMRcRc2c43eLpJLUx9r3xmbkrM2cyc2YNa8f9dJIGaBP7UeDSBbcvaZZJmkJtYn8cuCIiLo+I84DPAXu7GUtS10Y+9JaZZyPiNuDfmT/0tjszn+5ssvcqFj20+P/87T4V1+o4e2Y+BDzU0SySxsivy0pFGLtUhLFLRRi7VISxS0UYu1TERH+fvbWr/nTgqlu/d9/Qh/7Nv+4cuv6yv390pJGklcItu1SEsUtFGLtUhLFLRRi7VISxS0WsqENvpy47f+C6pQ6tXf7ga0PX+wuw+kPnll0qwtilIoxdKsLYpSKMXSrC2KUijF0qYkUdZ39r1eB1614Z/qek8/FDHU8jrSxu2aUijF0qwtilIoxdKsLYpSKMXSrC2KUiVtZx9tWDj6U/9bf/MvSxV73yxaHrL/j+L0aaSVopWsUeEUeAU8CbwNnMnOliKEnd62LL/teZ+UoH/46kMfIzu1RE29gT+FlEPBERs4vdISJmI2IuIubOcLrl00kaVdu38Vdn5tGI+AiwLyJ+mZkPL7xDZu4CdgF8MDb6dx2lnrTasmfm0ebyBPAAsK2LoSR1b+TYI2J9RHzg7evAdcDhrgaT1K02b+M3AQ9ExNv/zg8z8986mWqAD913cOC6656/Zehjf7t9+O+7XzDKQNIKMnLsmfkC8GcdziJpjDz0JhVh7FIRxi4VYexSEcYuFbGyfsX19dcHrotHh/+p6NWfurLrcaQVxS27VISxS0UYu1SEsUtFGLtUhLFLRRi7VMSKOs7++mcHHyu/4D+PD33sm+f5R3JUm1t2qQhjl4owdqkIY5eKMHapCGOXijB2qYgVdZz91CWrBq77y6+9OPSx/33/RV2PI60obtmlIoxdKsLYpSKMXSrC2KUijF0qwtilIlbUcfaLvv3UwHX3/fFfDH3slh/+z9D1Z0eaSFo5ltyyR8TuiDgREYcXLNsYEfsi4rnmcsN4x5TU1nLexn8HuP6cZbcD+zPzCmB/c1vSFFsy9sx8GDh5zuLtwJ7m+h7gxo7nktSxUT+zb8rMY831l4FNg+4YEbPALMA6zh/x6SS11XpvfGYmMPCvOWbmrsycycyZNaxt+3SSRjRq7McjYjNAc3miu5EkjcOose8FdjbXdwIPdjOOpHGJ+XfhQ+4QcS9wDXAhcBz4OvBT4MfAR4EXgZsy89ydeO/ywdiYV8a1LUeWNMhjuZ9X82Qstm7JHXSZuWPAKquVVhC/LisVYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRS8YeEbsj4kREHF6w7I6IOBoRB5qfG8Y7pqS2lrNl/w5w/SLL78rMrc3PQ92OJalrS8aemQ8DJycwi6QxavOZ/baIONi8zd8w6E4RMRsRcxExd4bTLZ5OUhujxn438HFgK3AM+MagO2bmrsycycyZNawd8ekktTVS7Jl5PDPfzMy3gG8B27odS1LXRoo9IjYvuPkZ4PCg+0qaDquXukNE3AtcA1wYES8BXweuiYitQAJHgC+McUZJHVgy9szcscjie8Ywi6Qx8ht0UhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFbFk7BFxaUT8PCKeiYinI+JLzfKNEbEvIp5rLjeMf1xJo1rOlv0s8JXM3AJcBdwaEVuA24H9mXkFsL+5LWlKLRl7Zh7LzCeb66eAZ4GLge3AnuZue4AbxzWkpPZWv5c7R8RlwCeAx4BNmXmsWfUysGnAY2aBWYB1nD/qnJJaWvYOuoh4P/AT4MuZ+erCdZmZQC72uMzclZkzmTmzhrWthpU0umXFHhFrmA/9B5l5f7P4eERsbtZvBk6MZ0RJXVjO3vgA7gGezcxvLli1F9jZXN8JPNj9eJK6spzP7J8EbgYORcSBZtlXgTuBH0fE54EXgZvGM6KkLiwZe2Y+AsSA1dd2O46kcfEbdFIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHLOT/7pRHx84h4JiKejogvNcvviIijEXGg+blh/ONKGtVyzs9+FvhKZj4ZER8AnoiIfc26uzLzH8c3nqSuLOf87MeAY831UxHxLHDxuAeT1K339Jk9Ii4DPgE81iy6LSIORsTuiNgw4DGzETEXEXNnON1qWEmjW3bsEfF+4CfAlzPzVeBu4OPAVua3/N9Y7HGZuSszZzJzZg1rOxhZ0iiWFXtErGE+9B9k5v0AmXk8M9/MzLeAbwHbxjempLaWszc+gHuAZzPzmwuWb15wt88Ah7sfT1JXlrM3/pPAzcChiDjQLPsqsCMitgIJHAG+MJYJJXViOXvjHwFikVUPdT+OpHHxG3RSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFWHsUhHGLhVh7FIRxi4VYexSEcYuFRGZObkni/hf4MUFiy4EXpnYAO/NtM42rXOBs42qy9n+KDM/vNiKicb+riePmMvMmd4GGGJaZ5vWucDZRjWp2XwbLxVh7FIRfce+q+fnH2ZaZ5vWucDZRjWR2Xr9zC5pcvreskuaEGOXiugl9oi4PiJ+FRHPR8TtfcwwSEQciYhDzWmo53qeZXdEnIiIwwuWbYyIfRHxXHO56Dn2epptKk7jPeQ0472+dn2f/nzin9kjYhXwa+BTwEvA48COzHxmooMMEBFHgJnM7P0LGBHxV8BrwHcz80+aZf8AnMzMO5v/UW7IzL+bktnuAF7r+zTezdmKNi88zThwI3ALPb52Q+a6iQm8bn1s2bcBz2fmC5n5BvAjYHsPc0y9zHwYOHnO4u3Anub6Hub/Y5m4AbNNhcw8lplPNtdPAW+fZrzX127IXBPRR+wXA79ZcPslput87wn8LCKeiIjZvodZxKbMPNZcfxnY1Ocwi1jyNN6TdM5pxqfmtRvl9OdtuYPu3a7OzD8HPg3c2rxdnUo5/xlsmo6dLus03pOyyGnGf6/P127U05+31UfsR4FLF9y+pFk2FTLzaHN5AniA6TsV9fG3z6DbXJ7oeZ7fm6bTeC92mnGm4LXr8/TnfcT+OHBFRFweEecBnwP29jDHu0TE+mbHCRGxHriO6TsV9V5gZ3N9J/Bgj7O8w7ScxnvQacbp+bXr/fTnmTnxH+AG5vfI/xfwtT5mGDDXx4Cnmp+n+54NuJf5t3VnmN+38XngQ8B+4DngP4CNUzTb94BDwEHmw9rc02xXM/8W/SBwoPm5oe/XbshcE3nd/LqsVIQ76KQijF0qwtilIoxdKsLYpSKMXSrC2KUi/g+c6njxmLUDLAAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ }
+ }
+ ]
+ }
+ ]
+} \ No newline at end of file