+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

暂无数据

输入大小为200的(-1,784)形状之间的冲突

发布于2021-04-11 15:54     阅读(2511)     评论(0)     点赞(20)     收藏(1)


0

1

2

3

4

5

张量的形状及其输入大小存在误差,因为这是相互矛盾的。我完全不知道该怎么办。因为我仍然是这个主题的新手,所以这就是为什么您可能会或可能不会发现根本不需要的代码行的原因。只是抬起头。如果需要,请亲自与我联系,我将把.ipynb文件发送给您

如果重要的话,我正在使用MNIST handrawn数字数据集进行计算机视觉。

import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import SubsetRandomSampler, DataLoader

dataset = MNIST(root='data/', download=True)

def split_indices(dataset, rate):
    eval = int(dataset*rate/100)
    index = np.random.permutation(dataset)
    return index[eval:], index[:eval]
 
train_index, eval_index = split_indices(len(dataset), rate=20)

dataset = MNIST(root='data/', train=True, transform=transforms.ToTensor())

train_sampler = SubsetRandomSampler(train_index)
train_dl = DataLoader(dataset, batch_size=200, sampler=train_sampler)

val_sampler = SubsetRandomSampler(eval_index)
val_dl = DataLoader(dataset, batch_size=200, sampler=eval_sampler)
 
inputs = 28*28
nums = 10
model = nn.Linear(inputs, nums)
 
class MnistModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(inputs, nums)

        def forward(self, xb):
            xb = xb.reshape(-1, 784)
            outputs = self.linear(xb)
            return outputs

        def accuracy(x, y):
            print(torch.sum(x == y).item()/len(x))

model = MnistModel()

for images, labels in train_dl:
        outputs = model(images)
        break

loss_fn = F.cross_entropy
loss = loss_fn(outputs, labels)
opt = torch.optim.Adam(model.parameters(), lr=7)

def loss_batch(model, loss_fn, xb, yb, opt=None, metric=None):
    preds = model(xb)
    loss = loss_fn(preds, yb)
 
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
 
    metric_result = None
    if metric is not None:
        metric_result = metric(preds, yb)
 
    return loss.item(), len(xb), metric_result
 
def evaluate(model, loss_fn, valid_dl, metric=None):
    with torch.no_grad():
        results = [loss_batch(model, loss_fn, xb, yb, metric=metric) for xb, yb in valid_dl]
 
        losses, nums, metrics = zip(*results)
        total = np.sum(nums)
        avg_loss = np.sum(np.multiply(losses, nums)) / total
 
        avg_metric = None
        if metric is not None:
            avg_metric = np.sum(np.multiply(metrics, nums)) / total
 
        return avg_loss, total, avg_metric
 
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.sum(preds == labels).item() / len(preds)

eval_loss, total, eval_acc = evaluate(model, loss_fn, val_dl, metric=accuracy)
print(f"loss: {eval_loss}, accuracy: {eval_acc*100}")

def fit(epochs, model, loss_fn, opt, train_dl, valid_dl, metric=None):
    for epoch in range(epochs):
        for xb, yb in train_dl:
            loss,_,_ = loss_batch(model, xb, yb, opt)

        result = evaluate(model, loss_fn, valid_dl, metric)
        eval_loss, total, eval_metric = result

        if metric is None:
            print(f"Epoch: {epoch+1}, loss: {loss.item()}")

        else:
            print(f"Epoch: {epoch+1}, loss: {loss.item()}, metric: {metric.__name__} {eval_metric}")

model = MnistModel()

opt = torch.optim.Adam(model.parameters(), lr=7)
fit(5, model, loss_fn, opt, train_dl, eval_dl, accuracy) #Error line

错误输出:

RuntimeError                              Traceback (most recent call last)

<ipython-input-55-90c5585d3b40> in <module>()
      1 opt = torch.optim.Adam(model.parameters(), lr=7)
----> 2 fit(5, model, loss_fn, opt, train_dl, eval_dl, accuracy)

3 frames

<ipython-input-49-afd130f584e4> in forward(self, xb)
     18 
     19     def forward(self, xb):
---> 20         xb = xb.reshape(-1, 784)
     21         outputs = self.linear(xb)
     22         return outputs

RuntimeError: shape '[-1, 784]' is invalid for input of size 200

解决方案


不要对炬管使用重塑形状。

使用torch.nn.Flatten()展平您的图像。这在您的程序中看起来是一致的。

0

1

2

3

4

5

6

7

8



所属网站分类: 技术文章 > 问答

作者:黑洞官方问答小能手

链接: https://www.pythonheidong.com/blog/article/933326/fb8c7287a94c417ade08/

来源: python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

20 0
收藏该文
已收藏

评论内容:(最多支持255个字符)