-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNode.java
More file actions
136 lines (109 loc) · 3.28 KB
/
Node.java
File metadata and controls
136 lines (109 loc) · 3.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * A single node (neuron) in a layered network.
 *
 * <p>Each node keeps two edge lists. {@code adjList} holds edges whose destination is a node
 * in the next layer. {@code revAdjList} holds edges whose destination is a node in the
 * previous layer; these are read during the forward pass. Because one connection is
 * represented by two distinct {@link Edge} objects, a weight change on a forward edge must be
 * mirrored onto the matching reverse edge.
 *
 * <p>Created by Srikanth on 2/27/2017.
 */
public class Node {
    /** Activation from the last forward pass (or a value set directly via the setter). */
    private double outputX;
    /** Error term computed during back-propagation. */
    private double gradient;
    private Layer layer;
    private boolean visited;
    /** Expected output; read only by {@link #calculateOutputGradientValue()}. */
    private double targetValue;
    /** Edges to nodes in the next layer (destination = successor). */
    private final List<Edge> adjList;
    /** Edges to nodes in the previous layer (destination = predecessor). */
    private final List<Edge> revAdjList;
    /** Bias added to the weighted input sum; the constructor leaves it 0 for input nodes. */
    private double bias;

    /**
     * Creates a node that belongs to {@code layer}. Nodes outside the input layer get a
     * random bias drawn from [-1, 1).
     *
     * @param layer the layer this node belongs to; must not be {@code null}
     */
    public Node(Layer layer) {
        this.layer = layer;
        adjList = new ArrayList<>();
        revAdjList = new ArrayList<>();
        // Only non-input nodes receive a random bias; input nodes keep the default 0.
        if (this.layer.getLayerType() != LayerType.INPUT) {
            Random rand = new Random();
            bias = rand.nextFloat() * 2 - 1;
        }
    }

    /**
     * Computes this node's activation and stores it in {@code outputX}.
     *
     * <p>The activation is the sum of (predecessor output * edge weight) over all reverse
     * edges, plus the bias, passed through {@code Utility.getSigmoidValue}. Predecessor
     * outputs are read as they currently are, so predecessors should be evaluated first.
     */
    public void performForwardPassCalculation() {
        // Accumulate in double so each double product is not narrowed to float by the
        // implicit cast in compound assignment.
        double sum = 0.0;
        for (Edge prevLayerEdge : revAdjList) {
            Node prevNode = prevLayerEdge.getDestination();
            double weight = prevLayerEdge.getWeight();
            sum += prevNode.getOutputX() * weight;
        }
        sum += bias;
        outputX = Utility.getSigmoidValue(sum);
    }

    /**
     * Computes the error term for an output node:
     * {@code (targetValue - outputX) * outputX * (1 - outputX)}.
     */
    public void calculateOutputGradientValue() {
        gradient = (targetValue - outputX) * sigmoidDerivative();
    }

    /**
     * Computes the error term for a non-output node by back-propagating its successors'
     * gradients: {@code sum(successor.gradient * edgeWeight) * outputX * (1 - outputX)}.
     * Successor gradients are read as they currently are, so they should be computed first.
     */
    public void calculateGradientValue() {
        double sum = 0.0;
        for (Edge edge : adjList) {
            sum += edge.getDestination().getGradient() * edge.getWeight();
        }
        gradient = sum * sigmoidDerivative();
    }

    /**
     * Updates the weight of every outgoing edge and mirrors each new value onto the
     * destination's reverse edge that points back at this node.
     *
     * <p>The update rule is {@code w += learningRate * outputX * destination.gradient}. Each
     * new weight is printed to standard output, prefixed by a space.
     *
     * <p>NOTE(review): the bias is not updated here. Confirm whether another class adjusts it
     * via {@link #setBias(double)}.
     *
     * @throws IllegalStateException if a destination has no reverse edge pointing back at this
     *     node, i.e. the forward and reverse edge lists are out of sync
     */
    public void updateEdgesWeight() {
        double learningRate = ApplicationRunner.getLearningRate();
        for (Edge edge : adjList) {
            Node next = edge.getDestination();
            double newWeight = edge.getWeight() + (learningRate * outputX * next.getGradient());
            edge.setWeight(newWeight);
            // Look the reverse edge up by identity. This node's index in its own adjList says
            // nothing about its position in the destination's revAdjList.
            next.findReverseEdgeTo(this).setWeight(newWeight);
            System.out.print(" " + newWeight);
        }
    }

    /** Returns the edge in {@code revAdjList} whose destination is {@code source}. */
    private Edge findReverseEdgeTo(Node source) {
        for (Edge reverse : revAdjList) {
            if (reverse.getDestination() == source) {
                return reverse;
            }
        }
        throw new IllegalStateException("No reverse edge back to the source node");
    }

    /** Sigmoid derivative expressed via the node's own output: {@code outputX * (1 - outputX)}. */
    private double sigmoidDerivative() {
        return outputX * (1 - outputX);
    }

    /** Appends an edge to the next layer. */
    public void addEdge(Edge edge) {
        adjList.add(edge);
    }

    /** Appends an edge whose destination is a node in the previous layer. */
    public void addRevEdge(Edge edge) {
        revAdjList.add(edge);
    }

    public double getOutputX() {
        return outputX;
    }

    public void setOutputX(double outputX) {
        this.outputX = outputX;
    }

    public double getGradient() {
        return gradient;
    }

    public void setGradient(double gradient) {
        this.gradient = gradient;
    }

    public Layer getLayer() {
        return layer;
    }

    public void setLayer(Layer layer) {
        this.layer = layer;
    }

    public boolean isVisited() {
        return visited;
    }

    public void setVisited(boolean visited) {
        this.visited = visited;
    }

    public double getTargetValue() {
        return targetValue;
    }

    public void setTargetValue(double targetValue) {
        this.targetValue = targetValue;
    }

    /** Returns the live (mutable) list of outgoing edges. */
    public List<Edge> getAdjList() {
        return adjList;
    }

    /** Returns the live (mutable) list of reverse edges. */
    public List<Edge> getRevAdjList() {
        return revAdjList;
    }

    public double getBias() {
        return bias;
    }

    public void setBias(double bias) {
        this.bias = bias;
    }
}