🚀 VideoMAE(基礎大小模型,僅預訓練)
VideoMAE是一個在視頻領域表現出色的模型,它通過自監督學習的方式在Something - Something - v2數據集上進行了2400個週期的預訓練。該模型能夠學習視頻的內在表示,為下游任務提取有用特徵,在視頻分類等領域具有重要價值。
🚀 快速開始
你可以使用這個預訓練的VideoMAE模型進行視頻掩碼塊的像素值預測,不過它主要還是用於在下游任務中進行微調。你可以在模型中心中尋找針對你感興趣任務的微調版本。
✨ 主要特性
- 擴展自MAE:VideoMAE是掩碼自編碼器(MAE)在視頻領域的擴展,其架構與標準的視覺變換器(ViT)非常相似,頂部添加了解碼器用於預測掩碼塊的像素值。
- 學習視頻內在表示:通過預訓練,模型學習到視頻的內在表示,可用於提取對下游任務有用的特徵。例如,在有標籤的視頻數據集上,可在預訓練編碼器頂部添加線性層來訓練標準分類器。
📚 詳細文檔
模型描述
VideoMAE將視頻以固定大小的塊(分辨率為16x16)序列形式輸入,這些塊經過線性嵌入處理。同時,在序列開頭添加[CLS]標記用於分類任務,並在將序列輸入到Transformer編碼器層之前添加固定的正弦/餘弦位置嵌入。
通過預訓練,模型學習到視頻的內在表示,可用於下游任務。通常在[CLS]標記頂部添加線性層,因為該標記的最後隱藏狀態可視為整個視頻的表示。
預期用途和侷限性
可以使用原始模型預測視頻掩碼塊的像素值,但它主要用於在下游任務中進行微調。
訓練數據
(待補充,歡迎提交PR)
訓練過程
預處理
(待補充,歡迎提交PR)
預訓練
(待補充,歡迎提交PR)
評估結果
(待補充,歡迎提交PR)
💻 使用示例
基礎用法
以下是如何使用該模型預測隨機掩碼塊的像素值:
from transformers import VideoMAEFeatureExtractor, VideoMAEForPreTraining
import numpy as np
import torch
num_frames = 16
video = list(np.random.randn(16, 3, 224, 224))
feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base-short-ssv2")
model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base-short-ssv2")
pixel_values = feature_extractor(video, return_tensors="pt").pixel_values
num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()
outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
loss = outputs.loss
更多代碼示例請參考文檔。
📄 許可證
本模型採用CC BY-NC 4.0許可證。
BibTeX引用
misc{https://doi.org/10.48550/arxiv.2203.12602,
doi = {10.48550/ARXIV.2203.12602},
url = {https://arxiv.org/abs/2203.12602},
author = {Tong, Zhan and Song, Yibing and Wang, Jue and Wang, Limin},
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}