🚀 基於LaTeX的Nougat模型
基於LaTeX的Nougat模型是從 facebook/nougat-base 微調而來,使用 im2latex-100k 數據集進行訓練,以提升其從圖像生成LaTeX代碼的能力。該模型解決了原Nougat模型在處理方程圖像片段時,因輸入圖像尺寸不合適導致的縮放偽影問題,從而提高了LaTeX代碼的生成質量。
🚀 快速開始
安裝依賴
pip install transformers >= 4.34.0
運行步驟
- 下載倉庫
git clone git@github.com:NormXU/nougat-latex-ocr.git
cd ./nougat-latex-ocr
- 進行推理
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel
from transformers.models.nougat import NougatTokenizerFast
from nougat_latex import NougatLaTexProcessor
model_name = "Norm/nougat-latex-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
tokenizer = NougatTokenizerFast.from_pretrained(model_name)
latex_processor = NougatLaTexProcessor.from_pretrained(model_name)
image = Image.open("path/to/latex/image.png")
if not image.mode == "RGB":
image = image.convert('RGB')
pixel_values = latex_processor(image, return_tensors="pt").pixel_values
decoder_input_ids = tokenizer(tokenizer.bos_token, add_special_tokens=False,
return_tensors="pt").input_ids
with torch.no_grad():
outputs = model.generate(
pixel_values.to(device),
decoder_input_ids=decoder_input_ids.to(device),
max_length=model.decoder.config.max_length,
early_stopping=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True,
num_beams=5,
bad_words_ids=[[tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
sequence = tokenizer.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "").replace(tokenizer.bos_token, "")
print(sequence)
✨ 主要特性
📚 詳細文檔
模型信息
評估結果
在從Wikipedia、arXiv和im2latex-100k收集的圖像 - 方程對數據集上進行評估,該數據集由 lukas-blecher 整理。
模型 |
標記準確率 ↑ |
歸一化編輯距離 ↓ |
pix2tex |
0.5346 |
0.10312 |
pix2tex* |
0.60 |
0.10 |
nougat-latex-based |
0.623850 |
0.06180 |
pix2tex是 LaTeX-OCR 中引入的ResNet + ViT + 文本解碼器架構。
pix2tex*:來自 LaTeX-OCR 的報告; pix2tex:使用發佈的 檢查點 進行的評估; nougat-latex-based:在使用束搜索策略生成的結果上進行評估。
注意事項
⚠️ 重要提示
推理API小部件有時會截斷響應。請查看 此問題 以獲取更多詳細信息。如果推理API的錯誤導致結果被截斷,您可能需要自己運行模型。
📄 許可證
本項目採用Apache-2.0許可證。