🚀 LSG模型
LSG模型是LEGAL - BERT模型的小型版本,目前尚未进行额外的预训练。它使用了相同数量的参数和层数,以及相同的分词器。该模型能够处理长序列,并且比Longformer或BigBird(来自Transformers)更快、更高效,它依赖于局部 + 稀疏 + 全局注意力机制(LSG)。
🚀 快速开始
此模型依赖自定义建模文件,使用时需要添加trust_remote_code=True
。具体可参考 #13467。
LSG的ArXiv 论文 可查看详细信息。Github上的转换脚本可通过此 链接 获取。
✨ 主要特性
- 能够处理长序列,相比Longformer或BigBird(来自Transformers),处理速度更快、效率更高。
- 依赖于局部 + 稀疏 + 全局注意力机制(LSG)。
- 模型具有“自适应”能力,可根据需要自动填充序列(在配置中
adaptive=True
)。不过,建议使用分词器截断输入(truncation=True
),并可选择以块大小的倍数进行填充(pad_to_multiple_of=...
)。
- 支持编码器 - 解码器架构,但尚未进行广泛测试。
- 采用PyTorch实现。
📦 安装指南
此模型依赖自定义建模文件,使用时需添加trust_remote_code=True
。以下是加载模型和分词器的示例代码:
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
💻 使用示例
基础用法
以下是使用模型进行掩码填充和序列分类的基础代码示例:
from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer
model = AutoModelForMaskedLM.from_pretrained("ccdv/legal-lsg-small-uncased-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
SENTENCES = ["Paris is the <mask> of France.", "The goal of life is <mask>."]
pipeline = FillMaskPipeline(model, tokenizer)
output = pipeline(SENTENCES, top_k=1)
output = [o[0]["sequence"] for o in output]
print(output)
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
trust_remote_code=True,
pool_with_global=True,
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
SENTENCE = "This is a test for sequence classification. " * 300
token_ids = tokenizer(
SENTENCE,
return_tensors="pt",
truncation=True
)
output = model(**token_ids)
print(output)
高级用法
可以修改模型的各种参数,例如全局令牌的数量、局部块大小、稀疏块大小等。以下是一个修改参数的示例:
from transformers import AutoModel
model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
trust_remote_code=True,
num_global_tokens=16,
block_size=64,
sparse_block_size=64,
attention_probs_dropout_prob=0.0,
sparsity_factor=4,
sparsity_type="none",
mask_first_token=True
)
📚 详细文档
参数说明
可以更改以下各种参数:
属性 |
详情 |
全局令牌数量 (num_global_tokens ) |
可设置全局令牌的数量,默认值为1。 |
局部块大小 (block_size ) |
局部块的大小,默认值为128。 |
稀疏块大小 (sparse_block_size ) |
稀疏块的大小,默认值为128。 |
稀疏因子 (sparsity_factor ) |
稀疏因子,默认值为2。 |
掩码第一个令牌 (mask_first_token ) |
由于第一个令牌与第一个全局令牌冗余,可选择掩码第一个令牌。 |
默认参数在实践中效果良好。如果内存不足,可以减小块大小、增加稀疏因子并去除注意力分数矩阵中的丢弃率。
稀疏选择类型
有6种不同的稀疏选择模式,最佳类型取决于具体任务。
- 如果
sparse_block_size = 0
或sparsity_type = "none"
,则仅考虑局部注意力。
- 注意,对于长度小于2 * 块大小的序列,稀疏选择类型没有影响。
以下是各种稀疏选择类型的说明:
sparsity_type = "bos_pooling"
(新):
- 使用BOS令牌进行加权平均池化。
- 通常效果最佳,尤其是在稀疏因子较大(8、16、32)的情况下。
- 无额外参数。
sparsity_type = "norm"
:选择范数最高的令牌。
- 在稀疏因子较小(2到4)时效果最佳。
- 无额外参数。
sparsity_type = "pooling"
:使用平均池化合并令牌。
- 在稀疏因子较小(2到4)时效果最佳。
- 无额外参数。
sparsity_type = "lsh"
:使用LSH算法对相似令牌进行聚类。
- 在稀疏因子较大(4+)时效果最佳。
- LSH依赖于随机投影,因此不同种子的推理结果可能略有不同。
- 额外参数:
lsg_num_pre_rounds = 1
,在计算质心之前先合并令牌n次。
sparsity_type = "stride"
:每个头使用由稀疏因子步长的不同令牌。
sparsity_type = "block_stride"
:每个头使用由稀疏因子步长的令牌块。
训练全局令牌
以下是仅训练全局令牌和分类头的示例代码:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
trust_remote_code=True,
pool_with_global=True,
num_global_tokens=16
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
for name, param in model.named_parameters():
if "global_embeddings" not in name:
param.requires_grad = False
else:
param.requires_grad = True
📄 许可证
LEGAL - BERT的引用信息如下:
@inproceedings{chalkidis-etal-2020-legal,
title = "{LEGAL}-{BERT}: The Muppets straight out of Law School",
author = "Chalkidis, Ilias and
Fergadiotis, Manos and
Malakasiotis, Prodromos and
Aletras, Nikolaos and
Androutsopoulos, Ion",
booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020",
month = nov,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
doi = "10.18653/v1/2020.findings-emnlp.261",
pages = "2898--2904"
}