代刷網(wǎng)可以做網(wǎng)站地圖全媒體廣告代理
提高模型復(fù)用性,讓模型對應(yīng)的配置更加清晰,代碼書寫條理
學(xué)習(xí)自https://zhuanlan.zhihu.com/p/409662511
Project
├── checkpoints # 存放模型
├── data # 定義各種用于訓(xùn)練測試的數(shù)據(jù)集
├── eval.py # 測試代碼
├── loss.py # 定義的各種loss
├── metrics.py # 定義約定俗成的評價指標
├── model/src # 定義實驗中的模型
├── options.py # 定義各種實驗的參數(shù),以命令行形式傳入
├── README.md # 介紹report
├── scripts # 訓(xùn)練、測試腳本(訓(xùn)練、測試的運行命令)
├── train.py # 訓(xùn)練代碼
└── utils # 訓(xùn)練工具代碼
文章目錄
- Checkpoints
- Scripts
- options.py
- train.py
Checkpoints
訓(xùn)練好的模型放在checkpoints里面,通常保存訓(xùn)練過程中的中間結(jié)果。主要包括:模型權(quán)重文件、模型配置文件、優(yōu)化器和日志文件等。
Scripts
每次訓(xùn)練或者測試用的腳本命令。
- 訓(xùn)練腳本:用于執(zhí)行模型訓(xùn)練的腳本文件,通常包括定義模型、加載數(shù)據(jù)、設(shè)置損失函數(shù)和優(yōu)化器、執(zhí)行循環(huán)等步驟。
- 評估腳本:用于評估模型性能的腳本文件。加載訓(xùn)練好的模型或者指定的checkpoints文件,對模型在測試集或驗證集上的表現(xiàn)進行評估。
- 預(yù)測腳本:……
- 數(shù)據(jù)預(yù)處理腳本:用于數(shù)據(jù)預(yù)處理和準備的腳本文件。
options.py
- 定義實驗參數(shù)。
def parse_common_args(parser):parser.add_argument('--model_type', type=str, default='base_model', help='used in model_entry.py')parser.add_argument('--data_type', type=str, default='base_dataset', help='used in data_entry.py')parser.add_argument('--save_prefix', type=str, default='pref', help='some comment for model or test result dir')parser.add_argument('--load_model_path', type=str, default='checkpoints/base_model_pref/0.pth', help='model path for pretrain or test')parser.add_argument('--load_not_strict', action='store_true', help='allow to load only common state dicts')parser.add_argument('--val_list', type=str, default='/data/dataset1/list/base/val.txt', help='val list in train, test list path in test')parser.add_argument('--gpus', nargs='+', type=int)return parserdef parse_train_args(parser):parser = parse_common_args(parser)...return parserdef parse_test_args(parser):parser = parse_common_args(parser)...return parser
- 路徑配置:定義數(shù)據(jù)集、模型、日志文件等路徑
DATA_PATH = '/path/to/dataset/'
MODEL_PATH = '/path/to/models/'
LOG_PATH = '/path/to/logs/'
- 數(shù)據(jù)處理
IMAGE_SIZE = (256, 256)
DATA_AUGMENTATION = True
- 加載模型超參
LEARNING_RATE = 0.001
BATCH_SIZE = 32
MAX_EPOCHS = 10
- 其他配置
train.py
主要任務(wù)是把整體寫好的內(nèi)容串起來
- 導(dǎo)入必要的庫和模塊
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from model import MyModel # 假設(shè)模型定義在model.py中
from options import * # 導(dǎo)入配置選項
- 數(shù)據(jù)加載和預(yù)處理
# 定義數(shù)據(jù)預(yù)處理和增強方式
transform = transforms.Compose([transforms.Resize(IMAGE_SIZE),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加載數(shù)據(jù)集
train_dataset = datasets.ImageFolder(root=DATA_PATH, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
- 模型定義和初始化
# 定義模型
model = MyModel()
# 如果有預(yù)訓(xùn)練模型,加載參數(shù)
# model.load_state_dict(torch.load(PRETRAINED_MODEL_PATH))
- 定義損失函數(shù)和優(yōu)化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
- 保存模型
torch.save(model.state_dict(), MODEL_SAVE_PATH)
- 可選的評估和測試
# 評估模型
model.eval()
with torch.no_grad():# 執(zhí)行評估代碼