🚀 コード脆弱性検出用にファインチューニングされたCodeBERT 💾⛔
codebert-base を CodeXGLUE -- Defect Detection データセットでファインチューニングし、コード脆弱性検出 という下流タスクに対応させたモデルです。
🚀 クイックスタート
このモデルは、ソースコードがソフトウェアシステムに攻撃を与える可能性のある脆弱性のあるコードかどうかを識別するために使用できます。以下のセクションでは、モデルの詳細、使用するデータセット、テストセットのメトリクス、そしてモデルの使用例について説明します。
✨ 主な機能
- CodeBERTの活用:CodeBERTは、プログラミング言語(PL)と自然言語(NL)のための二峰性事前学習モデルで、下流のNL - PLアプリケーションに役立つ汎用的な表現を学習します。
- コード脆弱性検出:ソースコードが脆弱性のあるコードか安全なコードかを二値分類することができます。
- 高精度:テストセットのメトリクスでは、他のモデルよりも高い精度を達成しています。
📚 ドキュメント
我々はCodeBERTを提案します。これは、プログラミング言語(PL)と自然言語(NL)のための二峰性事前学習モデルです。CodeBERTは、自然言語コード検索やコードドキュメント生成などの下流のNL - PLアプリケーションをサポートする汎用的な表現を学習します。我々は、TransformerベースのニューラルアーキテクチャでCodeBERTを開発し、生成器からサンプリングされた妥当な代替案を検出する置換トークン検出の事前学習タスクを組み込んだハイブリッド目的関数でトレーニングします。これにより、NL - PLペアの二峰性データと単峰性データの両方を利用でき、前者はモデルトレーニングの入力トークンを提供し、後者はより良い生成器を学習するのに役立ちます。我々は、モデルパラメータをファインチューニングすることで、2つのNL - PLアプリケーションでCodeBERTを評価しました。結果は、CodeBERTが自然言語コード検索とコードドキュメント生成の両方のタスクで最先端の性能を達成していることを示しています。さらに、CodeBERTで学習される知識の種類を調査するために、NL - PLプロービング用のデータセットを構築し、事前学習モデルのパラメータを固定したゼロショット設定で評価しました。結果は、CodeBERTがNL - PLプロービングで以前の事前学習モデルよりも優れていることを示しています。
🔎 下流タスク(コード分類)の詳細 - データセット 📚
与えられたソースコードが、リソースリーク、ユーザー後の解放脆弱性、DoS攻撃など、ソフトウェアシステムを攻撃する可能性のある脆弱性のあるコードかどうかを識別するタスクです。このタスクを二値分類(0/1)として扱い、1は脆弱性のあるコード、0は安全なコードを表します。
使用するデータセットは、論文 Devign: Effective Vulnerability Identification by Learning Comprehensive Program Semantics via Graph Neural Networks から取得したものです。すべてのプロジェクトを統合し、トレーニング/開発/テスト用に80%/10%/10%に分割しています。
データセットのデータ統計は以下の表に示されています。
|
#サンプル数 |
トレーニング |
21,854 |
開発 |
2,732 |
テスト |
2,732 |
🔎 テストセットのメトリクス 🧾
💻 使用例
基本的な使用法
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
tokenizer = AutoTokenizer.from_pretrained('mrm8488/codebert-base-finetuned-detect-insecure-code')
model = AutoModelForSequenceClassification.from_pretrained('mrm8488/codebert-base-finetuned-detect-insecure-code')
inputs = tokenizer("your code here", return_tensors="pt", truncation=True, padding='max_length')
labels = torch.tensor([1]).unsqueeze(0)
outputs = model(**inputs, labels=labels)
loss = outputs.loss
logits = outputs.logits
print(np.argmax(logits.detach().numpy()))
作成者: Manuel Romero/@mrm8488 | LinkedIn
スペインで ♥ を込めて作成