🚀 NLGP文檔字符串模型
NLGP文檔字符串模型在論文自然語言引導編程中被提出。該模型在一組Jupyter筆記本上進行訓練,可用於合成Python代碼,以在特定代碼上下文中實現自然語言意圖(見下面的示例)。
也可查看NLGP自然模型。
這項工作由諾基亞貝爾實驗室的一個研究團隊完成。
🚀 快速開始
NLGP文檔字符串模型可根據給定的代碼上下文和自然語言意圖生成相應的Python代碼。以下是一個簡單示例,展示瞭如何使用該模型根據上下文和意圖生成代碼。
示例
上下文
import matplotlib.pyplot as plt
values = [1, 2, 3, 4]
labels = ["a", "b", "c", "d"]
意圖
預測結果
plt.bar(labels, values)
plt.show()
💻 使用示例
基礎用法
以下是使用NLGP文檔字符串模型的完整代碼示例:
import re
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
tok = GPT2TokenizerFast.from_pretrained("Nokia/nlgp-docstring")
model = GPT2LMHeadModel.from_pretrained("Nokia/nlgp-docstring")
num_spaces = [2, 4, 6, 8, 10, 12, 14, 16, 18]
def preprocess(context, query):
"""
Encodes context + query as a single string and
replaces whitespace with special tokens <|2space|>, <|4space|>, ...
"""
input_str = f"{context}\n{query} <|endofcomment|>\n"
indentation_symbols = {n: f"<|{n}space|>" for n in num_spaces}
m = re.match("^[ ]+", input_str)
if not m:
return input_str
leading_whitespace = m.group(0)
N = len(leading_whitespace)
for n in self.num_spaces:
leading_whitespace = leading_whitespace.replace(n * " ", self.indentation_symbols[n])
return leading_whitespace + input_str[N:]
detokenize_pattern = re.compile(fr"<\|(\d+)space\|>")
def postprocess(output):
output = output.split("<|cell|>")[0]
def insert_space(m):
num_spaces = int(m.group(1))
return num_spaces * " "
return detokenize_pattern.sub(insert_space, output)
code_context = """
import matplotlib.pyplot as plt
values = [1, 2, 3, 4]
labels = ["a", "b", "c", "d"]
"""
query = "# plot a bar chart"
input_str = preprocess(code_context, query)
input_ids = tok(input_str, return_tensors="pt").input_ids
max_length = 150
total_max_length = min(1024 - input_ids.shape[-1], input_ids.shape[-1] + 150)
input_and_output = model.generate(
input_ids=input_ids,
max_length=total_max_length,
min_length=10,
do_sample=False,
num_beams=4,
early_stopping=True,
eos_token_id=tok.encode("<|cell|>")[0]
)
output = input_and_output[:, input_ids.shape[-1]:]
output_str = tok.decode(output[0])
postprocess(output_str)
📄 許可證
版權信息
Copyright 2021 Nokia
許可協議
本項目採用Apache License 2.0許可協議。
SPDX-License-Identifier: Apache-2.0