🚀 增強版SwarmFormer:高效序列建模的智能新方法
本項目持續致力於優化高效序列建模架構,對 SwarmFormer 進行了一系列更新,顯著提升了其性能、可擴展性和穩定性。新設計引入了 分層注意力、動態聚類和門控反饋機制,使模型能夠更有效地處理長序列,同時降低計算開銷。
🚀 快速開始
本項目主要聚焦於對SwarmFormer模型的改進,暫未提供具體的代碼運行等快速開始的操作步驟。若有相關代碼及使用需求,可根據後續補充的代碼部分進行操作。
✨ 主要特性
為何改進SwarmFormer?
原始的SwarmFormer引入了 令牌 - 集群交互模型,其中令牌自組織成集群,在更高層次上交換信息,然後傳播回精煉的表示。雖然這種方法有效地處理了長距離依賴關係,但它存在一些侷限性:
- ❌ 固定的集群分配 導致令牌分組僵化。
- ❌ 用於局部注意力的滾動移位 並非捕捉細粒度依賴關係的最佳方式。
- ❌ 集群到令牌的更新缺乏門控,導致更新嘈雜。
- ❌ 注意力層中沒有權重共享,增加了參數數量。
為了解決這些問題,我們引入了 一系列關鍵改進,在保持計算效率的同時提高了模型的表達能力。
新SwarmFormer架構的關鍵改進
- 使用局部窗口注意力取代滾動移位:我們用 局部窗口注意力(類似於滑動變換器和卷積濾波器)取代了滾動移位注意力。這允許更有效地提取局部特徵,而無需冗餘移位,從而改善局部建模。
- 對集群應用多頭注意力:我們沒有對集群使用單一的注意力機制,而是應用了 多頭自注意力(MHA)。這使每個注意力頭能夠學習 不同的集群 - 令牌關係,從而改善上下文表示。
- 使用令牌到集群的門控取代均勻分塊:以前,令牌是 均勻分配到集群 的,這限制了靈活性。我們現在使用基於注意力的動態路由機制,允許令牌 自適應地選擇其集群。這提高了集群形成中的 語義連貫性。
- 引入門控反饋機制以實現穩定的令牌更新:我們不再直接從集群更新令牌嵌入,而是引入了 殘差MLP門控機制。這過濾掉了 嘈雜的集群更新,確保只有 相關信息 被傳播回令牌。
- 在每個MLP和注意力塊之前進行層歸一化:我們發現,在每個前饋和注意力層之前添加 層歸一化(LayerNorm) 顯著穩定了訓練,改善了梯度流和收斂性。
- 在集群注意力中進行線性投影的權重綁定:為了 在不影響表達能力的情況下減小模型大小,我們現在在 GlobalClusterAttention 模塊中的 查詢、鍵和值投影 之間共享權重。這種優化減少了可訓練參數的數量,同時保持了性能。
- 採用金字塔結構的分層聚類:我們不再在所有層使用 固定的集群大小,而是實現了 分層金字塔:
- ✅ 較低層 專注於 細粒度的局部交互(更多集群)。
- ✅ 較高層 處理 抽象的、粗粒度的表示(較少集群)。
這種 多尺度集群形成 允許模型在不丟失局部細節的情況下有效地 傳播高層信息。
- 使用Gumbel - 軟最大化進行可微聚類:為了提高 集群分配的可訓練性,我們實現了 Gumbel - 軟最大化採樣。這使模型能夠通過反向傳播學習 集群分配,允許強化信號(如集群連貫性)指導優化。
🔧 技術細節
計算複雜度分析
原始SwarmFormer的計算複雜度
- 令牌到集群的注意力:在原始的SwarmFormer中,每個令牌關注所有集群,複雜度為 (O(NCd)),其中 (N) 是序列長度,(C) 是集群數量,(d) 是隱藏維度。
- 集群到令牌的廣播:每個集群更新所有令牌,複雜度同樣為 (O(NCd))。
- 總複雜度:原始SwarmFormer的總複雜度為 (O(2NCd))。
新SwarmFormer的計算複雜度
- 局部窗口注意力取代滾動移位注意力:每個令牌只關注大小為 (w)(通常 (w \ll N))的局部窗口,複雜度為 (O(Nwd)),取代了滾動移位操作,顯著降低了成本。
- 多頭集群注意力與權重共享:在原始版本中,查詢、鍵和值投影有單獨的權重。現在,我們在這些投影之間共享權重,減少了集群注意力層中的參數數量和浮點運算次數。注意力複雜度仍為 (O(NCd)),但矩陣乘法次數減少。
- 令牌到集群的門控:令牌不再統一分配到集群,而是根據學習到的路由選擇性地更新集群。這將從所有令牌到所有集群的更新數量減少到只有一小部分 (p) 的令牌參與,複雜度為 (O(pNCd)),其中 (p < 1)。由於 (p) 通常為 0.5 或更低,這顯著減少了計算量。
- 門控反饋機制(MLP過濾):我們不再直接從集群更新令牌嵌入,而是引入了殘差MLP門控機制。MLP的複雜度為 (O(Nd^2)),但它過濾掉了嘈雜的集群更新,確保只有相關信息被傳播回令牌,減少了後續層的有效計算量。
- 金字塔結構的分層聚類:我們不再在所有層使用固定的集群大小,而是實現了分層金字塔。較低層有 (C) 個集群,中間層有 (C/2) 個集群,頂層有 (C/4) 個集群。這導致聚類計算的有效減少,複雜度為 (O(NCd + NC/2d + NC/4d + …)),形成一個幾何級數,降低了總計算成本。
最終複雜度比較
模型 |
複雜度 |
原始SwarmFormer |
(O(2NCd)) |
新SwarmFormer |
(O(Nwd + pNCd + Nd^2)) |
由於 (w \ll N)(窗口注意力降低了成本),(p < 1)(較少的集群更新),(d^2) 項僅在小的MLP中,而不是在完整的注意力層中,並且分層聚類減少了總的集群交互,我們得到 (O(NCd) > O(Nwd + pNCd + Nd^2)),這表明新架構在計算上更高效。
結論:新SwarmFormer更高效
- ✅ 更低的浮點運算次數:由於窗口注意力和分層聚類,新架構的浮點運算次數更低。
- ✅ 更少的冗餘更新:門控反饋和令牌到集群的門控減少了冗餘更新。
- ✅ 權重共享:進一步減少了參數數量。
總結:🚀 新的SwarmFormer架構在保持或提高性能的同時,實現了更快的訓練和推理!
參考資料
@article{legg2025swarmformer,
title={SwarmFormer: Local-Global Hierarchical Attention via Swarming Token Representations},
author={Legg, Jordan and Sturmanis, Mikus and {Takara.ai}},
journal={Takara.ai Research},
year={2025},
url={https://takara.ai/papers/SwarmFormer-Local-Global-Hierarchical-Attention-via-Swarming-Token-Representations.pdf}
}
📄 許可證
本項目採用Apache - 2.0許可證。