import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_splitclass WaterDataset(Dataset):
def __init__(self,csv_path):
super().__init__() # here it behave like a TensorDataSet
df = pd.read_csv(csv_path)
self.data = df.to_numpy()
def __len__(self):
return self.data.shape[0]
def __getitem__(self,idx):
features = self.data[idx,:-1]
label = self.data[idx,-1]
return np.float32(features), np.float32(label)dataset_train = WaterDataset(
"/content/drive/MyDrive/AI/Pytorch/water_potability.csv"
)len(dataset_train)2011
from torch.utils.data import DataLoaderdataloader_train = DataLoader(
dataset_train,
batch_size=50,
shuffle=True
)features,labels = next(iter(dataloader_train))
print(f"Features: {features}, \nLabels:{labels}")Features: tensor([[8.2471e+00, 1.7591e+02, 9.9140e+03, 4.9745e+00, 3.2443e+02, 3.0220e+02,
1.1071e+01, 7.1438e+01, 3.9420e+00],
[7.9653e+00, 1.5154e+02, 2.5275e+04, 7.1060e+00, 3.5232e+02, 5.2769e+02,
1.5793e+01, 5.2268e+01, 3.3910e+00],
[6.3531e+00, 2.0627e+02, 1.8409e+04, 3.7238e+00, 3.4821e+02, 3.8968e+02,
1.3205e+01, 6.0291e+01, 2.8562e+00],
[3.4451e+00, 2.0793e+02, 3.3425e+04, 8.7821e+00, 3.8401e+02, 4.4179e+02,
1.3806e+01, 3.0285e+01, 4.1844e+00],
[9.9607e+00, 1.6993e+02, 1.5835e+04, 7.1439e+00, 3.0178e+02, 3.8890e+02,
1.3564e+01, 7.1594e+01, 3.4345e+00],
[8.2056e+00, 2.0467e+02, 1.7415e+04, 6.8396e+00, 2.7677e+02, 3.4654e+02,
1.2506e+01, 8.3917e+01, 5.1295e+00],
[6.1091e+00, 1.9176e+02, 2.6854e+04, 9.0646e+00, 3.1220e+02, 3.7555e+02,
1.5514e+01, 7.3790e+01, 4.8811e+00],
[7.8931e+00, 2.0143e+02, 2.0526e+04, 5.6288e+00, 2.9902e+02, 3.0388e+02,
1.5255e+01, 7.1542e+01, 3.3022e+00],
[6.5051e+00, 2.2642e+02, 1.6982e+04, 6.9385e+00, 3.1825e+02, 4.8409e+02,
1.8527e+01, 8.0463e+01, 2.8910e+00],
[6.5828e+00, 2.1918e+02, 2.1962e+04, 6.9880e+00, 3.4890e+02, 3.4121e+02,
1.5178e+01, 6.8982e+01, 3.6689e+00],
[9.9207e+00, 2.0282e+02, 9.9739e+03, 6.8822e+00, 3.3735e+02, 3.3319e+02,
2.3918e+01, 7.1834e+01, 4.6907e+00],
[8.2271e+00, 2.7435e+02, 4.0547e+04, 7.1302e+00, 2.4145e+02, 4.1767e+02,
9.8097e+00, 7.9397e+01, 3.6192e+00],
[7.3228e+00, 2.3034e+02, 2.4682e+04, 7.4256e+00, 3.2383e+02, 3.4931e+02,
9.4979e+00, 5.0660e+01, 3.9820e+00],
[8.5128e+00, 1.5767e+02, 3.3093e+04, 6.7655e+00, 3.0586e+02, 3.7762e+02,
1.3309e+01, 4.3019e+01, 4.0266e+00],
[9.0426e+00, 2.2132e+02, 1.4150e+04, 5.3259e+00, 3.6667e+02, 3.7786e+02,
1.3008e+01, 8.7896e+01, 4.3484e+00],
[5.0581e+00, 2.3857e+02, 3.4874e+04, 8.9833e+00, 3.7443e+02, 6.6973e+02,
1.3353e+01, 7.6522e+01, 5.1067e+00],
[8.6640e+00, 2.0692e+02, 2.9551e+04, 6.0303e+00, 3.3903e+02, 3.2971e+02,
9.1399e+00, 6.3766e+01, 3.6892e+00],
[9.9187e+00, 1.9918e+02, 2.1470e+04, 6.7998e+00, 3.2918e+02, 4.3051e+02,
1.5966e+01, 5.9292e+01, 3.3855e+00],
[7.4938e+00, 1.9733e+02, 2.6678e+04, 7.1984e+00, 2.6989e+02, 3.7550e+02,
1.3135e+01, 6.9591e+01, 3.8199e+00],
[5.8429e+00, 1.6830e+02, 1.9156e+04, 6.8783e+00, 3.3148e+02, 5.0676e+02,
1.4526e+01, 8.0424e+01, 4.1432e+00],
[6.3441e+00, 1.6482e+02, 1.4973e+04, 1.0707e+01, 3.1614e+02, 3.3722e+02,
1.9412e+01, 6.4385e+01, 3.8435e+00],
[8.5553e+00, 2.1666e+02, 1.8337e+04, 8.2907e+00, 3.1133e+02, 3.9094e+02,
1.7139e+01, 3.9777e+01, 3.6872e+00],
[7.2960e+00, 2.3574e+02, 3.6044e+04, 5.1962e+00, 3.7719e+02, 3.8561e+02,
1.7053e+01, 8.9624e+01, 4.1685e+00],
[7.5758e+00, 2.0388e+02, 2.0855e+04, 8.1093e+00, 3.3403e+02, 5.3231e+02,
1.4236e+01, 7.4639e+01, 3.1764e+00],
[6.0447e+00, 1.5067e+02, 1.3594e+04, 6.4562e+00, 4.0174e+02, 3.9221e+02,
1.9827e+01, 4.3564e+01, 4.9149e+00],
[6.9752e+00, 1.7542e+02, 3.5701e+04, 5.4943e+00, 2.9005e+02, 4.0106e+02,
1.0285e+01, 6.6421e+01, 5.2565e+00],
[8.9778e+00, 1.9900e+02, 2.0226e+04, 7.5695e+00, 3.5269e+02, 4.9210e+02,
1.9622e+01, 6.4177e+01, 3.2000e+00],
[6.3324e+00, 1.8684e+02, 2.3073e+04, 8.0820e+00, 3.2698e+02, 2.3391e+02,
9.6414e+00, 6.0940e+01, 5.1590e+00],
[6.8105e+00, 2.0974e+02, 3.2602e+04, 7.4228e+00, 3.4117e+02, 3.4003e+02,
1.6737e+01, 4.2349e+01, 4.4023e+00],
[9.3186e+00, 3.1734e+02, 2.4498e+04, 7.5975e+00, 3.5717e+02, 4.7651e+02,
1.2032e+01, 6.8600e+01, 4.6427e+00],
[6.9578e+00, 2.1924e+02, 2.0216e+04, 7.0541e+00, 3.0665e+02, 4.3137e+02,
1.7427e+01, 5.6436e+01, 4.5877e+00],
[6.7025e+00, 2.0732e+02, 1.7247e+04, 7.7081e+00, 3.0451e+02, 3.2927e+02,
1.6217e+01, 2.8879e+01, 3.4430e+00],
[3.7301e+00, 2.3030e+02, 1.6893e+04, 6.9972e+00, 3.2352e+02, 4.5691e+02,
1.0342e+01, 4.7096e+01, 4.9430e+00],
[5.0338e+00, 1.5532e+02, 3.4972e+04, 7.1215e+00, 3.2012e+02, 5.0064e+02,
1.8312e+01, 6.3193e+01, 3.2449e+00],
[5.6677e+00, 2.2993e+02, 1.6954e+04, 8.7743e+00, 2.9357e+02, 5.5412e+02,
1.4255e+01, 5.4437e+01, 3.6332e+00],
[7.2748e+00, 1.9512e+02, 2.1497e+04, 6.5711e+00, 3.6070e+02, 4.1837e+02,
1.1383e+01, 8.1236e+01, 4.2716e+00],
[1.2247e+01, 2.1737e+02, 1.1318e+04, 8.4652e+00, 3.7589e+02, 3.4765e+02,
9.7625e+00, 7.3832e+01, 3.5332e+00],
[7.9733e+00, 2.3750e+02, 2.3518e+04, 5.3545e+00, 2.8358e+02, 4.7889e+02,
1.5260e+01, 5.3671e+01, 2.8261e+00],
[8.6313e+00, 1.6437e+02, 1.4881e+04, 7.2783e+00, 3.5095e+02, 4.4411e+02,
1.6857e+01, 5.8178e+01, 3.6409e+00],
[6.1574e+00, 1.5584e+02, 2.5938e+04, 8.1631e+00, 2.9819e+02, 5.3329e+02,
1.4356e+01, 6.8121e+01, 4.7702e+00],
[6.6439e+00, 1.5189e+02, 1.0909e+04, 3.7496e+00, 2.4094e+02, 4.3791e+02,
1.5265e+01, 6.4204e+01, 3.8130e+00],
[8.0401e+00, 2.2486e+02, 6.8798e+03, 8.1369e+00, 4.1896e+02, 3.6096e+02,
1.2406e+01, 7.3218e+01, 3.9865e+00],
[7.8151e+00, 1.9031e+02, 2.0229e+04, 9.1869e+00, 3.3564e+02, 3.7922e+02,
1.4979e+01, 7.3425e+01, 3.0962e+00],
[6.7159e+00, 2.1806e+02, 1.7180e+04, 8.9160e+00, 3.9328e+02, 3.4900e+02,
1.4355e+01, 5.7247e+01, 2.4047e+00],
[5.4684e+00, 1.8038e+02, 1.2306e+04, 5.4466e+00, 4.1046e+02, 3.8875e+02,
1.2553e+01, 6.0628e+01, 4.9335e+00],
[8.6395e+00, 1.6861e+02, 7.0132e+03, 8.4516e+00, 3.5521e+02, 4.1979e+02,
1.7929e+01, 3.3642e+01, 4.1929e+00],
[4.8501e+00, 1.8668e+02, 3.2808e+04, 7.4963e+00, 2.9376e+02, 3.9270e+02,
1.0631e+01, 8.5158e+01, 3.8854e+00],
[4.4898e+00, 1.8825e+02, 1.4906e+04, 8.4860e+00, 3.7423e+02, 5.1859e+02,
1.1227e+01, 6.5183e+01, 3.7762e+00],
[5.2027e+00, 1.9533e+02, 2.3051e+04, 6.9596e+00, 2.4573e+02, 4.7355e+02,
1.1659e+01, 4.9522e+01, 3.9280e+00],
[7.8183e+00, 1.7922e+02, 2.7559e+04, 7.5045e+00, 3.1545e+02, 3.5910e+02,
1.2619e+01, 5.2377e+01, 3.3292e+00]]),
Labels:tensor([0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0.,
1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0.,
0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0.])
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.fc1 = nn.Linear(9,16)
self.fc2 = nn.Linear(16,8)
self.fc3 = nn.Linear(8,1)
def forward(self,x):
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = nn.functional.sigmoid(self.fc3(x))
return x
net = Net()criterion = nn.BCELoss()
optimizer = optim.RMSprop(net.parameters(),lr=0.01)epochs = 20
for epoch in range(epochs):
total_loss = 0.0
for features ,labels in dataloader_train:
optimizer.zero_grad()
outputs = net(features)
loss = criterion(outputs,labels.view(-1,1))
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(dataloader_train)}')
Epoch [1/20], Loss: 55.33805800647271
Epoch [2/20], Loss: 59.4013303431069
Epoch [3/20], Loss: 59.05543238942216
Epoch [4/20], Loss: 59.4013303431069
Epoch [5/20], Loss: 59.74722838983303
Epoch [6/20], Loss: 59.57427941299066
Epoch [7/20], Loss: 59.4013303431069
Epoch [8/20], Loss: 59.22838136626453
Epoch [9/20], Loss: 59.4013303431069
Epoch [10/20], Loss: 59.57427941299066
Epoch [11/20], Loss: 59.4013303431069
Epoch [12/20], Loss: 59.57427941299066
Epoch [13/20], Loss: 59.4013303431069
Epoch [14/20], Loss: 59.74722838983303
Epoch [15/20], Loss: 59.57427941299066
Epoch [16/20], Loss: 59.74722838983303
Epoch [17/20], Loss: 58.88248336605909
Epoch [18/20], Loss: 60.09312643655917
Epoch [19/20], Loss: 59.920177366675404
Epoch [20/20], Loss: 60.09312643655917
Code Explanations
Model Evaluation
!pip install torchmetricsCollecting torchmetrics
Downloading torchmetrics-1.4.0-py3-none-any.whl (868 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/868.8 kB ? eta -:--:--
━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.7/868.8 kB 2.0 MB/s eta 0:00:01
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 860.2/868.8 kB 12.3 MB/s eta 0:00:01
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 868.8/868.8 kB 10.1 MB/s eta 0:00:00
Requirement already satisfied: numpy>1.20.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (1.25.2)
Requirement already satisfied: packaging>17.1 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (24.0)
Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from torchmetrics) (2.2.1+cu121)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
Downloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Collecting pretty-errors==1.2.25 (from torchmetrics)
Downloading pretty_errors-1.2.25-py3-none-any.whl (17 kB)
Collecting colorama (from pretty-errors==1.2.25->torchmetrics)
Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (67.7.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from lightning-utilities>=0.8.0->torchmetrics) (4.11.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.14.0)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (1.12)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (2023.6.0)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)
Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)
Collecting nvidia-nccl-cu12==2.19.3 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)
Collecting nvidia-nvtx-cu12==12.1.105 (from torch>=1.10.0->torchmetrics)
Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)
Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->torchmetrics) (2.2.0)
Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->torchmetrics)
Using cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->torchmetrics) (2.1.5)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->torchmetrics) (1.3.0)
Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, colorama, pretty-errors, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics
Successfully installed colorama-0.4.6 lightning-utilities-0.11.2 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.1.105 pretty-errors-1.2.25 torchmetrics-1.4.0
from torchmetrics import Accuracyacc = Accuracy(task="binary")
net.eval()
with torch.no_grad():
for features,labels in dataloader_train:
outputs = net(features)
preds = (outputs>=0.5).float()
acc(preds,labels.view(-1,1))
test_accuracy = acc.compute()
print(f"Test accuracy:{test_accuracy}")Test accuracy:0.40328195691108704
Explanation
Vanishing and Exploding Gradients
Weights Initializations
layer = nn.Linear(8,1)
print(layer.weight)Parameter containing:
tensor([[-0.0200, 0.1698, 0.1512, 0.3476, 0.0478, -0.0884, -0.2626, 0.2111]],
requires_grad=True)
# weights initialization using He/Kaiming Initialization
import torch.nn.init as init
init.kaiming_uniform_(layer.weight)
print(layer.weight.detach().numpy())[[ 0.07571412 -0.857833 -0.7581954 -0.19864778 0.59920317 0.4359296
-0.38736093 -0.329777 ]]
# import torch
# import torch.nn.functional as F
# import math
# input_size=1200
# hidden_size=4000
# w=torch.randn(hidden_size,input_size)*math.sqrt(2/input_size)
# x=torch.randn(input_size,1)
# y=w@F.relu(x)
# x.var(),y.var()class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.fc1 = nn.Linear(9,16)
self.fc2 = nn.Linear(16,8)
self.fc3 = nn.Linear(8,1)
init.kaiming_uniform_(self.fc1.weight)
init.kaiming_uniform_(self.fc2.weight)
init.kaiming_uniform_(
self.fc3.weight,
nonlinearity="sigmoid",
)
def forward(self,x):
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = nn.functional.sigmoid(self.fc3(x))
return x
net = Net()
epochs = 20
for epoch in range(epochs):
total_loss = 0.0
for features ,labels in dataloader_train:
optimizer.zero_grad()
outputs = net(features)
loss = criterion(outputs,labels.view(-1,1))
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(dataloader_train)}')
acc = Accuracy(task="binary")
net.eval()
with torch.no_grad():
for features,labels in dataloader_train:
outputs = net(features)
preds = (outputs>=0.5).float()
acc(preds,labels.view(-1,1))
test_accuracy = acc.compute()
print(f"Test accuracy:{test_accuracy}")Epoch [1/20], Loss: 39.977170153361996
Epoch [2/20], Loss: 40.841915130615234
Epoch [3/20], Loss: 39.45832312979349
Epoch [4/20], Loss: 39.63127212989621
Epoch [5/20], Loss: 39.45832312979349
Epoch [6/20], Loss: 39.80422112999893
Epoch [7/20], Loss: 40.32306810704673
Epoch [8/20], Loss: 40.15011913020437
Epoch [9/20], Loss: 40.32306810704673
Epoch [10/20], Loss: 40.15011913020437
Epoch [11/20], Loss: 39.63127212989621
Epoch [12/20], Loss: 40.15011913020437
Epoch [13/20], Loss: 40.15011913020437
Epoch [14/20], Loss: 39.97717019988269
Epoch [15/20], Loss: 40.32306810704673
Epoch [16/20], Loss: 40.32306810704673
Epoch [17/20], Loss: 40.15011917672506
Epoch [18/20], Loss: 39.977170153361996
Epoch [19/20], Loss: 40.32306810704673
Epoch [20/20], Loss: 40.15011913020437
Test accuracy:0.5967180728912354
Batch Normalization
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.fc1 = nn.Linear(9,16)
self.bn1 = nn.BatchNorm1d(16)
self.fc2 = nn.Linear(16,8)
self.fc3 = nn.Linear(8,1)
init.kaiming_uniform_(self.fc1.weight)
init.kaiming_uniform_(self.fc2.weight)
init.kaiming_uniform_(
self.fc3.weight,
nonlinearity="sigmoid",
)
def forward(self,x):
x = self.fc1(x)
x = self.bn1(x)
x = nn.functional.elu(x)
x = nn.functional.elu(self.fc2(x))
x = nn.functional.sigmoid(self.fc3(x))
return x
net = Net()
epochs = 20
for epoch in range(epochs):
total_loss = 0.0
for features ,labels in dataloader_train:
optimizer.zero_grad()
outputs = net(features)
loss = criterion(outputs,labels.view(-1,1))
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(dataloader_train)}')
acc = Accuracy(task="binary")
net.eval()
with torch.no_grad():
for features,labels in dataloader_train:
outputs = net(features)
preds = (outputs>=0.5).float()
acc(preds,labels.view(-1,1))
test_accuracy = acc.compute()
print(f"Test accuracy:{test_accuracy}")Epoch [1/20], Loss: 0.676278037268941
Epoch [2/20], Loss: 0.6801634590800215
Epoch [3/20], Loss: 0.6774980760202175
Epoch [4/20], Loss: 0.6785745780642439
Epoch [5/20], Loss: 0.6793669302289079
Epoch [6/20], Loss: 0.6797820867561712
Epoch [7/20], Loss: 0.6769404280476454
Epoch [8/20], Loss: 0.6800835961248817
Epoch [9/20], Loss: 0.6798731190402333
Epoch [10/20], Loss: 0.6800505766054479
Epoch [11/20], Loss: 0.6773186674932155
Epoch [12/20], Loss: 0.6798041811803492
Epoch [13/20], Loss: 0.6802630497188102
Epoch [14/20], Loss: 0.6796207835034627
Epoch [15/20], Loss: 0.6793888763683599
Epoch [16/20], Loss: 0.6800950373091349
Epoch [17/20], Loss: 0.6793506639759715
Epoch [18/20], Loss: 0.6784993773553429
Epoch [19/20], Loss: 0.6791951525502089
Epoch [20/20], Loss: 0.6790863086537617
Test accuracy:0.5967180728912354