🚀 基于LaTeX的Nougat模型
基于LaTeX的Nougat模型是从 facebook/nougat-base 微调而来,使用 im2latex-100k 数据集进行训练,以提升其从图像生成LaTeX代码的能力。该模型解决了原Nougat模型在处理方程图像片段时,因输入图像尺寸不合适导致的缩放伪影问题,从而提高了LaTeX代码的生成质量。
🚀 快速开始
安装依赖
pip install transformers >= 4.34.0
运行步骤
- 下载仓库
git clone git@github.com:NormXU/nougat-latex-ocr.git
cd ./nougat-latex-ocr
- 进行推理
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel
from transformers.models.nougat import NougatTokenizerFast
from nougat_latex import NougatLaTexProcessor
model_name = "Norm/nougat-latex-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
tokenizer = NougatTokenizerFast.from_pretrained(model_name)
latex_processor = NougatLaTexProcessor.from_pretrained(model_name)
image = Image.open("path/to/latex/image.png")
if not image.mode == "RGB":
image = image.convert('RGB')
pixel_values = latex_processor(image, return_tensors="pt").pixel_values
decoder_input_ids = tokenizer(tokenizer.bos_token, add_special_tokens=False,
return_tensors="pt").input_ids
with torch.no_grad():
outputs = model.generate(
pixel_values.to(device),
decoder_input_ids=decoder_input_ids.to(device),
max_length=model.decoder.config.max_length,
early_stopping=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True,
num_beams=5,
bad_words_ids=[[tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
sequence = tokenizer.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "").replace(tokenizer.bos_token, "")
print(sequence)
✨ 主要特性
📚 详细文档
模型信息
评估结果
在从Wikipedia、arXiv和im2latex-100k收集的图像 - 方程对数据集上进行评估,该数据集由 lukas-blecher 整理。
模型 |
标记准确率 ↑ |
归一化编辑距离 ↓ |
pix2tex |
0.5346 |
0.10312 |
pix2tex* |
0.60 |
0.10 |
nougat-latex-based |
0.623850 |
0.06180 |
pix2tex是 LaTeX-OCR 中引入的ResNet + ViT + 文本解码器架构。
pix2tex*:来自 LaTeX-OCR 的报告; pix2tex:使用发布的 检查点 进行的评估; nougat-latex-based:在使用束搜索策略生成的结果上进行评估。
注意事项
⚠️ 重要提示
推理API小部件有时会截断响应。请查看 此问题 以获取更多详细信息。如果推理API的错误导致结果被截断,您可能需要自己运行模型。
📄 许可证
本项目采用Apache-2.0许可证。