網(wǎng)站建設(shè)案例圖片做銷售最掙錢(qián)的10個(gè)行業(yè)
半監(jiān)督學(xué)習(xí)與數(shù)據(jù)增強(qiáng)(論文復(fù)現(xiàn))
本文所涉及所有資源均在傳知代碼平臺(tái)可獲取
文章目錄
- 半監(jiān)督學(xué)習(xí)與數(shù)據(jù)增強(qiáng)(論文復(fù)現(xiàn))
- 概述
- 算法原理
- 核心邏輯
- 效果演示
- 使用方式
概述
本文復(fù)現(xiàn)論文提出的半監(jiān)督學(xué)習(xí)方法,半監(jiān)督學(xué)習(xí)(Semi-supervised Learning)是一種機(jī)器學(xué)習(xí)方法,它將少量的標(biāo)注數(shù)據(jù)(帶有標(biāo)簽的數(shù)據(jù))和大量的未標(biāo)注數(shù)據(jù)(不帶標(biāo)簽的數(shù)據(jù))結(jié)合起來(lái)訓(xùn)練模型。在許多實(shí)際應(yīng)用中,標(biāo)注數(shù)據(jù)獲取成本高且困難,而未標(biāo)注數(shù)據(jù)通常較為豐富和容易獲取。因此,半監(jiān)督學(xué)習(xí)方法被引入并被用于利用未標(biāo)注數(shù)據(jù)來(lái)提高模型的性能和泛化能力
該論文介紹了一種基于一致性和置信度的半監(jiān)督學(xué)習(xí)方法 FixMatch。FixMatch首先使用模型為弱增強(qiáng)后的未標(biāo)注圖像生成偽標(biāo)簽。對(duì)于給定圖像,只有當(dāng)模型產(chǎn)生高置信度預(yù)測(cè)時(shí)才保留偽標(biāo)簽。然后,模型在輸入同一圖像的強(qiáng)增強(qiáng)版本時(shí)被訓(xùn)練去預(yù)測(cè)偽標(biāo)簽。FixMatch 在各種半監(jiān)督學(xué)習(xí)數(shù)據(jù)集上實(shí)現(xiàn)了先進(jìn)的性能
算法原理
FixMatch 結(jié)合了兩種半監(jiān)督學(xué)習(xí)方法:一致性正則化和偽標(biāo)簽。其主要?jiǎng)?chuàng)新點(diǎn)在于這兩種方法的結(jié)合以及在執(zhí)行一致性正則化時(shí)分別使用了弱增強(qiáng)和強(qiáng)增強(qiáng)。
FixMatch 的損失函數(shù)由兩個(gè)交叉熵?fù)p失項(xiàng)組成:一個(gè)用于有標(biāo)簽數(shù)據(jù)的監(jiān)督損失 lsl**s 和一個(gè)用于無(wú)標(biāo)簽數(shù)據(jù)的無(wú)監(jiān)督損失 lul**u 。具體來(lái)說(shuō),lsl**s 只是對(duì)弱增強(qiáng)有標(biāo)簽樣本應(yīng)用的標(biāo)準(zhǔn)交叉熵?fù)p失
其中 BB 表示 batch size,HH 表示交叉熵?fù)p失,pbp**b 表示標(biāo)記,pm(y∣α(xb))p**m(y∣α(x**b)) 表示模型對(duì)弱增強(qiáng)樣本的預(yù)測(cè)結(jié)果。
FixMatch 對(duì)每個(gè)無(wú)標(biāo)簽樣本計(jì)算一個(gè)偽標(biāo)簽,然后在標(biāo)準(zhǔn)交叉熵?fù)p失中使用該標(biāo)簽。為了獲得偽標(biāo)簽,我們首先計(jì)算模型對(duì)給定無(wú)標(biāo)簽圖像的弱增強(qiáng)版本的預(yù)測(cè)類別分布:qb=pm(y∣α(ub))q**b=p**m(y∣α(u**b))。然后,我們使用 qb=arg?max?qb*q*b=argmaxq**b 作為偽標(biāo)簽,但我們?cè)诮徊骒負(fù)p失中對(duì)模型對(duì) ubu**b 的強(qiáng)增強(qiáng)版本的輸出進(jìn)行約束
其中 μμ 表示無(wú)標(biāo)簽樣本與有標(biāo)簽樣本數(shù)量之比,1(max(qb)>τ)1(max(q**b)>τ) 當(dāng)前僅當(dāng) max(qb)>τmax(q**b)>τ 成立時(shí)為 1 否則為 0,ττ 表示置信度閾值,A(ub)A(u**b) 表示對(duì)無(wú)標(biāo)簽樣本的強(qiáng)增強(qiáng)。
FixMatch的總損失就是 ls+λulul**s+λul**u,其中 λuλ**u 是表示無(wú)標(biāo)簽損失相對(duì)權(quán)重的標(biāo)量超參數(shù)
FixMatch 利用兩種增強(qiáng)方法:“弱增強(qiáng)”和“強(qiáng)增強(qiáng)”。論文所使用的弱增強(qiáng)是一種標(biāo)準(zhǔn)的翻轉(zhuǎn)和位移增強(qiáng)策略。具體來(lái)說(shuō),除了SVHN數(shù)據(jù)集之外,我們?cè)谒袛?shù)據(jù)集上以50%的概率隨機(jī)水平翻轉(zhuǎn)圖像,并隨機(jī)在垂直和水平方向上平移圖像最多12.5%。對(duì)于“強(qiáng)增強(qiáng)”,我采用了基于隨機(jī)幅度采樣的 RandAugment,然后進(jìn)行了 Cutout 處理。
我在CIFAR-10、CIFAR-100 、SVHN 和 FER2013 數(shù)據(jù)集上對(duì) FixMatch 進(jìn)行了實(shí)驗(yàn)。關(guān)于使用的神經(jīng)網(wǎng)絡(luò),我在 CIFAR-10 和 SVHN 上使用了 Wide ResNet-28-2,在 CIFAR-100 上使用了 Wide ResNet-28-8,在 FER2013 上使用了 Wide ResNe-37-2。實(shí)驗(yàn)結(jié)果如下表所示
為了直觀展示 FixMatch 的效果,我在線部署了基于 FER2013 數(shù)據(jù)集訓(xùn)練的 Wide ResNe-37-2 模型。FER2013[2] 是一個(gè)面部表情識(shí)別數(shù)據(jù)集,其包含約 30000 張不同表情的面部 RGB 圖像,尺寸限制為 48×48。其主要標(biāo)簽可分為 7 種類型:憤怒(Angry),厭惡(Disgust),恐懼(Fear),快樂(lè)(Happy),悲傷(Sad),驚訝(Surprise),中性(Neutral)。厭惡表情的圖像數(shù)量最少,只有 600 張,而其他標(biāo)簽的樣本數(shù)量均接近 5,000 張
核心邏輯
具體的核心邏輯如下所示:
for epoch in range(epochs):model.train()train_tqdm = zip(labeled_dataloader, unlabeled_dataloader)for labeled_batch, unlabeled_batch in train_tqdm:optimizer.zero_grad()# 利用標(biāo)記樣本計(jì)算損失data = labeled_batch[0].to(device)labels = labeled_batch[1].to(device)logits = model(normalize(strong_aug(data)))loss = F.cross_entropy(logits, labels)# 計(jì)算未標(biāo)記樣本偽標(biāo)簽with torch.no_grad():data = unlabeled_batch[0].to(device)logits = model(normalize(weak_aug(data)))probs = F.softmax(logits, dim=-1)trusted = torch.max(probs, dim=-1).values > thresholdpseudo_labels = torch.argmax(probs[trusted], dim=-1)loss_factor = weight * torch.sum(trusted).item() / data.shape[0]# 利用未標(biāo)記樣本計(jì)算損失logits = model(normalize(strong_aug(data[trusted])))loss += loss_factor * F.cross_entropy(logits, pseudo_labels)# 反向梯度傳播并更新模型參數(shù)loss.backward()optimizer.step()
效果演示
網(wǎng)站提供了在線體驗(yàn)功能。用戶需要輸入一張長(zhǎng)寬盡可能相等且大小不超過(guò) 1MB 的正面臉部 JPG 圖像,網(wǎng)站就會(huì)返回圖片中人物表情所表達(dá)的情緒
使用方式
解壓附件壓縮包并進(jìn)入工作目錄。如果是Linux系統(tǒng),請(qǐng)使用如下命令
unzip FixMatch.zip
cd FixMatch
代碼的運(yùn)行環(huán)境可通過(guò)如下命令進(jìn)行配置
pip install -r requirements.txt
如果希望在本地運(yùn)行程序,請(qǐng)運(yùn)行如下命令
python main.py
如果希望在線部署,請(qǐng)運(yùn)行如下命令
python main-flask.py
文章代碼資源點(diǎn)擊附件獲取