pytorch中tensorboard的使用

1.首先介绍一下tensorboard

  TensorBoard是一个可视化工具,它可以用来展示网络图、张量的指标变化、张量的分布情况等。特别是在训练网络的时候,我们可以设置不同的参数(比如:权重W、偏置B、卷积层数、全连接层数等),使用TensorBoader可以很直观的帮我们进行参数的选择。它通过运行一个本地服务器,来监听6006端口。在浏览器发出请求时,分析训练时记录的数据,绘制训练过程中的图像。

2.如何用tensorboard实现Pytorch的模型结构可视化

1.1安装tensorboard(pytorch)

conda install tensorboardX
conda install tensorboard
conda install tensorflow

至于为什么要安装这三个我是参考其他大神,我也不懂。

1.2代码实现

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('save') #建立一个保存数据用的东西,save是输出的文件名
dummy_input = torch.rand(512, 1, 28, 28)  # 网络中输入的数据维度
with SummaryWriter(comment='LeNet') as w:
    w.add_graph(net, (dummy_input,))  # net是你的网络名

添加完上述代码后,运行程序后程序里会出现下列文件夹,如果是在服务器上运行,到服务器上的代码里查看是否存在下列文件夹。

如果是在本地跑的,直接按照下面步骤继续。

在pycharm里面的终端窗口输入


tensorboard --logdir = C:\Users\huangxin1\PycharmProjects\untitled\runs

 注意你要进入你运行代码的虚拟环境里面输入上述代码,runs的路径是写自己的runs路径。

然后会出来下面这个:

TensorBoard 2.6.0 at http://localhost:6006/ (Press CTRL+C to quit)

点击网址,就可以显示自己的网络模型的结构了。

 如果是在服务器上跑的代码则需要本地访问远程服务器上的tensorboard,方法如下:

适用情况框架是pytorch,需要Xshell工具(其他工具也可以,我用的Xshell)。

打开Xhell,新建会话属性。

 主机号和端口号为你所使用的服务器的主机号以及端口号。随后点击隧道,按照我的设置,照搬就行。

 随后点击链接即可

 然后进入到你的虚拟环境,运行以下代码

(pytorch) WeiWB@shuguang:~$ tensorboard --logdir="/home/WeiWB/code/Conv-TasNet-master/src/runs" --port=6006

 然后复制给的网址http://localhost:6006 d到浏览器运行就可以了。

 。至此本地访问远程服务器上的tensorboard教程已经完毕。

3.用tensorboard查看loss以及其他参数的变换

直接放代码:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/logs_fina")  # 存放log文件的目录
writer.add_scalar('train/loss', ave_train_loss, epoch)  # 画loss,横坐标为epoch
writer.add_scalar('train/lr', ave_lr, epoch)
writer.close()

该有小伙伴疑惑,这个代码的位置应该放在哪儿呢?放在产生每个训练loss值的后面就可以了

来源:在学习的魏同学

物联沃分享整理
物联沃-IOTWORD物联网 » pytorch中tensorboard的使用

发表评论