diff --git a/Project_CNN.ipynb b/Project_CNN.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6a61ed44847d21de482ca4f64af37fb430f4fe1e --- /dev/null +++ b/Project_CNN.ipynb @@ -0,0 +1,499 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# 0. Imports" + ], + "metadata": { + "id": "s88AvoLByyvb" + } + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "ihpUQFaYvpQW" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torchvision.models import resnet18" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 1. 
Adat-előkészítés" + ], + "metadata": { + "id": "Z3fiYQZwy4wA" + } + }, + { + "cell_type": "code", + "source": [ + "train_transform = transforms.Compose([\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.RandomCrop(32, padding=4),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", + "])\n", + "\n", + "test_transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", + "])\n", + "\n", + "trainset = torchvision.datasets.CIFAR10(\n", + " root='./data',\n", + " train=True,\n", + " download=True,\n", + " transform=train_transform\n", + ")\n", + "\n", + "testset = torchvision.datasets.CIFAR10(\n", + " root='./data',\n", + " train=False,\n", + " download=True,\n", + " transform=test_transform\n", + ")\n", + "\n", + "trainloader = DataLoader(trainset, batch_size=64, shuffle=True)\n", + "testloader = DataLoader(testset, batch_size=64, shuffle=False)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6yReCwY_zHxd", + "outputId": "d8810cca-c44b-4179-98ae-36e4f5bdbd4b" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 170M/170M [00:07<00:00, 23.5MB/s]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 2. 
Modell architektúra" + ], + "metadata": { + "id": "F7MFKJWzL90w" + } + }, + { + "cell_type": "code", + "source": [ + "class CIFAR10_ResNet(nn.Module):\n", + " def __init__(self, pretrained=True):\n", + " super().__init__()\n", + " self.model = resnet18(weights='DEFAULT' if pretrained else None)\n", + "\n", + " # CIFAR-10 kompatibilis módosítások\n", + " self.model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n", + " self.model.maxpool = nn.Identity() # Eltávolítjuk a max pooling réteget\n", + " self.model.fc = nn.Linear(512, 10) # 10 osztály\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)" + ], + "metadata": { + "id": "0QBmgWSRLzWO" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# 3. Modell inicializálása" + ], + "metadata": { + "id": "fY_G06EKMCbl" + } + }, + { + "cell_type": "code", + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "model = CIFAR10_ResNet(pretrained=True).to(device)\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = nn.CrossEntropyLoss()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vjNrPunVMOsb", + "outputId": "f9882de9-08d1-4b7c-8e17-168690cf234d" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n", + "100%|██████████| 44.7M/44.7M [00:00<00:00, 189MB/s]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 4. 
Tanítási ciklus" + ], + "metadata": { + "id": "sJlf_EPNMRqS" + } + }, + { + "cell_type": "code", + "source": [ + "for epoch in range(30):\n", + " model.train()\n", + " running_loss = 0.0\n", + "\n", + " for i, (inputs, labels) in enumerate(trainloader):\n", + " inputs, labels = inputs.to(device), labels.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(inputs)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " running_loss += loss.item()\n", + " if i % 100 == 99:\n", + " print(f'Epoch: {epoch+1}, Batch: {i+1}, Loss: {running_loss/100:.4f}')\n", + " running_loss = 0.0\n", + "\n", + " # Tesztelés epoch végén\n", + " model.eval()\n", + " correct = 0\n", + " total = 0\n", + "\n", + " with torch.no_grad():\n", + " for inputs, labels in testloader:\n", + " inputs, labels = inputs.to(device), labels.to(device)\n", + " outputs = model(inputs)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + "\n", + " print(f'Epoch {epoch+1} Tesztpontosság: {100 * correct / total:.2f}%')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0D9j4DQIMR_8", + "outputId": "f10739b6-fbe0-4481-91b7-df7b3848ffc7" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 1, Batch: 100, Loss: 1.4680\n", + "Epoch: 1, Batch: 200, Loss: 0.9620\n", + "Epoch: 1, Batch: 300, Loss: 0.8665\n", + "Epoch: 1, Batch: 400, Loss: 0.7700\n", + "Epoch: 1, Batch: 500, Loss: 0.7240\n", + "Epoch: 1, Batch: 600, Loss: 0.6732\n", + "Epoch: 1, Batch: 700, Loss: 0.6179\n", + "Epoch 1 Tesztpontosság: 77.67%\n", + "Epoch: 2, Batch: 100, Loss: 0.5333\n", + "Epoch: 2, Batch: 200, Loss: 0.5511\n", + "Epoch: 2, Batch: 300, Loss: 0.5397\n", + "Epoch: 2, Batch: 400, Loss: 0.5328\n", + "Epoch: 2, Batch: 500, Loss: 0.4696\n", + "Epoch: 2, Batch: 600, Loss: 0.4975\n", + 
"Epoch: 2, Batch: 700, Loss: 0.4947\n", + "Epoch 2 Tesztpontosság: 82.47%\n", + "Epoch: 3, Batch: 100, Loss: 0.4422\n", + "Epoch: 3, Batch: 200, Loss: 0.4207\n", + "Epoch: 3, Batch: 300, Loss: 0.4368\n", + "Epoch: 3, Batch: 400, Loss: 0.4378\n", + "Epoch: 3, Batch: 500, Loss: 0.4034\n", + "Epoch: 3, Batch: 600, Loss: 0.4169\n", + "Epoch: 3, Batch: 700, Loss: 0.4098\n", + "Epoch 3 Tesztpontosság: 87.10%\n", + "Epoch: 4, Batch: 100, Loss: 0.3536\n", + "Epoch: 4, Batch: 200, Loss: 0.3616\n", + "Epoch: 4, Batch: 300, Loss: 0.3752\n", + "Epoch: 4, Batch: 400, Loss: 0.3411\n", + "Epoch: 4, Batch: 500, Loss: 0.3639\n", + "Epoch: 4, Batch: 600, Loss: 0.3431\n", + "Epoch: 4, Batch: 700, Loss: 0.3496\n", + "Epoch 4 Tesztpontosság: 86.70%\n", + "Epoch: 5, Batch: 100, Loss: 0.3108\n", + "Epoch: 5, Batch: 200, Loss: 0.3237\n", + "Epoch: 5, Batch: 300, Loss: 0.3030\n", + "Epoch: 5, Batch: 400, Loss: 0.3255\n", + "Epoch: 5, Batch: 500, Loss: 0.3062\n", + "Epoch: 5, Batch: 600, Loss: 0.3203\n", + "Epoch: 5, Batch: 700, Loss: 0.3354\n", + "Epoch 5 Tesztpontosság: 88.46%\n", + "Epoch: 6, Batch: 100, Loss: 0.2724\n", + "Epoch: 6, Batch: 200, Loss: 0.2630\n", + "Epoch: 6, Batch: 300, Loss: 0.2843\n", + "Epoch: 6, Batch: 400, Loss: 0.2824\n", + "Epoch: 6, Batch: 500, Loss: 0.2750\n", + "Epoch: 6, Batch: 600, Loss: 0.2951\n", + "Epoch: 6, Batch: 700, Loss: 0.2883\n", + "Epoch 6 Tesztpontosság: 88.35%\n", + "Epoch: 7, Batch: 100, Loss: 0.2395\n", + "Epoch: 7, Batch: 200, Loss: 0.2622\n", + "Epoch: 7, Batch: 300, Loss: 0.2538\n", + "Epoch: 7, Batch: 400, Loss: 0.2578\n", + "Epoch: 7, Batch: 500, Loss: 0.2514\n", + "Epoch: 7, Batch: 600, Loss: 0.2480\n", + "Epoch: 7, Batch: 700, Loss: 0.2719\n", + "Epoch 7 Tesztpontosság: 89.58%\n", + "Epoch: 8, Batch: 100, Loss: 0.2339\n", + "Epoch: 8, Batch: 200, Loss: 0.2230\n", + "Epoch: 8, Batch: 300, Loss: 0.2099\n", + "Epoch: 8, Batch: 400, Loss: 0.2184\n", + "Epoch: 8, Batch: 500, Loss: 0.2406\n", + "Epoch: 8, Batch: 600, Loss: 0.2465\n", + "Epoch: 
8, Batch: 700, Loss: 0.2453\n", + "Epoch 8 Tesztpontosság: 89.00%\n", + "Epoch: 9, Batch: 100, Loss: 0.2055\n", + "Epoch: 9, Batch: 200, Loss: 0.1891\n", + "Epoch: 9, Batch: 300, Loss: 0.2235\n", + "Epoch: 9, Batch: 400, Loss: 0.2157\n", + "Epoch: 9, Batch: 500, Loss: 0.2056\n", + "Epoch: 9, Batch: 600, Loss: 0.2027\n", + "Epoch: 9, Batch: 700, Loss: 0.2140\n", + "Epoch 9 Tesztpontosság: 91.07%\n", + "Epoch: 10, Batch: 100, Loss: 0.1896\n", + "Epoch: 10, Batch: 200, Loss: 0.1978\n", + "Epoch: 10, Batch: 300, Loss: 0.1876\n", + "Epoch: 10, Batch: 400, Loss: 0.1940\n", + "Epoch: 10, Batch: 500, Loss: 0.1812\n", + "Epoch: 10, Batch: 600, Loss: 0.2067\n", + "Epoch: 10, Batch: 700, Loss: 0.2033\n", + "Epoch 10 Tesztpontosság: 90.68%\n", + "Epoch: 11, Batch: 100, Loss: 0.1759\n", + "Epoch: 11, Batch: 200, Loss: 0.1731\n", + "Epoch: 11, Batch: 300, Loss: 0.1723\n", + "Epoch: 11, Batch: 400, Loss: 0.1741\n", + "Epoch: 11, Batch: 500, Loss: 0.1707\n", + "Epoch: 11, Batch: 600, Loss: 0.1888\n", + "Epoch: 11, Batch: 700, Loss: 0.1947\n", + "Epoch 11 Tesztpontosság: 88.09%\n", + "Epoch: 12, Batch: 100, Loss: 0.1695\n", + "Epoch: 12, Batch: 200, Loss: 0.1626\n", + "Epoch: 12, Batch: 300, Loss: 0.1641\n", + "Epoch: 12, Batch: 400, Loss: 0.1712\n", + "Epoch: 12, Batch: 500, Loss: 0.1614\n", + "Epoch: 12, Batch: 600, Loss: 0.1699\n", + "Epoch: 12, Batch: 700, Loss: 0.1678\n", + "Epoch 12 Tesztpontosság: 90.86%\n", + "Epoch: 13, Batch: 100, Loss: 0.1418\n", + "Epoch: 13, Batch: 200, Loss: 0.1360\n", + "Epoch: 13, Batch: 300, Loss: 0.1517\n", + "Epoch: 13, Batch: 400, Loss: 0.1409\n", + "Epoch: 13, Batch: 500, Loss: 0.1710\n", + "Epoch: 13, Batch: 600, Loss: 0.1737\n", + "Epoch: 13, Batch: 700, Loss: 0.1523\n", + "Epoch 13 Tesztpontosság: 91.47%\n", + "Epoch: 14, Batch: 100, Loss: 0.1171\n", + "Epoch: 14, Batch: 200, Loss: 0.1287\n", + "Epoch: 14, Batch: 300, Loss: 0.1397\n", + "Epoch: 14, Batch: 400, Loss: 0.1522\n", + "Epoch: 14, Batch: 500, Loss: 0.1401\n", + "Epoch: 14, Batch: 
600, Loss: 0.1544\n", + "Epoch: 14, Batch: 700, Loss: 0.1376\n", + "Epoch 14 Tesztpontosság: 91.91%\n", + "Epoch: 15, Batch: 100, Loss: 0.1405\n", + "Epoch: 15, Batch: 200, Loss: 0.1327\n", + "Epoch: 15, Batch: 300, Loss: 0.1288\n", + "Epoch: 15, Batch: 400, Loss: 0.1285\n", + "Epoch: 15, Batch: 500, Loss: 0.1299\n", + "Epoch: 15, Batch: 600, Loss: 0.1522\n", + "Epoch: 15, Batch: 700, Loss: 0.1458\n", + "Epoch 15 Tesztpontosság: 91.92%\n", + "Epoch: 16, Batch: 100, Loss: 0.1140\n", + "Epoch: 16, Batch: 200, Loss: 0.1281\n", + "Epoch: 16, Batch: 300, Loss: 0.1214\n", + "Epoch: 16, Batch: 400, Loss: 0.1279\n", + "Epoch: 16, Batch: 500, Loss: 0.1150\n", + "Epoch: 16, Batch: 600, Loss: 0.1293\n", + "Epoch: 16, Batch: 700, Loss: 0.1251\n", + "Epoch 16 Tesztpontosság: 91.35%\n", + "Epoch: 17, Batch: 100, Loss: 0.1142\n", + "Epoch: 17, Batch: 200, Loss: 0.1135\n", + "Epoch: 17, Batch: 300, Loss: 0.1085\n", + "Epoch: 17, Batch: 400, Loss: 0.1151\n", + "Epoch: 17, Batch: 500, Loss: 0.1299\n", + "Epoch: 17, Batch: 600, Loss: 0.1201\n", + "Epoch: 17, Batch: 700, Loss: 0.1175\n", + "Epoch 17 Tesztpontosság: 91.71%\n", + "Epoch: 18, Batch: 100, Loss: 0.1007\n", + "Epoch: 18, Batch: 200, Loss: 0.1121\n", + "Epoch: 18, Batch: 300, Loss: 0.1097\n", + "Epoch: 18, Batch: 400, Loss: 0.1106\n", + "Epoch: 18, Batch: 500, Loss: 0.0980\n", + "Epoch: 18, Batch: 600, Loss: 0.1032\n", + "Epoch: 18, Batch: 700, Loss: 0.1204\n", + "Epoch 18 Tesztpontosság: 91.99%\n", + "Epoch: 19, Batch: 100, Loss: 0.1125\n", + "Epoch: 19, Batch: 200, Loss: 0.0860\n", + "Epoch: 19, Batch: 300, Loss: 0.1014\n", + "Epoch: 19, Batch: 400, Loss: 0.1171\n", + "Epoch: 19, Batch: 500, Loss: 0.1103\n", + "Epoch: 19, Batch: 600, Loss: 0.1121\n", + "Epoch: 19, Batch: 700, Loss: 0.1062\n", + "Epoch 19 Tesztpontosság: 92.58%\n", + "Epoch: 20, Batch: 100, Loss: 0.0826\n", + "Epoch: 20, Batch: 200, Loss: 0.0860\n", + "Epoch: 20, Batch: 300, Loss: 0.0905\n", + "Epoch: 20, Batch: 400, Loss: 0.0881\n", + "Epoch: 20, Batch: 
500, Loss: 0.0938\n", + "Epoch: 20, Batch: 600, Loss: 0.0906\n", + "Epoch: 20, Batch: 700, Loss: 0.1004\n", + "Epoch 20 Tesztpontosság: 90.97%\n", + "Epoch: 21, Batch: 100, Loss: 0.0998\n", + "Epoch: 21, Batch: 200, Loss: 0.1026\n", + "Epoch: 21, Batch: 300, Loss: 0.0826\n", + "Epoch: 21, Batch: 400, Loss: 0.0920\n", + "Epoch: 21, Batch: 500, Loss: 0.1117\n", + "Epoch: 21, Batch: 600, Loss: 0.0951\n", + "Epoch: 21, Batch: 700, Loss: 0.0983\n", + "Epoch 21 Tesztpontosság: 92.63%\n", + "Epoch: 22, Batch: 100, Loss: 0.0791\n", + "Epoch: 22, Batch: 200, Loss: 0.0790\n", + "Epoch: 22, Batch: 300, Loss: 0.0890\n", + "Epoch: 22, Batch: 400, Loss: 0.0817\n", + "Epoch: 22, Batch: 500, Loss: 0.0875\n", + "Epoch: 22, Batch: 600, Loss: 0.0988\n", + "Epoch: 22, Batch: 700, Loss: 0.0929\n", + "Epoch 22 Tesztpontosság: 91.48%\n", + "Epoch: 23, Batch: 100, Loss: 0.0729\n", + "Epoch: 23, Batch: 200, Loss: 0.0762\n", + "Epoch: 23, Batch: 300, Loss: 0.0842\n", + "Epoch: 23, Batch: 400, Loss: 0.0796\n", + "Epoch: 23, Batch: 500, Loss: 0.0856\n", + "Epoch: 23, Batch: 600, Loss: 0.0948\n", + "Epoch: 23, Batch: 700, Loss: 0.0938\n", + "Epoch 23 Tesztpontosság: 92.47%\n", + "Epoch: 24, Batch: 100, Loss: 0.0694\n", + "Epoch: 24, Batch: 200, Loss: 0.0818\n", + "Epoch: 24, Batch: 300, Loss: 0.0998\n", + "Epoch: 24, Batch: 400, Loss: 0.0775\n", + "Epoch: 24, Batch: 500, Loss: 0.0724\n", + "Epoch: 24, Batch: 600, Loss: 0.0851\n", + "Epoch: 24, Batch: 700, Loss: 0.0796\n", + "Epoch 24 Tesztpontosság: 91.97%\n", + "Epoch: 25, Batch: 100, Loss: 0.0762\n", + "Epoch: 25, Batch: 200, Loss: 0.0717\n", + "Epoch: 25, Batch: 300, Loss: 0.0770\n", + "Epoch: 25, Batch: 400, Loss: 0.0848\n", + "Epoch: 25, Batch: 500, Loss: 0.0819\n", + "Epoch: 25, Batch: 600, Loss: 0.0759\n", + "Epoch: 25, Batch: 700, Loss: 0.0759\n", + "Epoch 25 Tesztpontosság: 92.07%\n", + "Epoch: 26, Batch: 100, Loss: 0.0739\n", + "Epoch: 26, Batch: 200, Loss: 0.0572\n", + "Epoch: 26, Batch: 300, Loss: 0.0765\n", + "Epoch: 26, Batch: 
400, Loss: 0.0729\n", + "Epoch: 26, Batch: 500, Loss: 0.0710\n", + "Epoch: 26, Batch: 600, Loss: 0.0701\n", + "Epoch: 26, Batch: 700, Loss: 0.0768\n", + "Epoch 26 Tesztpontosság: 92.46%\n", + "Epoch: 27, Batch: 100, Loss: 0.0641\n", + "Epoch: 27, Batch: 200, Loss: 0.0715\n", + "Epoch: 27, Batch: 300, Loss: 0.0572\n", + "Epoch: 27, Batch: 400, Loss: 0.0766\n", + "Epoch: 27, Batch: 500, Loss: 0.0809\n", + "Epoch: 27, Batch: 600, Loss: 0.0635\n", + "Epoch: 27, Batch: 700, Loss: 0.0622\n", + "Epoch 27 Tesztpontosság: 92.40%\n", + "Epoch: 28, Batch: 100, Loss: 0.0683\n", + "Epoch: 28, Batch: 200, Loss: 0.0676\n", + "Epoch: 28, Batch: 300, Loss: 0.0878\n", + "Epoch: 28, Batch: 400, Loss: 0.0681\n", + "Epoch: 28, Batch: 500, Loss: 0.0660\n", + "Epoch: 28, Batch: 600, Loss: 0.0616\n", + "Epoch: 28, Batch: 700, Loss: 0.0702\n", + "Epoch 28 Tesztpontosság: 92.49%\n", + "Epoch: 29, Batch: 100, Loss: 0.0736\n", + "Epoch: 29, Batch: 200, Loss: 0.0594\n", + "Epoch: 29, Batch: 300, Loss: 0.0706\n", + "Epoch: 29, Batch: 400, Loss: 0.0659\n", + "Epoch: 29, Batch: 500, Loss: 0.0609\n", + "Epoch: 29, Batch: 600, Loss: 0.0652\n", + "Epoch: 29, Batch: 700, Loss: 0.0598\n", + "Epoch 29 Tesztpontosság: 91.98%\n", + "Epoch: 30, Batch: 100, Loss: 0.0463\n", + "Epoch: 30, Batch: 200, Loss: 0.0531\n", + "Epoch: 30, Batch: 300, Loss: 0.0564\n", + "Epoch: 30, Batch: 400, Loss: 0.0621\n", + "Epoch: 30, Batch: 500, Loss: 0.0658\n", + "Epoch: 30, Batch: 600, Loss: 0.0604\n", + "Epoch: 30, Batch: 700, Loss: 0.0585\n", + "Epoch 30 Tesztpontosság: 92.40%\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Mentés" + ], + "metadata": { + "id": "NXIh9r4mXKti" + } + }, + { + "cell_type": "code", + "source": [ + "torch.save({\n", + " 'model_state_dict': model.state_dict(),\n", + " 'conv1_weights': model.model.conv1.weight,\n", + " 'fc_weights': model.model.fc.weight,\n", + " 'fc_bias': model.model.fc.bias\n", + "}, 'cnn_model.pth')" + ], + "metadata": { + "id": "wC2gba2EXRjg" + }, + 
"execution_count": 6, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/Project_CeNN.ipynb b/Project_CeNN.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..b6e1a45e32ba3c238b7fb5fa0c210d31244f3347 --- /dev/null +++ b/Project_CeNN.ipynb @@ -0,0 +1,742 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "8UDCqMDw_PWm" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn as nn" + ] + }, + { + "cell_type": "code", + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Használt eszköz: {device}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "iMNqDPPNIzIq", + "outputId": "b9b27e02-9489-4998-b28e-f67df0fd2bd1" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Használt eszköz: cuda:0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 1. 
Adathalmaz betöltése" + ], + "metadata": { + "id": "JyR4J1W0ALuy" + } + }, + { + "cell_type": "code", + "source": [ + "transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", + "])\n", + "\n", + "trainset = torchvision.datasets.CIFAR10(\n", + " root='./data',\n", + " train=True,\n", + " download=True,\n", + " transform=transform\n", + ")\n", + "\n", + "trainloader = DataLoader(\n", + " trainset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " num_workers=2\n", + ")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JGYQXMp6_hzN", + "outputId": "4fdd0075-1820-4bcb-884e-cfb46ca7a5af" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 170M/170M [00:12<00:00, 13.3MB/s]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 2. CeNN modell osztály" + ], + "metadata": { + "id": "hw8Y_wi_AaDM" + } + }, + { + "cell_type": "code", + "source": [ + "class CeNN(nn.Module):\n", + " def __init__(self, h=0.1, iter_num=20, alpha=0.01):\n", + " super().__init__()\n", + " self.h = h\n", + " self.iter_num = iter_num\n", + " self.alpha = alpha\n", + "\n", + " # A: visszacsatoló kernel\n", + " self.A = nn.Parameter(torch.randn(1, 3, 3, 3)) # 3 input csatorna (RGB)\n", + " # B: előrecsatoló kernel\n", + " self.B = nn.Parameter(torch.randn(1, 3, 3, 3)) # 3 input csatorna\n", + " # Z: bias\n", + " self.Z = nn.Parameter(torch.randn(1))\n", + "\n", + " def forward(self, U):\n", + " x = U # Kezdeti állapot\n", + " for _ in range(self.iter_num):\n", + " # Aktivációs függvény\n", + " y = torch.minimum(x, 1 + self.alpha * x)\n", + " y = torch.maximum(y, -1 + self.alpha * y)\n", + "\n", + " # Konvolúciós számítások\n", + " fwd = torch.nn.functional.conv2d(U, self.B, padding=1) + self.Z\n", + " bwd = torch.nn.functional.conv2d(y, self.A, padding=1)\n", + "\n", + " # Euler-lépés\n", + " x = x 
+ self.h * (-x + bwd + fwd)\n", + "\n", + " # Végső aktiváció\n", + " out = torch.minimum(x, 1 + self.alpha * x)\n", + " out = torch.maximum(out, -1 + self.alpha * out)\n", + " return out" + ], + "metadata": { + "id": "gYY4Wi04AZrw" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# 3. Modell és optimalizáló inicializálása" + ], + "metadata": { + "id": "xUgxBaQ8BgQK" + } + }, + { + "cell_type": "code", + "source": [ + "model = CeNN(h=0.1, iter_num=20).to(device)\n", + "optimizer = torch.optim.Adam([model.A, model.B, model.Z], lr=0.01)" + ], + "metadata": { + "id": "pGveFopjBhKc" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# 4. Tanítási ciklus" + ], + "metadata": { + "id": "HR1VbUxSCGWI" + } + }, + { + "cell_type": "code", + "source": [ + "for epoch in range(30): # 30 epoch\n", + " for batch_id, (data, labels) in enumerate(trainloader):\n", + " data = data.to(device)\n", + " labels = labels.to(device)\n", + " optimizer.zero_grad()\n", + "\n", + " # Előrehaladás\n", + " outputs = model(data)\n", + "\n", + " # Célértékek generálása (bináris eset: 0. 
osztály = -1, egyéb = 1)\n", + " expected_out = torch.ones_like(outputs)\n", + " expected_out[labels == 0] = -1\n", + "\n", + " # Hibaszámítás\n", + " loss = torch.mean((outputs - expected_out)**2)\n", + "\n", + " # Visszaterjesztés\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # Accuracy számítás\n", + " correct_dist = torch.mean((outputs - expected_out)**2, dim=[1,2,3])\n", + " incorrect_dist = torch.mean((outputs - (-expected_out))**2, dim=[1,2,3])\n", + " accuracy = (correct_dist < incorrect_dist).float().mean()\n", + "\n", + " if batch_id % 50 == 0:\n", + " print(f\"Epoch: {epoch+1}, Batch: {batch_id}, Loss: {loss.item():.4f}, Accuracy: {accuracy.item():.4f}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h1U7kJnfBjjF", + "outputId": "ac4b384b-541f-4d0b-8af9-0afb6dcdaea4" + }, + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch: 1, Batch: 0, Loss: 3.2504, Accuracy: 0.0781\n", + "Epoch: 1, Batch: 50, Loss: 1.7928, Accuracy: 0.0938\n", + "Epoch: 1, Batch: 100, Loss: 1.0809, Accuracy: 0.4062\n", + "Epoch: 1, Batch: 150, Loss: 0.9643, Accuracy: 0.6406\n", + "Epoch: 1, Batch: 200, Loss: 0.9468, Accuracy: 0.7344\n", + "Epoch: 1, Batch: 250, Loss: 0.9145, Accuracy: 0.6719\n", + "Epoch: 1, Batch: 300, Loss: 0.7937, Accuracy: 0.7500\n", + "Epoch: 1, Batch: 350, Loss: 0.3365, Accuracy: 0.9219\n", + "Epoch: 1, Batch: 400, Loss: 0.2891, Accuracy: 0.9219\n", + "Epoch: 1, Batch: 450, Loss: 0.6049, Accuracy: 0.7812\n", + "Epoch: 1, Batch: 500, Loss: 0.2799, Accuracy: 0.9219\n", + "Epoch: 1, Batch: 550, Loss: 0.3563, Accuracy: 0.8906\n", + "Epoch: 1, Batch: 600, Loss: 0.4664, Accuracy: 0.8438\n", + "Epoch: 1, Batch: 650, Loss: 0.2913, Accuracy: 0.9219\n", + "Epoch: 1, Batch: 700, Loss: 0.3457, Accuracy: 0.9062\n", + "Epoch: 1, Batch: 750, Loss: 0.3231, Accuracy: 0.8906\n", + "Epoch: 2, Batch: 0, Loss: 0.3062, Accuracy: 0.9219\n", + "Epoch: 2, Batch: 50, 
Loss: 0.2287, Accuracy: 0.9375\n", + "Epoch: 2, Batch: 100, Loss: 0.1790, Accuracy: 0.9375\n", + "Epoch: 2, Batch: 150, Loss: 0.2975, Accuracy: 0.9219\n", + "Epoch: 2, Batch: 200, Loss: 0.4369, Accuracy: 0.8594\n", + "Epoch: 2, Batch: 250, Loss: 0.2556, Accuracy: 0.9062\n", + "Epoch: 2, Batch: 300, Loss: 0.3841, Accuracy: 0.8906\n", + "Epoch: 2, Batch: 350, Loss: 0.2146, Accuracy: 0.9531\n", + "Epoch: 2, Batch: 400, Loss: 0.3204, Accuracy: 0.9062\n", + "Epoch: 2, Batch: 450, Loss: 0.4400, Accuracy: 0.8438\n", + "Epoch: 2, Batch: 500, Loss: 0.3326, Accuracy: 0.9062\n", + "Epoch: 2, Batch: 550, Loss: 0.3200, Accuracy: 0.9062\n", + "Epoch: 2, Batch: 600, Loss: 0.3158, Accuracy: 0.8906\n", + "Epoch: 2, Batch: 650, Loss: 0.4220, Accuracy: 0.8594\n", + "Epoch: 2, Batch: 700, Loss: 0.4758, Accuracy: 0.8594\n", + "Epoch: 2, Batch: 750, Loss: 0.4251, Accuracy: 0.8750\n", + "Epoch: 3, Batch: 0, Loss: 0.5105, Accuracy: 0.8438\n", + "Epoch: 3, Batch: 50, Loss: 0.1849, Accuracy: 0.9531\n", + "Epoch: 3, Batch: 100, Loss: 0.2802, Accuracy: 0.9219\n", + "Epoch: 3, Batch: 150, Loss: 0.1867, Accuracy: 0.9531\n", + "Epoch: 3, Batch: 200, Loss: 0.2983, Accuracy: 0.9219\n", + "Epoch: 3, Batch: 250, Loss: 0.2900, Accuracy: 0.9219\n", + "Epoch: 3, Batch: 300, Loss: 0.2088, Accuracy: 0.9375\n", + "Epoch: 3, Batch: 350, Loss: 0.2820, Accuracy: 0.9219\n", + "Epoch: 3, Batch: 400, Loss: 0.5264, Accuracy: 0.8438\n", + "Epoch: 3, Batch: 450, Loss: 0.3403, Accuracy: 0.9062\n", + "Epoch: 3, Batch: 500, Loss: 0.4824, Accuracy: 0.8438\n", + "Epoch: 3, Batch: 550, Loss: 0.2307, Accuracy: 0.9219\n", + "Epoch: 3, Batch: 600, Loss: 0.3859, Accuracy: 0.8906\n", + "Epoch: 3, Batch: 650, Loss: 0.3170, Accuracy: 0.9062\n", + "Epoch: 3, Batch: 700, Loss: 0.4403, Accuracy: 0.8438\n", + "Epoch: 3, Batch: 750, Loss: 0.3150, Accuracy: 0.9219\n", + "Epoch: 4, Batch: 0, Loss: 0.3866, Accuracy: 0.8750\n", + "Epoch: 4, Batch: 50, Loss: 0.3570, Accuracy: 0.8906\n", + "Epoch: 4, Batch: 100, Loss: 0.3416, Accuracy: 
0.8906\n", + "Epoch: 4, Batch: 150, Loss: 0.4609, Accuracy: 0.8594\n", + "Epoch: 4, Batch: 200, Loss: 0.2462, Accuracy: 0.9531\n", + "Epoch: 4, Batch: 250, Loss: 0.3144, Accuracy: 0.9062\n", + "Epoch: 4, Batch: 300, Loss: 0.2060, Accuracy: 0.9375\n", + "Epoch: 4, Batch: 350, Loss: 0.5053, Accuracy: 0.8438\n", + "Epoch: 4, Batch: 400, Loss: 0.3569, Accuracy: 0.8750\n", + "Epoch: 4, Batch: 450, Loss: 0.3021, Accuracy: 0.9219\n", + "Epoch: 4, Batch: 500, Loss: 0.4802, Accuracy: 0.8594\n", + "Epoch: 4, Batch: 550, Loss: 0.2447, Accuracy: 0.9375\n", + "Epoch: 4, Batch: 600, Loss: 0.3382, Accuracy: 0.9062\n", + "Epoch: 4, Batch: 650, Loss: 0.1233, Accuracy: 0.9688\n", + "Epoch: 4, Batch: 700, Loss: 0.3592, Accuracy: 0.8750\n", + "Epoch: 4, Batch: 750, Loss: 0.2131, Accuracy: 0.9375\n", + "Epoch: 5, Batch: 0, Loss: 0.1955, Accuracy: 0.9531\n", + "Epoch: 5, Batch: 50, Loss: 0.2461, Accuracy: 0.9219\n", + "Epoch: 5, Batch: 100, Loss: 0.2585, Accuracy: 0.9375\n", + "Epoch: 5, Batch: 150, Loss: 0.3313, Accuracy: 0.9062\n", + "Epoch: 5, Batch: 200, Loss: 0.1465, Accuracy: 0.9688\n", + "Epoch: 5, Batch: 250, Loss: 0.2710, Accuracy: 0.9219\n", + "Epoch: 5, Batch: 300, Loss: 0.1964, Accuracy: 0.9531\n", + "Epoch: 5, Batch: 350, Loss: 0.5045, Accuracy: 0.8438\n", + "Epoch: 5, Batch: 400, Loss: 0.3234, Accuracy: 0.9062\n", + "Epoch: 5, Batch: 450, Loss: 0.4351, Accuracy: 0.8594\n", + "Epoch: 5, Batch: 500, Loss: 0.4981, Accuracy: 0.8438\n", + "Epoch: 5, Batch: 550, Loss: 0.3685, Accuracy: 0.8906\n", + "Epoch: 5, Batch: 600, Loss: 0.3898, Accuracy: 0.8906\n", + "Epoch: 5, Batch: 650, Loss: 0.2581, Accuracy: 0.9219\n", + "Epoch: 5, Batch: 700, Loss: 0.3540, Accuracy: 0.9062\n", + "Epoch: 5, Batch: 750, Loss: 0.6129, Accuracy: 0.8125\n", + "Epoch: 6, Batch: 0, Loss: 0.2730, Accuracy: 0.9219\n", + "Epoch: 6, Batch: 50, Loss: 0.3072, Accuracy: 0.9219\n", + "Epoch: 6, Batch: 100, Loss: 0.3900, Accuracy: 0.8750\n", + "Epoch: 6, Batch: 150, Loss: 0.2398, Accuracy: 0.9375\n", + "Epoch: 6, 
Batch: 200, Loss: 0.2062, Accuracy: 0.9531\n", + "Epoch: 6, Batch: 250, Loss: 0.5174, Accuracy: 0.8438\n", + "Epoch: 6, Batch: 300, Loss: 0.2991, Accuracy: 0.9062\n", + "Epoch: 6, Batch: 350, Loss: 0.4451, Accuracy: 0.8594\n", + "Epoch: 6, Batch: 400, Loss: 0.2166, Accuracy: 0.9531\n", + "Epoch: 6, Batch: 450, Loss: 0.4295, Accuracy: 0.8594\n", + "Epoch: 6, Batch: 500, Loss: 0.2640, Accuracy: 0.9219\n", + "Epoch: 6, Batch: 550, Loss: 0.3824, Accuracy: 0.8906\n", + "Epoch: 6, Batch: 600, Loss: 0.3471, Accuracy: 0.8906\n", + "Epoch: 6, Batch: 650, Loss: 0.3431, Accuracy: 0.9219\n", + "Epoch: 6, Batch: 700, Loss: 0.2478, Accuracy: 0.9375\n", + "Epoch: 6, Batch: 750, Loss: 0.2856, Accuracy: 0.9219\n", + "Epoch: 7, Batch: 0, Loss: 0.2095, Accuracy: 0.9375\n", + "Epoch: 7, Batch: 50, Loss: 0.2560, Accuracy: 0.9219\n", + "Epoch: 7, Batch: 100, Loss: 0.4533, Accuracy: 0.8594\n", + "Epoch: 7, Batch: 150, Loss: 0.1734, Accuracy: 0.9531\n", + "Epoch: 7, Batch: 200, Loss: 0.3265, Accuracy: 0.9062\n", + "Epoch: 7, Batch: 250, Loss: 0.2912, Accuracy: 0.9062\n", + "Epoch: 7, Batch: 300, Loss: 0.2897, Accuracy: 0.9062\n", + "Epoch: 7, Batch: 350, Loss: 0.3915, Accuracy: 0.8906\n", + "Epoch: 7, Batch: 400, Loss: 0.4599, Accuracy: 0.8594\n", + "Epoch: 7, Batch: 450, Loss: 0.2889, Accuracy: 0.9219\n", + "Epoch: 7, Batch: 500, Loss: 0.3306, Accuracy: 0.9062\n", + "Epoch: 7, Batch: 550, Loss: 0.2967, Accuracy: 0.9375\n", + "Epoch: 7, Batch: 600, Loss: 0.3698, Accuracy: 0.8906\n", + "Epoch: 7, Batch: 650, Loss: 0.2918, Accuracy: 0.9062\n", + "Epoch: 7, Batch: 700, Loss: 0.1363, Accuracy: 0.9688\n", + "Epoch: 7, Batch: 750, Loss: 0.2358, Accuracy: 0.9375\n", + "Epoch: 8, Batch: 0, Loss: 0.4078, Accuracy: 0.8906\n", + "Epoch: 8, Batch: 50, Loss: 0.2609, Accuracy: 0.9219\n", + "Epoch: 8, Batch: 100, Loss: 0.3632, Accuracy: 0.9062\n", + "Epoch: 8, Batch: 150, Loss: 0.2642, Accuracy: 0.9219\n", + "Epoch: 8, Batch: 200, Loss: 0.3594, Accuracy: 0.9062\n", + "Epoch: 8, Batch: 250, Loss: 0.2832, 
Accuracy: 0.9219\n", + "Epoch: 8, Batch: 300, Loss: 0.2071, Accuracy: 0.9531\n", + "Epoch: 8, Batch: 350, Loss: 0.3322, Accuracy: 0.8906\n", + "Epoch: 8, Batch: 400, Loss: 0.1634, Accuracy: 0.9688\n", + "Epoch: 8, Batch: 450, Loss: 0.3029, Accuracy: 0.9062\n", + "Epoch: 8, Batch: 500, Loss: 0.2814, Accuracy: 0.9219\n", + "Epoch: 8, Batch: 550, Loss: 0.3699, Accuracy: 0.8906\n", + "Epoch: 8, Batch: 600, Loss: 0.3284, Accuracy: 0.9062\n", + "Epoch: 8, Batch: 650, Loss: 0.3867, Accuracy: 0.8750\n", + "Epoch: 8, Batch: 700, Loss: 0.1809, Accuracy: 0.9531\n", + "Epoch: 8, Batch: 750, Loss: 0.3607, Accuracy: 0.8906\n", + "Epoch: 9, Batch: 0, Loss: 0.2970, Accuracy: 0.9219\n", + "Epoch: 9, Batch: 50, Loss: 0.6045, Accuracy: 0.8125\n", + "Epoch: 9, Batch: 100, Loss: 0.4143, Accuracy: 0.8750\n", + "Epoch: 9, Batch: 150, Loss: 0.3316, Accuracy: 0.8906\n", + "Epoch: 9, Batch: 200, Loss: 0.2525, Accuracy: 0.9375\n", + "Epoch: 9, Batch: 250, Loss: 0.3518, Accuracy: 0.9062\n", + "Epoch: 9, Batch: 300, Loss: 0.3512, Accuracy: 0.8906\n", + "Epoch: 9, Batch: 350, Loss: 0.3825, Accuracy: 0.8750\n", + "Epoch: 9, Batch: 400, Loss: 0.2434, Accuracy: 0.9375\n", + "Epoch: 9, Batch: 450, Loss: 0.5630, Accuracy: 0.8281\n", + "Epoch: 9, Batch: 500, Loss: 0.2942, Accuracy: 0.9219\n", + "Epoch: 9, Batch: 550, Loss: 0.1786, Accuracy: 0.9531\n", + "Epoch: 9, Batch: 600, Loss: 0.2020, Accuracy: 0.9375\n", + "Epoch: 9, Batch: 650, Loss: 0.4735, Accuracy: 0.8594\n", + "Epoch: 9, Batch: 700, Loss: 0.3972, Accuracy: 0.8906\n", + "Epoch: 9, Batch: 750, Loss: 0.4706, Accuracy: 0.8594\n", + "Epoch: 10, Batch: 0, Loss: 0.3030, Accuracy: 0.9219\n", + "Epoch: 10, Batch: 50, Loss: 0.3658, Accuracy: 0.8906\n", + "Epoch: 10, Batch: 100, Loss: 0.3322, Accuracy: 0.8906\n", + "Epoch: 10, Batch: 150, Loss: 0.3684, Accuracy: 0.8906\n", + "Epoch: 10, Batch: 200, Loss: 0.4174, Accuracy: 0.8750\n", + "Epoch: 10, Batch: 250, Loss: 0.3050, Accuracy: 0.9062\n", + "Epoch: 10, Batch: 300, Loss: 0.3083, Accuracy: 
0.9062\n", + "Epoch: 10, Batch: 350, Loss: 0.4319, Accuracy: 0.8750\n", + "Epoch: 10, Batch: 400, Loss: 0.6263, Accuracy: 0.7969\n", + "Epoch: 10, Batch: 450, Loss: 0.2850, Accuracy: 0.9219\n", + "Epoch: 10, Batch: 500, Loss: 0.3495, Accuracy: 0.8906\n", + "Epoch: 10, Batch: 550, Loss: 0.2538, Accuracy: 0.9375\n", + "Epoch: 10, Batch: 600, Loss: 0.3938, Accuracy: 0.8750\n", + "Epoch: 10, Batch: 650, Loss: 0.4767, Accuracy: 0.8438\n", + "Epoch: 10, Batch: 700, Loss: 0.2534, Accuracy: 0.9375\n", + "Epoch: 10, Batch: 750, Loss: 0.3790, Accuracy: 0.8906\n", + "Epoch: 11, Batch: 0, Loss: 0.4360, Accuracy: 0.8594\n", + "Epoch: 11, Batch: 50, Loss: 0.3722, Accuracy: 0.8906\n", + "Epoch: 11, Batch: 100, Loss: 0.4026, Accuracy: 0.8906\n", + "Epoch: 11, Batch: 150, Loss: 0.1270, Accuracy: 0.9688\n", + "Epoch: 11, Batch: 200, Loss: 0.3616, Accuracy: 0.8906\n", + "Epoch: 11, Batch: 250, Loss: 0.2308, Accuracy: 0.9375\n", + "Epoch: 11, Batch: 300, Loss: 0.4520, Accuracy: 0.8594\n", + "Epoch: 11, Batch: 350, Loss: 0.4075, Accuracy: 0.8750\n", + "Epoch: 11, Batch: 400, Loss: 0.3222, Accuracy: 0.9062\n", + "Epoch: 11, Batch: 450, Loss: 0.2663, Accuracy: 0.9219\n", + "Epoch: 11, Batch: 500, Loss: 0.1387, Accuracy: 0.9531\n", + "Epoch: 11, Batch: 550, Loss: 0.4702, Accuracy: 0.8594\n", + "Epoch: 11, Batch: 600, Loss: 0.4505, Accuracy: 0.8594\n", + "Epoch: 11, Batch: 650, Loss: 0.4182, Accuracy: 0.8750\n", + "Epoch: 11, Batch: 700, Loss: 0.3478, Accuracy: 0.8906\n", + "Epoch: 11, Batch: 750, Loss: 0.3621, Accuracy: 0.8906\n", + "Epoch: 12, Batch: 0, Loss: 0.2849, Accuracy: 0.9219\n", + "Epoch: 12, Batch: 50, Loss: 0.3110, Accuracy: 0.9062\n", + "Epoch: 12, Batch: 100, Loss: 0.3051, Accuracy: 0.9062\n", + "Epoch: 12, Batch: 150, Loss: 0.2737, Accuracy: 0.9219\n", + "Epoch: 12, Batch: 200, Loss: 0.3742, Accuracy: 0.8750\n", + "Epoch: 12, Batch: 250, Loss: 0.4212, Accuracy: 0.8594\n", + "Epoch: 12, Batch: 300, Loss: 0.2280, Accuracy: 0.9375\n", + "Epoch: 12, Batch: 350, Loss: 0.4784, 
Accuracy: 0.8594\n", + "Epoch: 12, Batch: 400, Loss: 0.2956, Accuracy: 0.9062\n", + "Epoch: 12, Batch: 450, Loss: 0.2589, Accuracy: 0.9219\n", + "Epoch: 12, Batch: 500, Loss: 0.4105, Accuracy: 0.8906\n", + "Epoch: 12, Batch: 550, Loss: 0.3742, Accuracy: 0.8750\n", + "Epoch: 12, Batch: 600, Loss: 0.4328, Accuracy: 0.8594\n", + "Epoch: 12, Batch: 650, Loss: 0.4948, Accuracy: 0.8438\n", + "Epoch: 12, Batch: 700, Loss: 0.3466, Accuracy: 0.8906\n", + "Epoch: 12, Batch: 750, Loss: 0.2288, Accuracy: 0.9375\n", + "Epoch: 13, Batch: 0, Loss: 0.2461, Accuracy: 0.9375\n", + "Epoch: 13, Batch: 50, Loss: 0.1805, Accuracy: 0.9531\n", + "Epoch: 13, Batch: 100, Loss: 0.2795, Accuracy: 0.9062\n", + "Epoch: 13, Batch: 150, Loss: 0.3512, Accuracy: 0.9062\n", + "Epoch: 13, Batch: 200, Loss: 0.5196, Accuracy: 0.8125\n", + "Epoch: 13, Batch: 250, Loss: 0.4389, Accuracy: 0.8594\n", + "Epoch: 13, Batch: 300, Loss: 0.3456, Accuracy: 0.8906\n", + "Epoch: 13, Batch: 350, Loss: 0.2473, Accuracy: 0.9375\n", + "Epoch: 13, Batch: 400, Loss: 0.3316, Accuracy: 0.9062\n", + "Epoch: 13, Batch: 450, Loss: 0.3315, Accuracy: 0.9062\n", + "Epoch: 13, Batch: 500, Loss: 0.3654, Accuracy: 0.8906\n", + "Epoch: 13, Batch: 550, Loss: 0.2489, Accuracy: 0.9375\n", + "Epoch: 13, Batch: 600, Loss: 0.3257, Accuracy: 0.9062\n", + "Epoch: 13, Batch: 650, Loss: 0.2387, Accuracy: 0.9375\n", + "Epoch: 13, Batch: 700, Loss: 0.5489, Accuracy: 0.8281\n", + "Epoch: 13, Batch: 750, Loss: 0.2914, Accuracy: 0.9219\n", + "Epoch: 14, Batch: 0, Loss: 0.4353, Accuracy: 0.8594\n", + "Epoch: 14, Batch: 50, Loss: 0.3106, Accuracy: 0.9062\n", + "Epoch: 14, Batch: 100, Loss: 0.2361, Accuracy: 0.9219\n", + "Epoch: 14, Batch: 150, Loss: 0.3523, Accuracy: 0.9062\n", + "Epoch: 14, Batch: 200, Loss: 0.3012, Accuracy: 0.9219\n", + "Epoch: 14, Batch: 250, Loss: 0.2116, Accuracy: 0.9375\n", + "Epoch: 14, Batch: 300, Loss: 0.3845, Accuracy: 0.8750\n", + "Epoch: 14, Batch: 350, Loss: 0.4540, Accuracy: 0.8594\n", + "Epoch: 14, Batch: 400, Loss: 
0.2243, Accuracy: 0.9375\n", + "Epoch: 14, Batch: 450, Loss: 0.3854, Accuracy: 0.8750\n", + "Epoch: 14, Batch: 500, Loss: 0.2479, Accuracy: 0.9219\n", + "Epoch: 14, Batch: 550, Loss: 0.5586, Accuracy: 0.8281\n", + "Epoch: 14, Batch: 600, Loss: 0.3884, Accuracy: 0.8750\n", + "Epoch: 14, Batch: 650, Loss: 0.5463, Accuracy: 0.8281\n", + "Epoch: 14, Batch: 700, Loss: 0.3759, Accuracy: 0.8906\n", + "Epoch: 14, Batch: 750, Loss: 0.3011, Accuracy: 0.9062\n", + "Epoch: 15, Batch: 0, Loss: 0.6202, Accuracy: 0.8125\n", + "Epoch: 15, Batch: 50, Loss: 0.4044, Accuracy: 0.8906\n", + "Epoch: 15, Batch: 100, Loss: 0.3505, Accuracy: 0.8906\n", + "Epoch: 15, Batch: 150, Loss: 0.2088, Accuracy: 0.9531\n", + "Epoch: 15, Batch: 200, Loss: 0.1554, Accuracy: 0.9688\n", + "Epoch: 15, Batch: 250, Loss: 0.3466, Accuracy: 0.9062\n", + "Epoch: 15, Batch: 300, Loss: 0.1943, Accuracy: 0.9531\n", + "Epoch: 15, Batch: 350, Loss: 0.5770, Accuracy: 0.8281\n", + "Epoch: 15, Batch: 400, Loss: 0.2344, Accuracy: 0.9375\n", + "Epoch: 15, Batch: 450, Loss: 0.4119, Accuracy: 0.8750\n", + "Epoch: 15, Batch: 500, Loss: 0.3551, Accuracy: 0.9062\n", + "Epoch: 15, Batch: 550, Loss: 0.2234, Accuracy: 0.9375\n", + "Epoch: 15, Batch: 600, Loss: 0.2032, Accuracy: 0.9531\n", + "Epoch: 15, Batch: 650, Loss: 0.1824, Accuracy: 0.9531\n", + "Epoch: 15, Batch: 700, Loss: 0.2137, Accuracy: 0.9375\n", + "Epoch: 15, Batch: 750, Loss: 0.3519, Accuracy: 0.8750\n", + "Epoch: 16, Batch: 0, Loss: 0.4850, Accuracy: 0.8594\n", + "Epoch: 16, Batch: 50, Loss: 0.2885, Accuracy: 0.9219\n", + "Epoch: 16, Batch: 100, Loss: 0.2115, Accuracy: 0.9375\n", + "Epoch: 16, Batch: 150, Loss: 0.1637, Accuracy: 0.9531\n", + "Epoch: 16, Batch: 200, Loss: 0.3442, Accuracy: 0.8906\n", + "Epoch: 16, Batch: 250, Loss: 0.4721, Accuracy: 0.8594\n", + "Epoch: 16, Batch: 300, Loss: 0.2668, Accuracy: 0.9219\n", + "Epoch: 16, Batch: 350, Loss: 0.2947, Accuracy: 0.9062\n", + "Epoch: 16, Batch: 400, Loss: 0.4097, Accuracy: 0.8594\n", + "Epoch: 16, Batch: 
450, Loss: 0.5181, Accuracy: 0.8438\n", + "Epoch: 16, Batch: 500, Loss: 0.2716, Accuracy: 0.9062\n", + "Epoch: 16, Batch: 550, Loss: 0.2996, Accuracy: 0.9062\n", + "Epoch: 16, Batch: 600, Loss: 0.1031, Accuracy: 0.9844\n", + "Epoch: 16, Batch: 650, Loss: 0.4878, Accuracy: 0.8594\n", + "Epoch: 16, Batch: 700, Loss: 0.2796, Accuracy: 0.9219\n", + "Epoch: 16, Batch: 750, Loss: 0.2015, Accuracy: 0.9531\n", + "Epoch: 17, Batch: 0, Loss: 0.2510, Accuracy: 0.9219\n", + "Epoch: 17, Batch: 50, Loss: 0.2685, Accuracy: 0.9219\n", + "Epoch: 17, Batch: 100, Loss: 0.2255, Accuracy: 0.9375\n", + "Epoch: 17, Batch: 150, Loss: 0.2705, Accuracy: 0.9219\n", + "Epoch: 17, Batch: 200, Loss: 0.3725, Accuracy: 0.8750\n", + "Epoch: 17, Batch: 250, Loss: 0.4283, Accuracy: 0.8750\n", + "Epoch: 17, Batch: 300, Loss: 0.3770, Accuracy: 0.8906\n", + "Epoch: 17, Batch: 350, Loss: 0.2558, Accuracy: 0.9375\n", + "Epoch: 17, Batch: 400, Loss: 0.2779, Accuracy: 0.9219\n", + "Epoch: 17, Batch: 450, Loss: 0.3149, Accuracy: 0.8906\n", + "Epoch: 17, Batch: 500, Loss: 0.1625, Accuracy: 0.9531\n", + "Epoch: 17, Batch: 550, Loss: 0.3953, Accuracy: 0.8594\n", + "Epoch: 17, Batch: 600, Loss: 0.3675, Accuracy: 0.8906\n", + "Epoch: 17, Batch: 650, Loss: 0.4696, Accuracy: 0.8594\n", + "Epoch: 17, Batch: 700, Loss: 0.3609, Accuracy: 0.8906\n", + "Epoch: 17, Batch: 750, Loss: 0.3258, Accuracy: 0.9219\n", + "Epoch: 18, Batch: 0, Loss: 0.3999, Accuracy: 0.8750\n", + "Epoch: 18, Batch: 50, Loss: 0.4544, Accuracy: 0.8906\n", + "Epoch: 18, Batch: 100, Loss: 0.4615, Accuracy: 0.8438\n", + "Epoch: 18, Batch: 150, Loss: 0.2437, Accuracy: 0.9219\n", + "Epoch: 18, Batch: 200, Loss: 0.2197, Accuracy: 0.9531\n", + "Epoch: 18, Batch: 250, Loss: 0.4701, Accuracy: 0.8594\n", + "Epoch: 18, Batch: 300, Loss: 0.3827, Accuracy: 0.8750\n", + "Epoch: 18, Batch: 350, Loss: 0.2714, Accuracy: 0.9219\n", + "Epoch: 18, Batch: 400, Loss: 0.2385, Accuracy: 0.9375\n", + "Epoch: 18, Batch: 450, Loss: 0.2662, Accuracy: 0.9062\n", + "Epoch: 18, 
Batch: 500, Loss: 0.5095, Accuracy: 0.8438\n", + "Epoch: 18, Batch: 550, Loss: 0.4597, Accuracy: 0.8750\n", + "Epoch: 18, Batch: 600, Loss: 0.2917, Accuracy: 0.9219\n", + "Epoch: 18, Batch: 650, Loss: 0.2518, Accuracy: 0.9219\n", + "Epoch: 18, Batch: 700, Loss: 0.2892, Accuracy: 0.9219\n", + "Epoch: 18, Batch: 750, Loss: 0.3718, Accuracy: 0.8906\n", + "Epoch: 19, Batch: 0, Loss: 0.4317, Accuracy: 0.8750\n", + "Epoch: 19, Batch: 50, Loss: 0.3516, Accuracy: 0.8906\n", + "Epoch: 19, Batch: 100, Loss: 0.4401, Accuracy: 0.8750\n", + "Epoch: 19, Batch: 150, Loss: 0.3609, Accuracy: 0.8906\n", + "Epoch: 19, Batch: 200, Loss: 0.3140, Accuracy: 0.9062\n", + "Epoch: 19, Batch: 250, Loss: 0.4060, Accuracy: 0.8750\n", + "Epoch: 19, Batch: 300, Loss: 0.2689, Accuracy: 0.9219\n", + "Epoch: 19, Batch: 350, Loss: 0.1900, Accuracy: 0.9375\n", + "Epoch: 19, Batch: 400, Loss: 0.3847, Accuracy: 0.8750\n", + "Epoch: 19, Batch: 450, Loss: 0.3713, Accuracy: 0.8906\n", + "Epoch: 19, Batch: 500, Loss: 0.2912, Accuracy: 0.9219\n", + "Epoch: 19, Batch: 550, Loss: 0.3616, Accuracy: 0.8906\n", + "Epoch: 19, Batch: 600, Loss: 0.3331, Accuracy: 0.9062\n", + "Epoch: 19, Batch: 650, Loss: 0.2643, Accuracy: 0.9219\n", + "Epoch: 19, Batch: 700, Loss: 0.3527, Accuracy: 0.9062\n", + "Epoch: 19, Batch: 750, Loss: 0.3679, Accuracy: 0.8750\n", + "Epoch: 20, Batch: 0, Loss: 0.4632, Accuracy: 0.8438\n", + "Epoch: 20, Batch: 50, Loss: 0.3168, Accuracy: 0.8906\n", + "Epoch: 20, Batch: 100, Loss: 0.3277, Accuracy: 0.9219\n", + "Epoch: 20, Batch: 150, Loss: 0.1503, Accuracy: 0.9688\n", + "Epoch: 20, Batch: 200, Loss: 0.3160, Accuracy: 0.9062\n", + "Epoch: 20, Batch: 250, Loss: 0.2083, Accuracy: 0.9531\n", + "Epoch: 20, Batch: 300, Loss: 0.4475, Accuracy: 0.8594\n", + "Epoch: 20, Batch: 350, Loss: 0.4463, Accuracy: 0.8594\n", + "Epoch: 20, Batch: 400, Loss: 0.1941, Accuracy: 0.9531\n", + "Epoch: 20, Batch: 450, Loss: 0.2940, Accuracy: 0.9219\n", + "Epoch: 20, Batch: 500, Loss: 0.4440, Accuracy: 0.8594\n", + 
"Epoch: 20, Batch: 550, Loss: 0.5096, Accuracy: 0.8594\n", + "Epoch: 20, Batch: 600, Loss: 0.2695, Accuracy: 0.9219\n", + "Epoch: 20, Batch: 650, Loss: 0.4495, Accuracy: 0.8594\n", + "Epoch: 20, Batch: 700, Loss: 0.3003, Accuracy: 0.9062\n", + "Epoch: 20, Batch: 750, Loss: 0.2347, Accuracy: 0.9375\n", + "Epoch: 21, Batch: 0, Loss: 0.3034, Accuracy: 0.9219\n", + "Epoch: 21, Batch: 50, Loss: 0.3877, Accuracy: 0.8750\n", + "Epoch: 21, Batch: 100, Loss: 0.4309, Accuracy: 0.8750\n", + "Epoch: 21, Batch: 150, Loss: 0.3924, Accuracy: 0.8750\n", + "Epoch: 21, Batch: 200, Loss: 0.4508, Accuracy: 0.8594\n", + "Epoch: 21, Batch: 250, Loss: 0.2740, Accuracy: 0.9219\n", + "Epoch: 21, Batch: 300, Loss: 0.4891, Accuracy: 0.8438\n", + "Epoch: 21, Batch: 350, Loss: 0.2856, Accuracy: 0.9219\n", + "Epoch: 21, Batch: 400, Loss: 0.0455, Accuracy: 1.0000\n", + "Epoch: 21, Batch: 450, Loss: 0.4910, Accuracy: 0.8438\n", + "Epoch: 21, Batch: 500, Loss: 0.2694, Accuracy: 0.9062\n", + "Epoch: 21, Batch: 550, Loss: 0.3082, Accuracy: 0.9062\n", + "Epoch: 21, Batch: 600, Loss: 0.2260, Accuracy: 0.9531\n", + "Epoch: 21, Batch: 650, Loss: 0.1631, Accuracy: 0.9531\n", + "Epoch: 21, Batch: 700, Loss: 0.3069, Accuracy: 0.9219\n", + "Epoch: 21, Batch: 750, Loss: 0.4678, Accuracy: 0.8438\n", + "Epoch: 22, Batch: 0, Loss: 0.4563, Accuracy: 0.8594\n", + "Epoch: 22, Batch: 50, Loss: 0.2691, Accuracy: 0.9375\n", + "Epoch: 22, Batch: 100, Loss: 0.4520, Accuracy: 0.8438\n", + "Epoch: 22, Batch: 150, Loss: 0.3972, Accuracy: 0.8750\n", + "Epoch: 22, Batch: 200, Loss: 0.3320, Accuracy: 0.8906\n", + "Epoch: 22, Batch: 250, Loss: 0.4119, Accuracy: 0.8906\n", + "Epoch: 22, Batch: 300, Loss: 0.2597, Accuracy: 0.9219\n", + "Epoch: 22, Batch: 350, Loss: 0.2030, Accuracy: 0.9531\n", + "Epoch: 22, Batch: 400, Loss: 0.2151, Accuracy: 0.9375\n", + "Epoch: 22, Batch: 450, Loss: 0.2871, Accuracy: 0.9219\n", + "Epoch: 22, Batch: 500, Loss: 0.2892, Accuracy: 0.9062\n", + "Epoch: 22, Batch: 550, Loss: 0.4443, Accuracy: 
0.8594\n", + "Epoch: 22, Batch: 600, Loss: 0.2177, Accuracy: 0.9219\n", + "Epoch: 22, Batch: 650, Loss: 0.3408, Accuracy: 0.9062\n", + "Epoch: 22, Batch: 700, Loss: 0.4698, Accuracy: 0.8594\n", + "Epoch: 22, Batch: 750, Loss: 0.4157, Accuracy: 0.8594\n", + "Epoch: 23, Batch: 0, Loss: 0.3918, Accuracy: 0.8906\n", + "Epoch: 23, Batch: 50, Loss: 0.1003, Accuracy: 0.9844\n", + "Epoch: 23, Batch: 100, Loss: 0.3034, Accuracy: 0.9062\n", + "Epoch: 23, Batch: 150, Loss: 0.2433, Accuracy: 0.9219\n", + "Epoch: 23, Batch: 200, Loss: 0.4006, Accuracy: 0.8750\n", + "Epoch: 23, Batch: 250, Loss: 0.4214, Accuracy: 0.8594\n", + "Epoch: 23, Batch: 300, Loss: 0.4411, Accuracy: 0.8594\n", + "Epoch: 23, Batch: 350, Loss: 0.2651, Accuracy: 0.9219\n", + "Epoch: 23, Batch: 400, Loss: 0.2992, Accuracy: 0.9219\n", + "Epoch: 23, Batch: 450, Loss: 0.3926, Accuracy: 0.8750\n", + "Epoch: 23, Batch: 500, Loss: 0.1348, Accuracy: 0.9688\n", + "Epoch: 23, Batch: 550, Loss: 0.2408, Accuracy: 0.9375\n", + "Epoch: 23, Batch: 600, Loss: 0.2532, Accuracy: 0.9219\n", + "Epoch: 23, Batch: 650, Loss: 0.3790, Accuracy: 0.8906\n", + "Epoch: 23, Batch: 700, Loss: 0.3529, Accuracy: 0.8906\n", + "Epoch: 23, Batch: 750, Loss: 0.4256, Accuracy: 0.8750\n", + "Epoch: 24, Batch: 0, Loss: 0.5058, Accuracy: 0.8438\n", + "Epoch: 24, Batch: 50, Loss: 0.2562, Accuracy: 0.9375\n", + "Epoch: 24, Batch: 100, Loss: 0.1293, Accuracy: 0.9688\n", + "Epoch: 24, Batch: 150, Loss: 0.2718, Accuracy: 0.9219\n", + "Epoch: 24, Batch: 200, Loss: 0.1990, Accuracy: 0.9375\n", + "Epoch: 24, Batch: 250, Loss: 0.5353, Accuracy: 0.8281\n", + "Epoch: 24, Batch: 300, Loss: 0.4225, Accuracy: 0.8750\n", + "Epoch: 24, Batch: 350, Loss: 0.5562, Accuracy: 0.8281\n", + "Epoch: 24, Batch: 400, Loss: 0.4328, Accuracy: 0.8750\n", + "Epoch: 24, Batch: 450, Loss: 0.5232, Accuracy: 0.8438\n", + "Epoch: 24, Batch: 500, Loss: 0.4480, Accuracy: 0.8750\n", + "Epoch: 24, Batch: 550, Loss: 0.2664, Accuracy: 0.9219\n", + "Epoch: 24, Batch: 600, Loss: 0.1258, 
Accuracy: 0.9844\n", + "Epoch: 24, Batch: 650, Loss: 0.2008, Accuracy: 0.9375\n", + "Epoch: 24, Batch: 700, Loss: 0.3194, Accuracy: 0.9062\n", + "Epoch: 24, Batch: 750, Loss: 0.2823, Accuracy: 0.9219\n", + "Epoch: 25, Batch: 0, Loss: 0.3141, Accuracy: 0.9062\n", + "Epoch: 25, Batch: 50, Loss: 0.4506, Accuracy: 0.8594\n", + "Epoch: 25, Batch: 100, Loss: 0.3840, Accuracy: 0.8906\n", + "Epoch: 25, Batch: 150, Loss: 0.3327, Accuracy: 0.8906\n", + "Epoch: 25, Batch: 200, Loss: 0.2974, Accuracy: 0.9219\n", + "Epoch: 25, Batch: 250, Loss: 0.1352, Accuracy: 0.9688\n", + "Epoch: 25, Batch: 300, Loss: 0.2615, Accuracy: 0.9219\n", + "Epoch: 25, Batch: 350, Loss: 0.2592, Accuracy: 0.9219\n", + "Epoch: 25, Batch: 400, Loss: 0.1725, Accuracy: 0.9531\n", + "Epoch: 25, Batch: 450, Loss: 0.2702, Accuracy: 0.9375\n", + "Epoch: 25, Batch: 500, Loss: 0.2383, Accuracy: 0.9375\n", + "Epoch: 25, Batch: 550, Loss: 0.3933, Accuracy: 0.8750\n", + "Epoch: 25, Batch: 600, Loss: 0.3011, Accuracy: 0.9219\n", + "Epoch: 25, Batch: 650, Loss: 0.2424, Accuracy: 0.9375\n", + "Epoch: 25, Batch: 700, Loss: 0.3150, Accuracy: 0.9062\n", + "Epoch: 25, Batch: 750, Loss: 0.4006, Accuracy: 0.8750\n", + "Epoch: 26, Batch: 0, Loss: 0.1929, Accuracy: 0.9531\n", + "Epoch: 26, Batch: 50, Loss: 0.5716, Accuracy: 0.8125\n", + "Epoch: 26, Batch: 100, Loss: 0.2425, Accuracy: 0.9375\n", + "Epoch: 26, Batch: 150, Loss: 0.3950, Accuracy: 0.8750\n", + "Epoch: 26, Batch: 200, Loss: 0.3319, Accuracy: 0.9062\n", + "Epoch: 26, Batch: 250, Loss: 0.1654, Accuracy: 0.9688\n", + "Epoch: 26, Batch: 300, Loss: 0.5397, Accuracy: 0.8438\n", + "Epoch: 26, Batch: 350, Loss: 0.2591, Accuracy: 0.9219\n", + "Epoch: 26, Batch: 400, Loss: 0.2968, Accuracy: 0.9219\n", + "Epoch: 26, Batch: 450, Loss: 0.5532, Accuracy: 0.8125\n", + "Epoch: 26, Batch: 500, Loss: 0.1760, Accuracy: 0.9531\n", + "Epoch: 26, Batch: 550, Loss: 0.3264, Accuracy: 0.8906\n", + "Epoch: 26, Batch: 600, Loss: 0.2803, Accuracy: 0.9219\n", + "Epoch: 26, Batch: 650, Loss: 
0.2472, Accuracy: 0.9375\n", + "Epoch: 26, Batch: 700, Loss: 0.2606, Accuracy: 0.9219\n", + "Epoch: 26, Batch: 750, Loss: 0.2745, Accuracy: 0.9219\n", + "Epoch: 27, Batch: 0, Loss: 0.3512, Accuracy: 0.8906\n", + "Epoch: 27, Batch: 50, Loss: 0.2941, Accuracy: 0.9062\n", + "Epoch: 27, Batch: 100, Loss: 0.2187, Accuracy: 0.9375\n", + "Epoch: 27, Batch: 150, Loss: 0.2879, Accuracy: 0.9219\n", + "Epoch: 27, Batch: 200, Loss: 0.2874, Accuracy: 0.9219\n", + "Epoch: 27, Batch: 250, Loss: 0.5574, Accuracy: 0.8125\n", + "Epoch: 27, Batch: 300, Loss: 0.4487, Accuracy: 0.8594\n", + "Epoch: 27, Batch: 350, Loss: 0.2447, Accuracy: 0.9375\n", + "Epoch: 27, Batch: 400, Loss: 0.1519, Accuracy: 0.9688\n", + "Epoch: 27, Batch: 450, Loss: 0.3609, Accuracy: 0.8906\n", + "Epoch: 27, Batch: 500, Loss: 0.2143, Accuracy: 0.9375\n", + "Epoch: 27, Batch: 550, Loss: 0.3586, Accuracy: 0.9062\n", + "Epoch: 27, Batch: 600, Loss: 0.1934, Accuracy: 0.9375\n", + "Epoch: 27, Batch: 650, Loss: 0.3593, Accuracy: 0.8906\n", + "Epoch: 27, Batch: 700, Loss: 0.3218, Accuracy: 0.9062\n", + "Epoch: 27, Batch: 750, Loss: 0.3630, Accuracy: 0.8906\n", + "Epoch: 28, Batch: 0, Loss: 0.3300, Accuracy: 0.8906\n", + "Epoch: 28, Batch: 50, Loss: 0.3491, Accuracy: 0.8906\n", + "Epoch: 28, Batch: 100, Loss: 0.2256, Accuracy: 0.9375\n", + "Epoch: 28, Batch: 150, Loss: 0.3707, Accuracy: 0.8750\n", + "Epoch: 28, Batch: 200, Loss: 0.4507, Accuracy: 0.8438\n", + "Epoch: 28, Batch: 250, Loss: 0.3541, Accuracy: 0.8906\n", + "Epoch: 28, Batch: 300, Loss: 0.3883, Accuracy: 0.8750\n", + "Epoch: 28, Batch: 350, Loss: 0.2863, Accuracy: 0.9219\n", + "Epoch: 28, Batch: 400, Loss: 0.2894, Accuracy: 0.9219\n", + "Epoch: 28, Batch: 450, Loss: 0.3787, Accuracy: 0.8906\n", + "Epoch: 28, Batch: 500, Loss: 0.3526, Accuracy: 0.9062\n", + "Epoch: 28, Batch: 550, Loss: 0.3710, Accuracy: 0.8906\n", + "Epoch: 28, Batch: 600, Loss: 0.4035, Accuracy: 0.8750\n", + "Epoch: 28, Batch: 650, Loss: 0.2845, Accuracy: 0.9062\n", + "Epoch: 28, Batch: 
700, Loss: 0.3513, Accuracy: 0.8906\n", + "Epoch: 28, Batch: 750, Loss: 0.3650, Accuracy: 0.8906\n", + "Epoch: 29, Batch: 0, Loss: 0.4340, Accuracy: 0.8594\n", + "Epoch: 29, Batch: 50, Loss: 0.3820, Accuracy: 0.8906\n", + "Epoch: 29, Batch: 100, Loss: 0.4393, Accuracy: 0.8750\n", + "Epoch: 29, Batch: 150, Loss: 0.5519, Accuracy: 0.8281\n", + "Epoch: 29, Batch: 200, Loss: 0.2818, Accuracy: 0.9219\n", + "Epoch: 29, Batch: 250, Loss: 0.4767, Accuracy: 0.8438\n", + "Epoch: 29, Batch: 300, Loss: 0.2885, Accuracy: 0.9219\n", + "Epoch: 29, Batch: 350, Loss: 0.4896, Accuracy: 0.8594\n", + "Epoch: 29, Batch: 400, Loss: 0.3687, Accuracy: 0.8906\n", + "Epoch: 29, Batch: 450, Loss: 0.2621, Accuracy: 0.9219\n", + "Epoch: 29, Batch: 500, Loss: 0.3818, Accuracy: 0.8906\n", + "Epoch: 29, Batch: 550, Loss: 0.3212, Accuracy: 0.9062\n", + "Epoch: 29, Batch: 600, Loss: 0.3866, Accuracy: 0.8906\n", + "Epoch: 29, Batch: 650, Loss: 0.3533, Accuracy: 0.9062\n", + "Epoch: 29, Batch: 700, Loss: 0.3396, Accuracy: 0.9062\n", + "Epoch: 29, Batch: 750, Loss: 0.3327, Accuracy: 0.9219\n", + "Epoch: 30, Batch: 0, Loss: 0.1774, Accuracy: 0.9531\n", + "Epoch: 30, Batch: 50, Loss: 0.1823, Accuracy: 0.9531\n", + "Epoch: 30, Batch: 100, Loss: 0.2740, Accuracy: 0.9219\n", + "Epoch: 30, Batch: 150, Loss: 0.2063, Accuracy: 0.9531\n", + "Epoch: 30, Batch: 200, Loss: 0.3956, Accuracy: 0.8906\n", + "Epoch: 30, Batch: 250, Loss: 0.2873, Accuracy: 0.9062\n", + "Epoch: 30, Batch: 300, Loss: 0.3908, Accuracy: 0.8906\n", + "Epoch: 30, Batch: 350, Loss: 0.2038, Accuracy: 0.9531\n", + "Epoch: 30, Batch: 400, Loss: 0.2114, Accuracy: 0.9531\n", + "Epoch: 30, Batch: 450, Loss: 0.1781, Accuracy: 0.9531\n", + "Epoch: 30, Batch: 500, Loss: 0.5614, Accuracy: 0.8281\n", + "Epoch: 30, Batch: 550, Loss: 0.3801, Accuracy: 0.8906\n", + "Epoch: 30, Batch: 600, Loss: 0.2480, Accuracy: 0.9375\n", + "Epoch: 30, Batch: 650, Loss: 0.2774, Accuracy: 0.9219\n", + "Epoch: 30, Batch: 700, Loss: 0.4436, Accuracy: 0.8594\n", + "Epoch: 30, 
# 5. Save the trained model: the full state_dict, plus the CeNN templates
# (A, B, Z) under their own keys for convenient direct access at load time.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'A': model.A,
    'B': model.B,
    'Z': model.Z,
}
torch.save(checkpoint, 'cenn_model.pth')
# 1. Data preparation: augmented CIFAR-10 training pipeline.
_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    # Map each RGB channel from [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_augmentations)

# CIFAR-10 training split, downloaded to ./data on first use.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 for training.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
class HybridCeNN_CNN(nn.Module):
    """Hybrid Cellular Neural Network (CeNN) front-end + ResNet-18 classifier.

    The CeNN stage runs a fixed number of explicit Euler steps of the CeNN
    state equation on the input image, using learnable feedback (A) and
    feed-forward (B) templates plus a bias Z.  The evolved state is then
    classified by a ResNet-18 adapted for 32x32 CIFAR-10 inputs.
    """

    def __init__(self, h=0.1, iter_num=10):
        """
        Args:
            h: Euler integration step size for the CeNN dynamics.
            iter_num: number of CeNN update iterations applied in forward().
        """
        super().__init__()
        self.h = h
        self.iter_num = iter_num

        # CeNN part.  NOTE(review): weight shape (1, 3, 3, 3) maps the
        # 3-channel input to a single channel, which is then broadcast back
        # over all 3 channels in the state update below — confirm intended.
        self.A = nn.Parameter(torch.randn(1, 3, 3, 3))  # feedback template
        self.B = nn.Parameter(torch.randn(1, 3, 3, 3))  # feed-forward template
        self.Z = nn.Parameter(torch.randn(1))           # bias

        # CNN part (modified ResNet): adapted for 32x32, 10-class CIFAR-10.
        self.cnn = resnet18(weights=None)
        self.cnn.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.cnn.maxpool = nn.Identity()  # small inputs: skip early downsampling
        self.cnn.fc = nn.Linear(512, 10)

    def forward(self, U):
        """Apply CeNN preprocessing to U, then classify with the CNN.

        Args:
            U: input image batch; serves as both CeNN input and initial state.

        Returns:
            Class logits from the ResNet-18 head.
        """
        # CeNN preprocessing.  The feed-forward term depends only on the
        # constant input U (B and Z do not change within a forward pass),
        # so it is computed once outside the iteration loop.
        fwd = torch.nn.functional.conv2d(U, self.B, padding=1) + self.Z
        x = U
        for _ in range(self.iter_num):
            y = torch.clamp(x, -1, 1)  # simplified CeNN output nonlinearity
            bwd = torch.nn.functional.conv2d(y, self.A, padding=1)
            # Explicit Euler step of the CeNN state equation.
            x = x + self.h * (-x + bwd + fwd)

        # CNN classification on the evolved state.
        return self.cnn(x)
# 3. Model and optimizer setup.
# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Használt eszköz: {device}")

model = HybridCeNN_CNN().to(device)

# One parameter group per CeNN template plus one for the CNN backbone,
# all sharing the same learning rate.
param_groups = [
    {'params': model.A},
    {'params': model.B},
    {'params': model.Z},
    {'params': model.cnn.parameters()},
]
optimizer = torch.optim.Adam(param_groups, lr=0.001)

criterion = nn.CrossEntropyLoss()
# 4. Training loop over the augmented CIFAR-10 training set.
NUM_EPOCHS = 30
LOG_EVERY = 100  # report the current batch loss every 100 mini-batches

for epoch in range(NUM_EPOCHS):
    model.train()
    for batch_idx, batch in enumerate(trainloader):
        inputs, labels = (t.to(device) for t in batch)

        # Standard step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if batch_idx % LOG_EVERY == 0:
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
+ "Epoch: 4, Batch: 400, Loss: 0.5126\n", + "Epoch: 4, Batch: 500, Loss: 0.6614\n", + "Epoch: 4, Batch: 600, Loss: 0.5521\n", + "Epoch: 4, Batch: 700, Loss: 0.5489\n", + "Epoch: 5, Batch: 0, Loss: 0.5379\n", + "Epoch: 5, Batch: 100, Loss: 0.4864\n", + "Epoch: 5, Batch: 200, Loss: 0.5838\n", + "Epoch: 5, Batch: 300, Loss: 0.7733\n", + "Epoch: 5, Batch: 400, Loss: 0.3621\n", + "Epoch: 5, Batch: 500, Loss: 0.6030\n", + "Epoch: 5, Batch: 600, Loss: 0.4617\n", + "Epoch: 5, Batch: 700, Loss: 0.4785\n", + "Epoch: 6, Batch: 0, Loss: 0.6555\n", + "Epoch: 6, Batch: 100, Loss: 0.6239\n", + "Epoch: 6, Batch: 200, Loss: 0.4051\n", + "Epoch: 6, Batch: 300, Loss: 0.4033\n", + "Epoch: 6, Batch: 400, Loss: 0.3421\n", + "Epoch: 6, Batch: 500, Loss: 0.4544\n", + "Epoch: 6, Batch: 600, Loss: 0.4110\n", + "Epoch: 6, Batch: 700, Loss: 0.4839\n", + "Epoch: 7, Batch: 0, Loss: 0.4884\n", + "Epoch: 7, Batch: 100, Loss: 0.4682\n", + "Epoch: 7, Batch: 200, Loss: 0.4639\n", + "Epoch: 7, Batch: 300, Loss: 0.5159\n", + "Epoch: 7, Batch: 400, Loss: 0.2979\n", + "Epoch: 7, Batch: 500, Loss: 0.4135\n", + "Epoch: 7, Batch: 600, Loss: 0.4771\n", + "Epoch: 7, Batch: 700, Loss: 0.3821\n", + "Epoch: 8, Batch: 0, Loss: 0.3357\n", + "Epoch: 8, Batch: 100, Loss: 0.4008\n", + "Epoch: 8, Batch: 200, Loss: 0.3050\n", + "Epoch: 8, Batch: 300, Loss: 0.2159\n", + "Epoch: 8, Batch: 400, Loss: 0.4112\n", + "Epoch: 8, Batch: 500, Loss: 0.4043\n", + "Epoch: 8, Batch: 600, Loss: 0.2823\n", + "Epoch: 8, Batch: 700, Loss: 0.3153\n", + "Epoch: 9, Batch: 0, Loss: 0.3715\n", + "Epoch: 9, Batch: 100, Loss: 0.4164\n", + "Epoch: 9, Batch: 200, Loss: 0.2845\n", + "Epoch: 9, Batch: 300, Loss: 0.3807\n", + "Epoch: 9, Batch: 400, Loss: 0.5043\n", + "Epoch: 9, Batch: 500, Loss: 0.2632\n", + "Epoch: 9, Batch: 600, Loss: 0.3511\n", + "Epoch: 9, Batch: 700, Loss: 0.3831\n", + "Epoch: 10, Batch: 0, Loss: 0.3084\n", + "Epoch: 10, Batch: 100, Loss: 0.1943\n", + "Epoch: 10, Batch: 200, Loss: 0.2923\n", + "Epoch: 10, Batch: 300, Loss: 
0.1154\n", + "Epoch: 10, Batch: 400, Loss: 0.1852\n", + "Epoch: 10, Batch: 500, Loss: 0.1287\n", + "Epoch: 10, Batch: 600, Loss: 0.3440\n", + "Epoch: 10, Batch: 700, Loss: 0.2018\n", + "Epoch: 11, Batch: 0, Loss: 0.1535\n", + "Epoch: 11, Batch: 100, Loss: 0.2198\n", + "Epoch: 11, Batch: 200, Loss: 0.3176\n", + "Epoch: 11, Batch: 300, Loss: 0.1781\n", + "Epoch: 11, Batch: 400, Loss: 0.2402\n", + "Epoch: 11, Batch: 500, Loss: 0.3165\n", + "Epoch: 11, Batch: 600, Loss: 0.1474\n", + "Epoch: 11, Batch: 700, Loss: 0.3127\n", + "Epoch: 12, Batch: 0, Loss: 0.2592\n", + "Epoch: 12, Batch: 100, Loss: 0.1840\n", + "Epoch: 12, Batch: 200, Loss: 0.3848\n", + "Epoch: 12, Batch: 300, Loss: 0.2294\n", + "Epoch: 12, Batch: 400, Loss: 0.2759\n", + "Epoch: 12, Batch: 500, Loss: 0.2546\n", + "Epoch: 12, Batch: 600, Loss: 0.2649\n", + "Epoch: 12, Batch: 700, Loss: 0.4619\n", + "Epoch: 13, Batch: 0, Loss: 0.1814\n", + "Epoch: 13, Batch: 100, Loss: 0.3335\n", + "Epoch: 13, Batch: 200, Loss: 0.1933\n", + "Epoch: 13, Batch: 300, Loss: 0.2852\n", + "Epoch: 13, Batch: 400, Loss: 0.2613\n", + "Epoch: 13, Batch: 500, Loss: 0.2625\n", + "Epoch: 13, Batch: 600, Loss: 0.1588\n", + "Epoch: 13, Batch: 700, Loss: 0.2675\n", + "Epoch: 14, Batch: 0, Loss: 0.1518\n", + "Epoch: 14, Batch: 100, Loss: 0.1809\n", + "Epoch: 14, Batch: 200, Loss: 0.1823\n", + "Epoch: 14, Batch: 300, Loss: 0.2635\n", + "Epoch: 14, Batch: 400, Loss: 0.4669\n", + "Epoch: 14, Batch: 500, Loss: 0.1534\n", + "Epoch: 14, Batch: 600, Loss: 0.0958\n", + "Epoch: 14, Batch: 700, Loss: 0.2517\n", + "Epoch: 15, Batch: 0, Loss: 0.1085\n", + "Epoch: 15, Batch: 100, Loss: 0.1393\n", + "Epoch: 15, Batch: 200, Loss: 0.1517\n", + "Epoch: 15, Batch: 300, Loss: 0.1721\n", + "Epoch: 15, Batch: 400, Loss: 0.0785\n", + "Epoch: 15, Batch: 500, Loss: 0.2807\n", + "Epoch: 15, Batch: 600, Loss: 0.0837\n", + "Epoch: 15, Batch: 700, Loss: 0.1786\n", + "Epoch: 16, Batch: 0, Loss: 0.0274\n", + "Epoch: 16, Batch: 100, Loss: 0.1039\n", + "Epoch: 16, Batch: 
200, Loss: 0.2527\n", + "Epoch: 16, Batch: 300, Loss: 0.1042\n", + "Epoch: 16, Batch: 400, Loss: 0.1901\n", + "Epoch: 16, Batch: 500, Loss: 0.2128\n", + "Epoch: 16, Batch: 600, Loss: 0.1787\n", + "Epoch: 16, Batch: 700, Loss: 0.1444\n", + "Epoch: 17, Batch: 0, Loss: 0.0999\n", + "Epoch: 17, Batch: 100, Loss: 0.2429\n", + "Epoch: 17, Batch: 200, Loss: 0.0779\n", + "Epoch: 17, Batch: 300, Loss: 0.1294\n", + "Epoch: 17, Batch: 400, Loss: 0.1542\n", + "Epoch: 17, Batch: 500, Loss: 0.2249\n", + "Epoch: 17, Batch: 600, Loss: 0.1600\n", + "Epoch: 17, Batch: 700, Loss: 0.3770\n", + "Epoch: 18, Batch: 0, Loss: 0.1151\n", + "Epoch: 18, Batch: 100, Loss: 0.0851\n", + "Epoch: 18, Batch: 200, Loss: 0.1027\n", + "Epoch: 18, Batch: 300, Loss: 0.1325\n", + "Epoch: 18, Batch: 400, Loss: 0.1680\n", + "Epoch: 18, Batch: 500, Loss: 0.2423\n", + "Epoch: 18, Batch: 600, Loss: 0.1666\n", + "Epoch: 18, Batch: 700, Loss: 0.2186\n", + "Epoch: 19, Batch: 0, Loss: 0.1842\n", + "Epoch: 19, Batch: 100, Loss: 0.1486\n", + "Epoch: 19, Batch: 200, Loss: 0.1381\n", + "Epoch: 19, Batch: 300, Loss: 0.0629\n", + "Epoch: 19, Batch: 400, Loss: 0.0521\n", + "Epoch: 19, Batch: 500, Loss: 0.0506\n", + "Epoch: 19, Batch: 600, Loss: 0.2298\n", + "Epoch: 19, Batch: 700, Loss: 0.0978\n", + "Epoch: 20, Batch: 0, Loss: 0.1976\n", + "Epoch: 20, Batch: 100, Loss: 0.0710\n", + "Epoch: 20, Batch: 200, Loss: 0.1511\n", + "Epoch: 20, Batch: 300, Loss: 0.1046\n", + "Epoch: 20, Batch: 400, Loss: 0.2077\n", + "Epoch: 20, Batch: 500, Loss: 0.1461\n", + "Epoch: 20, Batch: 600, Loss: 0.1101\n", + "Epoch: 20, Batch: 700, Loss: 0.2041\n", + "Epoch: 21, Batch: 0, Loss: 0.0737\n", + "Epoch: 21, Batch: 100, Loss: 0.1191\n", + "Epoch: 21, Batch: 200, Loss: 0.1232\n", + "Epoch: 21, Batch: 300, Loss: 0.1898\n", + "Epoch: 21, Batch: 400, Loss: 0.1074\n", + "Epoch: 21, Batch: 500, Loss: 0.0719\n", + "Epoch: 21, Batch: 600, Loss: 0.0778\n", + "Epoch: 21, Batch: 700, Loss: 0.2157\n", + "Epoch: 22, Batch: 0, Loss: 0.1120\n", + "Epoch: 
22, Batch: 100, Loss: 0.0175\n", + "Epoch: 22, Batch: 200, Loss: 0.1115\n", + "Epoch: 22, Batch: 300, Loss: 0.1243\n", + "Epoch: 22, Batch: 400, Loss: 0.0250\n", + "Epoch: 22, Batch: 500, Loss: 0.0922\n", + "Epoch: 22, Batch: 600, Loss: 0.2028\n", + "Epoch: 22, Batch: 700, Loss: 0.0796\n", + "Epoch: 23, Batch: 0, Loss: 0.1341\n", + "Epoch: 23, Batch: 100, Loss: 0.0497\n", + "Epoch: 23, Batch: 200, Loss: 0.2450\n", + "Epoch: 23, Batch: 300, Loss: 0.0892\n", + "Epoch: 23, Batch: 400, Loss: 0.0652\n", + "Epoch: 23, Batch: 500, Loss: 0.2296\n", + "Epoch: 23, Batch: 600, Loss: 0.0666\n", + "Epoch: 23, Batch: 700, Loss: 0.1161\n", + "Epoch: 24, Batch: 0, Loss: 0.1039\n", + "Epoch: 24, Batch: 100, Loss: 0.1519\n", + "Epoch: 24, Batch: 200, Loss: 0.1032\n", + "Epoch: 24, Batch: 300, Loss: 0.0345\n", + "Epoch: 24, Batch: 400, Loss: 0.1313\n", + "Epoch: 24, Batch: 500, Loss: 0.1651\n", + "Epoch: 24, Batch: 600, Loss: 0.1424\n", + "Epoch: 24, Batch: 700, Loss: 0.1687\n", + "Epoch: 25, Batch: 0, Loss: 0.0630\n", + "Epoch: 25, Batch: 100, Loss: 0.0532\n", + "Epoch: 25, Batch: 200, Loss: 0.1020\n", + "Epoch: 25, Batch: 300, Loss: 0.1182\n", + "Epoch: 25, Batch: 400, Loss: 0.0716\n", + "Epoch: 25, Batch: 500, Loss: 0.0289\n", + "Epoch: 25, Batch: 600, Loss: 0.1287\n", + "Epoch: 25, Batch: 700, Loss: 0.1860\n", + "Epoch: 26, Batch: 0, Loss: 0.0946\n", + "Epoch: 26, Batch: 100, Loss: 0.0702\n", + "Epoch: 26, Batch: 200, Loss: 0.1193\n", + "Epoch: 26, Batch: 300, Loss: 0.0780\n", + "Epoch: 26, Batch: 400, Loss: 0.0149\n", + "Epoch: 26, Batch: 500, Loss: 0.0486\n", + "Epoch: 26, Batch: 600, Loss: 0.0951\n", + "Epoch: 26, Batch: 700, Loss: 0.1454\n", + "Epoch: 27, Batch: 0, Loss: 0.0472\n", + "Epoch: 27, Batch: 100, Loss: 0.0720\n", + "Epoch: 27, Batch: 200, Loss: 0.0521\n", + "Epoch: 27, Batch: 300, Loss: 0.0654\n", + "Epoch: 27, Batch: 400, Loss: 0.0928\n", + "Epoch: 27, Batch: 500, Loss: 0.0249\n", + "Epoch: 27, Batch: 600, Loss: 0.1215\n", + "Epoch: 27, Batch: 700, Loss: 
0.0526\n", + "Epoch: 28, Batch: 0, Loss: 0.0769\n", + "Epoch: 28, Batch: 100, Loss: 0.1254\n", + "Epoch: 28, Batch: 200, Loss: 0.1225\n", + "Epoch: 28, Batch: 300, Loss: 0.0684\n", + "Epoch: 28, Batch: 400, Loss: 0.1082\n", + "Epoch: 28, Batch: 500, Loss: 0.0621\n", + "Epoch: 28, Batch: 600, Loss: 0.0434\n", + "Epoch: 28, Batch: 700, Loss: 0.0655\n", + "Epoch: 29, Batch: 0, Loss: 0.0395\n", + "Epoch: 29, Batch: 100, Loss: 0.0380\n", + "Epoch: 29, Batch: 200, Loss: 0.0291\n", + "Epoch: 29, Batch: 300, Loss: 0.0632\n", + "Epoch: 29, Batch: 400, Loss: 0.0638\n", + "Epoch: 29, Batch: 500, Loss: 0.1050\n", + "Epoch: 29, Batch: 600, Loss: 0.0911\n", + "Epoch: 29, Batch: 700, Loss: 0.0645\n", + "Epoch: 30, Batch: 0, Loss: 0.0548\n", + "Epoch: 30, Batch: 100, Loss: 0.0504\n", + "Epoch: 30, Batch: 200, Loss: 0.0496\n", + "Epoch: 30, Batch: 300, Loss: 0.0294\n", + "Epoch: 30, Batch: 400, Loss: 0.0203\n", + "Epoch: 30, Batch: 500, Loss: 0.0696\n", + "Epoch: 30, Batch: 600, Loss: 0.1455\n", + "Epoch: 30, Batch: 700, Loss: 0.0719\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Mentés" + ], + "metadata": { + "id": "fJOZXRqVUoqo" + } + }, + { + "cell_type": "code", + "source": [ + "torch.save({\n", + " 'model_state_dict': model.state_dict(),\n", + " 'A': model.A,\n", + " 'B': model.B,\n", + " 'Z': model.Z,\n", + " 'cnn_conv1': model.cnn.conv1.weight,\n", + " 'cnn_fc': model.cnn.fc.state_dict(),\n", + " 'h': model.h,\n", + " 'iter_num': model.iter_num\n", + "}, 'hybrid_model.pth')" + ], + "metadata": { + "id": "WObN5P2rXhJg" + }, + "execution_count": 9, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/Project_Test.ipynb b/Project_Test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..7877a2114e31719948c0dac4770cabbce74c817f --- /dev/null +++ b/Project_Test.ipynb @@ -0,0 +1,598 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + 
"display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# 0. Importok, device" + ], + "metadata": { + "id": "tDLgxyVKaDl_" + } + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "id": "6l13ZBfQZLuf" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn as nn\n", + "from torchvision.models import resnet18" + ] + }, + { + "cell_type": "code", + "source": [ + "# GPU ellenőrzés\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Használt eszköz: {device}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "icTgUWEAZV6b", + "outputId": "7ddc398d-1396-4e2f-ae0f-7775dcd256eb" + }, + "execution_count": 30, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Használt eszköz: cpu\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 1. Dataset" + ], + "metadata": { + "id": "rnDrSHuyaJ3I" + } + }, + { + "cell_type": "code", + "source": [ + "transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", + "])\n", + "\n", + "testset = torchvision.datasets.CIFAR10(\n", + " root='./data',\n", + " train=False,\n", + " download=True,\n", + " transform=transform\n", + ")\n", + "\n", + "testloader = DataLoader(\n", + " testset,\n", + " batch_size=64,\n", + " shuffle=False\n", + ")" + ], + "metadata": { + "id": "7PXRQQNGaRXF" + }, + "execution_count": 31, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# 2. 
Model definitions" + ], + "metadata": { + "id": "MdqLUOlQb6Z3" + } + }, + { + "cell_type": "code", + "source": [ + "# CeNN modell\n", + "class CeNN(nn.Module):\n", + " def __init__(self, h=0.1, iter_num=20):\n", + " super().__init__()\n", + " self.h = h\n", + " self.iter_num = iter_num\n", + " self.alpha = 0.01\n", + " self.A = nn.Parameter(torch.randn(1, 3, 3, 3))\n", + " self.B = nn.Parameter(torch.randn(1, 3, 3, 3))\n", + " self.Z = nn.Parameter(torch.randn(1))\n", + "\n", + " def forward(self, U):\n", + " x = U\n", + " for _ in range(self.iter_num):\n", + " y = torch.minimum(x, 1 + self.alpha * x)\n", + " y = torch.maximum(y, -1 + self.alpha * y)\n", + " fwd = torch.nn.functional.conv2d(U, self.B, padding=1) + self.Z\n", + " bwd = torch.nn.functional.conv2d(y, self.A, padding=1)\n", + " x = x + self.h * (-x + bwd + fwd)\n", + " out = torch.minimum(x, 1 + self.alpha * x)\n", + " out = torch.maximum(out, -1 + self.alpha * out)\n", + " return out" + ], + "metadata": { + "id": "zgm3MJOMqlIC" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# CNN modell\n", + "class CIFAR10_ResNet(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.model = resnet18()\n", + " self.model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n", + " self.model.maxpool = nn.Identity()\n", + " self.model.fc = nn.Linear(512, 10)\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)" + ], + "metadata": { + "id": "1Muj5ySpqk73" + }, + "execution_count": 20, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Hibrid modell\n", + "class HybridCeNN_CNN(nn.Module):\n", + " def __init__(self, h=0.1, iter_num=10):\n", + " super().__init__()\n", + " self.h = h\n", + " self.iter_num = iter_num\n", + "\n", + " # CeNN rész\n", + " self.A = nn.Parameter(torch.randn(1, 3, 3, 3))\n", + " self.B = nn.Parameter(torch.randn(1, 3, 3, 3))\n", + " self.Z = nn.Parameter(torch.randn(1))\n", + 
"\n", + " # CNN rész (módosított ResNet)\n", + " self.cnn = resnet18(weights=None)\n", + " self.cnn.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n", + " self.cnn.maxpool = nn.Identity()\n", + " self.cnn.fc = nn.Linear(512, 10)\n", + "\n", + " def forward(self, U):\n", + " # CeNN előfeldolgozás\n", + " x = U\n", + " for _ in range(self.iter_num):\n", + " y = torch.clamp(x, -1, 1) # Egyszerűsített aktiváció\n", + " fwd = torch.nn.functional.conv2d(U, self.B, padding=1) + self.Z\n", + " bwd = torch.nn.functional.conv2d(y, self.A, padding=1)\n", + " x = x + self.h * (-x + bwd + fwd)\n", + "\n", + " # CNN osztályozás\n", + " return self.cnn(x)" + ], + "metadata": { + "id": "r-3L0H3Obzcm" + }, + "execution_count": 32, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# 3. Modellek betöltése" + ], + "metadata": { + "id": "-W8u5L-tc4Ot" + } + }, + { + "cell_type": "code", + "source": [ + "# CeNN betöltése\n", + "cenn_model = CeNN().to(device)\n", + "checkpoint_cenn = torch.load('cenn_model.pth', map_location=device)\n", + "cenn_model.load_state_dict(checkpoint_cenn['model_state_dict'])\n", + "cenn_model.A.data = checkpoint_cenn['A']\n", + "cenn_model.B.data = checkpoint_cenn['B']\n", + "cenn_model.Z.data = checkpoint_cenn['Z']\n", + "cenn_model.eval()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "k9EGUbpNqsu1", + "outputId": "64996546-d8bf-40ea-ab05-81b49da6338e" + }, + "execution_count": 21, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "CeNN()" + ] + }, + "metadata": {}, + "execution_count": 21 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# CNN betöltése\n", + "cnn_model = CIFAR10_ResNet().to(device)\n", + "checkpoint_cnn = torch.load('cnn_model.pth', map_location=device)\n", + "cnn_model.load_state_dict(checkpoint_cnn['model_state_dict'])\n", + "cnn_model.model.conv1.weight = checkpoint_cnn['conv1_weights']\n", + 
"cnn_model.model.fc.weight = checkpoint_cnn['fc_weights']\n", + "cnn_model.model.fc.bias = checkpoint_cnn['fc_bias']\n", + "cnn_model.eval()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S9-ZoV2Zc7ih", + "outputId": "42c4c6ab-e083-4767-bd65-e0572106a9de" + }, + "execution_count": 22, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "CIFAR10_ResNet(\n", + " (model): ResNet(\n", + " (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (maxpool): Identity()\n", + " (layer1): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (layer2): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 
1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (layer3): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " 
(layer4): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", + " (fc): Linear(in_features=512, out_features=10, bias=True)\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Hibrid modell betöltése\n", + "hybrid_model = HybridCeNN_CNN().to(device)\n", + "checkpoint_h = torch.load('hybrid_model.pth', map_location=device)\n", + "hybrid_model.load_state_dict(checkpoint_h['model_state_dict'])\n", + "hybrid_model.A.data = checkpoint_h['A']\n", + "hybrid_model.B.data = checkpoint_h['B']\n", + "hybrid_model.Z.data = checkpoint_h['Z']\n", + "hybrid_model.cnn.conv1.weight = checkpoint_h['cnn_conv1']\n", + "hybrid_model.cnn.fc.load_state_dict(checkpoint_h['cnn_fc'])\n", + "hybrid_model.eval()" + ], + "metadata": { + "colab": { + "base_uri": 
"https://localhost:8080/" + }, + "id": "oukhJeGVqsT8", + "outputId": "42d1af54-c5d6-4c12-e1b7-a06e41056d98" + }, + "execution_count": 33, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "HybridCeNN_CNN(\n", + " (cnn): ResNet(\n", + " (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (maxpool): Identity()\n", + " (layer1): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (layer2): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(64, 128, 
kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (layer3): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (layer4): Sequential(\n", + " (0): BasicBlock(\n", + " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(512, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (downsample): Sequential(\n", + " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", + " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): BasicBlock(\n", + " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", + " (fc): Linear(in_features=512, out_features=10, bias=True)\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 33 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# 4. 
Kiértékelés\n", + "\n" + ], + "metadata": { + "id": "1AjDZUsWf63_" + } + }, + { + "cell_type": "code", + "source": [ + "def evaluate_model(model, model_name):\n", + " correct = 0\n", + " total = 0\n", + " with torch.no_grad():\n", + " for images, labels in testloader:\n", + " images, labels = images.to(device), labels.to(device)\n", + " outputs = model(images)\n", + "\n", + " # CeNN speciális értékelés\n", + " if isinstance(model, CeNN):\n", + " expected = torch.ones_like(outputs)\n", + " expected[labels == 0] = -1\n", + " correct_dist = torch.mean((outputs - expected)**2, [1,2,3])\n", + " incorrect_dist = torch.mean((outputs - (-expected))**2, [1,2,3])\n", + " correct += (correct_dist < incorrect_dist).sum().item()\n", + "\n", + " # CNN/Hibrid standard értékelés\n", + " else:\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " correct += (predicted == labels).sum().item()\n", + "\n", + " total += labels.size(0)\n", + "\n", + " print(f\"{model_name} pontossága: {100 * correct / total:.2f}%\")" + ], + "metadata": { + "id": "GyRfDVwugAd3" + }, + "execution_count": 23, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "evaluate_model(cenn_model, \"CeNN\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Qsbc5sXjq_Op", + "outputId": "84d1f18e-f5c3-44a1-8306-05b2bff06915" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CeNN pontossága: 90.02%\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "evaluate_model(cnn_model, \"CNN\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PviP-1Y-rAMW", + "outputId": "c3f456a1-c434-4fd2-faca-814487815524" + }, + "execution_count": 24, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CNN pontossága: 92.40%\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "evaluate_model(hybrid_model, \"Hibrid modell\")" + ], + 
"metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CJTPURkcq_p9", + "outputId": "9a53ec7f-3d78-45e4-d8bb-30906f11633a" + }, + "execution_count": 34, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Hibrid modell pontossága: 91.45%\n" + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/cenn_model.pth b/cenn_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..08622191c943f0013ea6bbc5714263894ffe191d Binary files /dev/null and b/cenn_model.pth differ diff --git a/cnn-project b/cnn-project new file mode 160000 index 0000000000000000000000000000000000000000..aeaa420605938f77f59aba8231cd39a09ac9e9a0 --- /dev/null +++ b/cnn-project @@ -0,0 +1 @@ +Subproject commit aeaa420605938f77f59aba8231cd39a09ac9e9a0 diff --git a/cnn_model.pth b/cnn_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..352b21cb2c8d02d12c5740931b97364efafce4c4 Binary files /dev/null and b/cnn_model.pth differ diff --git a/hybrid_model.pth b/hybrid_model.pth new file mode 100644 index 0000000000000000000000000000000000000000..9ac070cf831e5bf129fab4af33cd5caca5d8a2c4 Binary files /dev/null and b/hybrid_model.pth differ