🚀 CED-Base模型
CED是基于ViT-Transformer的简单音频标签模型,在AudioSet数据集上实现了最优性能。
🚀 快速开始
CED模型为音频标签任务提供了高效且性能卓越的解决方案。它在AudioSet数据集上表现出色,具备简化微调、支持可变长度输入、加速训练和推理以及优异的性能等特点。
✨ 主要特性
- 简化微调:对Mel频谱图进行批量归一化。在微调时,无需像AST模型那样先计算数据集的均值和方差。
- 支持可变长度输入:大多数其他模型使用静态时频位置嵌入,这限制了模型对短于10秒音频片段的泛化能力。许多先前的Transformer模型为避免性能影响,将输入填充到10秒,这会大幅降低训练和推理速度。
- 训练/推理加速:采用64维梅尔滤波器组和16x16无重叠的图像块,从10秒的频谱图中可得到248个图像块。相比之下,AST在训练/推理时使用128个梅尔滤波器组和16x16(10x10重叠)卷积,会产生1212个图像块。CED-Tiny在普通CPU上的运行速度与可比的MobileNetV3相当。
- 性能优异:参数为1000万的CED模型性能优于大多数先前参数约为8000万的方法。
📦 安装指南
pip install git+https://github.com/jimbozhang/hf_transformers_custom_model_ced.git
💻 使用示例
基础用法
>>> from ced_model.feature_extraction_ced import CedFeatureExtractor
>>> from ced_model.modeling_ced import CedForAudioClassification
>>> model_name = "mispeech/ced-base"
>>> feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
>>> model = CedForAudioClassification.from_pretrained(model_name)
>>> import torchaudio
>>> audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
>>> assert sampling_rate == 16000
>>> inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
>>> import torch
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_class_id = torch.argmax(logits, dim=-1).item()
>>> model.config.id2label[predicted_class_id]
'Finger snapping'
高级用法
example_finetune_esc50.ipynb
展示了如何在ESC-50数据集上,冻结CED编码器的情况下训练一个线性头。
📚 详细文档
模型性能
模型 |
参数数量(M) |
AS-20K (mAP) |
AS-2M (mAP) |
CED-Tiny |
5.5 |
36.5 |
48.1 |
CED-Mini |
9.6 |
38.5 |
49.0 |
CED-Small |
22 |
41.6 |
49.6 |
CED-Base |
86 |
44.0 |
50.0 |
模型来源
模型信息
属性 |
详情 |
模型类型 |
音频分类 |
训练数据 |
AudioSet |
评估指标 |
mAP |
📄 许可证
本项目采用Apache-2.0许可证。