Commit 845fa5f3 authored by Shawn Nithyan Stanley's avatar Shawn Nithyan Stanley
Browse files

Replace activation_layer.c

parent 9d0f47f1
......@@ -23,7 +23,41 @@ matrix forward_activation_layer(layer l, matrix x)
// logistic(x) = 1/(1+e^(-x))
// relu(x) = x if x > 0 else 0
// lrelu(x) = x if x > 0 else .01 * x
// softmax(x) = e^{x_i} / sum(e^{x_j}) for all x_j in the same row
// softmax(x) = e^{x_i} / sum(e^{x_j}) for all x_j in the same row
if (a == LOGISTIC) {
scal_matrix(-1, y);
for (int i = 0; i < y.rows; i++) {
for (int j = 0; j < y.cols; j++) {
y.data[i * y.cols + j] = 1 / (1 + exp(y.data[i * y.cols + j]));
}
}
} else if (a == RELU) {
for (int i = 0; i < y.rows; i++) {
for (int j = 0; j < y.cols; j++) {
if (y.data[i * y.cols + j] <= 0) {
y.data[i * y.cols + j] = 0;
}
}
}
} else if (a == LRELU) {
for (int i = 0; i < y.rows; i++) {
for (int j = 0; j < y.cols; j++) {
if (y.data[i * y.cols + j] <= 0) {
y.data[i * y.cols + j] = 0.01 * y.data[i * y.cols + j];
}
}
}
} else if (a == SOFTMAX) {
for (int i = 0; i < y.rows; i++) {
float sum = 0;
for (int j = 0; j < y.cols; j++) {
sum += exp(y.data[i * y.cols + j]);
}
for (int j = 0; j < y.cols; j++) {
y.data[i * y.cols + j] = exp(y.data[i * y.cols + j]) / sum;
}
}
}
return y;
}
......@@ -47,7 +81,56 @@ matrix backward_activation_layer(layer l, matrix dy)
// d/dx relu(x) = 1 if x > 0 else 0
// d/dx lrelu(x) = 1 if x > 0 else 0.01
// d/dx softmax(x) = 1
matrix x_prime = copy_matrix(x);
if (a == LOGISTIC) {
scal_matrix(-1, x_prime);
for (int i = 0; i < x.rows; i++) {
for (int j = 0; j < x.cols; j++) {
x_prime.data[i * x.cols + j] = 1 / (1 + exp(x_prime.data[i * x.cols + j]));
}
}
matrix other = copy_matrix(x_prime);
for (int i = 0; i < x_prime.rows; i++) {
for (int j = 0; j < x_prime.cols; j++) {
x_prime.data[i * x_prime.cols + j] *= 1 - other.data[i * x_prime.cols + j];
}
}
for (int i = 0; i < x_prime.rows; i++) {
for (int j = 0; j < x_prime.cols; j++) {
dx.data[i * x_prime.cols + j] *= x_prime.data[i * x_prime.cols + j];
}
}
} else if (a == RELU) {
for (int i = 0; i < x_prime.rows; i++) {
for (int j = 0; j < x_prime.cols; j++) {
if (x_prime.data[i * x_prime.cols + j] <= 0) {
x_prime.data[i * x_prime.cols + j] = 0;
} else {
x_prime.data[i * x_prime.cols + j] = 1;
}
}
}
for (int i = 0; i < x_prime.rows; i++) {
for (int j = 0; j < x_prime.cols; j++) {
dx.data[i * x_prime.cols + j] *= x_prime.data[i * x_prime.cols + j];
}
}
} else if (a == LRELU) {
for (int i = 0; i < x_prime.rows; i++) {
for (int j = 0; j < x_prime.cols; j++) {
if (x_prime.data[i * x_prime.cols + j] <= 0) {
x_prime.data[i * x_prime.cols + j] = 0.01;
} else {
x_prime.data[i * x_prime.cols + j] = 1;
}
}
}
for (int i = 0; i < x_prime.rows; i++) {
for (int j = 0; j < x_prime.cols; j++) {
dx.data[i * x_prime.cols + j] *= x_prime.data[i * x_prime.cols + j];
}
}
}
return dx;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment