其他
PyTorch 教程-图像识别神经网络的验证
在训练部分,我们使用MNIST数据集(手写数字数据集)对模型进行了训练,并且似乎达到了合理的损失和准确性。如果模型能够将其学到的知识应用于新数据并进行泛化,那么这将是其性能的真正体现。以下是验证模型的步骤:
步骤1:
validation_dataset=datasets.MNIST(root='./data',train=False,download=True,transform=transform1)
步骤2:
validation_loader=torch.utils.data.DataLoader(dataset=validation_dataset,batch_size=100,shuffle=False)
步骤3:
val_loss_history=[]
val_correct_history=[]
步骤4:
在接下来的步骤中,我们将验证模型。模型将在同一epoch中验证。在完成迭代整个训练集以训练数据之后,我们现在将在整个验证集上进行迭代,以测试我们的数据。
val_loss=0.0
val_correct=0.0
步骤5:
for val_input,val_labels in validation_loader:
步骤6:
当我们迭代图像的批次时,我们必须将它们展平,并使用view方法进行形状变换。
注意:每个图像张量的形状是(1,28,28),这意味着总共有784个像素。
val_inputs=val_input.view(val_input.shape[0],-1)
val_outputs=model(val_inputs)
步骤7:
val_loss1=criteron(val_outputs,val_labels)
注意:整个验证迭代都应放在 with torch.no_grad(): 上下文中。它会临时将所有 requires_grad 标志设置为 False——验证阶段不需要计算梯度,这样既节省内存又加快计算。
步骤8:
_,val_preds=torch.max(val_outputs,1)
val_loss+=val_loss1.item()
val_correct+=torch.sum(val_preds==val_labels.data)
步骤9:
val_epoch_loss=val_loss/len(validation_loader)
val_epoch_acc=val_correct.float()/len(validation_dataset)
val_loss_history.append(val_epoch_loss)
val_correct_history.append(val_epoch_acc)
步骤10:
print('validation_loss:{:.4f},{:.4f}'.format(val_epoch_loss,val_epoch_acc.item()))
这将产生预期的结果,如下所示:
步骤11:
plt.plot(loss_history,label='Training Loss')
plt.plot(val_loss_history,label='Validation Loss')
plt.legend()
plt.show()
plt.plot(correct_history,label='Training accuracy')
plt.plot(val_correct_history,label='Validation accuracy')
plt.legend()
plt.show()
完整代码
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as func
from torch import nn
from torchvision import datasets,transforms
# Resize to 28x28, convert to tensor, then map single-channel pixel values
# from [0, 1] to [-1, 1] via Normalize(mean=0.5, std=0.5).
transform1 = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# MNIST train/test splits; downloaded into ./data on first use.
training_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform1)
validation_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform1)
# Shuffle only the training data; validation order is irrelevant.
training_loader = torch.utils.data.DataLoader(dataset=training_dataset, batch_size=100, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=100, shuffle=False)
def im_convert(tensor):
    """Convert a normalized CHW image tensor to an HWC numpy array in [0, 1].

    tensor: image tensor as produced by transform1, i.e. values normalized
    with mean 0.5 and std 0.5 (so in [-1, 1]). Returns a numpy array suitable
    for matplotlib's imshow.
    """
    image = tensor.clone().detach().numpy()
    # CHW -> HWC, the layout matplotlib expects.
    image = image.transpose(1, 2, 0)
    # BUG FIX: the original computed image * (0.5 + 0.5), i.e. image * 1.0,
    # which never denormalized (and used a 3-channel constant on a 1-channel
    # MNIST image). Inverting Normalize((0.5,), (0.5,)) is x * std + mean.
    # (The stray debug print of the shape was removed.)
    image = image * 0.5 + 0.5
    image = image.clip(0, 1)
    return image
class classification1(nn.Module):
    """Fully-connected 3-layer MLP classifier (returns raw logits)."""

    def __init__(self, input_layer, hidden_layer1, hidden_layer2, output_layer):
        """Build the stack: input -> hidden1 -> hidden2 -> output."""
        super().__init__()
        self.linear1 = nn.Linear(input_layer, hidden_layer1)
        self.linear2 = nn.Linear(hidden_layer1, hidden_layer2)
        self.linear3 = nn.Linear(hidden_layer2, output_layer)

    def forward(self, x):
        """Apply ReLU after each hidden layer; no activation on the output."""
        hidden = func.relu(self.linear1(x))
        hidden = func.relu(self.linear2(hidden))
        logits = self.linear3(hidden)
        return logits
# 784 input pixels (28x28 flattened) -> two hidden layers -> 10 digit classes.
model = classification1(784, 125, 65, 10)
criteron = nn.CrossEntropyLoss()  # NOTE: name kept as in the text (misspelling of "criterion").
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 12
loss_history = []
correct_history = []
val_loss_history = []
val_correct_history = []
for e in range(epochs):
    loss = 0.0
    correct = 0.0
    val_loss = 0.0
    val_correct = 0.0
    # ---- Training pass over the whole training set ----
    for images, labels in training_loader:  # renamed from `input`, which shadowed the builtin
        inputs = images.view(images.shape[0], -1)  # flatten (B, 1, 28, 28) -> (B, 784)
        outputs = model(inputs)
        loss1 = criteron(outputs, labels)
        optimizer.zero_grad()
        loss1.backward()
        optimizer.step()
        _, preds = torch.max(outputs, 1)
        loss += loss1.item()
        correct += torch.sum(preds == labels.data)
    # ---- Validation pass; no gradients needed, so wrap it in no_grad ----
    # (Replaces the original for/else construct, which ran this unconditionally anyway.)
    with torch.no_grad():
        for val_images, val_labels in validation_loader:
            val_inputs = val_images.view(val_images.shape[0], -1)
            val_outputs = model(val_inputs)
            val_loss1 = criteron(val_outputs, val_labels)
            _, val_preds = torch.max(val_outputs, 1)
            val_loss += val_loss1.item()
            val_correct += torch.sum(val_preds == val_labels.data)
    # Loss averaged per batch; accuracy as fraction of samples correct.
    epoch_loss = loss / len(training_loader)
    # BUG FIX: was correct / len(training_loader) (number of batches), which
    # yields "average correct per batch" (up to batch_size), not an accuracy.
    epoch_acc = correct.float() / len(training_dataset)
    loss_history.append(epoch_loss)
    correct_history.append(epoch_acc)
    val_epoch_loss = val_loss / len(validation_loader)
    # BUG FIX: was correct.float() / len(validation_loader) — it used the
    # TRAINING counter and divided by the number of batches.
    val_epoch_acc = val_correct.float() / len(validation_dataset)
    val_loss_history.append(val_epoch_loss)
    # BUG FIX: was appending epoch_acc (training accuracy) into the validation history.
    val_correct_history.append(val_epoch_acc)
    print('training_loss:{:.4f},{:.4f}'.format(epoch_loss, epoch_acc.item()))
    print('validation_loss:{:.4f},{:.4f}'.format(val_epoch_loss, val_epoch_acc.item()))
# BUG FIX: the first figure duplicated the accuracy plot; the article text
# plots the loss curves first, then the accuracy curves.
plt.plot(loss_history, label='Training Loss')
plt.plot(val_loss_history, label='Validation Loss')
plt.legend()
plt.show()
plt.plot(correct_history, label='Training accuracy')
plt.plot(val_correct_history, label='Validation accuracy')
plt.legend()
plt.show()