🚀 LSG模型
LSG模型是LEGAL - BERT模型的小型版本,目前尚未進行額外的預訓練。它使用了相同數量的參數和層數,以及相同的分詞器。該模型能夠處理長序列,並且比Longformer或BigBird(來自Transformers)更快、更高效,它依賴於局部 + 稀疏 + 全局注意力機制(LSG)。
🚀 快速開始
此模型依賴自定義建模文件,使用時需要添加trust_remote_code=True
。具體可參考 #13467。
LSG的ArXiv 論文 可查看詳細信息。Github上的轉換腳本可通過此 鏈接 獲取。
✨ 主要特性
- 能夠處理長序列,相比Longformer或BigBird(來自Transformers),處理速度更快、效率更高。
- 依賴於局部 + 稀疏 + 全局注意力機制(LSG)。
- 模型具有“自適應”能力,可根據需要自動填充序列(在配置中
adaptive=True
)。不過,建議使用分詞器截斷輸入(truncation=True
),並可選擇以塊大小的倍數進行填充(pad_to_multiple_of=...
)。
- 支持編碼器 - 解碼器架構,但尚未進行廣泛測試。
- 採用PyTorch實現。
📦 安裝指南
此模型依賴自定義建模文件,使用時需添加trust_remote_code=True
。以下是加載模型和分詞器的示例代碼:
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
💻 使用示例
基礎用法
以下是使用模型進行掩碼填充和序列分類的基礎代碼示例:
from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer
model = AutoModelForMaskedLM.from_pretrained("ccdv/legal-lsg-small-uncased-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
SENTENCES = ["Paris is the <mask> of France.", "The goal of life is <mask>."]
pipeline = FillMaskPipeline(model, tokenizer)
output = pipeline(SENTENCES, top_k=1)
output = [o[0]["sequence"] for o in output]
print(output)
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
trust_remote_code=True,
pool_with_global=True,
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
SENTENCE = "This is a test for sequence classification. " * 300
token_ids = tokenizer(
SENTENCE,
return_tensors="pt",
truncation=True
)
output = model(**token_ids)
print(output)
高級用法
可以修改模型的各種參數,例如全局令牌的數量、局部塊大小、稀疏塊大小等。以下是一個修改參數的示例:
from transformers import AutoModel
model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
trust_remote_code=True,
num_global_tokens=16,
block_size=64,
sparse_block_size=64,
attention_probs_dropout_prob=0.0,
sparsity_factor=4,
sparsity_type="none",
mask_first_token=True
)
📚 詳細文檔
參數說明
可以更改以下各種參數:
屬性 |
詳情 |
全局令牌數量 (num_global_tokens ) |
可設置全局令牌的數量,默認值為1。 |
局部塊大小 (block_size ) |
局部塊的大小,默認值為128。 |
稀疏塊大小 (sparse_block_size ) |
稀疏塊的大小,默認值為128。 |
稀疏因子 (sparsity_factor ) |
稀疏因子,默認值為2。 |
掩碼第一個令牌 (mask_first_token ) |
由於第一個令牌與第一個全局令牌冗餘,可選擇掩碼第一個令牌。 |
默認參數在實踐中效果良好。如果內存不足,可以減小塊大小、增加稀疏因子並去除注意力分數矩陣中的丟棄率。
稀疏選擇類型
有6種不同的稀疏選擇模式,最佳類型取決於具體任務。
- 如果
sparse_block_size = 0
或sparsity_type = "none"
,則僅考慮局部注意力。
- 注意,對於長度小於2 * 塊大小的序列,稀疏選擇類型沒有影響。
以下是各種稀疏選擇類型的說明:
sparsity_type = "bos_pooling"
(新):
- 使用BOS令牌進行加權平均池化。
- 通常效果最佳,尤其是在稀疏因子較大(8、16、32)的情況下。
- 無額外參數。
sparsity_type = "norm"
:選擇範數最高的令牌。
- 在稀疏因子較小(2到4)時效果最佳。
- 無額外參數。
sparsity_type = "pooling"
:使用平均池化合並令牌。
- 在稀疏因子較小(2到4)時效果最佳。
- 無額外參數。
sparsity_type = "lsh"
:使用LSH算法對相似令牌進行聚類。
- 在稀疏因子較大(4+)時效果最佳。
- LSH依賴於隨機投影,因此不同種子的推理結果可能略有不同。
- 額外參數:
lsg_num_pre_rounds = 1
,在計算質心之前先合併令牌n次。
sparsity_type = "stride"
:每個頭使用由稀疏因子步長的不同令牌。
sparsity_type = "block_stride"
:每個頭使用由稀疏因子步長的令牌塊。
訓練全局令牌
以下是僅訓練全局令牌和分類頭的示例代碼:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
trust_remote_code=True,
pool_with_global=True,
num_global_tokens=16
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
for name, param in model.named_parameters():
if "global_embeddings" not in name:
param.requires_grad = False
else:
param.requires_grad = True
📄 許可證
LEGAL - BERT的引用信息如下:
@inproceedings{chalkidis-etal-2020-legal,
title = "{LEGAL}-{BERT}: The Muppets straight out of Law School",
author = "Chalkidis, Ilias and
Fergadiotis, Manos and
Malakasiotis, Prodromos and
Aletras, Nikolaos and
Androutsopoulos, Ion",
booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020",
month = nov,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
doi = "10.18653/v1/2020.findings-emnlp.261",
pages = "2898--2904"
}