程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2023-10(1)

解决错误 RuntimeError: cuda runtime error (710) : device-side assert triggered a

发布于2020-03-14 18:57     阅读(876)     评论(0)     点赞(23)     收藏(1)


在github上看别人的代码,用别人的数据集跑通了,满心欢喜的换自己的数据集,修改了一番后,发现遇到了莫名其妙的错误,如下

Traceback (most recent call last):
  File "train_discriminator.py", line 167, in <module>
    main()
  File "train_discriminator.py", line 105, in main
    loss_D.backward()
  File "//anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: cuda runtime error (710) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1573049306803/work/aten/src/ATen/native/cuda/SoftMax.cu:647

这里的log里似乎没有什么关键信息,搜了一下,说切换到cpu运行就可以看到一些信息了

切换到cpu,遇到如下错误

File "train_discriminator.py", line 167, in <module>
    main()
  File "train_discriminator.py", line 100, in main
    loss_s = F.cross_entropy(y_disc_real, labels)
  File "/python3.7/site-packages/torch/nn/functional.py", line 2009, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1838, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at /opt/conda/conda-bld/pytorch_1573049306803/work/aten/src/THNN/generic/ClassNLLCriterion.c:97

这样我就想到原因了,是因为使用pytorch的torchtext加载数据集,其中测试时遇到了在训练集中没有的label,这样就造成了错误,提醒我们要注意分层抽样,特别是在某一类别的数量特别少时。

 

 


 



所属网站分类: 技术文章 > 博客

作者:坚持就是胜利

链接:https://www.pythonheidong.com/blog/article/259449/4ef5b24c06875c815c0f/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

23 0
收藏该文
已收藏

评论内容:(最多支持255个字符)