
Blog homepage: 王樂予 🎈
For young people: Living for the moment! 💪
🏆 Recommended columns: [Image Processing] [Python Drills] [Deep Learning] [Sorting Algorithms]

Contents

  • 😺 0. Repository source code
  • 😺 1. Dataset overview
    • 🐶 1.1 GitHub raw dataset
    • 🐶 1.2 GitHub preprocessed dataset
      • 🦄 1.2.1 Simplified drawing files (.ndjson)
      • 🦄 1.2.2 Binary files (.bin)
      • 🦄 1.2.3 Numpy bitmaps (.npy)
    • 🐶 1.3 Kaggle dataset
  • 😺 2. Dataset preparation
  • 😺 3. Converting to PNG images
  • 😺 4. Training
    • 🐶 4.1 split_datasets.py
    • 🐶 4.2 option.py
    • 🐶 4.3 getdata.py
    • 🐶 4.4 model.py
    • 🐶 4.5 train-DDP.py
    • 🐶 4.6 model_transfer.py
    • 🐶 4.7 evaluate.py

😺 0. Repository source code

All the code in this article is stored in the GitHub repository QuickDraw-DDP. Forks and stars are welcome.

😺 1. Dataset overview

The Quick Draw dataset is a collection of 50 million drawings across 345 categories, contributed by players of the game Quick, Draw!. The drawings were captured as timestamped vectors and tagged with metadata, including what the player was asked to draw and the country in which the player was located.

GitHub dataset: 📎 The Quick, Draw! Dataset

Kaggle dataset: 📎 Quick, Draw! Doodle Recognition Challenge

The GitHub repository provides two kinds of data: the raw dataset and the preprocessed dataset.
Google Cloud hosts the download links: quickdraw_dataset

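🐶 1.1 GitHub raw dataset

The raw data is provided as ndjson files separated by category, in the following format:

Key          Type                     Description
key_id       64-bit unsigned integer  A unique identifier across all drawings
word         string                   The category the player was asked to draw
recognized   boolean                  Whether the word was recognized by the game
timestamp    datetime                 When the drawing was created
countrycode  string                   Two-letter country code (ISO 3166-1 alpha-2) of where the player was located
drawing      string                   A JSON array representing the vector drawing

Each line contains one drawing. Below is an example of a single drawing: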

  { "key_id":"5891796615823360","word":"nose","countrycode":"AE","timestamp":"2017-03-01 20:41:36.70725 UTC","recognized":true,"drawing":[[[129,128,129,129,130,130,131,132,132,133,133,133,133,...]]]}

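The drawing field has the following format:

[
  [  // First stroke
    [x0, x1, x2, x3, ...],
    [y0, y1, y2, y3, ...],
    [t0, t1, t2, t3, ...]
  ],
  [  // Second stroke
    [x0, x1, x2, x3, ...],
    [y0, y1, y2, y3, ...],
    [t0, t1, t2, t3, ...]
  ],
  ... // Additional strokes
]

Here x and y are pixel coordinates, and t is the time in milliseconds since the first point. Because different devices are used for display and input, raw drawings can have vastly different bounding boxes and numbers of points.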
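To make the remark about bounding boxes and point counts concrete, here is a minimal sketch that measures both for one raw drawing. The helper name `stroke_stats` and the sample values are hypothetical and only illustrate the stroke layout described above.

```python
def stroke_stats(drawing):
    """Return (bounding_box, point_count) for a raw drawing.

    `drawing` is a list of strokes, each stroke being [xs, ys, ts].
    The bounding box is (min_x, min_y, max_x, max_y).
    """
    xs = [x for stroke in drawing for x in stroke[0]]
    ys = [y for stroke in drawing for y in stroke[1]]
    bbox = (min(xs), min(ys), max(xs), max(ys))
    return bbox, len(xs)


# Hypothetical two-stroke drawing (values invented for illustration).
sample = [
    [[10, 20, 30], [5, 15, 25], [0, 16, 33]],
    [[40, 50], [60, 70], [120, 140]],
]
bbox, n_points = stroke_stats(sample)
print(bbox, n_points)
```

Running the same function over several records of a raw file is a quick way to see how much the coordinate ranges vary before any normalization.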

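🐶 1.2 GitHub preprocessed dataset

🦄 1.2.1 Simplified drawing files (.ndjson)

The vectors have been simplified, the timing information removed, and the data positioned and scaled into a 256x256 region. The data is exported in ndjson format with the same metadata as the raw format. The simplification process is:

  1. Align the drawing to the top-left corner so that the minimum values are 0.
  2. Uniformly scale the drawing so that the maximum value is 255.
  3. Resample all strokes with a 1 pixel spacing.
  4. Simplify all strokes using the Ramer-Douglas-Peucker algorithm with an epsilon value of 2.0.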
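The steps above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the dataset authors' actual pipeline: the function names are hypothetical, step 3 (resampling) is omitted, and strokes are assumed to be (xs, ys) pairs.

```python
import math


def align_and_scale(strokes):
    """Steps 1-2: shift the minimum to 0, then scale uniformly so the maximum is 255."""
    xs = [x for sx, _ in strokes for x in sx]
    ys = [y for _, sy in strokes for y in sy]
    min_x, min_y = min(xs), min(ys)
    extent = max(max(xs) - min_x, max(ys) - min_y) or 1  # avoid dividing by zero
    scale = 255.0 / extent
    return [([(x - min_x) * scale for x in sx],
             [(y - min_y) * scale for y in sy]) for sx, sy in strokes]


def perpendicular_distance(p, a, b):
    """Distance from point p to the line through a and b."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    length = math.hypot(dx, dy)
    if length == 0:                      # a and b coincide
        return math.hypot(px - ax, py - ay)
    return abs(dx * (ay - py) - (ax - px) * dy) / length


def rdp(points, epsilon=2.0):
    """Step 4: Ramer-Douglas-Peucker simplification of a list of (x, y) points."""
    if len(points) < 3:
        return list(points)
    # Find the interior point farthest from the chord joining the endpoints.
    dists = [perpendicular_distance(p, points[0], points[-1]) for p in points[1:-1]]
    idx = max(range(len(dists)), key=dists.__getitem__) + 1
    if dists[idx - 1] <= epsilon:
        return [points[0], points[-1]]   # all interior points are within tolerance
    left = rdp(points[:idx + 1], epsilon)
    right = rdp(points[idx:], epsilon)
    return left[:-1] + right             # drop the duplicated split point


scaled = align_and_scale([([10, 20], [30, 50])])
flat = rdp([(0, 0), (5, 1), (10, 0)])
bent = rdp([(0, 0), (5, 5), (10, 0)])
print(scaled, flat, bent)
```

With epsilon 2.0 the nearly straight stroke collapses to its two endpoints, while the stroke with a 5-pixel bend keeps its middle point.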

The following code reads an ndjson file:

# read_ndjson.py
import json

with open('aircraft carrier.ndjson', 'r') as file:
    for line in file:
        data = json.loads(line)      # one drawing per line
        key_id = data['key_id']
        drawing = data['drawing']
        # ...

Reading aircraft carrier.ndjson and inspecting the result in a debugger shows that the first record contains 8 strokes.

🦄 1.2.2 Binary files (.bin)

The simplified drawings and metadata are also available in a custom binary format for efficient compression and loading.

The following code reads a bin file:

# read_bin.py
import struct
from struct import unpack


def unpack_drawing(file_handle):
    # Fixed-size header: key_id, country code, recognized flag, timestamp, stroke count
    key_id, = unpack('Q', file_handle.read(8))
    country_code, = unpack('2s', file_handle.read(2))
    recognized, = unpack('b', file_handle.read(1))
    timestamp, = unpack('I', file_handle.read(4))
    n_strokes, = unpack('H', file_handle.read(2))
    image = []
    for i in range(n_strokes):
        # Each stroke: point count, then all x bytes, then all y bytes
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        image.append((x, y))
    return {
        'key_id': key_id,
        'country_code': country_code,
        'recognized': recognized,
        'timestamp': timestamp,
        'image': image
    }


def unpack_drawings(filename):
    with open(filename, 'rb') as f:
        while True:
            try:
                yield unpack_drawing(f)
            except struct.error:    # raised when the end of the file is reached
                break


for drawing in unpack_drawings('nose.bin'):
    # do something with the drawing
    print(drawing['country_code'])
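The record layout that read_bin.py expects is easier to see with a round trip. The sketch below writes one synthetic record with the same field order and format characters, then parses it from memory. `pack_drawing` is a hypothetical helper written for this illustration (the dataset ships no such writer), and the field values are invented.

```python
import io
import struct
from struct import unpack


def pack_drawing(key_id, country_code, recognized, timestamp, strokes):
    """Serialize one drawing in the field order that read_bin.py reads."""
    out = struct.pack('Q', key_id)
    out += struct.pack('2s', country_code)
    out += struct.pack('b', recognized)
    out += struct.pack('I', timestamp)
    out += struct.pack('H', len(strokes))
    for xs, ys in strokes:
        out += struct.pack('H', len(xs))
        out += struct.pack(f'{len(xs)}B', *xs)
        out += struct.pack(f'{len(ys)}B', *ys)
    return out


def unpack_drawing(file_handle):
    """Same parsing logic as read_bin.py, repeated so the sketch is self-contained."""
    key_id, = unpack('Q', file_handle.read(8))
    country_code, = unpack('2s', file_handle.read(2))
    recognized, = unpack('b', file_handle.read(1))
    timestamp, = unpack('I', file_handle.read(4))
    n_strokes, = unpack('H', file_handle.read(2))
    image = []
    for _ in range(n_strokes):
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        image.append((x, y))
    return {'key_id': key_id, 'country_code': country_code,
            'recognized': recognized, 'timestamp': timestamp, 'image': image}


# Invented values: two strokes with 3 and 2 points.
blob = pack_drawing(42, b'AE', 1, 1488400896, [([1, 2, 3], [4, 5, 6]), ([7, 8], [9, 10])])
decoded = unpack_drawing(io.BytesIO(blob))
print(len(blob), decoded)
```

The header takes 17 bytes (8 + 2 + 1 + 4 + 2), and each stroke adds 2 bytes plus two bytes per point, which is why coordinates must fit in the 0-255 range.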

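🦄 1.2.3 Numpy bitmaps (.npy)

All the simplified drawings have been rendered into 28x28 grayscale bitmaps in numpy format. These images were generated from the simplified data, but they are aligned to the center of the drawing's bounding box rather than to the top-left corner.

The following code reads an npy file:

# read_npy.py
import numpy as np

data_path = 'aircraft_carrier.npy'
data = np.load(data_path)
print(data)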

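🐶 1.3 Kaggle dataset

The Kaggle competition uses a version of the dataset with 340 categories, and all data is provided as CSV tables. The dataset contains 5 files:

  • sample_submission.csv - a sample submission file in the correct format
  • test_raw.csv - the test data in raw vector format
  • test_simplified.csv - the test data in simplified vector format
  • train_raw.zip - the training data in raw vector format; one CSV file per word
  • train_simplified.zip - the training data in simplified vector format; one CSV file per word

Note: the column titles of the CSV files are the same as the key names in the ndjson files.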
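Since the columns mirror the ndjson keys, a CSV row can be handled much like an ndjson record once the drawing column is parsed. The sketch below uses an invented one-row excerpt; the column order and the values are assumptions made for illustration, not copied from the Kaggle files.

```python
import csv
import io
import json

# Hypothetical excerpt in the shape of a train_simplified CSV (values invented).
csv_text = (
    'countrycode,drawing,key_id,recognized,timestamp,word\n'
    'US,"[[[0, 255], [10, 20]], [[5, 6, 7], [1, 2, 3]]]",1234,True,'
    '2017-03-01 20:41:36.70725,nose\n'
)

rows = list(csv.DictReader(io.StringIO(csv_text)))
for row in rows:
    strokes = json.loads(row['drawing'])   # the drawing column holds a nested list as text
    print(row['word'], row['key_id'], len(strokes))
```

`csv.DictReader` returns every field as a string, so numeric and boolean columns need explicit conversion if they are used as such.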

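😺 2. Dataset preparation

This article uses the train_simplified dataset provided by Kaggle. The workflow is:

  1. Convert the CSV files of all classes to PNG images;
  2. For each of the 340 classes, take 10,000 PNG images for the experiments that follow;
  3. Split the 10,000 images of each class into training, validation, and test sets at a ratio of 8:1:1;
  4. Train the model;
  5. Evaluate the model.

😺 3. Converting to PNG images

The following script converts the CSV data to PNG images and saves them.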

# csv2png.py
import ast
import os

import matplotlib
matplotlib.use('Agg')   # non-interactive backend, suitable for batch rendering
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

input_dir = 'kaggle/train_simplified'
output_base_dir = 'datasets256'
os.makedirs(output_base_dir, exist_ok=True)

csv_files = [f for f in os.listdir(input_dir) if f.endswith('.csv')]    # Retrieve all CSV files from the folder
skipped_files = []  # Record skipped files

for csv_file in csv_files:
    csv_file_path = os.path.join(input_dir, csv_file)   # Build the complete file path
    output_dir = os.path.join(output_base_dir, os.path.splitext(csv_file)[0])   # Build the output directory
    if os.path.exists(output_dir):      # Skip classes whose output directory already exists
        skipped_files.append(csv_file)
        print(f'The directory already exists, skip file: {csv_file}')
        continue
    os.makedirs(output_dir, exist_ok=True)

    data = pd.read_csv(csv_file_path)       # Read the CSV file
    for index, row in data.iterrows():      # Traverse each row of data
        drawing = ast.literal_eval(row['drawing'])   # parse the nested list without eval()
        key_id = row['key_id']
        word = row['word']
        fig = plt.figure(figsize=(256/96, 256/96), dpi=96)   # 256x256 pixels
        for stroke in drawing:      # Draw each stroke
            x = np.array(stroke[0])
            y = np.array(stroke[1])
            plt.plot(x, y, 'k')
        ax = plt.gca()
        ax.xaxis.set_ticks_position('top')
        ax.invert_yaxis()           # image coordinates: y grows downwards
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, f'{word}-{key_id}.png'))
        plt.close(fig)
        print(f'Conversion completed: {csv_file} the {index:06d}image')

print("The skipped files are:")
for file in skipped_files:
    print(file)

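Note: there are roughly 50 million drawings, so processing takes a very long time. Running several copies of the script at the same time is recommended (PS: the code checks whether the output folder already exists, so there is no need to worry about writing the same class twice). The joblib library can also be used for parallel speed-up, but it can easily bring the machine down if misused, so it is not recommended.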
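One way to run several copies side by side without two of them picking up the same class at the same moment is to give each copy a fixed, disjoint slice of the file list. This is a sketch of that idea only; `shard`, `worker_id`, and `num_workers` are hypothetical names that do not appear in csv2png.py.

```python
def shard(items, worker_id, num_workers):
    """Return the items assigned to one worker: every num_workers-th item, offset by worker_id."""
    if not 0 <= worker_id < num_workers:
        raise ValueError('worker_id must be in [0, num_workers)')
    # Sorting first makes the assignment identical in every process.
    return sorted(items)[worker_id::num_workers]


# Hypothetical file names standing in for os.listdir('kaggle/train_simplified').
csv_files = ['cat.csv', 'apple.csv', 'zigzag.csv', 'nose.csv', 'bus.csv']
shards = [shard(csv_files, worker_id, 2) for worker_id in range(2)]
print(shards)
```

Each copy of the conversion script would then loop over its own shard, with the worker index passed in, for example, as a command-line argument.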

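The storage footprint of the relevant files is:

  • GitHub preprocessed ndjson files: 23 GB
  • Kaggle train_raw.zip: 206 GB
  • Kaggle train_simplified.zip: 23 GB
  • Kaggle train_simplified converted to 256x256 images: 470 GB

If disk space is limited, the PNG conversion can use a size of 128x128 or 64x64 instead. Saving single-channel images is another option.

Once processing has finished, it is a good idea to use the script below to check whether any class was left unprocessed: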

# check_class_num.py
import os

folder = 'datasets256'
subfolders = [f.path for f in os.scandir(folder) if f.is_dir()]

for subfolder in subfolders:    # Traverse each subfolder
    folder_name = os.path.basename(subfolder)   # Get the name of the subfolder
    files = [f for f in os.scandir(subfolder) if f.is_file()]   # Retrieve all files in the subfolder
    image_count = sum(1 for f in files if f.name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')))   # Count the images
    if image_count == 0:        # No images: report the subfolder and delete it
        print(f"There are no images in the subfolders '{folder_name}', deleting them...")
        os.rmdir(subfolder)
        print(f"subfolders '{folder_name}' deleted")
    else:
        print(f"Number of images in subfolders: '{folder_name}' : {image_count}")

If any empty folders are found, run csv2png.py again.

😺 4. Training

🐶 4.1 split_datasets.py

The first step is to split the dataset. The source data is the dataset of PNG images.

import os
import random
import shutil

original_dataset_path = 'datasets256'     # Original dataset path
new_dataset_path = 'datasets'             # Path of the split dataset

train_path = os.path.join(new_dataset_path, 'train')
val_path = os.path.join(new_dataset_path, 'val')
test_path = os.path.join(new_dataset_path, 'test')
for path in (train_path, val_path, test_path):
    os.makedirs(path, exist_ok=True)

classes = os.listdir(original_dataset_path)     # Get all categories
random.seed(42)

for class_name in classes:      # Traverse each category
    src_folder = os.path.join(original_dataset_path, class_name)    # Source folder path
    train_folder = os.path.join(train_path, class_name)
    val_folder = os.path.join(val_path, class_name)
    test_folder = os.path.join(test_path, class_name)

    # Skip categories whose train, val, and test folders already exist and are not empty
    if all(os.path.isdir(d) and os.listdir(d) for d in (train_folder, val_folder, test_folder)):
        print(f"Category {class_name} already exists and is not empty, skip processing.")
        continue

    for dst_folder in (train_folder, val_folder, test_folder):
        os.makedirs(dst_folder, exist_ok=True)

    files = os.listdir(src_folder)      # Retrieve all file names under this category
    files = files[:10000]               # Only keep the first 10000 files
    random.shuffle(files)               # Shuffle the file list

    total_files = len(files)
    train_split_index = int(total_files * 0.8)
    val_split_index = int(total_files * 0.9)
    splits = (
        (train_folder, files[:train_split_index]),
        (val_folder, files[train_split_index:val_split_index]),
        (test_folder, files[val_split_index:]),
    )
    for dst_folder, split_files in splits:
        for file in split_files:
            shutil.copy(os.path.join(src_folder, file), os.path.join(dst_folder, file))

print("Dataset partitioning completed!")

After the script finishes, the datasets directory contains three folders: train, val, and test.

🐶 4.2 option.py

Define the parameters that the later scripts need.

import argparse


def str2bool(value):
    # type=bool would treat any non-empty string (including 'False') as True,
    # so boolean options are parsed explicitly
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', '1', 'yes', 'y'):
        return True
    if value.lower() in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def get_args():
    parser = argparse.ArgumentParser(description='all argument')
    parser.add_argument('--num_classes', type=int, default=340, help='image num classes')
    parser.add_argument('--loadsize', type=int, default=64, help='image size')
    parser.add_argument('--epochs', type=int, default=100, help='all epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='init lr')
    parser.add_argument('--use_lr_scheduler', type=str2bool, default=True, help='use lr scheduler')
    parser.add_argument('--dataset_train', type=str, default='./datasets/train', help='train path')
    parser.add_argument('--dataset_val', type=str, default="./datasets/val", help='val path')
    parser.add_argument('--dataset_test', type=str, default="./datasets/test", help='test path')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='ckpt path')
    parser.add_argument('--tensorboard_dir', type=str, default='./tensorboard_dir', help='log path')
    parser.add_argument('--resume', type=str2bool, default=False, help='continue training')
    parser.add_argument('--resume_ckpt', type=str, default='./checkpoints/model_best.pth', help='choose breakpoint ckpt')
    parser.add_argument('--local-rank', type=int, default=-1, help='local rank')
    parser.add_argument('--use_mix_precision', type=str2bool, default=False, help='use mix pretrain')
    parser.add_argument('--test_img_path', type=str, default='datasets/test/zigzag/zigzag-4508464694951936.png', help='choose test image')
    parser.add_argument('--test_dir_path', type=str, default='./datasets/test', help='choose test path')
    return parser.parse_args()

Because training later uses single-machine multi-GPU DDP together with AMP, the extra local-rank and use_mix_precision parameters are included.
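Two argparse details matter for these parameters: an option named `--local-rank` is read back as the attribute `local_rank`, and `type=bool` does not parse the text "False" as false. The sketch below demonstrates both on a throwaway parser; `--naive_flag` and `--parsed_flag` are hypothetical option names used only for this illustration.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--local-rank', type=int, default=-1)
# bool('False') is True because any non-empty string is truthy.
parser.add_argument('--naive_flag', type=bool, default=False)
# Parsing the text explicitly gives the intended result.
parser.add_argument('--parsed_flag', type=lambda s: s.lower() in ('true', '1', 'yes'),
                    default=False)

args = parser.parse_args(['--local-rank', '1',
                          '--naive_flag', 'False',
                          '--parsed_flag', 'False'])
print(args.local_rank)    # dashes in option names become underscores
print(args.naive_flag)
print(args.parsed_flag)
```

Passing `True` on the command line works with either approach; the difference only shows up when someone tries to switch an option off by passing `False`.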

🐶 4.3 getdata.py

Next, define the data pipeline.

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from option import get_args

opt = get_args()

# Per-channel mean and standard deviation used for normalization
mean = [0.9367, 0.9404, 0.9405]
std = [0.1971, 0.1970, 0.1972]


def data_augmentation():
    data_transform = {
        'train': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),  # HWC -> CHW
            transforms.Normalize(mean, std)
        ]),
        'val': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }
    return data_transform


def MyData():
    data_transform = data_augmentation()
    image_datasets = {
        'train': ImageFolder(opt.dataset_train, data_transform['train']),
        'val': ImageFolder(opt.dataset_val, data_transform['val']),
    }
    # DistributedSampler gives each process its own shard of the dataset
    data_sampler = {
        'train': torch.utils.data.distributed.DistributedSampler(image_datasets['train']),
        'val': torch.utils.data.distributed.DistributedSampler(image_datasets['val']),
    }
    # shuffle must stay False when a sampler is supplied
    dataloaders = {
        'train': DataLoader(image_datasets['train'], batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=data_sampler['train']),
        'val': DataLoader(image_datasets['val'], batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=data_sampler['val'])
    }
    return dataloaders


class_names = ['The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'airplane', 'alarm clock', 'ambulance', 'angel', 'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon',
'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden', 'garden hose', 'giraffe', 'goatee', 'golf club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house', 'house plant', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote control', 'rhinoceros', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 
'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 'sword', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 'zigzag'
]


if __name__ == '__main__':
    mean_std_transform = transforms.Compose([transforms.ToTensor()])
    dataset = ImageFolder(opt.dataset_val, transform=mean_std_transform)
    print(dataset.class_to_idx)     # Index assigned to each category

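🐶 4.4 model.py

Define the model. The small variant of MobileNetV3 is used here, and the output of the final classifier layer has to be changed to the number of classes.
Other good models, such as ShuffleNet and SqueezeNet, can also be trained on this dataset.

import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small
from torchsummary import summary
from option import get_args

opt = get_args()


def CustomMobileNetV3():
    model = mobilenet_v3_small(weights='MobileNet_V3_Small_Weights.IMAGENET1K_V1')
    # Replace the last classifier layer so the output size equals the number of classes
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, opt.num_classes)
    return model


if __name__ == '__main__':
    # option.py defines no device argument, so the device is chosen here
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = CustomMobileNetV3()
    print(model)
    summary(model.to(device), (3, opt.loadsize, opt.loadsize), opt.batch_size, device=device)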

The model structure is as follows:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [1024, 16, 32, 32]             432
       BatchNorm2d-2         [1024, 16, 32, 32]              32
         Hardswish-3         [1024, 16, 32, 32]               0
            Conv2d-4         [1024, 16, 16, 16]             144
       BatchNorm2d-5         [1024, 16, 16, 16]              32
              ReLU-6         [1024, 16, 16, 16]               0
 AdaptiveAvgPool2d-7           [1024, 16, 1, 1]               0
            Conv2d-8            [1024, 8, 1, 1]             136
              ReLU-9            [1024, 8, 1, 1]               0
           Conv2d-10           [1024, 16, 1, 1]             144
      Hardsigmoid-11           [1024, 16, 1, 1]               0
SqueezeExcitation-12         [1024, 16, 16, 16]               0
           Conv2d-13         [1024, 16, 16, 16]             256
      BatchNorm2d-14         [1024, 16, 16, 16]              32
 InvertedResidual-15         [1024, 16, 16, 16]               0
           Conv2d-16         [1024, 72, 16, 16]           1,152
      BatchNorm2d-17         [1024, 72, 16, 16]             144
             ReLU-18         [1024, 72, 16, 16]               0
           Conv2d-19           [1024, 72, 8, 8]             648
      BatchNorm2d-20           [1024, 72, 8, 8]             144
             ReLU-21           [1024, 72, 8, 8]               0
           Conv2d-22           [1024, 24, 8, 8]           1,728
      BatchNorm2d-23           [1024, 24, 8, 8]              48
 InvertedResidual-24           [1024, 24, 8, 8]               0
           Conv2d-25           [1024, 88, 8, 8]           2,112
      BatchNorm2d-26           [1024, 88, 8, 8]             176
             ReLU-27           [1024, 88, 8, 8]               0
           Conv2d-28           [1024, 88, 8, 8]             792
      BatchNorm2d-29           [1024, 88, 8, 8]             176
             ReLU-30           [1024, 88, 8, 8]               0
           Conv2d-31           [1024, 24, 8, 8]           2,112
      BatchNorm2d-32           [1024, 24, 8, 8]              48
 InvertedResidual-33           [1024, 24, 8, 8]               0
           Conv2d-34           [1024, 96, 8, 8]           2,304
      BatchNorm2d-35           [1024, 96, 8, 8]             192
        Hardswish-36           [1024, 96, 8, 8]               0
           Conv2d-37           [1024, 96, 4, 4]           2,400
      BatchNorm2d-38           [1024, 96, 4, 4]             192
        Hardswish-39           [1024, 96, 4, 4]               0
AdaptiveAvgPool2d-40           [1024, 96, 1, 1]               0
           Conv2d-41           [1024, 24, 1, 1]           2,328
             ReLU-42           [1024, 24, 1, 1]               0
           Conv2d-43           [1024, 96, 1, 1]           2,400
      Hardsigmoid-44           [1024, 96, 1, 1]               0
SqueezeExcitation-45           [1024, 96, 4, 4]               0
           Conv2d-46           [1024, 40, 4, 4]           3,840
      BatchNorm2d-47           [1024, 40, 4, 4]              80
 InvertedResidual-48           [1024, 40, 4, 4]               0
           Conv2d-49          [1024, 240, 4, 4]           9,600
      BatchNorm2d-50          [1024, 240, 4, 4]             480
        Hardswish-51          [1024, 240, 4, 4]               0
           Conv2d-52          [1024, 240, 4, 4]           6,000
      BatchNorm2d-53          [1024, 240, 4, 4]             480
        Hardswish-54          [1024, 240, 4, 4]               0
AdaptiveAvgPool2d-55          [1024, 240, 1, 1]               0
           Conv2d-56           [1024, 64, 1, 1]          15,424
             ReLU-57           [1024, 64, 1, 1]               0
           Conv2d-58          [1024, 240, 1, 1]          15,600
      Hardsigmoid-59          [1024, 240, 1, 1]               0
SqueezeExcitation-60          [1024, 240, 4, 4]               0
           Conv2d-61           [1024, 40, 4, 4]           9,600
      BatchNorm2d-62           [1024, 40, 4, 4]              80
 InvertedResidual-63           [1024, 40, 4, 4]               0
           Conv2d-64          [1024, 240, 4, 4]           9,600
      BatchNorm2d-65          [1024, 240, 4, 4]             480
        Hardswish-66          [1024, 240, 4, 4]               0
           Conv2d-67          [1024, 240, 4, 4]           6,000
      BatchNorm2d-68          [1024, 240, 4, 4]             480
        Hardswish-69          [1024, 240, 4, 4]               0
AdaptiveAvgPool2d-70          [1024, 240, 1, 1]               0
           Conv2d-71           [1024, 64, 1, 1]          15,424
             ReLU-72           [1024, 64, 1, 1]               0
           Conv2d-73          [1024, 240, 1, 1]          15,600
      Hardsigmoid-74          [1024, 240, 1, 1]               0
SqueezeExcitation-75          [1024, 240, 4, 4]               0
           Conv2d-76           [1024, 40, 4, 4]           9,600
      BatchNorm2d-77           [1024, 40, 4, 4]              80
 InvertedResidual-78           [1024, 40, 4, 4]               0
           Conv2d-79          [1024, 120, 4, 4]           4,800
      BatchNorm2d-80          [1024, 120, 4, 4]             240
        Hardswish-81          [1024, 120, 4, 4]               0
           Conv2d-82          [1024, 120, 4, 4]           3,000
      BatchNorm2d-83          [1024, 120, 4, 4]             240
        Hardswish-84          [1024, 120, 4, 4]               0
AdaptiveAvgPool2d-85          [1024, 120, 1, 1]               0
           Conv2d-86           [1024, 32, 1, 1]           3,872
             ReLU-87           [1024, 32, 1, 1]               0
           Conv2d-88          [1024, 120, 1, 1]           3,960
      Hardsigmoid-89          [1024, 120, 1, 1]               0
SqueezeExcitation-90          [1024, 120, 4, 4]               0
           Conv2d-91           [1024, 48, 4, 4]           5,760
      BatchNorm2d-92           [1024, 48, 4, 4]              96
 InvertedResidual-93           [1024, 48, 4, 4]               0
           Conv2d-94          [1024, 144, 4, 4]           6,912
      BatchNorm2d-95          [1024, 144, 4, 4]             288
        Hardswish-96          [1024, 144, 4, 4]               0
           Conv2d-97          [1024, 144, 4, 4]           3,600
      BatchNorm2d-98          [1024, 144, 4, 4]             288
        Hardswish-99          [1024, 144, 4, 4]               0
AdaptiveAvgPool2d-100          [1024, 144, 1, 1]               0
          Conv2d-101           [1024, 40, 1, 1]           5,800
            ReLU-102           [1024, 40, 1, 1]               0
          Conv2d-103          [1024, 144, 1, 1]           5,904
     Hardsigmoid-104          [1024, 144, 1, 1]               0
SqueezeExcitation-105          [1024, 144, 4, 4]               0
          Conv2d-106           [1024, 48, 4, 4]           6,912
     BatchNorm2d-107           [1024, 48, 4, 4]              96
InvertedResidual-108           [1024, 48, 4, 4]               0
          Conv2d-109          [1024, 288, 4, 4]          13,824
     BatchNorm2d-110          [1024, 288, 4, 4]             576
       Hardswish-111          [1024, 288, 4, 4]               0
          Conv2d-112          [1024, 288, 2, 2]           7,200
     BatchNorm2d-113          [1024, 288, 2, 2]             576
       Hardswish-114          [1024, 288, 2, 2]               0
AdaptiveAvgPool2d-115          [1024, 288, 1, 1]               0
          Conv2d-116           [1024, 72, 1, 1]          20,808
            ReLU-117           [1024, 72, 1, 1]               0
          Conv2d-118          [1024, 288, 1, 1]          21,024
     Hardsigmoid-119          [1024, 288, 1, 1]               0
SqueezeExcitation-120          [1024, 288, 2, 2]               0
          Conv2d-121           [1024, 96, 2, 2]          27,648
     BatchNorm2d-122           [1024, 96, 2, 2]             192
InvertedResidual-123           [1024, 96, 2, 2]               0
          Conv2d-124          [1024, 576, 2, 2]          55,296
     BatchNorm2d-125          [1024, 576, 2, 2]           1,152
       Hardswish-126          [1024, 576, 2, 2]               0
          Conv2d-127          [1024, 576, 2, 2]          14,400
     BatchNorm2d-128          [1024, 576, 2, 2]           1,152
       Hardswish-129          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-130          [1024, 576, 1, 1]               0
          Conv2d-131          [1024, 144, 1, 1]          83,088
            ReLU-132          [1024, 144, 1, 1]               0
          Conv2d-133          [1024, 576, 1, 1]          83,520
     Hardsigmoid-134          [1024, 576, 1, 1]               0
SqueezeExcitation-135          [1024, 576, 2, 2]               0
          Conv2d-136           [1024, 96, 2, 2]          55,296
     BatchNorm2d-137           [1024, 96, 2, 2]             192
InvertedResidual-138           [1024, 96, 2, 2]               0
          Conv2d-139          [1024, 576, 2, 2]          55,296
     BatchNorm2d-140          [1024, 576, 2, 2]           1,152
       Hardswish-141          [1024, 576, 2, 2]               0
          Conv2d-142          [1024, 576, 2, 2]          14,400
     BatchNorm2d-143          [1024, 576, 2, 2]           1,152
       Hardswish-144          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-145          [1024, 576, 1, 1]               0
          Conv2d-146          [1024, 144, 1, 1]          83,088
            ReLU-147          [1024, 144, 1, 1]               0
          Conv2d-148          [1024, 576, 1, 1]          83,520
     Hardsigmoid-149          [1024, 576, 1, 1]               0
SqueezeExcitation-150          [1024, 576, 2, 2]               0
          Conv2d-151           [1024, 96, 2, 2]          55,296
     BatchNorm2d-152           [1024, 96, 2, 2]             192
InvertedResidual-153           [1024, 96, 2, 2]               0
          Conv2d-154          [1024, 576, 2, 2]          55,296
     BatchNorm2d-155          [1024, 576, 2, 2]           1,152
       Hardswish-156          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-157          [1024, 576, 1, 1]               0
          Linear-158               [1024, 1024]         590,848
       Hardswish-159               [1024, 1024]               0
         Dropout-160               [1024, 1024]               0
          Linear-161                [1024, 340]         348,500
================================================================
Total params: 1,866,356
Trainable params: 1,866,356
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 48.00
Forward/backward pass size (MB): 2979.22
Params size (MB): 7.12
Estimated Total Size (MB): 3034.34
----------------------------------------------------------------

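🐶 4.5 train-DDP.py

Note that train-DDP.py combines several training strategies:

  • DDP distributed training (one machine, two GPUs);
  • AMP mixed-precision training;
  • learning-rate decay;
  • early stopping;
  • resuming training from a checkpoint.
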
# python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True
# Watch Training Log:tensorboard --logdir=tensorboard_dir
from tqdm import tqdm
import torch
import torch.nn.parallel
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import time
import os
import torch.optim
import torch.utils.data
import torch.nn as nn
from collections import OrderedDict
from model import CustomMobileNetV3
from getdata import MyData
from torch.cuda.amp import GradScaler
from option import get_args
opt = get_args()
dist.init_process_group(backend='nccl', init_method='env://')os.makedirs(opt.checkpoints, exist_ok=True)def train(gpu):rank = dist.get_rank()model = CustomMobileNetV3()model.cuda(gpu)criterion = nn.CrossEntropyLoss().to(gpu)optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)model = nn.SyncBatchNorm.convert_sync_batchnorm(model)model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])scaler = GradScaler(enabled=opt.use_mix_precision)  dataloaders = MyData()train_loader = dataloaders['train']test_loader = dataloaders['val']if opt.use_lr_scheduler:scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)start_time = time.time()best_val_acc = 0.0no_improve_epochs = 0early_stopping_patience = 6  # Early Stopping Patience"""breakckpt resume"""if opt.resume:checkpoint = torch.load(opt.resume_ckpt)print('Loading checkpoint from:', opt.resume_ckpt)new_state_dict = OrderedDict()      # Create a new ordered dictionary and remove prefixesfor k, v in checkpoint['model'].items():name = k[7:]                    # Remove 'module.' 
To match the original model definitionnew_state_dict[name] = vmodel.load_state_dict(new_state_dict, strict=False)     # Load a new state dictionaryoptimizer.load_state_dict(checkpoint['optimizer'])start_epoch = checkpoint['epoch']                       # Set the starting epochif opt.use_lr_scheduler:scheduler.load_state_dict(checkpoint['scheduler'])else:start_epoch = 0for epoch in range(start_epoch + 1, opt.epochs):tqdm_trainloader = tqdm(train_loader, desc=f'Epoch {epoch}')running_loss, running_correct_top1, running_correct_top3, running_correct_top5 = 0.0, 0.0, 0.0, 0.0total_samples = 0for i, (images, target) in enumerate(tqdm_trainloader if rank == 0 else train_loader, 0):images = images.to(gpu)target = target.to(gpu)with torch.cuda.amp.autocast(enabled=opt.use_mix_precision):output = model(images)loss = criterion(output, target)optimizer.zero_grad()scaler.scale(loss).backward()scaler.step(optimizer)scaler.update() running_loss += loss.item() * images.size(0)_, predicted = torch.max(output.data, 1)running_correct_top1  += (predicted == target).sum().item()_, predicted_top3 = torch.topk(output.data, 3, dim=1)_, predicted_top5 = torch.topk(output.data, 5, dim=1)running_correct_top3 += (predicted_top3[:, :3] == target.unsqueeze(1).expand_as(predicted_top3)).sum().item()running_correct_top5 += (predicted_top5[:, :5] == target.unsqueeze(1).expand_as(predicted_top5)).sum().item()total_samples += target.size(0)state = {'epoch': epoch,'model': model.module.state_dict(),'optimizer': optimizer.state_dict(),'scheduler': scheduler.state_dict()}if rank == 0:current_lr = scheduler.get_last_lr()[0] if opt.use_lr_scheduler else opt.lrprint(f'[Epoch {epoch}]  'f'[Train Loss: {running_loss / len(train_loader.dataset):.6f}]  'f'[Train Top-1 Acc: {running_correct_top1 / len(train_loader.dataset):.6f}]  'f'[Train Top-3 Acc: {running_correct_top3 / len(train_loader.dataset):.6f}]  'f'[Train Top-5 Acc: {running_correct_top5 / len(train_loader.dataset):.6f}]  'f'[Learning Rate: 
{current_lr:.6f}]  '
                  f'[Time: {time.time() - start_time:.6f} seconds]')
            writer.add_scalar('Train/Loss', running_loss / len(train_loader.dataset), epoch)
            writer.add_scalar('Train/Top-1 Accuracy', running_correct_top1 / len(train_loader.dataset), epoch)
            writer.add_scalar('Train/Top-3 Accuracy', running_correct_top3 / len(train_loader.dataset), epoch)
            writer.add_scalar('Train/Top-5 Accuracy', running_correct_top5 / len(train_loader.dataset), epoch)
            writer.add_scalar('Train/Learning Rate', current_lr, epoch)
            torch.save(state, f'{opt.checkpoints}model_epoch_{epoch}.pth')
        # dist.barrier()
        tqdm_trainloader.close()
        if opt.use_lr_scheduler:    # Learning-rate Scheduler
            scheduler.step()
        acc_top1 = valid(test_loader, model, epoch, gpu, rank)
        if acc_top1 is not None:
            if acc_top1 > best_val_acc:
                best_val_acc = acc_top1
                no_improve_epochs = 0
                torch.save(state, f'{opt.checkpoints}/model_best.pth')
            else:
                no_improve_epochs += 1
            if no_improve_epochs >= early_stopping_patience:
                print(f'Early stopping triggered after {early_stopping_patience} epochs without improvement.')
                break
        else:
            print("Warning: acc_top1 is None, skipping this epoch.")
    dist.destroy_process_group()


def valid(val_loader, model, epoch, gpu, rank):
    model.eval()
    correct_top1, correct_top3, correct_top5, total = torch.tensor(0.).to(gpu), torch.tensor(0.).to(gpu), torch.tensor(0.).to(gpu), torch.tensor(0.).to(gpu)
    with torch.no_grad():
        tqdm_valloader = tqdm(val_loader, desc=f'Epoch {epoch}')
        for i, (images, target) in enumerate(tqdm_valloader, 0):
            images = images.to(gpu)
            target = target.to(gpu)
            output = model(images)
            total += target.size(0)
            correct_top1 += (output.argmax(1) == target).type(torch.float).sum()
            _, predicted_top3 = torch.topk(output, 3, dim=1)
            _, predicted_top5 = torch.topk(output, 5, dim=1)
            correct_top3 += (predicted_top3[:, :3] == target.unsqueeze(1).expand_as(predicted_top3)).sum().item()
            correct_top5 += (predicted_top5[:, :5] == target.unsqueeze(1).expand_as(predicted_top5)).sum().item()
    dist.reduce(total, 0, op=dist.ReduceOp.SUM)     # Group communication reduce operation (change to allreduce if Gloo)
    dist.reduce(correct_top1, 0, op=dist.ReduceOp.SUM)
    dist.reduce(correct_top3, 0, op=dist.ReduceOp.SUM)
    dist.reduce(correct_top5, 0, op=dist.ReduceOp.SUM)
    if rank == 0:
        print(f'[Epoch {epoch}]  '
              f'[Val Top-1 Acc: {correct_top1 / total:.6f}]  '
              f'[Val Top-3 Acc: {correct_top3 / total:.6f}]  '
              f'[Val Top-5 Acc: {correct_top5 / total:.6f}]')
        writer.add_scalar('Validation/Top-1 Accuracy', correct_top1 / total, epoch)
        writer.add_scalar('Validation/Top-3 Accuracy', correct_top3 / total, epoch)
        writer.add_scalar('Validation/Top-5 Accuracy', correct_top5 / total, epoch)
        return float(correct_top1 / total)  # Return top 1 precision
    tqdm_valloader.close()


def main():
    train(opt.local_rank)


if __name__ == '__main__':
    writer = SummaryWriter(log_dir=opt.tensorboard_dir)
    main()
    writer.close()
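The top-k counting done in valid with torch.topk can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not the author's code: the helper name topk_correct and the toy scores are made up for illustration. A sample counts as correct at k when its true label is among the k highest-scoring classes.

```python
def topk_correct(logits, targets, k):
    """Count samples whose true label is among the k highest-scoring classes."""
    correct = 0
    for row, target in zip(logits, targets):
        # Indices of the k largest scores in this row.
        topk = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        if target in topk:
            correct += 1
    return correct


# Toy batch: 3 samples, 4 classes (values are made up for illustration).
logits = [
    [0.1, 0.7, 0.2, 0.0],  # true class 1: hit at top-1
    [0.5, 0.1, 0.3, 0.1],  # true class 2: missed at top-1, hit at top-3
    [0.4, 0.3, 0.2, 0.1],  # true class 3: missed even at top-3
]
targets = [1, 2, 3]

top1 = topk_correct(logits, targets, 1) / len(targets)
top3 = topk_correct(logits, targets, 3) / len(targets)
print(f"top-1: {top1:.3f}, top-3: {top3:.3f}")
```

In the distributed version each process only sees its own shard of the validation set, which is why the per-process counts are summed with dist.reduce before dividing by the total.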

Run the following command in a terminal to start multi-GPU distributed training:

python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True

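The parameters have the following meanings:

  • nproc_per_node: number of processes per machine, normally one per GPU
  • nnodes: number of machines
  • node_rank: index of the current machine
  • master_addr: IP address of the master machine (node rank 0)
  • master_port: port on the master machine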
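How these launch parameters relate to the number of processes and their ranks can be sketched as follows. This is a simplified illustration assuming a static launch in which ranks are assigned in machine order; the helper names world_size and global_rank are hypothetical and are not part of train-DDP.py or PyTorch.

```python
def world_size(nnodes, nproc_per_node):
    """Total number of training processes across all machines."""
    return nnodes * nproc_per_node


def global_rank(node_rank, nproc_per_node, local_rank):
    """Global rank of a process: machines are numbered by node_rank,
    processes inside one machine by local_rank."""
    return node_rank * nproc_per_node + local_rank


# Values from the command above: one machine, two processes.
print(world_size(nnodes=1, nproc_per_node=2))
print([global_rank(0, 2, local_rank) for local_rank in range(2)])

# Hypothetical two-machine setup with 4 GPUs each: ranks on the second machine.
ranks_node1 = [global_rank(1, 4, local_rank) for local_rank in range(4)]
print(ranks_node1)
```

With a single machine the global rank and the local rank coincide, which matches the script passing opt.local_rank straight into train.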

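Starting the training with nohup triggers a bug: the workers receive SIGHUP and are shut down, as the following log shows.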

W0914 18:33:15.081479 140031432897728 torch/distributed/elastic/agent/server/api.py:741] Received Signals.SIGHUP death signal, shutting down workers
W0914 18:33:15.085310 140031432897728 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1685186 closing signal SIGHUP
W0914 18:33:15.085644 140031432897728 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1685192 closing signal SIGHUP

For the cause, see the thread on the official PyTorch discussion forum: DDP Error: torch.distributed.elastic.agent.server.api:Received 1 death signal, shutting down workers

tmux can be used to work around this problem.

  1. Install tmux: sudo apt-get install tmux
  2. Create a new session: tmux new -s train-DDP (the session name is up to you)
  3. Activate the virtual environment: conda activate pytorch (use whichever environment you actually need)
  4. Start the training job: python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True

Commonly used tmux commands:

  • List all current tmux sessions: tmux ls
  • Create a new session: tmux new -s session-name
  • Re-enter a session: tmux attach -t session-name
  • Kill a session: tmux kill-session -t session-name

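The training logs from this run are shown in the figures below:
[Figure: training log screenshot 1]
[Figure: training log screenshot 2]
Early stopping was triggered at epoch 11.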
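The early-stopping check in train-DDP.py (best_val_acc, no_improve_epochs, early_stopping_patience) can be isolated into a small sketch. The class name EarlyStopper, the patience value of 3, and the accuracy history below are all assumptions for illustration; the patience actually used in this run is not stated here.

```python
class EarlyStopper:
    """Stop when validation accuracy has not improved for `patience` epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best = 0.0
        self.no_improve = 0

    def step(self, acc):
        """Record one epoch's accuracy; return True when training should stop."""
        if acc > self.best:
            self.best = acc
            self.no_improve = 0
        else:
            self.no_improve += 1
        return self.no_improve >= self.patience


# Made-up validation accuracies, one per epoch.
history = [0.50, 0.60, 0.65, 0.64, 0.65, 0.63]
stopper = EarlyStopper(patience=3)  # hypothetical patience
stopped_at = None
for epoch, acc in enumerate(history):
    if stopper.step(acc):
        stopped_at = epoch
        break
print(f"best={stopper.best}, stopped at epoch index {stopped_at}")
```

Note that an accuracy equal to the current best does not reset the counter, because the comparison is strict, the same as in the training loop.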

🐶4.6 model_transfer.py

This script converts the .pth model into the mobile .ptl format and the ONNX format, which makes on-device deployment easier.

from torch.utils.mobile_optimizer import optimize_for_mobile
import torch
from model import CustomMobileNetV3
import onnx
from onnxsim import simplify
from torch.autograd import Variable
from option import get_args
opt = get_args()

model = CustomMobileNetV3()
model.load_state_dict(torch.load(f'{opt.checkpoints}model_best.pth', map_location='cpu')['model'])
model.eval()
print("Model loaded successfully.")

"""Save .pth format model"""
torch.save(model, f'{opt.checkpoints}/model.pth')

"""Save .ptl format model"""
example = torch.rand(1, 3, 64, 64)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter(f'{opt.checkpoints}model.ptl')

"""Save .onnx format model"""
input_name = ['input']
output_name = ['output']
input = Variable(torch.randn(1, 3, opt.loadsize, opt.loadsize))
torch.onnx.export(model, input, f'{opt.checkpoints}model.onnx', input_names=input_name, output_names=output_name, verbose=True)
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(f'{opt.checkpoints}model.onnx')), f'{opt.checkpoints}model.onnx')   # Run shape inference and save the annotated model
# Simplify the model
model_onnx = onnx.load(f'{opt.checkpoints}model.onnx')
model_simplified, check = simplify(model_onnx)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simplified, f'{opt.checkpoints}model_simplified.onnx')
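The script builds file paths by string concatenation, sometimes as {opt.checkpoints}model.ptl and sometimes as {opt.checkpoints}/model.pth, so the result depends on whether opt.checkpoints ends with a slash. A possible alternative, sketched here with a hypothetical helper ckpt_path and a hypothetical directory name "checkpoints", is to let os.path.join handle the separator.

```python
import os


def ckpt_path(checkpoint_dir, filename):
    """Join a checkpoint directory and a file name, with or without a trailing slash."""
    return os.path.join(checkpoint_dir, filename)


# Both spellings of the directory lead to the same file.
with_slash = ckpt_path("checkpoints/", "model.ptl")
without_slash = ckpt_path("checkpoints", "model.ptl")
print(os.path.normpath(with_slash) == os.path.normpath(without_slash))

# Plain concatenation silently produces a different name when the slash is missing.
concatenated = "checkpoints" + "model.ptl"
print(concatenated)
```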

🐶4.7 evaluate.py

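The script defines three functions:

  • evaluate_image_single: predicts a single image
  • evaluate_image_dir: predicts all images in a folder
  • evaluate_onnx_model: predicts an image with the ONNX model

It also produces several visualizations and evaluation metrics, including the confusion matrix and the F1 score.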

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch.nn.functional as F
import torch.utils.data
import onnxruntime
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, roc_curve, auc
from tqdm import tqdm
from getdata import mean, std, class_names
from option import get_args
opt = get_args()
device = 'cuda:1'


"""Predicting a single image"""
def evaluate_image_single(img_path, transform_test, model, class_names, top_k):
    image = Image.open(img_path).convert('RGB')
    img = transform_test(image).to(device)
    img = img.unsqueeze_(0)
    out = model(img)
    pred_softmax = F.softmax(out, dim=1)
    top_n, top_n_indices = torch.topk(pred_softmax, top_k)
    confs = top_n[0].cpu().detach().numpy().tolist()
    class_names_top = [class_names[i] for i in top_n_indices[0]]
    for i in range(top_k):
        print(f'Pre: {class_names_top[i]}   Conf: {confs[i]:.3f}')
    confs_max = confs[0]
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.title(f'Pre: {class_names_top[0]}   Conf: {confs_max:.3f}')
    plt.imshow(image)
    sorted_pairs = sorted(zip(class_names_top, confs), key=lambda x: x[1], reverse=True)
    sorted_class_names_top, sorted_confs = zip(*sorted_pairs)
    plt.subplot(1, 2, 2)
    bars = plt.bar(sorted_class_names_top, sorted_confs, color='lightcoral')
    plt.xlabel('Class Names')
    plt.ylabel('Confidence')
    plt.title('Top 5 Predictions (Descending Order)')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.tight_layout()
    for bar, conf in zip(bars, sorted_confs):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, f'{conf:.3f}', ha='center', va='bottom')
    plt.savefig('predict_image_with_bars.jpg')


"""Predicting folder images"""
def evaluate_image_dir(model, dataloader, class_names):
    model.eval()
    all_preds = []
    all_labels = []
    correct_top1, correct_top3, correct_top5, total = torch.tensor(0.).to(device), torch.tensor(0.).to(device), torch.tensor(0.).to(device), torch.tensor(0.).to(device)
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            total += labels.size(0)
            correct_top1 += (outputs.argmax(1) == labels).type(torch.float).sum()
            _, predicted_top3 = torch.topk(outputs, 3, dim=1)
            _, predicted_top5 = torch.topk(outputs, 5, dim=1)
            correct_top3 += (predicted_top3[:, :3] == labels.unsqueeze(1).expand_as(predicted_top3)).sum().item()
            correct_top5 += (predicted_top5[:, :5] == labels.unsqueeze(1).expand_as(predicted_top5)).sum().item()
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds)
            all_labels.extend(labels)
    all_preds = torch.tensor(all_preds)
    all_labels = torch.tensor(all_labels)
    top1 = correct_top1 / total
    top3 = correct_top3 / total
    top5 = correct_top5 / total
    print(f"Top-1 Accuracy: {top1:.4f}")
    print(f"Top-3 Accuracy: {top3:.4f}")
    print(f"Top-5 Accuracy: {top5:.4f}")
    accuracy = accuracy_score(all_labels.cpu().numpy(), all_preds.cpu().numpy())
    precision = precision_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
    recall = recall_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
    f1 = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
    cm = confusion_matrix(all_labels.cpu().numpy(), all_preds.cpu().numpy())
    report = classification_report(all_labels.cpu().numpy(), all_preds.cpu().numpy(), target_names=class_names)
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(report)
    plt.figure(figsize=(100, 100))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, annot_kws={"size": 8})
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.jpg')


"""Using .onnx model to predict images"""
def evaluate_onnx_model(img_path, data_transform, onnx_model_path, class_names, top_k=5):
    ort_session = onnxruntime.InferenceSession(onnx_model_path)
    img_pil = Image.open(img_path).convert('RGB')
    input_img = data_transform(img_pil)
    input_tensor = input_img.unsqueeze(0).numpy()
    ort_inputs = {'input': input_tensor}
    out = ort_session.run(['output'], ort_inputs)[0]

    def softmax(x):
        return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)

    prob_dist = softmax(out)
    result_dict = {label: float(prob_dist[0][i]) for i, label in enumerate(class_names)}
    result_dict = dict(sorted(result_dict.items(), key=lambda item: item[1], reverse=True))
    for key, value in list(result_dict.items())[:top_k]:
        print(f'Pre: {key}   Conf: {value:.3f}')
    confs_max = list(result_dict.values())[0]
    class_names_top = list(result_dict.keys())
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.title(f'Pre: {class_names_top[0]}   Conf: {confs_max:.3f}')
    plt.imshow(img_pil)
    plt.subplot(1, 2, 2)
    bars = plt.bar(class_names_top[:top_k], list(result_dict.values())[:top_k], color='lightcoral')
    plt.xlabel('Class Names')
    plt.ylabel('Confidence')
    plt.title('Top 5 Predictions (Descending Order)')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.tight_layout()
    for bar, conf in zip(bars, list(result_dict.values())[:top_k]):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, f'{conf:.3f}', ha='center', va='bottom')
    plt.savefig('predict_image_with_bars.jpg')


if __name__ == '__main__':
    data_transform = transforms.Compose([transforms.Resize((opt.loadsize, opt.loadsize)), transforms.ToTensor(), transforms.Normalize(mean, std)])
    image_datasets = ImageFolder(opt.dataset_test, data_transform)
    dataloaders = DataLoader(image_datasets, batch_size=512, shuffle=True)
    ptl_model_path = opt.checkpoints + 'model.ptl'
    pth_model_path = opt.checkpoints + 'model.pth'
    onnx_model_path = opt.checkpoints + 'model.onnx'
    ptl_model = torch.jit.load(ptl_model_path).to(device)
    pth_model = torch.load(pth_model_path).to(device)
    evaluate_image_single(opt.test_img_path, data_transform, pth_model, class_names, top_k=5)     # Predicting a single image
    # evaluate_image_dir(pth_model, dataloaders, class_names)     # Predicting folder images
    # evaluate_onnx_model(opt.test_img_path, data_transform, onnx_model_path, class_names, top_k=5)   # Predicting a single image
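The softmax helper inside evaluate_onnx_model exponentiates the raw logits directly. That works for moderate values, but exponentials of large logits can overflow. A common variant subtracts the row maximum first, which leaves the result unchanged mathematically. The sketch below shows that variant in pure Python; the function names and the toy logits and class names are illustrative and not taken from the script.

```python
import math


def stable_softmax(row):
    """Softmax over one row of logits, shifted by the row maximum to avoid overflow."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, names, k):
    """Return the k (name, probability) pairs with the highest probability."""
    pairs = sorted(zip(names, probs), key=lambda p: p[1], reverse=True)
    return pairs[:k]


# Toy logits large enough that math.exp(1000.0) alone would raise OverflowError.
logits = [998.0, 1000.0, 999.0]
names = ["circle", "zigzag", "line"]  # made-up subset of class names

probs = stable_softmax(logits)
for name, p in top_k(probs, names, 2):
    print(f"Pre: {name}   Conf: {p:.3f}")
```

The shift cancels out because numerator and denominator are both scaled by the same factor exp(-max).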

Running evaluate_image_single on datasets/test/zigzag/zigzag-4508464694951936.png gives the following result:
[Figure: prediction result for the image]
Running evaluate_image_dir on the images under datasets/test gives the following results:

Top-1 Accuracy: 0.6833
Top-3 Accuracy: 0.8521
Top-5 Accuracy: 0.8933
Accuracy: 0.6833
Precision: 0.6875
Recall: 0.6833
F1 Score: 0.6817
class    precision    recall    f1-score    support

The Eiffel Tower    0.83    0.88    0.85    1000
The Great Wall of China    0.47    0.36    0.41    1000
The Mona Lisa    0.68    0.86    0.76    1000
airplane    0.83    0.74    0.78    1000
alarm clock    0.76    0.76    0.76    1000
ambulance    0.70    0.65    0.67    1000
angel    0.87    0.78    0.82    1000
animal migration    0.47    0.66    0.55    1000
ant    0.77    0.74    0.75    1000
anvil    0.80    0.66    0.72    1000
apple    0.82    0.85    0.83    1000
arm    0.74    0.69    0.71    1000
asparagus    0.54    0.44    0.48    1000
axe    0.69    0.67    0.68    1000
backpack    0.61    0.75    0.67    1000
banana    0.68    0.72    0.70    1000
bandage    0.83    0.71    0.77    1000
barn    0.66    0.68    0.67    1000
baseball    0.77    0.71    0.74    1000
baseball bat    0.75    0.73    0.74    1000
basket    0.71    0.62    0.66    1000
basketball    0.62    0.72    0.66    1000
bat    0.79    0.62    0.69    1000
bathtub    0.60    0.64    0.62    1000
beach    0.58    0.65    0.61    1000
bear    0.46    0.31    0.37    1000
beard    0.56    0.73    0.63    1000
bed    0.80    0.67    0.73    1000
bee    0.82    0.74    0.78    1000
belt    0.78    0.55    0.64    1000
bench    0.59    0.53    0.56    1000
bicycle    0.73    0.72    0.72    1000
binoculars    0.74    0.77    0.76    1000
bird    0.47    0.43    0.45    1000
birthday cake    0.52    0.64    0.57    1000
blackberry    0.46    0.42    0.44    1000
blueberry    0.58    0.47    0.52    1000
book    0.72    0.78    0.75    1000
boomerang    0.73    0.70    0.71    1000
bottlecap    0.58    0.54    0.56    1000
bowtie    0.87    0.86    0.86    1000
bracelet    0.68    0.60    0.64    1000
brain    0.59    0.60    0.59    1000
bread    0.54    0.63    0.58    1000
bridge    0.61    0.64    0.63    1000
broccoli    0.58    0.70    0.64    1000
broom    0.56    0.68    0.61    1000
bucket    0.62    0.66    0.64    1000
bulldozer    0.69    0.70    0.70    1000
bus    0.56    0.42    0.48    1000
bush    0.47    0.65    0.55    1000
butterfly    0.86    0.88    0.87    1000
cactus    0.69    0.87    0.77    1000
cake    0.53    0.42    0.47    1000
calculator    0.76    0.82    0.79    1000
calendar    0.54    0.50    0.52    1000
camel    0.82    0.84    0.83    1000
camera    0.87    0.74    0.80    1000
camouflage    0.23    0.43    0.30    1000
campfire    0.72    0.77    0.75    1000
candle    0.75    0.73    0.74    1000
cannon    0.77    0.69    0.72    1000
canoe    0.67    0.63    0.65    1000
car    0.65    0.63    0.64    1000
carrot    0.75    0.82    0.78    1000
castle    0.79    0.72    0.75    1000
cat    0.69    0.66    0.68    1000
ceiling fan    0.83    0.64    0.72    1000
cell phone    0.62    0.60    0.61    1000
cello    0.51    0.67    0.58    1000
chair    0.83    0.80    0.81    1000
chandelier    0.74    0.71    0.73    1000
church    0.72    0.67    0.69    1000
circle    0.53    0.86    0.66    1000
clarinet    0.53    0.63    0.58    1000
clock    0.86    0.77    0.82    1000
cloud    0.73    0.69    0.71    1000
coffee cup    0.67    0.43    0.52    1000
compass    0.69    0.78    0.73    1000
computer    0.79    0.62    0.69    1000
cookie    0.68    0.80    0.74    1000
cooler    0.47    0.33    0.38    1000
couch    0.76    0.82    0.79    1000
cow    0.70    0.57    0.63    1000
crab    0.70    0.72    0.71    1000
crayon    0.44    0.52    0.47    1000
crocodile    0.65    0.57    0.60    1000
crown    0.87    0.87    0.87    1000
cruise ship    0.76    0.69    0.73    1000
cup    0.43    0.50    0.47    1000
diamond    0.73    0.88    0.80    1000
dishwasher    0.56    0.47    0.51    1000
diving board    0.53    0.54    0.53    1000
dog    0.50    0.41    0.45    1000
dolphin    0.79    0.59    0.68    1000
donut    0.75    0.88    0.81    1000
door    0.69    0.72    0.70    1000
dragon    0.52    0.42    0.47    1000
dresser    0.75    0.65    0.70    1000
drill    0.78    0.71    0.75    1000
drums    0.71    0.68    0.70    1000
duck    0.68    0.49    0.57    1000
dumbbell    0.78    0.80    0.79    1000
ear    0.81    0.75    0.78    1000
elbow    0.74    0.62    0.68    1000
elephant    0.66    0.66    0.66    1000
envelope    0.87    0.94    0.90    1000
eraser    0.50    0.61    0.55    1000
eye    0.83    0.85    0.84    1000
eyeglasses    0.84    0.80    0.82    1000
face    0.62    0.64    0.63    1000
fan    0.76    0.60    0.67    1000
feather    0.58    0.60    0.59    1000
fence    0.67    0.71    0.69    1000
finger    0.70    0.63    0.67    1000
fire hydrant    0.56    0.64    0.60    1000
fireplace    0.74    0.67    0.71    1000
firetruck    0.71    0.50    0.59    1000
fish    0.89    0.85    0.87    1000
flamingo    0.69    0.75    0.72    1000
flashlight    0.80    0.82    0.81    1000
flip flops    0.64    0.75    0.69    1000
floor lamp    0.77    0.70    0.74    1000
flower    0.79    0.83    0.81    1000
flying saucer    0.65    0.64    0.64    1000
foot    0.68    0.66    0.67    1000
fork    0.81    0.79    0.80    1000
frog    0.46    0.47    0.47    1000
frying pan    0.78    0.76    0.77    1000
garden    0.59    0.63    0.61    1000
garden hose    0.42    0.28    0.33    1000
giraffe    0.87    0.80    0.84    1000
goatee    0.72    0.73    0.72    1000
golf club    0.60    0.62    0.61    1000
grapes    0.68    0.65    0.66    1000
grass    0.59    0.83    0.69    1000
guitar    0.68    0.50    0.58    1000
hamburger    0.66    0.83    0.73    1000
hammer    0.71    0.75    0.73    1000
hand    0.83    0.83    0.83    1000
harp    0.83    0.78    0.80    1000
hat    0.72    0.71    0.72    1000
headphones    0.92    0.91    0.92    1000
hedgehog    0.73    0.74    0.73    1000
helicopter    0.81    0.83    0.82    1000
helmet    0.63    0.66    0.64    1000
hexagon    0.70    0.73    0.72    1000
hockey puck    0.59    0.61    0.60    1000
hockey stick    0.59    0.54    0.56    1000
horse    0.53    0.85    0.65    1000
hospital    0.80    0.68    0.74    1000
hot air balloon    0.79    0.72    0.75    1000
hot dog    0.60    0.63    0.62    1000
hot tub    0.58    0.51    0.54    1000
hourglass    0.86    0.87    0.87    1000
house    0.77    0.77    0.77    1000
house plant    0.85    0.82    0.83    1000
hurricane    0.39    0.45    0.42    1000
ice cream    0.82    0.85    0.84    1000
jacket    0.75    0.72    0.74    1000
jail    0.71    0.72    0.71    1000
kangaroo    0.73    0.71    0.72    1000
key    0.71    0.76    0.74    1000
keyboard    0.50    0.48    0.49    1000
knee    0.63    0.68    0.65    1000
ladder    0.88    0.91    0.89    1000
lantern    0.70    0.53    0.60    1000
laptop    0.63    0.80    0.71    1000
leaf    0.73    0.71    0.72    1000
leg    0.58    0.50    0.54    1000
light bulb    0.69    0.79    0.73    1000
lighthouse    0.71    0.74    0.72    1000
lightning    0.76    0.69    0.72    1000
line    0.55    0.82    0.66    1000
lion    0.70    0.76    0.73    1000
lipstick    0.59    0.69    0.63    1000
lobster    0.61    0.47    0.53    1000
lollipop    0.76    0.85    0.80    1000
mailbox    0.75    0.66    0.70    1000
map    0.65    0.73    0.68    1000
marker    0.39    0.16    0.23    1000
matches    0.52    0.47    0.49    1000
megaphone    0.80    0.70    0.75    1000
mermaid    0.76    0.84    0.80    1000
microphone    0.64    0.73    0.68    1000
microwave    0.79    0.75    0.77    1000
monkey    0.59    0.56    0.57    1000
moon    0.69    0.60    0.64    1000
mosquito    0.48    0.55    0.51    1000
motorbike    0.64    0.62    0.63    1000
mountain    0.74    0.80    0.77    1000
mouse    0.53    0.46    0.49    1000
moustache    0.75    0.72    0.73    1000
mouth    0.72    0.76    0.74    1000
mug    0.54    0.65    0.59    1000
mushroom    0.66    0.76    0.70    1000
nail    0.58    0.66    0.62    1000
necklace    0.75    0.63    0.68    1000
nose    0.69    0.75    0.72    1000
ocean    0.54    0.54    0.54    1000
octagon    0.71    0.62    0.66    1000
octopus    0.89    0.83    0.86    1000
onion    0.75    0.68    0.71    1000
oven    0.50    0.39    0.44    1000
owl    0.68    0.65    0.67    1000
paint can    0.51    0.49    0.50    1000
paintbrush    0.58    0.63    0.61    1000
palm tree    0.73    0.83    0.78    1000
panda    0.66    0.62    0.64    1000
pants    0.75    0.68    0.71    1000
paper clip    0.75    0.78    0.76    1000
parachute    0.81    0.79    0.80    1000
parrot    0.54    0.59    0.56    1000
passport    0.60    0.55    0.58    1000
peanut    0.70    0.73    0.71    1000
pear    0.72    0.80    0.76    1000
peas    0.70    0.56    0.62    1000
pencil    0.58    0.60    0.59    1000
penguin    0.69    0.78    0.73    1000
piano    0.65    0.66    0.65    1000
pickup truck    0.60    0.64    0.62    1000
picture frame    0.68    0.89    0.77    1000
pig    0.77    0.56    0.65    1000
pillow    0.60    0.58    0.59    1000
pineapple    0.80    0.85    0.82    1000
pizza    0.65    0.77    0.70    1000
pliers    0.69    0.55    0.61    1000
police car    0.67    0.68    0.67    1000
pond    0.40    0.47    0.43    1000
pool    0.51    0.23    0.32    1000
popsicle    0.70    0.79    0.75    1000
postcard    0.74    0.58    0.65    1000
potato    0.54    0.40    0.46    1000
power outlet    0.61    0.72    0.66    1000
purse    0.64    0.69    0.66    1000
rabbit    0.66    0.80    0.72    1000
raccoon    0.43    0.44    0.44    1000
radio    0.71    0.59    0.64    1000
rain    0.77    0.90    0.83    1000
rainbow    0.79    0.92    0.85    1000
rake    0.69    0.67    0.68    1000
remote control    0.67    0.68    0.67    1000
rhinoceros    0.65    0.75    0.69    1000
river    0.66    0.61    0.64    1000
roller coaster    0.70    0.52    0.60    1000
rollerskates    0.86    0.83    0.84    1000
sailboat    0.84    0.87    0.86    1000
sandwich    0.50    0.68    0.57    1000
saw    0.81    0.83    0.82    1000
saxophone    0.79    0.77    0.78    1000
school bus    0.51    0.44    0.47    1000
scissors    0.80    0.84    0.82    1000
scorpion    0.70    0.76    0.73    1000
screwdriver    0.58    0.62    0.60    1000
sea turtle    0.79    0.73    0.76    1000
see saw    0.85    0.79    0.82    1000
shark    0.72    0.72    0.72    1000
sheep    0.75    0.80    0.77    1000
shoe    0.73    0.75    0.74    1000
shorts    0.67    0.76    0.71    1000
shovel    0.62    0.73    0.67    1000
sink    0.62    0.76    0.68    1000
skateboard    0.83    0.85    0.84    1000
skull    0.86    0.83    0.85    1000
skyscraper    0.65    0.56    0.60    1000
sleeping bag    0.55    0.59    0.57    1000
smiley face    0.74    0.80    0.77    1000
snail    0.79    0.90    0.84    1000
snake    0.65    0.66    0.65    1000
snorkel    0.79    0.73    0.76    1000
snowflake    0.79    0.84    0.81    1000
snowman    0.83    0.90    0.86    1000
soccer ball    0.69    0.70    0.69    1000
sock    0.77    0.75    0.76    1000
speedboat    0.65    0.65    0.65    1000
spider    0.72    0.79    0.76    1000
spoon    0.69    0.57    0.63    1000
spreadsheet    0.67    0.62    0.65    1000
square    0.52    0.84    0.65    1000
squiggle    0.41    0.40    0.40    1000
squirrel    0.71    0.74    0.72    1000
stairs    0.90    0.91    0.90    1000
star    0.93    0.91    0.92    1000
steak    0.53    0.46    0.49    1000
stereo    0.61    0.68    0.64    1000
stethoscope    0.87    0.75    0.81    1000
stitches    0.71    0.79    0.75    1000
stop sign    0.86    0.88    0.87    1000
stove    0.71    0.66    0.69    1000
strawberry    0.80    0.80    0.80    1000
streetlight    0.75    0.71    0.73    1000
string bean    0.51    0.39    0.44    1000
submarine    0.83    0.67    0.74    1000
suitcase    0.75    0.57    0.64    1000
sun    0.87    0.88    0.87    1000
swan    0.69    0.67    0.68    1000
sweater    0.68    0.65    0.67    1000
swing set    0.89    0.90    0.89    1000
sword    0.85    0.81    0.83    1000
t-shirt    0.80    0.78    0.79    1000
table    0.73    0.76    0.74    1000
teapot    0.82    0.77    0.80    1000
teddy-bear    0.66    0.74    0.70    1000
telephone    0.67    0.54    0.60    1000
television    0.88    0.85    0.86    1000
tennis racquet    0.86    0.74    0.80    1000
tent    0.80    0.77    0.78    1000
tiger    0.53    0.47    0.50    1000
toaster    0.59    0.70    0.64    1000
toe    0.67    0.63    0.65    1000
toilet    0.74    0.80    0.77    1000
tooth    0.72    0.74    0.73    1000
toothbrush    0.74    0.76    0.75    1000
toothpaste    0.54    0.56    0.55    1000
tornado    0.63    0.69    0.66    1000
tractor    0.65    0.71    0.68    1000
traffic light    0.84    0.84    0.84    1000
train    0.61    0.74    0.67    1000
tree    0.72    0.75    0.73    1000
triangle    0.87    0.93    0.90    1000
trombone    0.58    0.48    0.53    1000
truck    0.50    0.41    0.45    1000
trumpet    0.65    0.49    0.56    1000
umbrella    0.91    0.86    0.88    1000
underwear    0.83    0.64    0.72    1000
van    0.46    0.58    0.51    1000
vase    0.82    0.67    0.74    1000
violin    0.52    0.52    0.52    1000
washing machine    0.74    0.78    0.76    1000
watermelon    0.56    0.66    0.61    1000
waterslide    0.57    0.70    0.63    1000
whale    0.71    0.74    0.72    1000
wheel    0.82    0.50    0.62    1000
windmill    0.82    0.77    0.79    1000
wine bottle    0.77    0.81    0.79    1000
wine glass    0.86    0.85    0.86    1000
wristwatch    0.72    0.74    0.73    1000
yoga    0.60    0.57    0.58    1000
zebra    0.73    0.66    0.69    1000
zigzag    0.73    0.75    0.74    1000

accuracy    0.68    340000
macro avg    0.69    0.68    0.68    340000
weighted avg    0.69    0.68    0.68    340000
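The macro-averaged precision, recall and F1 reported above are per-class scores averaged with equal weight per class. A minimal pure-Python sketch of that computation from a confusion matrix follows; the function names and the tiny 3-class labels are made up for illustration, whereas the script itself uses scikit-learn.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    """cm[t][p] counts samples with true class t predicted as class p."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def macro_scores(cm):
    """Return macro-averaged (precision, recall, f1) from a confusion matrix."""
    n = len(cm)
    precisions, recalls, f1s = [], [], []
    for c in range(n):
        tp = cm[c][c]
        predicted_c = sum(cm[r][c] for r in range(n))  # column sum
        actual_c = sum(cm[c])                          # row sum (support)
        p = tp / predicted_c if predicted_c else 0.0
        r = tp / actual_c if actual_c else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


# Toy labels: 3 classes with 2 samples each.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = confusion_matrix(y_true, y_pred, 3)
precision, recall, f1 = macro_scores(cm)
accuracy = sum(cm[c][c] for c in range(3)) / len(y_true)
print(f"Accuracy: {accuracy:.4f}  Precision: {precision:.4f}  Recall: {recall:.4f}  F1: {f1:.4f}")
```

When every class has the same support, as in the test set here with 1000 images per class, macro recall equals overall accuracy, which is consistent with both being 0.6833 in the results above.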