问题描述:
最近一直在使用pytorch, 由于深度学习的网络往往需要设置验证集来验证模型是否稳定.
我一直再做一个关于医学影像分割的课题,为了查看自己的模型是否稳定,于是设置了验证集.
但是在运行的过程中,当程序执行到 validatioon时,显存立即上升,我可怜的显卡只有8GB显存,瞬间爆炸.
怎么办呢?实验得做呀.于是找了不少方法,比如设置各个网络变量requires_grad=False,但是并不管用,显存依然爆炸.
后来百度了一番,终于解决了显存爆炸的问题.
解决方案:
假设训练程序是这样的:
for train_data, train_label in train_dataloader:
do
trainning
then
for valid_data,valid_label in valid_dataloader:
do
validtion
当程序执行到validation时,显存忽然上升,几乎是之前的两倍.
只需要这样改:
for train_data, train_label in train_dataloader:
do
trainning
then
with torch.no_grad():
for valid_data,valid_label in valid_dataloader:
do
validtion
当程序执行到validation时,显存将不再上升.问题得到解决.真的是非常简单.