模型概述
模型特點
模型能力
使用案例
🚀 DistilBERT-產品分類器
這是一個針對電子商務領域產品分類任務,對DistilBERT模型進行微調後的版本。該模型能夠區分“CPU”、“數碼相機”、“洗碗機”等多種類別,可用於在線零售平臺的產品整理和分類。
🚀 快速開始
使用以下代碼開始使用該模型進行產品分類:
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Load the model and tokenizer from the Hugging Face Hub
def load_model_and_tokenizer(model_name, num_labels):
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
model.eval() # Set the model to evaluation mode
return model, tokenizer
# Predict categories for the provided prompts
def predict(model, tokenizer, prompts, category_mapping, device):
model.to(device)
inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors='pt', max_length=128)
with torch.no_grad():
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits
predictions = torch.argmax(logits, dim=1).cpu().numpy()
predicted_categories = [category_mapping[pred] for pred in predictions]
return predicted_categories
# Main execution block
if __name__ == "__main__":
# Define some example prompts for prediction
prompts = [
"Intel Core i7 CPU",
"Nikon D3500 Digital Camera",
"Bosch Series 6 Dishwasher",
"Samsung 32 inch Smart TV",
"Apple iPhone 13"
]
# Create the category mapping based on provided comments
category_mapping = {
0: 'cpus',
1: 'digital cameras',
2: 'dishwashers',
3: 'fridge freezers',
4: 'microwaves',
5: 'mobile phones',
6: 'tvs',
7: 'washing machines'
}
model_name = 'Adnan-AI-Labs/DistilBERT-ProductClassifier'
# Load the model and tokenizer
print(f"Loading model and tokenizer from Hugging Face Hub: {model_name}")
model, tokenizer = load_model_and_tokenizer(model_name, len(category_mapping))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Make predictions
predicted_categories = predict(model, tokenizer, prompts, category_mapping, device)
# Display the predictions
for prompt, category in zip(prompts, predicted_categories):
print(f"Prompt: '{prompt}' | Predicted Category: '{category}'")
輸出示例
Loading model and tokenizer from Hugging Face Hub: Adnan-AI-Labs/DistilBERT-ProductClassifier
Prompt: 'Intel Core i7 CPU' | Predicted Category: 'cpus'
Prompt: 'Nikon D3500 Digital Camera' | Predicted Category: 'digital cameras'
Prompt: 'Bosch Series 6 Dishwasher' | Predicted Category: 'dishwashers'
Prompt: 'Samsung 32 inch Smart TV' | Predicted Category: 'tvs'
Prompt: 'Apple iPhone 13' | Predicted Category: 'mobile phones'
✨ 主要特性
- 該模型是DistilBERT的微調版本,專為電子商務領域的產品分類任務而訓練。
- 能夠區分多種產品類別,如“CPU”、“數碼相機”、“洗碗機”等,有助於在線零售平臺的產品組織和分類。
- 利用緊湊的DistilBERT架構,適合在即時應用中使用,包括移動和Web環境。
📚 詳細文檔
模型詳情
模型描述
DistilBERT-ProductClassifier模型在電子商務數據集上進行訓練,用於將產品分類到特定類別。它針對高效的文本分類任務進行了優化,設計為在有限的計算資源下也能良好工作。
- 開發者:Adnan AI Labs
- 模型類型:針對文本分類進行微調的DistilBERT
- 語言:英語
- 許可證:Apache 2.0
- 微調基礎模型:DistilBERT
模型來源
用途
直接用途
該模型旨在對電子商務平臺中基於文本的產品列表進行產品分類。用戶可以使用此模型自動對產品進行分類,減少手動標記的需求,提高產品的可發現性。
非適用用途
此模型不適用於與產品分類無關或非電子商務場景的任務。它不用於情感分析、通用文本生成或其他無關的自然語言處理任務。
偏差、風險和侷限性
該模型在電子商務數據上進行訓練,對於訓練範圍之外的產品或類別可能表現不佳。此外,某些類別在訓練數據中的代表性可能較少,可能會影響這些類別的分類準確性。
建議
對於涉及其他語言或訓練數據中未包含的高度專業化產品類別的用例,可能需要進一步微調。用戶在將模型輸出用於高風險決策之前,應驗證模型的輸出。
訓練詳情
訓練數據
該模型在一個電子商務數據集上進行訓練,該數據集包括各種產品類別,如CPU、數碼相機、洗碗機、冰箱冰櫃、微波爐、手機、電視和洗衣機。數據經過去重、小寫轉換和分詞等預處理。
訓練過程
- 預處理:文本數據進行了清理、小寫轉換和分詞處理。產品描述被截斷為128個標記以保持一致性。
- 超參數:微調時學習率為2e-5,批量大小為16,訓練3個週期。
- 訓練硬件:模型在單個NVIDIA Tesla V100 GPU上訓練約3小時。
評估
測試數據、因素和指標
該模型在一個單獨的產品描述測試集上進行評估,使用精確率、召回率和F1分數作為評估指標。
總結
該模型總體準確率達到96.16%,在多個產品類別上表現出色。F1分數表明,該模型在“CPU”和“數碼相機”類別中表現尤其出色。
技術規格
模型架構和目標
DistilBERT-ProductClassifier模型採用DistilBERT架構,並通過文本分類頭進行微調,用於電子商務產品分類任務。
計算基礎設施
該模型經過優化,可在CPU和小型GPU上高效運行,適合即時應用。
硬件要求
該模型進行高效推理至少需要4GB的RAM,建議使用現代CPU或基本GPU。
軟件要求
- Transformers庫:v4.26.0
- Python版本:3.8或更高
- CUDA [可選]:10.2或更高(如果在GPU上運行)
引用
如果您使用此模型,請按以下方式引用:
@misc{distilbert_product_classifier,
author = {Adnan AI Labs},
title = {DistilBERT-ProductClassifier for E-commerce},
year = {2024},
url = {https://huggingface.co/Adnan-AI-Labs/DistilBERT-ProductClassifier}
}
📄 許可證
本項目採用Apache 2.0許可證。








