Commit f2764877 authored by RyanBHC

Adding nonlinear projection in self-attention

parent 82633454
Branch: master
@@ -83,7 +83,12 @@ class CausalSelfAttention(nn.Module):
         # key, query, value projections for all heads, but in a batch
         # self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
         # output projection
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+        # self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+        self.c_proj = nn.Sequential(
+            nn.Linear(config.n_embd, config.n_embd),
+            nn.ReLU(),
+            nn.Linear(config.n_embd, config.n_embd)
+        )
         # regularization
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -102,8 +107,8 @@ class CausalSelfAttention(nn.Module):
         y[mask_tokens] = 0
         # output projection
-        # y = self.resid_dropout(self.c_proj(y))
-        y = self.resid_dropout(y)
+        y = self.resid_dropout(self.c_proj(y))
+        # y = self.resid_dropout(y)
         return y
...
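For context, here is a minimal, self-contained sketch of what the module might look like after this commit. Only the c_proj change and the re-enabled projection in forward are taken from the diff; the config field names (n_embd, n_head, attn_pdrop, resid_pdrop, block_size) and the separate key/query/value projections are assumptions in minGPT style, since the batched c_attn is commented out in this file, and the y[mask_tokens] = 0 line is omitted because mask_tokens is not defined in the hunks shown.

import math
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # per-head key, query, value projections (assumed; c_attn is disabled)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # output projection: now a two-layer MLP instead of a single Linear,
        # as added in the first hunk of this commit
        self.c_proj = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd),
            nn.ReLU(),
            nn.Linear(config.n_embd, config.n_embd)
        )
        # regularization
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # causal mask so attention only looks at earlier positions
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
                 .view(1, 1, config.block_size, config.block_size),
        )
        self.n_head = config.n_head

    def forward(self, x):
        B, T, C = x.size()
        hs = C // self.n_head
        k = self.key(x).view(B, T, self.n_head, hs).transpose(1, 2)    # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        # scaled dot-product attention with the causal mask
        att = (q @ k.transpose(-2, -1)) / math.sqrt(hs)
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = self.attn_dropout(F.softmax(att, dim=-1))
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        # output projection, re-enabled in forward by the second hunk
        y = self.resid_dropout(self.c_proj(y))
        return y

# quick shape check
cfg = SimpleNamespace(n_embd=64, n_head=4, attn_pdrop=0.1,
                      resid_pdrop=0.1, block_size=16)
attn = CausalSelfAttention(cfg)
print(attn(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])

Net effect of the commit: c_proj becomes a two-layer MLP (Linear -> ReLU -> Linear), roughly doubling that projection's parameter count, and the second hunk re-enables it in forward so residual dropout is applied to the projected output rather than to the raw attention output.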