logistic回归的参数梯度更新方法的个人理解
创始人
2024-06-01 05:13:03
0

logistic回归参数更新看了几篇博文,感觉理解不透彻,所以自己写一下,希望能有更深的理解。logistic回归输入是一个线性函数Wx+b\boldsymbol{W}\boldsymbol{x}+\boldsymbol{b}Wx+b,为了简单理解,考虑batchsize为1的情况。这时输入x\boldsymbol{x}x为一个n×1n\times1n×1的向量,标签y\boldsymbol{y}y我们采用oneHot编码为一个m×1m\times1m×1的向量,显然\boldsymbol{b}也是一个m×1m\times1m×1的向量,参数W\boldsymbol{W}W为一个m×nm\times nm×n的矩阵。若n=4n=4n=4、m=3m=3m=3,我们用图形表示logistic回归如下:
在这里插入图片描述
这里的标签y\boldsymbol{y}y采用onehot编码,长度为3,如果类别编号为1,则其编码为{1,0,0}T\{1,0,0\}^T{1,0,0}T,对应上图的话,就是y∗1=1y_*^1=1y∗1​=1,y∗2=0y_*^2=0y∗2​=0,y∗3=0y_*^3=0y∗3​=0。损失函数LLL就是y1y^1y1和y∗1y_*^1y∗1​的交叉熵损失+y2y^2y2和y∗2y_*^2y∗2​的交叉熵损失+y3y^3y3和y∗3y_*^3y∗3​的交叉熵损失。
L=∑i=13y∗ilog⁡yi=y∗1log⁡y1+y∗2log⁡y2+y∗3log⁡y3\begin{aligned} L&=\sum_{i=1}^3y^i_*\log{y^i}\\ &=y^1_*\log{y^1}+y^2_*\log{y^2}+y^3_*\log{y^3} \end{aligned} L​=i=1∑3​y∗i​logyi=y∗1​logy1+y∗2​logy2+y∗3​logy3​
上式中:
y1=ez1ez1+ez2+ez3y2=ez2ez1+ez2+ez3y3=ez3ez1+ez2+ez3\begin{aligned} y^1&=\frac{e^{z^1}}{e^{z^1}+e^{z^2}+e^{z^3}}\\ y^2&=\frac{e^{z^2}}{e^{z^1}+e^{z^2}+e^{z^3}}\\ y^3&=\frac{e^{z^3}}{e^{z^1}+e^{z^2}+e^{z^3}}\\ \end{aligned} y1y2y3​=ez1+ez2+ez3ez1​=ez1+ez2+ez3ez2​=ez1+ez2+ez3ez3​​

z1=w1Tx+b1z2=w2Tx+b2z3=w3Tx+b3\begin{aligned} z^1=\boldsymbol{w_1}^T \boldsymbol{x}+b_1\\ z^2=\boldsymbol{w_2}^T \boldsymbol{x}+b_2\\ z^3=\boldsymbol{w_3}^T \boldsymbol{x}+b_3 \end{aligned} z1=w1​Tx+b1​z2=w2​Tx+b2​z3=w3​Tx+b3​​
其中,w1={w11,w12,w13,w14}T\boldsymbol{w_1}=\{w_{11},w_{12},w_{13},w_{14}\}^Tw1​={w11​,w12​,w13​,w14​}T,x={x1,x2,x3,x4}T\boldsymbol{x}=\{x_{1},x_{2},x_{3},x_{4}\}^Tx={x1​,x2​,x3​,x4​}T因此:

损失函数LLL对w1\boldsymbol{w_1}w1​求导:
∂L∂w1=∂L∂y1∂y1∂z1∂z1∂w1+∂L∂y2∂y2∂z1∂z1∂w1+∂L∂y3∂y3∂z1∂z1∂w1=y1∗y1×y1(1−y1)×x−y2∗y2×y1y2×x−y3∗y3×y1y3×x=(y1∗(1−y1)−y2∗y1−y3∗y1)x=(y1∗−y1(y1∗+y2∗+y3∗))x=(y1∗−y1)x\begin{aligned} \frac{\partial L}{\partial \boldsymbol{w_1}}&=\frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial z^1}\frac{\partial z^1}{\partial \boldsymbol{w_1}}+\frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial z^1}\frac{\partial z^1}{\partial \boldsymbol{w_1}}+\frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial z^1}\frac{\partial z^1}{\partial \boldsymbol{w_1}}\\ &=\frac{y_1^*}{y_1}\times y_1(1-y_1)\times \boldsymbol{x}-\frac{y_2^*}{y_2}\times y_1y_2\times \boldsymbol{x}-\frac{y_3^*}{y_3}\times y_1y_3\times \boldsymbol{x}\\ &=(y_1^*(1-y_1)-y_2^*y_1-y_3^*y_1)\boldsymbol{x}\\ &=(y_1^*-y_1(y_1^*+y_2^*+y_3^*))\boldsymbol{x}\\ &=(y_1^*-y_1)\boldsymbol{x}\\ \end{aligned} ∂w1​∂L​​=∂y1​∂L​∂z1∂y1​​∂w1​∂z1​+∂y2​∂L​∂z1∂y2​​∂w1​∂z1​+∂y3​∂L​∂z1∂y3​​∂w1​∂z1​=y1​y1∗​​×y1​(1−y1​)×x−y2​y2∗​​×y1​y2​×x−y3​y3∗​​×y1​y3​×x=(y1∗​(1−y1​)−y2∗​y1​−y3∗​y1​)x=(y1∗​−y1​(y1∗​+y2∗​+y3∗​))x=(y1∗​−y1​)x​
注意(y1∗+y2∗+y3∗)(y_1^*+y_2^*+y_3^*)(y1∗​+y2∗​+y3∗​)是标签onehot编码的三个值,和正好为1。同理可得到剩下的两个导数:
∂L∂w2=(y2∗−y2)x∂L∂w3=(y3∗−y3)x\frac{\partial L}{\partial \boldsymbol{w_2}} = (y_2^*-y_2)\boldsymbol{x}\\ \frac{\partial L}{\partial \boldsymbol{w_3}} = (y_3^*-y_3)\boldsymbol{x} ∂w2​∂L​=(y2∗​−y2​)x∂w3​∂L​=(y3∗​−y3​)x
交叉熵损失函数LLL关于w\boldsymbol{w}w的梯度为:
[(y1∗−y1)x1(y2∗−y2)x1(y3∗−y3)x1(y1∗−y1)x2(y2∗−y2)x2(y3∗−y3)x2(y1∗−y1)x3(y2∗−y2)x3(y3∗−y3)x3(y1∗−y1)x4(y2∗−y2)x4(y3∗−y3)x4(y1∗−y1)x5(y2∗−y2)x5(y3∗−y3)x5]T\left[ \begin{aligned} &(y_1^*-y_1)x1&(y_2^*-y_2)x1\space\space\space\space&(y_3^*-y_3)x1\\ &(y_1^*-y_1)x2&(y_2^*-y_2)x2\space\space\space\space&(y_3^*-y_3)x2\\ &(y_1^*-y_1)x3&(y_2^*-y_2)x3\space\space\space\space&(y_3^*-y_3)x3\\ &(y_1^*-y_1)x4&(y_2^*-y_2)x4\space\space\space\space&(y_3^*-y_3)x4\\ &(y_1^*-y_1)x5&(y_2^*-y_2)x5\space\space\space\space&(y_3^*-y_3)x5\\ \end{aligned} \right]^T ​​(y1∗​−y1​)x1(y1∗​−y1​)x2(y1∗​−y1​)x3(y1∗​−y1​)x4(y1∗​−y1​)x5​(y2∗​−y2​)x1    (y2∗​−y2​)x2    (y2∗​−y2​)x3    (y2∗​−y2​)x4    (y2∗​−y2​)x5    ​(y3∗​−y3​)x1(y3∗​−y3​)x2(y3∗​−y3​)x3(y3∗​−y3​)x4(y3∗​−y3​)x5​​T
这样交叉熵损失函数LLL关于w\boldsymbol{w}w的梯度用numpy的外积计算表示为:
∂L∂w=numpy.outer(x,y∗−y)\frac{\partial L}{\partial \boldsymbol{w}}=numpy.outer(\boldsymbol{x},\boldsymbol{y^*}-\boldsymbol{y}) ∂w∂L​=numpy.outer(x,y∗−y)
用同样的方法可以推导出:
∂L∂b=y∗−y\frac{\partial L}{\partial \boldsymbol{b}}=\boldsymbol{y^*}-\boldsymbol{y} ∂b∂L​=y∗−y

相关内容

热门资讯

监控摄像头接入GB28181平... 流程简介将监控摄像头的视频在网站和APP中直播,要解决的几个问题是:1&...
Windows10添加群晖磁盘... 在使用群晖NAS时,我们需要通过本地映射的方式把NAS映射成本地的一块磁盘使用。 通过...
protocol buffer... 目录 目录 什么是protocol buffer 1.protobuf 1.1安装  1.2使用...
在Word、WPS中插入AxM... 引言 我最近需要写一些文章,在排版时发现AxMath插入的公式竟然会导致行间距异常&#...
【PdgCntEditor】解... 一、问题背景 大部分的图书对应的PDF,目录中的页码并非PDF中直接索引的页码...
修复 爱普生 EPSON L4... L4151 L4153 L4156 L4158 L4163 L4165 L4166 L4168 L4...
Fluent中创建监测点 1 概述某些仿真问题,需要创建监测点,用于获取空间定点的数据࿰...
educoder数据结构与算法...                                                   ...
MySQL下载和安装(Wind... 前言:刚换了一台电脑,里面所有东西都需要重新配置,习惯了所...
MFC文件操作  MFC提供了一个文件操作的基类CFile,这个类提供了一个没有缓存的二进制格式的磁盘...