引文 Gradio是一个是用友好的web界面演示机器学习模型的最快方法,它的操作非常简便,很方便上手。
MNIST是几乎每个接触机器学习的同学都使用过的数据集,内含0~9共10个数字的上千张手写图片,其每张图片的大小均为28px*28px
Pytorch是一款非常方便的深度学习库,可以轻松搭建深度神经网络。
本文将使用Pytorch训练卷积神经网络(CNN)来进行手写数字识别,然后使用Gradio中的手写板功能输入手写数字,进行识别测试。在文章末尾可以下载本文中的全部代码(ipynb格式)
现状 目前网络上的中文教程大多局限于搭建MNIST识别模型,并使用数据集的内建测试集进行测试,并未涉及自行输入图片进行测试,而使用Gradio进行可视化展示的更是少之又少。
Gradio官方文档中的使用方法利用了外部模型,并未涉及自行训练。而外网大部分教程都使用tensorflow。
流程 导入必备包 1 2 3 4 5 6 import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torchvision import datasets, transformsimport gradio
设定一些超参数 1 2 3 BATCH_SIZE=512 #大概需要2G的显存 EPOCHS=10 # 总共训练批次 DEVICE = torch.device("mps") # mps, cuda or cpu
超参数可以理解为决定模型如何进行训练的设置参数。BATCH_SIZE代表一次进入训练的图片数量;EPOCHS代表训练多少个周期;DEVICE代表使用什么硬件进行训练——本文使用了MacBook的M1芯片,因此选择mps。N卡用户请选择cuda,其余可使用cpu进行训练(或使用其他核心,在此不详述)。
加载内建的训练数据和测试数据 1 2 3 4 5 6 7 8 9 10 11 12 13 14 train_loader = torch.utils.data.DataLoader( datasets.MNIST('data' , train=True , download=True , transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307 ,), (0.3081 ,)) ])), batch_size=BATCH_SIZE, shuffle=True ) test_loader = torch.utils.data.DataLoader( datasets.MNIST('data' , train=False , transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307 ,), (0.3081 ,)) ])), batch_size=BATCH_SIZE, shuffle=True )
定义卷积神经网络 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 class ConvNet (nn.Module): def __init__ (self ): super ().__init__() self.conv1=nn.Conv2d(1 ,10 ,5 ) self.conv2=nn.Conv2d(10 ,20 ,5 ) self.fc1=nn.Linear(20 *8 *8 ,640 ) self.fc2=nn.Linear(640 ,10 ) def forward (self,x): in_size = x.size(0 ) out=self.conv1(x) out=F.relu(out) out=F.max_pool2d(out,2 ,2 ) out = self.conv2(out) out = F.relu(out) out = out.view(in_size, -1 ) out=self.fc1(out) out=self.fc2(out) out = F.log_softmax(out, dim=1 ) return out
神经网络使用了卷积层--->relu激活层--->2*2最大池化层--->卷积层--->relu激活层--->两个全连接层--->softmax
激活层的结构。
定义损失函数和优化器 1 2 model=ConvNet().to(DEVICE) optimizer=optim.Adam(model.parameters())
封装训练和测试函数 1 2 3 4 5 6 7 8 9 10 11 12 13 def train (model, device, train_loader, optimizer, epoch ): model.train() for batch_idx, (data, target) in enumerate (train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output=model(data) loss=F.nll_loss(output,target) loss.backward() optimizer.step() if (batch_idx + 1 ) % 30 == 0 : print ('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}' .format ( epoch, batch_idx * len (data), len (train_loader.dataset), 100. * batch_idx / len (train_loader), loss.item()))
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def test (model,device,test_loader ): model.eval () test_loss=0 correct=0 with torch.no_grad(): for data,target in test_loader: data,target=data.to(device),target.to(device) output=model(data) test_loss += F.nll_loss(output, target, reduction='sum' ).item() pred = output.max (1 , keepdim=True )[1 ] correct += pred.eq(target.view_as(pred)).sum ().item() test_loss /= len (test_loader.dataset) print ('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n' .format ( test_loss, correct, len (test_loader.dataset), 100. * correct / len (test_loader.dataset)))
开始训练 1 2 3 4 for epoch in range (EPOCHS): train(model,DEVICE,train_loader,optimizer,epoch) test(model,DEVICE,test_loader)
输出样例:
Train Epoch: 0 [14848/60000 (25%)] Loss: 0.327558 Train Epoch: 0 [30208/60000 (50%)] Loss: 0.129759 Train Epoch: 0 [45568/60000 (75%)] Loss: 0.135278
Test set: Average loss: 0.0902, Accuracy: 9718/10000 (97%)
加载Gradio 定义一个预测函数 不同于之前版本可以加载pytorch模型,当前版本的Gradio必须自行书写函数传入。在Gradio启动后,手写板中书写的数字将会以单通道(B/W)的形式传入预测函数中。通过type( )
函数查看,该图片为numpy类,因此需要先将其通过transforms.ToTensor()
将其转化为pytorch的张量tensor形式。随后,拓展图片维度以适应神经网络要求,随后将其送入DEVICE中进行推理。代码如下:
1 2 3 4 5 6 7 8 9 10 def predict (inp ): img = transforms.ToTensor()(inp) img_ = img.unsqueeze(0 ) img_ = img_.to(DEVICE) output = model(img_) pred_index = int (torch.argmax(output, dim=1 )) return pred_index
启动Gradio 1 2 3 inp = gradio.Sketchpad() io = gradio.Interface(fn=predict,inputs=inp, outputs="text" ,live=True ) io.launch()
gradio.Sketchpad()
是指加载Gradio的手写板,将其接收到的笔划赋值给inp,gradio.Interface()
是指启动Gradio界面,fn代表使用的函数,这里用到了上面的predict()
函数,即输入inp,输出pred_index,输出方式为“text”文本格式。
代码下载 使用Colab 在线运行:(需要连接Google)
https://colab.research.google.com/drive/1CgcdxfgQkHth98IqHo3uQlMFMeN0bM5C#scrollTo=0zrX_u-0KN6e
国内云盘下载:
gradio.ipynb: https://url80.ctfile.com/f/35431880-763654876-9c0f8c?p=9119 (密码:9119)
参考 [1] 加载外部图片进行测试
[2] 旧版Gradio的使用方式 (现在已经不再适用)
[3] Gradio官网
[4] pytorch 中文 手册