This example uses the Jupyter Notebook and python to understand the feedforward and backpropagation methods in an ANN.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

248 lines
5.3 KiB

1 year ago
  1. # Example
  2. $$ \frac{\partial E}{\partial w_{jk}} = -e_k\cdot \sigma\left(\sum_j w_{jk}\, o_j\right) \left(1-\sigma\left(\sum_j w_{jk}\, o_j\right)\right) o_j $$
  3. $$ w_{new} = w_{old}-\alpha \frac{\partial E}{\partial w} $$
  4. ## Doing step by step
  5. ```
  6. def af(x):
  7. sigmoid = 1/(1+np.exp(-x))
  8. return sigmoid
  9. import numpy as np
  10. inN = 3
  11. hiN= 4
  12. outN=3
  13. lr= 0.4
  14. #weight W11 W21 W31
  15. # W12 W22 W32
  16. # .....
  17. np.random.seed(53)
  18. wih=np.random.rand(hiN, inN)-0.5
  19. who=np.random.rand(outN, hiN)-0.5
  20. print("Wih: ", wih)
  21. print("Who: ", who)
  22. Wih: [[ 0.34666241 0.06116554 -0.0451246 ]
  23. [-0.14782509 0.08585138 0.03574974]
  24. [ 0.32745628 -0.2354578 -0.02162094]
  25. [-0.15221498 -0.36552168 -0.24002265]]
  26. Who: [[-0.45236532 -0.1057067 -0.12838381 0.05673292]
  27. [ 0.39749455 -0.33265411 -0.09279358 0.15235334]
  28. [ 0.06774908 0.06651886 0.0243551 0.10758002]]
  29. ```
  30. ## Feedforward
  31. ```
  32. inputList = [0.32, 0.27, 0.18]
  33. inputs = np.array(inputList, ndmin=2).T
  34. Xh = np.dot(wih, inputs)
  35. print('Xh: ', Xh)
  36. Oh = af(Xh)
  37. print('Oh:', Oh)
  38. # computing output
  39. Xo = np.dot(who, Oh)
  40. print('Xo: ', Xo)
  41. Oo = af(Xo)
  42. print('Oo: ', Oo)
  43. Xh: [[ 0.11932424]
  44. [-0.0176892 ]
  45. [ 0.03732063]
  46. [-0.19060372]]
  47. Oh: [[0.52979571]
  48. [0.49557782]
  49. [0.50932908]
  50. [0.45249281]]
  51. Xo: [[-0.33176547]
  52. [ 0.06741123]
  53. [ 0.12994239]]
  54. Oo: [[0.41781112]
  55. [0.51684643]
  56. [0.53243996]]
  57. ```
  58. ## Backpropagation
  59. ```
  60. inputList = [0.32, 0.27, 0.18]
  61. targetList = [0.82, 0.25, 0.44]
  62. inputs = np.array(inputList, ndmin=2).T
  63. target = np.array(targetList, ndmin=2).T
  64. #computting hidden layer
  65. Xh = np.dot(wih, inputs)
  66. Oh = af(Xh)
  67. # computing output
  68. Xo = np.dot(who, Oh)
  69. Oo = af(Xo)
  70. # Output error
  71. oe = target-Oo
  72. # E propagation
  73. hiddenE = np.dot(who.T, oe)
  74. # updating weights
  75. #who+=lr*np.dot(oe*Oo*(1-Oo), Oh.T)
  76. #wih+=lr*np.dot(hiddenE*Oh*(1-Oh), inputs.T)
  77. #print('New wih: ', wih)
  78. #print('New who: ', who)
  79. NewW=who-lr*np.dot(-oe*Oo*(1-Oo),Oh.T)
  80. NewW
  81. array([[-0.43163327, -0.08631366, -0.10845266, 0.07443995],
  82. [ 0.38337319, -0.34586342, -0.10636942, 0.14029244],
  83. [ 0.06287227, 0.06195702, 0.01966668, 0.10341479]])
  84. newWho=who-lr*np.dot(-oe*Oo*(1-Oo), Oh.T)
  85. newWho
  86. array([[-0.43163327, -0.08631366, -0.10845266, 0.07443995],
  87. [ 0.38337319, -0.34586342, -0.10636942, 0.14029244],
  88. [ 0.06287227, 0.06195702, 0.01966668, 0.10341479]])
  89. ```
  90. ## Using class
  91. ```
  92. import numpy as np
  93. class NeuralNetwork:
  94. # init method
  95. def __init__(self, inputN,hiddenN, outputN, lr):
  96. # creates a NN with three layers (input, hidden, output)
  97. # inputN - Number of input nodes
  98. # hiddenN - Number of hidden nodes
  99. self.inN=inputN
  100. self.hiN=hiddenN
  101. self.outN=outputN
  102. self.lr=lr
  103. #weight W11 W21 W31
  104. # W12 W22 W32
  105. # .....
  106. np.random.seed(53)
  107. self.wih=np.random.rand(self.hiN, self.inN)-0.5
  108. self.who=np.random.rand(self.outN,self.hiN)-0.5
  109. print("Wih: ", self.wih)
  110. print("Who: ", self.who)
  111. pass
  112. # NN computing method
  113. def feedforward(self, inputList):
  114. # computing hidden output
  115. inputs = np.array(inputList, ndmin=2).T
  116. self.Xh = np.dot(self.wih, inputs)
  117. print('Xh: ', self.Xh)
  118. self.af = lambda x:1/(1+np.exp(-x))
  119. self.Oh = self.af(self.Xh)
  120. print('Oh:', self.Oh)
  121. # computing output
  122. self.Xo = np.dot(self.who, self.Oh)
  123. print('Xo: ', self.Xo)
  124. self.Oo = self.af(self.Xo)
  125. print('Oo: ', self.Oo)
  126. pass
  127. # NN trainning method
  128. def backpropagation(self, inputList, targetList):
  129. # data
  130. lr = self.lr
  131. inputs = np.array(inputList, ndmin=2).T
  132. target = np.array(targetList, ndmin=2).T
  133. #computting hidden layer
  134. Xh = np.dot(self.wih, inputs)
  135. af = lambda x:1/(1+np.exp(-x))
  136. Oh = af(Xh)
  137. # computing output
  138. Xo = np.dot(self.who, Oh)
  139. Oo = af(Xo)
  140. # Output error
  141. oe = target-Oo
  142. # E propagation
  143. hiddenE = np.dot(self.who.T, oe)
  144. # updating weights
  145. self.who+=lr*np.dot(oe*Oo*(1-Oo), Oh.T)
  146. self.wih+=lr*np.dot(hiddenE*Oh*(1-Oh), inputs.T)
  147. return self.wih, self.who
  148. ```
  149. ```
  150. NN = NeuralNetwork(3,4,3,0.4)
  151. Wih: [[ 0.34666241 0.06116554 -0.0451246 ]
  152. [-0.14782509 0.08585138 0.03574974]
  153. [ 0.32745628 -0.2354578 -0.02162094]
  154. [-0.15221498 -0.36552168 -0.24002265]]
  155. Who: [[-0.45236532 -0.1057067 -0.12838381 0.05673292]
  156. [ 0.39749455 -0.33265411 -0.09279358 0.15235334]
  157. [ 0.06774908 0.06651886 0.0243551 0.10758002]]
  158. ```
  159. ```
  160. NN.feedforward([0.32, 0.27, 0.18])
  161. Xh: [[ 0.11932424]
  162. [-0.0176892 ]
  163. [ 0.03732063]
  164. [-0.19060372]]
  165. Oh: [[0.52979571]
  166. [0.49557782]
  167. [0.50932908]
  168. [0.45249281]]
  169. Xo: [[-0.33176547]
  170. [ 0.06741123]
  171. [ 0.12994239]]
  172. Oo: [[0.41781112]
  173. [0.51684643]
  174. [0.53243996]]
  175. ```
  176. ```
  177. NN.backpropagation([0.32, 0.27, 0.18], [0.82, 0.25, 0.44])
  178. (array([[ 0.33727924, 0.05324849, -0.05040263],
  179. [-0.14654184, 0.08693412, 0.03647157],
  180. [ 0.32652462, -0.23624388, -0.02214499],
  181. [-0.15309598, -0.36626503, -0.24051822]]),
  182. array([[-0.43163327, -0.08631366, -0.10845266, 0.07443995],
  183. [ 0.38337319, -0.34586342, -0.10636942, 0.14029244],
  184. [ 0.06287227, 0.06195702, 0.01966668, 0.10341479]]))
  185. ```