博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现
阅读量:6268 次
发布时间:2019-06-22

本文共 4099 字,大约阅读时间需要 13 分钟。

GRU(Gated Recurrent Unit) 更新过程推导及简单代码实现

RNN网络考虑到了具有时间数列的样本数据,但是RNN仍存在着一些问题,比如随着时间的推移,RNN单元就失去了对很久之前信息的保存和处理的能力,而且存在着gradient vanishing问题。

所以有些特殊类型的RNN网络相继被提出,比如LSTM(long short term memory)和GRU(gated recurrent unit)().这里我主要推导一下GRU参数的迭代过程

GRU单元结构如下图所示

enter description here
1479126283494.jpg

数据流过程如下

1027162-20161117222055092-1837147551.png

其中1027162-20161117222055373-1685413489.png表示Hadamard积,即对应元素乘积;下标表示节点的index,上标表示时刻;1027162-20161117222055623-1280427293.png表示隐层到输出层的参数矩阵,1027162-20161117222055857-33496224.png分别是隐层和输出层的节点个数;1027162-20161117222056107-299089892.png分别表示输入和上一时刻隐层到更新门z的连接矩阵,1027162-20161117222056404-903553224.png表示输入数据的维度;1027162-20161117222056763-2055975632.png分别表示输入和上一时刻隐层到重置门r的连接矩阵;1027162-20161117222057076-1889152704.png分别表示输入和上一时刻的隐层到待选状态1027162-20161117222057451-400138159.png的连接矩阵。

针对于时刻t,使用链式求导法则,计算参数矩阵的梯度,其中E是代价函数,首先计算对隐层输出的梯度,因为隐层输出牵涉到多个时刻

1027162-20161117222057795-1209891899.png

所以

1027162-20161117222058076-2071167850.png

其中1027162-20161117222058357-1940450245.png分别是对应激活函数的线性和部分

现在对参数计算梯度

1027162-20161117222058748-1060872356.png

1027162-20161117222059123-1167300414.png

1027162-20161117222059451-1655288241.png

将上面的式子矢量化(行向量)表示:

1027162-20161117222059701-26260650.png
1027162-20161117222059967-2048645455.png

那接下来使用matlab来实现一个小例子,看看GRU的效果,同样是二进制相加的问题

  1. function error= GRUtest( ) 

  2. % 初始化训练数据 

  3. uNum=16;%单元个数 

  4. maxInt=2^uNum; 

  5. % 初始化网络结构 

  6. xdim=2

  7. ydim=1

  8. hdim=16

  9. eta=0.1

  10. %初始化网络参数 

  11. Wy=rand(hdim,ydim)*2-1

  12. Wr=rand(xdim,hdim)*2-1

  13. Ur=rand(hdim,hdim)*2-1

  14. W =rand(xdim,hdim)*2-1

  15. U =rand(hdim,hdim)*2-1

  16. Wz=rand(xdim,hdim)*2-1

  17. Uz=rand(hdim,hdim)*2-1

  18.  

  19. rvalues=zeros(uNum+1,hdim); 

  20. zvalues=zeros(uNum+1,hdim); 

  21. hbarvalues=zeros(uNum,hdim); 

  22. hvalues = zeros(uNum,hdim); 

  23. yvalues=zeros(uNum,ydim); 

  24.  

  25. for p=1:10000 

  26. aInt=randi(maxInt/2); 

  27. bInt=randi(maxInt/2); 

  28. cInt=aInt+bInt; 

  29. at=dec2bin(aInt)-'0'

  30. bt=dec2bin(bInt)-'0'

  31. ct=dec2bin(cInt)-'0'

  32. a=zeros(1,uNum); 

  33. b=zeros(1,uNum); 

  34. c=zeros(1,uNum); 

  35. a(1:size(at,2))=at(end:-1:1); 

  36. b(1:size(bt,2))=bt(end:-1:1); 

  37. c(1:size(ct,2))=ct(end:-1:1); 

  38. xvalues=[a;b]'

  39. d=c'

  40.  

  41. % 前向计算 

  42. rvalues(1,:)=sigmoid(xvalues(1,:)*Wr); 

  43. hbarvalues(1,:)=outTanh(xvalues(1,:)*W); 

  44. zvalues(1,:)=sigmoid(xvalues(1,:)*Wz); 

  45. hvalues(1,:)=zvalues(1,:).*hbarvalues(1,:); 

  46. yvalues(1,:)=sigmoid(hvalues(1,:)*Wy); 

  47. for t=2:uNum 

  48. rvalues(t,:)=sigmoid(xvalues(t,:)*Wr+hvalues(t-1,:)*Ur); 

  49. hbarvalues(t,:)=outTanh(xvalues(t,:)*W+(rvalues(t,:).*hvalues(t-1,:))*U); 

  50. zvalues(t,:)=sigmoid(xvalues(t,:)*Wz+hvalues(t-1,:)*Uz); 

  51. hvalues(t,:)=(1-zvalues(t,:)).*hvalues(t-1,:)+zvalues(t,:).*hbarvalues(t,:); 

  52. yvalues(t,:)=sigmoid(hvalues(t,:)*Wy);  

  53. end 

  54.  

  55. % 误差反向传播 

  56. delta_r_next=zeros(1,hdim); 

  57. delta_z_next=zeros(1,hdim); 

  58. delta_h_next=zeros(1,hdim); 

  59. delta_next=zeros(1,hdim); 

  60.  

  61. dWy=zeros(hdim,ydim); 

  62. dWr=zeros(xdim,hdim); 

  63. dUr=zeros(hdim,hdim); 

  64. dW=zeros(xdim,hdim); 

  65. dU=zeros(hdim,hdim); 

  66. dWz=zeros(xdim,hdim); 

  67. dUz=zeros(hdim,hdim); 

  68.  

  69. for t=uNum:-1:2 

  70. delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:)); 

  71. delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:)); 

  72. delta_z=delta_h.*(hbarvalues(t,:)-hvalues(t-1,:)).*diffsigmoid(zvalues(t,:)); 

  73. delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)); 

  74. delta_r=hvalues(t-1,:).*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:)); 

  75.  

  76. dWy=dWy+hvalues(t,:)'*delta_y; 

  77. dWz=dWz+xvalues(t,:)'*delta_z; 

  78. dUz=dUz+hvalues(t-1,:)'*delta_z; 

  79. dW =dW+xvalues(t,:)'*delta; 

  80. dU =dU+(rvalues(t,:).*hvalues(t-1,:))'*delta ; 

  81. dWr=dWr+xvalues(t,:)'*delta_r; 

  82. dUr=dUr+hvalues(t-1,:)'*delta_r; 

  83.  

  84. delta_r_next=delta_r; 

  85. delta_z_next=delta_z; 

  86. delta_h_next=delta_h; 

  87. delta_next =delta; 

  88.  

  89. end 

  90.  

  91. t=1

  92. delta_y=(yvalues(t,:)-d(t,:)).*diffsigmoid(yvalues(t,:)); 

  93. delta_h=delta_y*Wy'+delta_z_next*Uz'+delta_next*U'.*rvalues(t+1,:)+delta_r_next*Ur'+delta_h_next.*(1-zvalues(t+1,:)); 

  94. delta_z=delta_h.*(hbarvalues(t,:)-0).*diffsigmoid(zvalues(t,:)); 

  95. delta =delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)); 

  96. delta_r=0.*((delta_h.*zvalues(t,:).*diffoutTanh(hbarvalues(t,:)))*U').*diffsigmoid(rvalues(t,:)); 

  97.  

  98. dWy=dWy+hvalues(t,:)'*delta_y; 

  99. dWz=dWz+xvalues(t,:)'*delta_z; 

  100. dW =dW+xvalues(t,:)'*delta; 

  101. dWr=dWr+xvalues(t,:)'*delta_r; 

  102.  

  103. Wy = Wy-eta*dWy; 

  104. Wr = Wr-eta*dWr; 

  105. Ur = Ur-eta*dUr; 

  106. W = W -eta*dW; 

  107. U = U-eta*dU; 

  108. Wz = Wz-eta*dWz; 

  109. Uz = Uz-eta*dUz; 

  110. error = (norm(yvalues-d,2))/2.0

  111. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 

  112. if mod(p,500)==0 

  113. fprintf('******************第%s次迭代****************\n',int2str(p)); 

  114. yvalues=round(yvalues(end:-1:1)); 

  115. y=bin2dec(int2str(yvalues')); 

  116. fprintf('y=%d\n',y); 

  117. fprintf('c=%d\n',cInt); 

  118. fprintf('样本误差:e=%f\n',error); 

  119. end 

  120. end 

  121. end 

  122.  

  123. function f=sigmoid(x) 

  124. f=1./(1+exp(-x)); 

  125. end 

  126.  

  127. function fd = diffsigmoid(f) 

  128. fd=f.*(1-f); 

  129. end 

  130.  

  131. function g=outTanh(x) 

  132. g=1-2./(1+exp(2*x)); 

  133. end 

  134.  

  135. function gd=diffoutTanh(g) 

  136. gd=1-g.^2

  137. end 

部分实验结果

enter description here
1479392393541.jpg

转载于:https://www.cnblogs.com/YiXiaoZhou/p/6075777.html

你可能感兴趣的文章
添加浏览器的用户样式表
查看>>
LigerUI学习笔记之布局篇 layout
查看>>
LeetCode题解(二)
查看>>
Mybatis通用Mapper
查看>>
文件磁盘命令(就该这么学6章内容)
查看>>
2016-207-19 随笔
查看>>
java的double类型如何精确到一位小数?
查看>>
看看国外的javascript题目,你能全部做对吗?
查看>>
ffmpeg 如何选择具有相同AVCodecID的编解码器 (AVCodec)
查看>>
真正解决 Windows 中 Chromium “缺少 Google API 密钥” 的问题
查看>>
Spring 之 AOP
查看>>
软件项目管理|期末复习(二)
查看>>
直接调用VS.net2005中的配置界面
查看>>
程序员的自我修养五Windows PE/COFF
查看>>
关于字符集,编码格式,大小端的简单总结
查看>>
js string 转 int Number()
查看>>
课堂练习:ex 4-20
查看>>
20155328 2016-2017-2 《Java程序设计》 第8周学习总结
查看>>
python操作redis--string
查看>>
echarts图表初始大小问题及echarts随窗口变化自适应
查看>>