程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2022-06(11)

2022-07(1)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices

发布于2023-05-20 15:56     阅读(377)     评论(0)     点赞(14)     收藏(4)


运行以下代码时报错:RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_cat)

这行代码在函数定义里,且当时loss和model都有.to(device)的操作

  1. encoder_Z_distr = self.encoder_result(batch_x) #从batch-x中生成Z
  2. #第一个生成的Z,包含Z的过去和现在
  3. to_decoder = self.sample_encoder_Z(batch=batch_x)
  4. padding_num = torch.zeros((1, 2*self.dimZ), dtype=float)
  5. to_decoder1 = to_decoder.view(batch_size, -1)
  6. to_decoder_I1 = torch.cat((to_decoder1[:batch_size-1,], to_decoder1[1:batch_size,]), axis=1)
  7. to_decoder_I1 = torch.cat((to_decoder_I1, padding_num), axis = 0)
  8. z2_row2 = torch.randn(batch_size-1, 2*self.dimZ)
  9. to_decoder_I2 = torch.cat((to_decoder1[:batch_size-1,], z2_row2), axis=1)
  10. to_decoder_I2 = torch.cat((to_decoder_I2, padding_num), axis = 0)
  11. decoder_logits_mean1 = torch.mean(self.decoder_logits(to_decoder_I1), dim=0)#lower bound L1的计算
  12. decoder_logits_mean2 = torch.mean(self.decoder_logits(to_decoder_I2), dim=0)#lower bound L2的计算

在网上找了很久都没有找到原因,报错的地方在cat:

即在在数据拼接的时候,即一个数据在GPU0上,一个数据在GPU1上,这就会出现错误

解决办法:

一定要注意,产生随机数的地方也要用.to(device)

  1. padding_num = torch.zeros((1, 2*self.dimZ), dtype=float).to(device)
  2. z2_row2 = torch.randn(batch_size-1, 2*self.dimZ).to(device)

哪些数据需要放到gpu上训练:

可以参考这篇文章https://blog.csdn.net/qimo601/article/details/123822178

  1. import torch
  2. import time
  3. #1.通常用法
  4. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  5. data = data.to(device)
  6. model = model.to(device)
  7. '''
  8. 1.先创建device对象
  9. 2.to函数指定数据或者模型放到哪里
  10. '''
  11. #2.将构建的tensor或者模型放到指定设备上(GPU)
  12. torch.device('cuda',0) #这里的0指的是设备的序号
  13. torch.device('cuda:0')
  14. #3.例子 cpu转到GPU
  15. s = time.time()
  16. device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  17. a = torch.rand([1024, 1024, 10]).to(device)
  18. print('time:{:6.3f}'.format(time.time()-s)) # 输出: time: 0.087
  19. #4.例子2 Tensor GPU 转到cpu运行
  20. predict_boxes = predictions["boxes"].to("cpu").numpy()
  21. predict_classes = predictions["labels"].to("cpu").numpy()
  22. predict_scores = predictions["scores"].to("cpu").numpy()

to(device)需要注意的是:

使用GPU训练的时候,需要将Module对象和Tensor类型的数据送入到device。通常会使用 to.(device)。但是需要注意的是:

对于Tensor类型的数据,使用to.(device) 之后,需要接收返回值,返回值才是正确设置了device的Tensor。

对于Module对象,只用调用to.(device) 就可以将模型设置为指定的device。不必接收返回值。

  1. # Module对象设置device的写法
  2. model.to(device)
  3. # Tensor类型的数据设置 device 的写法。
  4. samples = samples.to(device)

其他cat可能会遇到的错误:

在多GPU训练时,遇到了下述的错误:
1. Expected tensor for argument 1 'input' to have the same device as tensor for argument 2 'weight'; but device 0 does not equal 1
2. RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

造成这个错误的可能性有挺多,总起来是模型、输入、模型内参数不在一个GPU上。本人是在调试RandLA-Net pytorch源码,希望使用双GPU训练,经过尝试解决这个问题,此处做一个记录,希望给后来人一个提醒。

经过调试,发现报错的地方主要是在数据拼接的时候,即一个数据在GPU0上,一个数据在GPU1上,这就会出现错误,相关代码如下:
return torch.cat((
self.mlp(concat),
features.expand(B, -1, N, K)
), dim=-3)

上述代码中,必须保证self.mlp(concat)与features.expand(B, -1, N, K)在同一个GPU中。在多GPU运算时,features(此时是输入变量)有可能放在任何一个GPU中,因此此处在拼接前,获取一下features的GPU,然后将concat放入相应的GPU中,再进行数据拼接就可以了,代码如下:
device = features.device
concat = concat.to(device)
return torch.cat((
self.mlp(concat),
features.expand(B, -1, N, K)
), dim=-3)

该源码中默认状态下device是一个固定的值,在多GPU训练状态下就会报错,代码中还有几处数据融合,大家可以依据上述思路做修改。此外该源码中由于把device的值写死了,训练好的模型也必须在相应的GPU中做推理,如在cuda0中训练的模型如果在cuda1中推理就会报错,各位可以依据此思路对源码做相应的修改。

原文链接: https://blog.csdn.net/weixin_41496173/article/details/119789280

原文链接:https://blog.csdn.net/Viviane_2022/article/details/128638452



所属网站分类: 技术文章 > 博客

作者:fhue34873

链接:https://www.pythonheidong.com/blog/article/1979397/37c5055143968aa556fc/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

14 0
收藏该文
已收藏

评论内容:(最多支持255个字符)