🚀 Kwaipilot OASIS-1.5B
Kwaipilotが開発した最先端のコード埋め込みモデルOASISは、独自の手法を用いて、コード検索の効率と精度において新たな基準を設定します。このモデルは、コード検索システムの改善に最適で、様々なプログラミングコンテキストでのコードスニペットの意味理解と検索に優れています。
🚀 クイックスタート
モデルの概要
OASIS (Order-Augmented Strategy for Improved Code Search) は、Kwaipilotによって開発された最先端のコード埋め込みモデルです。このモデルは、リポジトリレベルのプログラム分析、OASIS-instructデータ合成アルゴリズム、および特殊な融合損失関数を含む独自の手法を用いて、コード検索の効率と精度において新たな基準を設定しています。
主な用途
このモデルは、コード検索システムの改善に取り組む開発者や研究者に最適です。OASISは、様々なプログラミングコンテキストでのコードスニペットの意味理解と検索が必要なシナリオで優れた性能を発揮します。
トレーニングとパフォーマンス
OASISは、リポジトリレベルの分析を通じて作成された合成データセットでトレーニングされており、さまざまなコーディングスタイルや言語に対する幅広い理解を保証します。最新のコード検索ベンチマークで最先端の性能を示しています。
最新情報 📢
✨ 主な機能
- 独自の手法を用いたコード埋め込み:リポジトリレベルのプログラム分析、OASIS-instructデータ合成アルゴリズム、特殊な融合損失関数を用いて、コード検索の効率と精度を向上させます。
- 幅広いコーディングスタイルと言語に対応:合成データセットでトレーニングされており、さまざまなコーディングスタイルや言語に対する幅広い理解を保証します。
- 最先端のパフォーマンス:最新のコード検索ベンチマークで最先端の性能を示しています。
📦 インストール
直接使用する場合
pip install -U torch
pip install -U transformers
torch_dtype=torch.bfloat16
でモデルをロードする際には、torch=2.5.0
の使用を避けてください。最適なパフォーマンスと安定性を得るために、PyTorchバージョン2.4.1以前を使用するか、2.5.1以降にアップグレードしてください。
Sentence Transformersを使用する場合
pip install -U sentence-transformers
💻 使用例
基本的な使用法
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoModel, AutoTokenizer
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_states.shape[0]
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
def get_query_prompt(query: str):
query_description = 'Given a code search query, retrieve relevant code snippet that answer the query'
prompt = f'Instruct: {query_description}\nQuery: {query}'
return prompt
query = "How to do quicksort in python?"
code1 = """def bubble_sort(arr):
n = len(arr)
for i in range(n):
swapped = False
for j in range(1, n - i):
if arr[j - 1] > arr[j]:
arr[j - 1], arr[j] = arr[j], arr[j - 1]
swapped = True
if not swapped:
break
return arr"""
code2 = """def quick_sort(arr):
if len(arr) <= 1:
return arr
else:
pivot = arr[0]
less = [x for x in arr[1:] if x <= pivot]
greater = [x for x in arr[1:] if x > pivot]
return quick_sort(less) + [pivot] + quick_sort(greater)"""
model = AutoModel.from_pretrained("Kwaipilot/OASIS-code-1.5B", output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained("Kwaipilot/OASIS-code-1.5B")
inputs = tokenizer([get_query_prompt(query), code1, code2], max_length=1024, padding=True, truncation=True, return_tensors='pt')
outputs = model(**inputs)
embeddings = last_token_pool(outputs.hidden_states[-1], inputs['attention_mask'])
print(embeddings.shape)
embeddings = F.normalize(embeddings, dim=1, p=2)
similarity = embeddings @ embeddings.T
print(similarity[0, 1:])
高度な使用法
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("Kwaipilot/OASIS-code-1.5B")
query = "How to do quicksort in python?"
code1 = """def bubble_sort(arr):
n = len(arr)
for i in range(n):
swapped = False
for j in range(1, n - i):
if arr[j - 1] > arr[j]:
arr[j - 1], arr[j] = arr[j], arr[j - 1]
swapped = True
if not swapped:
break
return arr"""
code2 = """def quick_sort(arr):
if len(arr) <= 1:
return arr
else:
pivot = arr[0]
less = [x for x in arr[1:] if x <= pivot]
greater = [x for x in arr[1:] if x > pivot]
return quick_sort(less) + [pivot] + quick_sort(greater)"""
query_embedding = model.encode([query], prompt_name="query")
code_embeddings = model.encode([code1, code2])
print(code_embeddings.shape)
print(model.similarity(query_embedding[0], code_embeddings[0]))
print(model.similarity(query_embedding[0], code_embeddings[1]))
📚 ドキュメント
パフォーマンス
|
サイズ |
CoSQA |
AdvTest |
CSN-Py |
CSN-Ja |
CSN-JS |
CSN-PHP |
CSN-Go |
CSN-Ruby |
平均 |
OpenAI-Embedding-Ada-002 |
不明 |
0.4423 |
0.3808 |
0.6802 |
0.7149 |
0.6750 |
0.6062 |
0.8563 |
0.7472 |
0.6378 |
OpenAI-Text-embedding-3-large |
不明 |
0.5538 |
0.4684 |
0.7084 |
0.7292 |
0.6813 |
0.5959 |
0.8764 |
0.7525 |
0.6707 |
jina-embeddings-v2-base-code |
161M |
0.6837 |
0.385 |
0.6634 |
0.6803 |
0.6304 |
0.5701 |
0.8595 |
0.7095 |
0.6477 |
CodeSage-large |
1.3B |
0.4753 |
0.5267 |
0.7077 |
0.7021 |
0.695 |
0.6133 |
0.8371 |
0.7192 |
0.6595 |
CodeFuse-CGE-Small |
3.8B |
0.5619 |
0.4639 |
0.6958 |
0.6863 |
0.6564 |
0.6133 |
0.8637 |
0.7341 |
0.6594 |
OASIS-code-1.5B |
1.5B |
0.5577 |
0.5727 |
0.7369 |
0.7397 |
0.6980 |
0.6384 |
0.8821 |
0.7547 |
0.6975 |
BibTeX
@misc{kwaipilotoasis,
title = {Optimized Augmentation Strategy for Improved code Search},
author = {Kwaipilot team},
year = {2024},
}
📄 ライセンス
このプロジェクトはMITライセンスの下で公開されています。