pytorch之求梯度和nn.Linear的理解
创始人
2024-03-31 01:10:22
0

文章目录

  • 前言
  • pytorch求梯度
    • 例1
    • 例2
  • nn.Linear()
  • 总结

前言

在这里插入图片描述
假设有这么个简单的神经网络结构。
这篇文章可以说是机器学习之神经网络的公式推导与python代码(手写+pytorch)实现

补充。

pytorch求梯度

例1

pytorch里的数据结构称之为张量(tensor),因为能够自动求导,方便我们模型的计算。
因为计算梯度会自动保留梯度,隐藏需要及时清空梯度才能得到正确的结果。

import torch
from torch.autograd import Variable
def f(x):y = x ** 3return y
# y = x ^ 3
# y' = 3 * x ^ 2 def main1():x = Variable(torch.tensor([5.0]), requires_grad=True)y = f(x)y.backward(retain_graph=True)print(x.grad) # 75y.backward(retain_graph=True)print(x.grad) # 75 + 75x.grad.data.zero_() # 清空梯度y.backward()print(x.grad) # 75
main1()

输出
在这里插入图片描述

例2

import torch
from torch.autograd import Variable
def f(x):y =  1 / xreturn y# y = 1 / x
# y' = - 1 / x^2
def main1():x = Variable(torch.tensor([2.0]), requires_grad=True)y = f(x)y.backward(retain_graph=True)print(x.grad) # 在x=2这点的y的导数= -1/2^2 = -0.25y.backward(retain_graph=True)print(x.grad) # 因为上一次没清空,所以会累加x.grad.data.zero_() # 清空梯度y.backward()print(x.grad) # -0.25
main1()

输出:
在这里插入图片描述

nn.Linear()

这个就是我们经常使用的全连接层了。
其实就是一组权重。
我们可以通过示例理解一下。

import torch.nn as nn
import torch as th
demo = nn.Linear(2, 3, bias=False)
print(demo.weight)

首先定义一个全连接层实例,输入为2输出为3
打印一下权重(因为随机初始化,每次可能不一样)
在这里插入图片描述
可以发现这就是一个2x3的权重矩阵。能够将2维全连接转成3维(5x2 输入得出 5x3(5x2 2x3 ))

然后我们输入数据看看:

data = th.tensor([1.0, 1.0])
res = demo(data)
print(res)

将数据输入全连接层,可以得到如下输出:
在这里插入图片描述

其实就是经过了个矩阵乘法得到的,我们可以用numpy验证一下。

首先我们将全连接层的权重取出来转换成numpy数组,然后我们都知道np.dot就是矩阵乘法,这样就能够进行验证了

import numpy as np
demo2 = demo.weight.detach().numpy()
print(demo2)
print(demo2.shape)
data2 = data.detach().numpy()
print(np.dot(demo2, data2))

输出结果:
在这里插入图片描述
可以看到,输出结果和我们刚开始的nn.Linear的结果是一样,因此这个类其实就是相当于一个权重矩阵,传入的数据通过矩阵乘法得到输出的结果。

总结

在这里插入图片描述
那么这个网络用pytorch如何实现呢?
那就非常简单了,两个nn.Linear就行了:


class network(nn.Module):def __init__(self, input_size, hidden_size, output_size):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_sizeself.w1 = nn.Linear(input_size, hidden_size, bias=False)self.w2 = nn.Linear(hidden_size, output_size, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):i2h = self.w1(x)i2h = self.sigmoid(i2h)h2o = self.w2(i2h)h2o = self.sigmoid(h2o)return h2o

其中w1就相当于vij的系数矩阵
w2就相当于wjk的系数矩阵
然后forward就是进行正向传播计算。
反向传播的话就是后面计算出损失后,调用backward一句话就好了。

应该比手写简介且好理解吧。

相关内容

热门资讯

监控摄像头接入GB28181平... 流程简介将监控摄像头的视频在网站和APP中直播,要解决的几个问题是:1&...
Windows10添加群晖磁盘... 在使用群晖NAS时,我们需要通过本地映射的方式把NAS映射成本地的一块磁盘使用。 通过...
protocol buffer... 目录 目录 什么是protocol buffer 1.protobuf 1.1安装  1.2使用...
在Word、WPS中插入AxM... 引言 我最近需要写一些文章,在排版时发现AxMath插入的公式竟然会导致行间距异常&#...
Fluent中创建监测点 1 概述某些仿真问题,需要创建监测点,用于获取空间定点的数据࿰...
educoder数据结构与算法...                                                   ...
MySQL下载和安装(Wind... 前言:刚换了一台电脑,里面所有东西都需要重新配置,习惯了所...
MFC文件操作  MFC提供了一个文件操作的基类CFile,这个类提供了一个没有缓存的二进制格式的磁盘...
有效的括号 一、题目 给定一个只包括 '(',')','{','}'...
【PdgCntEditor】解... 一、问题背景 大部分的图书对应的PDF,目录中的页码并非PDF中直接索引的页码...