🚀 VideoMAE(基础大小模型,仅预训练)
VideoMAE是一个在视频领域表现出色的模型,它通过自监督学习的方式在Something - Something - v2数据集上进行了2400个周期的预训练。该模型能够学习视频的内在表示,为下游任务提取有用特征,在视频分类等领域具有重要价值。
🚀 快速开始
你可以使用这个预训练的VideoMAE模型进行视频掩码块的像素值预测,不过它主要还是用于在下游任务中进行微调。你可以在模型中心中寻找针对你感兴趣任务的微调版本。
✨ 主要特性
- 扩展自MAE:VideoMAE是掩码自编码器(MAE)在视频领域的扩展,其架构与标准的视觉变换器(ViT)非常相似,顶部添加了解码器用于预测掩码块的像素值。
- 学习视频内在表示:通过预训练,模型学习到视频的内在表示,可用于提取对下游任务有用的特征。例如,在有标签的视频数据集上,可在预训练编码器顶部添加线性层来训练标准分类器。
📚 详细文档
模型描述
VideoMAE将视频以固定大小的块(分辨率为16x16)序列形式输入,这些块经过线性嵌入处理。同时,在序列开头添加[CLS]标记用于分类任务,并在将序列输入到Transformer编码器层之前添加固定的正弦/余弦位置嵌入。
通过预训练,模型学习到视频的内在表示,可用于下游任务。通常在[CLS]标记顶部添加线性层,因为该标记的最后隐藏状态可视为整个视频的表示。
预期用途和局限性
可以使用原始模型预测视频掩码块的像素值,但它主要用于在下游任务中进行微调。
训练数据
(待补充,欢迎提交PR)
训练过程
预处理
(待补充,欢迎提交PR)
预训练
(待补充,欢迎提交PR)
评估结果
(待补充,欢迎提交PR)
💻 使用示例
基础用法
以下是如何使用该模型预测随机掩码块的像素值:
from transformers import VideoMAEFeatureExtractor, VideoMAEForPreTraining
import numpy as np
import torch
num_frames = 16
video = list(np.random.randn(16, 3, 224, 224))
feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base-short-ssv2")
model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base-short-ssv2")
pixel_values = feature_extractor(video, return_tensors="pt").pixel_values
num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()
outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
loss = outputs.loss
更多代码示例请参考文档。
📄 许可证
本模型采用CC BY-NC 4.0许可证。
BibTeX引用
misc{https://doi.org/10.48550/arxiv.2203.12602,
doi = {10.48550/ARXIV.2203.12602},
url = {https://arxiv.org/abs/2203.12602},
author = {Tong, Zhan and Song, Yibing and Wang, Jue and Wang, Limin},
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}