Neweret commited on
Commit
3edc6be
·
verified ·
1 Parent(s): 1be121d

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -20
model.py DELETED
@@ -1,20 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- class SimpleClassifier(nn.Module):
6
- def __init__(self, input_dim, num_classes, p_dropout=0.3):
7
- super().__init__()
8
- self.linear1 = nn.Linear(input_dim, 256)
9
- self.ln1 = nn.LayerNorm(256)
10
- self.dropout = nn.Dropout(p_dropout)
11
- self.linear2 = nn.Linear(256, 128)
12
- self.ln2 = nn.LayerNorm(128)
13
- self.linear_out = nn.Linear(128, num_classes)
14
-
15
- def forward(self, x):
16
- x = F.gelu(self.ln1(self.linear1(x)))
17
- x = self.dropout(x)
18
- x = F.gelu(self.ln2(self.linear2(x)))
19
- x = self.dropout(x)
20
- return self.linear_out(x)