昆明賢邦網(wǎng)站建設百度站長工具seo查詢
優(yōu)于立方復雜度的 Rust 中矩陣乘法
邁克·克維特
跟隨
更好的編程
143
中途:三次矩陣乘法
一、說明
????????幾年前,我在 C++ 年編寫了?Strassen 矩陣乘法算法的實現(xiàn),最近在?Rust 中重新實現(xiàn)了它,因為我繼續(xù)學習該語言。這是學習 Rust 性能特征和優(yōu)化技術的有用練習,因為盡管 Strassen 的算法復雜性優(yōu)于樸素方法,但它在算法結構中的分配和遞歸開銷中具有很高的常數(shù)因子。
- 通用算法
- 換位以獲得更好的性能
- 次立方:斯特拉森算法的工作原理
- 排比
- 標桿
- 分析和性能優(yōu)化
二、通用算法
????????一般(樸素)矩陣乘法算法是每個人在他們的第一堂線性代數(shù)課上學習的三個嵌套循環(huán)方法,大多數(shù)人會將其識別為?O(n3)
pub fn
mult_naive (a: &Matrix, b: &Matrix) -> Matrix {if a.rows == b.cols {let m = a.rows;let n = a.cols;// preallocatelet mut c: Vec<f64> = Vec::with_capacity(m * m);for i in 0..m {for j in 0..m {let mut sum: f64 = 0.0;for k in 0..n {sum += a.at(i, k) * b.at(k, j);}c.push(sum);}}return Matrix::with_vector(c, m, m);} else {panic!("Matrix sizes do not match");}
}
????????這種算法很慢,不僅因為三個嵌套循環(huán),還因為按列通過而不是按行的內(nèi)部循環(huán)遍歷對于 CPU 緩存命中率來說是可怕的。B
b.at(k, j)
三、換位以獲得更好的性能
? ? ? ? 轉置樸素方法允許 B 上的乘法迭代在行而不是列上運行,將矩陣 B 的乘法步幅重新組織為更有利于緩存的格式。從而變成A x B
A x B^t
?????????它涉及一個新的矩陣分配(無論如何,在這個實現(xiàn)中)和一個完整的矩陣迭代(一個 O(n2) 操作,更準確地說,這種方法是 O(n3) + O(n2))——我將進一步展示它的性能有多好。它如下所示:
fn multiply_transpose (A: Matrix, B: Matrix):C = new Matrix(A.num_rows, B.num_cols)// Construct transpose; requires allocation and iteration through BB’ = B.transpose()for i in 0 to A.num_rows:for j in 0 to B'.num_rows:sum = 0;for k in 0 to A.num_cols:// Sequential access of B'[j, k] is much faster than B[k, j]sum += A[i, k] * B'[j, k]C[i, j] = sumreturn C
四、次立方:斯特拉森算法的工作原理
????????要了解 Strassen 算法的工作原理(此處為 Rust 代碼),首先考慮矩陣如何用象限表示。要概念化它的外觀:
????????在樸素算法中使用此象限模型,結果矩陣?C?的四個象限中的每一個都是兩個子矩陣乘積的總和,總共產(chǎn)生 8 次乘法。
????????考慮到這八個乘法,每個乘法都在一個塊矩陣上運行,其行和列跨度約為 A 和 B 大小的一半,復雜性相同:
????????斯特拉森算法定義了由這些象限組成的七個中間塊矩陣:
????????僅通過?7?次乘法而不是 8 次乘法計算。這些乘法可以是遞歸斯特拉森乘法,并可用于組成最終矩陣:
由此產(chǎn)生的亞立方復雜度:
五、排比
????????中間矩陣 M1 的計算 ...M7 是一個令人尷尬的并行問題,因此也很容易檢測算法的并發(fā)變體(一旦你開始理解?Rust 關于閉包的規(guī)則)。
/*** Execute a recursive strassen multiplication of the given vectors, * from a thread contained within the provided thread pool.*/
fn
_par_run_strassen (a: Vec<f64>, b: Vec<f64>, m: usize, pool: &ThreadPool) -> Arc<Mutex<Option<Matrix>>> {let m1: Arc<Mutex<Option<Matrix>>> = Arc::new(Mutex::new(None));let m1_clone = Arc::clone(&m1);pool.execute(move|| { // Recurse with non-parallel algorithm once we're // in a working threadlet result = mult_strassen(&mut Matrix::with_vector(a, m, m),&mut Matrix::with_vector(b, m, m));*m1_clone.lock().unwrap() = Some(result);});return m1;
}
六、標桿
????????我編寫了一些快速的基準測試代碼,該代碼在不斷增加的矩陣維度范圍內(nèi)運行四種算法中的每一種進行幾次試驗,并報告每種算法的平均時間。
~/code/strassen ~>> ./strassen --lower 75 --upper 100 --factor 50 --trials 2running 50 groups of 2 trials with bounds between [75->3750, 100->5000]x y nxn naive transpose strassen par_strassen
75 100 7500 0.00ms 0.00ms 1.00ms 0.00ms
150 200 30000 6.50ms 4.00ms 4.00ms 1.00ms
225 300 67500 12.50ms 9.00ms 8.50ms 2.50ms
300 400 120000 26.50ms 22.00ms 18.00ms 5.50ms
[...]
3600 4800 17280000 131445.00ms 53683.50ms 21210.50ms 5660.00ms
3675 4900 18007500 141419.00ms 58530.00ms 28291.50ms 6811.00ms
3750 5000 18750000 154941.00ms 60990.00ms 26132.00ms 6613.00ms
????????然后,我通過以下方式可視化結果:pyplot
????????此圖顯示了矩陣從 7.5k 元素 () 到大約 19 萬 () 的乘法時間。你可以看到樸素算法在計算上變得不切實際的速度有多快,在高端需要兩分半鐘。N x M = 75 x 100
N x M = 3750 x 5000
????????相比之下,Strassen 算法的擴展更平滑,并行算法計算兩個 19M 個元素的矩陣的結果,而樸素算法只處理 3.6M 個元素所花費的時間。
????????對我來說最有趣的是算法的性能。如前所述,緩存性能的改進(以犧牲完整矩陣副本為代價)在這些結果中得到了清楚地證明 - 即使使用與該方法漸近等效的算法也是如此。transpose
naive
七、分析和性能優(yōu)化
????????這個文檔是理解 Rust 性能基礎知識的絕佳資源。在?Mac OS 上啟動并運行儀器進行分析是微不足道的,這要歸功于貨運儀器的 Rust 指南。這是調(diào)查分配行為、CPU 熱點和其他事情的絕佳工具。
在此過程中發(fā)生了一些變化:
- Strassen 代碼通過分而治之策略遞歸調(diào)用自己,但是一旦矩陣達到足夠小的大小,其高常數(shù)因子使其比一般矩陣算法慢。我發(fā)現(xiàn)這個點是大約?64?的行寬或列寬;通過提高吞吐量提高幾個因素來增加此閾值
2
- 斯特拉森算法要求矩陣填充到最接近的指數(shù) 2;減少這種情況以懶惰地確保矩陣只有偶數(shù)行和列?通過減少昂貴的大分配,將吞吐量提高了大約兩倍
- 將小矩陣回退算法從 更改為 導致大約 20% 的改進
naive
transpose
- 添加和添加到?Cargo.toml?發(fā)布構建標志大約提高了 5%。有趣的是,性能持續(xù)惡化
codegen-units = 1
lto = "thin"
lto = “true”
- 一絲不茍地刪除所有可能的副本大約提高了~10%
Vec
- 提供一些提示并刪除隨機訪問查找中的向量邊界檢查,又提高了大約 20%
#[inline]
/*** Returns the element at (i, j). Unsafe.*/#[inline]pub fn at (&self, i: usize, j: usize) -> f64 {unsafe {return *self.elements.get_unchecked(i * self.cols + j);}}
參考資料: