🚀 Enhancing SwarmFormer: A Smarter Approach to Efficient Sequence Modeling
In the pursuit of refining efficient sequence modeling architectures, we have updated SwarmFormer to improve its performance, scalability, and stability. The new design features hierarchical attention, dynamic clustering, and gated feedback mechanisms, enabling more effective long-sequence processing with reduced computational overhead.
🚀 Quick Start
This README focuses on the enhancements made to the SwarmFormer architecture. For a quick start on using the model, please refer to the official documentation of the `transformers` library.
✨ Features
Why Improve SwarmFormer?
The original SwarmFormer introduced a token-cluster interaction model: tokens self-organized into clusters, exchanged information at a higher level, and then propagated refined representations back to the tokens. However, it had several limitations:
- ❌ Fixed cluster assignments led to rigid token grouping.
- ❌ Rolling shifts for local attention were suboptimal for capturing fine-grained dependencies.
- ❌ Cluster-to-token updates lacked gating, causing noisy updates.
- ❌ No weight sharing in attention layers, increasing the parameter count.
To address these issues, we introduced a set of key enhancements that improve the model's expressiveness while maintaining computational efficiency.
Key Enhancements in Our New SwarmFormer Architecture
1. Local Windowed Attention Instead of Rolling Shifts
We replaced rolling shift attention with local windowed attention (similar to sliding-window transformers and convolutional filters). This allows more efficient local feature extraction without redundant shifts, improving locality modeling.
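A minimal PyTorch sketch of the idea, using non-overlapping windows for simplicity; the class name `LocalWindowAttention` and the single-head formulation are illustrative assumptions rather than the exact implementation:

```python
import torch
import torch.nn as nn

class LocalWindowAttention(nn.Module):
    """Each token attends only to the tokens inside its own fixed-size window."""
    def __init__(self, d_model: int, window_size: int = 8):
        super().__init__()
        self.window_size = window_size
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):                        # x: (B, N, d); assumes N % window_size == 0
        B, N, d = x.shape
        w = self.window_size
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Fold the sequence into non-overlapping windows: (B, N // w, w, d).
        q, k, v = (t.reshape(B, N // w, w, d) for t in (q, k, v))
        attn = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)   # (B, N // w, w, w)
        return self.out((attn @ v).reshape(B, N, d))
```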
2. Multi-Head Attention Over Clusters
Rather than using a single attention mechanism over clusters, we applied Multi-Head Self-Attention (MHA). This enables each attention head to learn different cluster-token relationships, improving contextual representation.
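As a sketch, the cluster-level attention can be written with PyTorch's built-in `nn.MultiheadAttention`; the class name `ClusterSelfAttention` and the residual/pre-norm wiring are illustrative assumptions:

```python
import torch
import torch.nn as nn

class ClusterSelfAttention(nn.Module):
    """Multi-head self-attention over the C cluster representations."""
    def __init__(self, d_model: int, num_heads: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, clusters):                 # clusters: (B, C, d)
        h = self.norm(clusters)
        attended, _ = self.mha(h, h, h)          # each head learns its own cluster-cluster pattern
        return clusters + attended               # residual connection
```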
3. Token-to-Cluster Gating Instead of Uniform Chunking
Previously, tokens were uniformly assigned to clusters, which limited flexibility. We now use an attention-based dynamic routing mechanism, allowing tokens to adaptively choose their cluster. This improves semantic coherence in cluster formation.
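The routing idea can be sketched as a learned soft assignment from tokens to clusters; `TokenToClusterRouter` and its single linear scorer are a simplified stand-in for the actual attention-based mechanism:

```python
import torch
import torch.nn as nn

class TokenToClusterRouter(nn.Module):
    """Tokens score themselves against clusters; each cluster is the
    assignment-weighted average of the tokens routed to it."""
    def __init__(self, d_model: int, num_clusters: int):
        super().__init__()
        self.score = nn.Linear(d_model, num_clusters)

    def forward(self, tokens):                              # tokens: (B, N, d)
        assign = torch.softmax(self.score(tokens), dim=-1)  # (B, N, C) soft routing weights
        # Normalise per cluster so each cluster averages the tokens that chose it.
        weights = assign / (assign.sum(dim=1, keepdim=True) + 1e-6)
        clusters = weights.transpose(1, 2) @ tokens          # (B, C, d)
        return clusters, assign
```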
4. Gated Feedback Mechanism for Stable Token Updates
Instead of directly updating token embeddings from clusters, we now introduce a residual MLP gating mechanism. This filters out noisy cluster updates, ensuring that only relevant information is propagated back to tokens.
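A hedged sketch of the gated feedback path, assuming the soft assignment matrix from the routing step is available; the gate/update split and the module name are illustrative:

```python
import torch
import torch.nn as nn

class GatedClusterFeedback(nn.Module):
    """Broadcast cluster information back to tokens through a learned gate,
    so only relevant cluster updates modify the token embeddings."""
    def __init__(self, d_model: int):
        super().__init__()
        self.gate = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.Sigmoid())
        self.update = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(),
                                    nn.Linear(d_model, d_model))

    def forward(self, tokens, clusters, assign):
        # tokens: (B, N, d), clusters: (B, C, d), assign: (B, N, C)
        broadcast = assign @ clusters                        # (B, N, d) cluster message per token
        g = self.gate(torch.cat([tokens, broadcast], dim=-1))
        return tokens + g * self.update(broadcast)           # residual, gated MLP update
```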
5. Layer Normalization Before Every MLP & Attention Block
We found that adding LayerNorm before every feedforward and attention layer significantly stabilizes training, improving gradient flow and convergence.
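This is the standard pre-norm pattern; a small illustrative wrapper:

```python
import torch.nn as nn

class PreNorm(nn.Module):
    """Apply LayerNorm before the wrapped sub-layer, with a residual connection."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

# Usage: wrap any feed-forward or attention sub-layer, e.g.
# ffn = PreNorm(64, nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64)))
```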
6. Weight-Tying for Linear Projections in Cluster Attention
To reduce model size without compromising expressiveness, we now share weights across query, key, and value projections in the GlobalClusterAttention module. This optimization reduces the number of trainable parameters while maintaining performance.
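A minimal sketch of the tied-projection idea (a single matrix reused for queries, keys, and values); this simplification of the cluster attention module is illustrative, not the exact GlobalClusterAttention code:

```python
import torch
import torch.nn as nn

class TiedProjectionClusterAttention(nn.Module):
    """One projection matrix is shared across Q, K, and V to cut parameters."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)    # shared Q/K/V projection
        self.out = nn.Linear(d_model, d_model)

    def forward(self, clusters):                    # clusters: (B, C, d)
        q = k = v = self.proj(clusters)             # weight tying: the same weights for all three
        attn = torch.softmax(q @ k.transpose(-2, -1) / clusters.size(-1) ** 0.5, dim=-1)
        return self.out(attn @ v)
```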
7. Hierarchical Clustering with a Pyramid Structure
Instead of using fixed cluster sizes across all layers, we now implement a hierarchical pyramid:
- ✅ Lower layers focus on fine-grained local interactions (more clusters).
- ✅ Higher layers process abstract, coarse-grained representations (fewer clusters).
This multi-scale cluster formation allows the model to efficiently propagate high-level information without losing local details.
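One way to express the pyramid is as a per-layer cluster schedule; the halving interval and the helper name below are assumptions for illustration:

```python
def cluster_schedule(num_layers: int, base_clusters: int, halve_every: int = 2):
    """Number of clusters per layer, halving every `halve_every` layers."""
    return [max(1, base_clusters // (2 ** (layer // halve_every)))
            for layer in range(num_layers)]

print(cluster_schedule(num_layers=6, base_clusters=32))   # [32, 32, 16, 16, 8, 8]
```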
8. Gumbel-Softmax for Differentiable Clustering
To improve the trainability of cluster assignments, we implemented Gumbel-Softmax sampling. This enables the model to learn cluster assignments via backpropagation, allowing reinforcement signals (such as cluster coherence) to guide optimization.
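PyTorch ships `torch.nn.functional.gumbel_softmax`, so the assignment step can be sketched directly; the tensor shapes and the straight-through (`hard=True`) setting below are illustrative assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, N, d, C = 2, 16, 64, 4                     # batch, tokens, hidden dim, clusters (illustrative)
tokens = torch.randn(B, N, d)
scorer = nn.Linear(d, C)                      # scores each token against every cluster

logits = scorer(tokens)                                        # (B, N, C)
# hard=True returns one-hot assignments in the forward pass while gradients
# flow through the soft relaxation (straight-through estimator).
assign = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)  # (B, N, C)

# Pool tokens into clusters according to the sampled assignments.
counts = assign.sum(dim=1).unsqueeze(-1) + 1e-6                # (B, C, 1)
clusters = (assign.transpose(1, 2) @ tokens) / counts          # (B, C, d)
```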
🔧 Technical Details
Computational Complexity of Original SwarmFormer
Token-to-Cluster Attention:
In the original SwarmFormer, each token attends to all clusters, resulting in \(O(NCd)\), where:
- N = Sequence length
- C = Number of clusters
- d = Hidden dimension
Cluster-to-Token Broadcast:
Each cluster updates all tokens, leading to another \(O(NCd)\).
Total Complexity (Original SwarmFormer):
\(O(NCd) + O(NCd) = O(2NCd)\)
Computational Complexity of New SwarmFormer
Local Windowed Attention Instead of Rolling Shift Attention:
Instead of global attention over all tokens, each token only attends to a local window of size \(w\) (typically \(w \ll N\)), which costs \(O(Nwd)\). This replaces the rolling shift operation, making it significantly cheaper.
Multi-Head Cluster Attention with Weight Sharing:
In the original version, query, key, and value projections had separate weights. Now, we tie the weights across these projections, reducing the number of parameters and FLOPs in cluster attention layers. The attention complexity remains \(O(NCd)\), but with fewer matrix multiplications.
Token-to-Cluster Gating:
Instead of every token updating every cluster, tokens selectively update clusters based on learned routing. This reduces the updates from all tokens attending to all clusters to only a fraction \(p\) of tokens participating, giving \(O(pNCd)\) with \(p < 1\). Since \(p\) is usually 0.5 or lower, this significantly cuts down computation.
Gated Feedback Mechanism (MLP Filtering):
Instead of fully propagating updates from clusters to tokens, we apply a residual MLP with gating before broadcasting updates. The MLP has complexity \(O(Nd^{2})\), but it prevents unnecessary updates, reducing the effective computation in later layers.
Hierarchical Clustering with a Pyramid Structure:
Instead of a fixed number of clusters at all layers, we gradually reduce clusters as we go deeper:
- Lower layers: C clusters
- Middle layers: C/2 clusters
- Top layers: C/4 clusters
This results in an effective reduction in clustering computation: \(O(NCd + \tfrac{NCd}{2} + \tfrac{NCd}{4} + \cdots)\). This forms a geometric series, reducing the total computational cost.
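To make the saving explicit, assume for illustration that the cluster count halves at every level (the actual pyramid halves per stage, but the same bound applies). The total clustering cost is then bounded by the geometric series

\[
\sum_{l \ge 0} N \, \frac{C}{2^{l}} \, d \;=\; NCd \sum_{l \ge 0} 2^{-l} \;=\; 2NCd,
\]

whereas keeping all \(C\) clusters in every one of \(L\) layers would cost \(L \cdot NCd\).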
Final Complexity Comparison
| Model | Complexity |
|---|---|
| Original SwarmFormer | \(O(2NCd)\) |
| New SwarmFormer | \(O(Nwd + pNCd + Nd^{2})\) |
Since:
- \(w \ll N\) (window attention reduces cost)
- \(p < 1\) (fewer cluster updates)
- The \(d^{2}\) term appears only in a small MLP, not in full attention layers
- Hierarchical clustering reduces total cluster interactions
Putting these together, \(O(Nwd + pNCd + Nd^{2})\) is smaller than the original \(O(2NCd)\), which shows that the new architecture is computationally less expensive.
Conclusion: New SwarmFormer is More Efficient
- ✅ Lower FLOPs due to windowed attention and hierarchical clustering
- ✅ Fewer redundant updates with gated feedback and token-to-cluster gating
- ✅ Weight sharing further reduces the parameter count
Bottom line: 🚀 The new SwarmFormer architecture achieves faster training and inference while maintaining or improving performance!
References
```bibtex
@article{legg2025swarmformer,
  title={SwarmFormer: Local-Global Hierarchical Attention via Swarming Token Representations},
  author={Legg, Jordan and Sturmanis, Mikus and {Takara.ai}},
  journal={Takara.ai Research},
  year={2025},
  url={https://takara.ai/papers/SwarmFormer-Local-Global-Hierarchical-Attention-via-Swarming-Token-Representations.pdf}
}
```
📜 License
This project is licensed under the Apache-2.0 license.