Model Class definition:
Base Class for all NN modules:
torch.nn.Module(...)
All models should subclass this class.
Most basic network
import torch
import torch.nn as nn

class BasicNet(nn.Module):
    def __init__(self):
        super(BasicNet, self).__init__()
        # Single fully connected layer: 4 input features -> 3 output features
        self.net = nn.Linear(4, 3)

    def forward(self, x):
        return self.net(x)
Every layer is in fact an nn.Module, be it a linear layer, a convolutional layer or something else. This makes sense, because each one simply maps an input tensor to an output tensor. The number and type of layers in between do not matter for the class definition.
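A small sketch to illustrate this: single layers, containers and custom models all pass an isinstance check against nn.Module, and a module can hold other modules as attributes. The class name TwoLayerNet and its layer sizes are made up for this example.

```python
import torch
import torch.nn as nn

# Single layers and containers are all nn.Module instances.
layers = [nn.Linear(4, 3), nn.Conv2d(1, 6, 5), nn.ReLU(), nn.Sequential()]
all_modules = all(isinstance(layer, nn.Module) for layer in layers)
print(all_modules)


class TwoLayerNet(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        # Assigning a module to an attribute registers it as a submodule.
        self.hidden = nn.Linear(4, 8)
        self.out = nn.Linear(8, 3)

    def forward(self, x):
        return self.out(torch.relu(self.hidden(x)))


net = TwoLayerNet()
child_names = [name for name, _ in net.named_children()]
print(child_names)
```

Because the submodules are registered, net.parameters() also returns the weights of both linear layers, which is what an optimizer needs.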
class GarmentClassifier(nn.Module):
    def __init__(self):
        super(GarmentClassifier, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        # Pass the input through the whole layer stack
        return self.net(x)
Running it
model = GarmentClassifier()
output = model(input)  # input is a batch tensor; calling the model runs forward()
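As a runnable sketch of this call: the block below rebuilds the same layer stack as GarmentClassifier as a plain nn.Sequential so that it is self-contained. The 28x28 grayscale input size is an assumption; it is the size for which the flattened feature map has 16 * 4 * 4 entries. The batch size of 8 is arbitrary.

```python
import torch
import torch.nn as nn

# Same layers as GarmentClassifier, written as a plain Sequential.
net = nn.Sequential(
    nn.Conv2d(1, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120), nn.ReLU(),
    nn.Linear(120, 84), nn.ReLU(),
    nn.Linear(84, 10),
)

# Assumed input: a batch of 8 grayscale images of size 28x28.
batch = torch.rand(8, 1, 28, 28)

net.eval()                 # switch to inference behaviour
with torch.no_grad():      # no gradients needed for a plain forward pass
    output = net(batch)    # calling the module runs forward()

print(output.shape)        # one row of 10 class scores per image
```

Calling the module directly instead of net.forward(batch) is the usual pattern, because the call also runs any registered hooks.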
Input data format
Since PyTorch works with tensors, the input data needs to be in tensor format. Furthermore, we want to feed the data in batches to accelerate training.
To figure out the expected input shape, look at the forward method and at the first layer it calls.
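One common way to get batches is torch.utils.data.DataLoader. The sketch below uses made-up data (100 samples with 4 features and random labels) and an arbitrary batch size of 32; it only shows what shapes come out of the loader.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Made-up data: 100 samples with 4 features each, labels in {0, 1, 2}.
features = torch.rand(100, 4)
labels = torch.randint(0, 3, (100,))

# TensorDataset pairs up features and labels; DataLoader slices them into batches.
loader = DataLoader(TensorDataset(features, labels), batch_size=32, shuffle=True)

# Each iteration yields one (features, labels) batch.
batch_shapes = [tuple(x.shape) for x, y in loader]
print(batch_shapes)
```

The last batch is smaller because 100 is not a multiple of 32; passing drop_last=True to the DataLoader would skip it.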
Example 1: Linear Net:
class BasicNet(nn.Module):
    def __init__(self):
        super(BasicNet, self).__init__()
        self.net = nn.Linear(4, 3)

    def forward(self, x):
        return self.net(x)
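Here the important part is the first layer: nn.Linear(4, 3). The input has dimension 4 and the output has dimension 3, so the last dimension of the input tensor must be 4. The dtype should be torch.float32, which is the default dtype of the layer's weights.
The following input, a batch of 10 samples, would work:
import numpy as np
model = BasicNet()
input = torch.tensor(np.random.rand(10, 4), dtype=torch.float32)
model(input)  # output shape: [10, 3]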
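To make the shape rule concrete, here is a minimal check using a standalone nn.Linear(4, 3). The layer is applied to the last dimension, so both a single sample and a batch work, while a tensor whose last dimension is not 4 is rejected. The variable names are only for illustration.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)

single = layer(torch.rand(4))        # one sample with 4 features
batched = layer(torch.rand(10, 4))   # batch of 10 samples

print(single.shape, batched.shape)

# A tensor whose last dimension is 10 does not match in_features=4.
try:
    layer(torch.rand(10))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

print(mismatch_raised)
```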
Example 2: Conv Net:
model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)
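The important part for the input is the first layer: nn.Conv2d(1, 20, 5)
**Accepted input format:** [batch_size, in_channels, height, width]. Recent PyTorch versions also accept a single unbatched sample as [in_channels, height, width].
Note that, without padding, the height and width must be at least as large as the kernel size at every layer. With two 5x5 convolutions, the input here needs to be at least 9x9.
The following input would work:
input = torch.tensor(np.random.rand(1, 10, 10), dtype=torch.float32)  # one unbatched 10x10 image with 1 channel
model(input)  # output shape: [64, 2, 2]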
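The size constraint follows from how a convolution shrinks its input. As a sketch, assuming dilation 1, the output length of one spatial dimension is floor((size + 2 * padding - kernel_size) / stride) + 1. The helper conv_out_size below is written for this note and is not part of PyTorch; the batch size of 16 is arbitrary.

```python
import torch
import torch.nn as nn


def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Output length of one spatial dimension of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


convnet = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU(),
)

# Two 5x5 convolutions without padding: 10 -> 6 -> 2
expected = conv_out_size(conv_out_size(10, 5), 5)

# Batched input: 16 images, 1 channel, 10x10 pixels
out = convnet(torch.rand(16, 1, 10, 10))
print(expected, out.shape)
```

With a 9x9 input the second convolution would produce a 1x1 map; anything smaller makes the second layer fail.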
More Model types:
Further steps
These steps are described in more detail in 01. MOC CS knowledge base
Pytorch Model training
Hyperparameter Search Optuna
K-Fold cross validation
python, ROC Curve, AUC
Model Training Acceleration
Model Debugging (pytorch)