Step 1: Prepare the data
Data for 17 monkey species:
self.class_indict = ["白頭卷尾猴", "彌猴", "山魈", "松鼠猴", "葉猴", "銀色絨猴", "印度烏葉猴", "疣猴", "侏絨", "白禿猴", "赤猴", "滇金絲猴", "狒狒", "黑色吼猴", "黑葉猴", "金絲猴", "懶猴"]
There are 1,800 images in total, and each species is stored in its own folder.
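Before training, it is worth checking that every class folder is present and how many images each one holds. The snippet below is a minimal sketch that assumes the layout described above, with one sub-folder per species under a root directory matching the training script's --data-path default; the path itself is only illustrative.

import os

# Hypothetical dataset root: one sub-folder per species, e.g. training/白頭卷尾猴/xxx.jpg
data_root = r"G:\demo\data\monkeys\training"

# Count the images in every class folder
for cls in sorted(os.listdir(data_root)):
    cls_dir = os.path.join(data_root, cls)
    if os.path.isdir(cls_dir):
        n = len([f for f in os.listdir(cls_dir)
                 if f.lower().endswith((".jpg", ".jpeg", ".png"))])
        print(f"{cls}: {n} images")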
Step 2: Build the model
This article uses a ShuffleNetV2 network; its underlying principles are introduced below.
ShuffleNet v2 is Megvii's upgraded version of ShuffleNet and was accepted at ECCV 2018. According to the paper, at the same complexity ShuffleNet v2 is more accurate than ShuffleNet v1 and MobileNetV2. ShuffleNet v2 was obtained by improving ShuffleNet v1 according to four practical guidelines, which are as follows:
(G1) Equal channel widths minimize memory access cost. Lightweight CNNs commonly use depthwise separable convolutions, in which the pointwise (1x1) convolution accounts for most of the complexity. Let the input and output channel counts be c1 and c2. It can be shown that, for a fixed FLOPs budget, the memory access cost (MAC) reaches its minimum only when c1 = c2, and this theoretical result is confirmed experimentally. See reference [1] for the detailed proof.
(G2) Excessive group convolution increases MAC. Group convolution is a common design component because it reduces complexity without sacrificing model capacity. However, the paper finds that using too many groups increases MAC. See reference [1] for the detailed proof; a small numeric illustration of G1 and G2 is given after the list of guidelines.
(G3) Network fragmentation reduces the degree of parallelism. Some networks, such as Inception, and AutoML-generated architectures, such as NASNet-A, tend to adopt "multi-path" structures, i.e. a single block contains many small convolutions or pooling operations. This easily fragments the network, lowers the model's parallelism, and slows it down, which is also confirmed experimentally.
(G4) Element-wise operations cannot be ignored. Element-wise operators such as ReLU and Add have small FLOPs but relatively large MAC. Experiments show that removing the ReLU and shortcut from ResNet's residual unit yields roughly a 20% speedup.
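As a concrete illustration of G1 and G2, the short sketch below plugs numbers into the MAC expressions for a 1x1 convolution used in the ShuffleNet v2 paper: MAC = hw(c1 + c2) + c1*c2 for a plain pointwise convolution, and MAC = hw(c1 + c2) + c1*c2/g for a pointwise group convolution with g groups. The feature-map size and FLOPs budget are made-up values chosen only for demonstration.

# Illustrative numeric check of guidelines G1 and G2 (values are arbitrary)
h = w = 28                   # feature-map height and width
B = 28 * 28 * 256 * 256      # fixed FLOPs budget of a 1x1 convolution

# G1: with FLOPs fixed (h*w*c1*c2 == B), MAC is smallest when c1 == c2
print("G1: plain 1x1 conv under a fixed FLOPs budget")
for c1 in (64, 128, 256, 512, 1024):
    c2 = B // (h * w * c1)                  # keep h*w*c1*c2 == B
    mac = h * w * (c1 + c2) + c1 * c2       # feature maps + weights
    print(f"  c1={c1:4d}, c2={c2:4d} -> MAC={mac}")

# G2: with c1 and FLOPs fixed, MAC of a 1x1 group conv grows with g
print("G2: 1x1 group conv, same FLOPs, more groups")
c1 = 256
for g in (1, 2, 4, 8):
    c2 = B * g // (h * w * c1)              # keep h*w*c1*c2/g == B
    mac = h * w * (c1 + c2) + c1 * c2 // g  # feature maps + weights
    print(f"  g={g}, c2={c2:4d} -> MAC={mac}")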
Based on these four guidelines, the authors analyzed the shortcomings of the ShuffleNet v1 design and improved it to obtain ShuffleNet v2; the paper compares the building blocks of the two networks in a figure (not reproduced here).
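The essence of the stride-1 ShuffleNet v2 unit can also be sketched in code: the input is split in half along the channel dimension, one half goes through 1x1 conv -> 3x3 depthwise conv -> 1x1 conv with equal input and output channels (G1) and no group convolution in the 1x1 layers (G2), and the two halves are then concatenated and channel-shuffled rather than added (G4). This is a simplified sketch for illustration only, not the full implementation in the project's model.py.

import torch
import torch.nn as nn

def channel_shuffle(x, groups=2):
    # Reorder channels so information mixes between the two branches
    b, c, h, w = x.size()
    x = x.view(b, groups, c // groups, h, w)
    x = torch.transpose(x, 1, 2).contiguous()
    return x.view(b, c, h, w)

class ShuffleV2BasicUnit(nn.Module):
    # Simplified stride-1 ShuffleNet v2 unit (channel split + concat + shuffle)
    def __init__(self, channels):
        super().__init__()
        branch_ch = channels // 2          # equal split keeps c1 == c2 (G1)
        self.branch = nn.Sequential(
            nn.Conv2d(branch_ch, branch_ch, 1, bias=False),   # plain 1x1, no groups (G2)
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 3, padding=1,
                      groups=branch_ch, bias=False),          # 3x3 depthwise conv
            nn.BatchNorm2d(branch_ch),
            nn.Conv2d(branch_ch, branch_ch, 1, bias=False),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)                       # channel split
        out = torch.cat((x1, self.branch(x2)), dim=1)    # concat instead of add (G4)
        return channel_shuffle(out, 2)

# quick shape check
print(ShuffleV2BasicUnit(116)(torch.randn(1, 116, 28, 28)).shape)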
Step 3: Training code
1) Loss function: cross-entropy loss (applied inside the train_one_epoch helper; a sketch of the helper functions follows the training script below).
2) Training code:
import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import shufflenet_v2_x1_0
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # load pretrained weights if a weights file is given
    model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze every layer except the final fully connected layer
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)
        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=17)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.1)

    # root directory of the dataset
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\monkeys\training")

    # official shufflenetv2_x1.0 pretrained weights:
    # https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
    parser.add_argument('--weights', type=str, default='./shufflenetv2_x1-5666bf0f80.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
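The script imports train_one_epoch and evaluate from utils, which is not shown in this post. A minimal sketch of what those two helpers do, assuming the cross-entropy loss from step 3 item 1 and data loaders that yield (images, labels) batches, might look like the following; the project's actual utils.py may differ in its logging and progress reporting.

import torch
from torch import nn

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    # One pass over the training set with cross-entropy loss
    # (epoch would only be used for logging in the real helper)
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    total_loss, steps = 0.0, 0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        steps += 1
    return total_loss / max(steps, 1)   # mean loss over the epoch

@torch.no_grad()
def evaluate(model, data_loader, device):
    # Top-1 accuracy on the validation set
    model.eval()
    correct, total = 0, 0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        correct += (model(images).argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / max(total, 1)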
Step 4: Compute the accuracy statistics
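The script for this step is part of the downloadable project and is not reproduced in the post. A hedged sketch of how the overall and per-class accuracy could be tallied on the validation set is shown below; it assumes the trained model, the validation data loader, the device, and the 17-entry class list class_indict from step 1 are already available, and the function name itself is illustrative.

import torch

@torch.no_grad()
def per_class_accuracy(model, val_loader, device, class_indict):
    # Accumulate a confusion matrix: rows are true classes, columns are predictions
    model.eval()
    num_classes = len(class_indict)
    confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for images, labels in val_loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        for t, p in zip(labels, preds):
            confusion[t, p] += 1
    overall = confusion.diag().sum().item() / confusion.sum().item()
    print("overall accuracy: {:.3f}".format(overall))
    for i, name in enumerate(class_indict):
        row = confusion[i].sum().item()
        acc = confusion[i, i].item() / row if row else 0.0
        print("{}: {:.3f}".format(name, acc))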
Step 5: Build the GUI
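The GUI code ships with the project; purely as an illustration of the idea (choose an image, run the model, display the predicted species), a minimal Tkinter sketch might look like the one below. The preprocessing mirrors the validation transform from step 3, while the widget layout and the weight file name model-99.pth are assumptions rather than the project's actual GUI.

import torch
import tkinter as tk
from tkinter import filedialog
from PIL import Image
from torchvision import transforms
from model import shufflenet_v2_x1_0   # same model definition as in the training script

class_indict = ["白頭卷尾猴", "彌猴", "山魈", "松鼠猴", "葉猴", "銀色絨猴", "印度烏葉猴",
                "疣猴", "侏絨", "白禿猴", "赤猴", "滇金絲猴", "狒狒", "黑色吼猴",
                "黑葉猴", "金絲猴", "懶猴"]

# validation-style preprocessing, matching the training script
preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406],
                                                      [0.229, 0.224, 0.225])])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = shufflenet_v2_x1_0(num_classes=17).to(device)
model.load_state_dict(torch.load("./weights/model-99.pth", map_location=device))  # assumed weight file
model.eval()

def pick_and_predict():
    # Let the user choose an image, then show the predicted species
    path = filedialog.askopenfilename()
    if not path:
        return
    img = preprocess(Image.open(path).convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        idx = model(img).argmax(dim=1).item()
    result_var.set("prediction: " + class_indict[idx])

root = tk.Tk()
root.title("Monkey species classifier")
result_var = tk.StringVar(value="pick an image to classify")
tk.Button(root, text="Open image", command=pick_and_predict).pack(padx=20, pady=10)
tk.Label(root, textvariable=result_var).pack(padx=20, pady=10)
root.mainloop()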
Step 6: Contents of the complete project
The project includes the training code, the trained model, the training logs, the dataset, and the GUI code.
Code download link: source code for the PyTorch-based ShuffleNetV2 deep-learning recognition and classification system for seventeen monkey species
If you have any questions, feel free to send a private message or leave a comment; every question will be answered.