模型简介
模型特点
模型能力
使用案例
🚀 DistilBERT-产品分类器
这是一个针对电子商务领域产品分类任务,对DistilBERT模型进行微调后的版本。该模型能够区分“CPU”、“数码相机”、“洗碗机”等多种类别,可用于在线零售平台的产品整理和分类。
🚀 快速开始
使用以下代码开始使用该模型进行产品分类:
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Load the model and tokenizer from the Hugging Face Hub
def load_model_and_tokenizer(model_name, num_labels):
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
model.eval() # Set the model to evaluation mode
return model, tokenizer
# Predict categories for the provided prompts
def predict(model, tokenizer, prompts, category_mapping, device):
model.to(device)
inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors='pt', max_length=128)
with torch.no_grad():
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits
predictions = torch.argmax(logits, dim=1).cpu().numpy()
predicted_categories = [category_mapping[pred] for pred in predictions]
return predicted_categories
# Main execution block
if __name__ == "__main__":
# Define some example prompts for prediction
prompts = [
"Intel Core i7 CPU",
"Nikon D3500 Digital Camera",
"Bosch Series 6 Dishwasher",
"Samsung 32 inch Smart TV",
"Apple iPhone 13"
]
# Create the category mapping based on provided comments
category_mapping = {
0: 'cpus',
1: 'digital cameras',
2: 'dishwashers',
3: 'fridge freezers',
4: 'microwaves',
5: 'mobile phones',
6: 'tvs',
7: 'washing machines'
}
model_name = 'Adnan-AI-Labs/DistilBERT-ProductClassifier'
# Load the model and tokenizer
print(f"Loading model and tokenizer from Hugging Face Hub: {model_name}")
model, tokenizer = load_model_and_tokenizer(model_name, len(category_mapping))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Make predictions
predicted_categories = predict(model, tokenizer, prompts, category_mapping, device)
# Display the predictions
for prompt, category in zip(prompts, predicted_categories):
print(f"Prompt: '{prompt}' | Predicted Category: '{category}'")
输出示例
Loading model and tokenizer from Hugging Face Hub: Adnan-AI-Labs/DistilBERT-ProductClassifier
Prompt: 'Intel Core i7 CPU' | Predicted Category: 'cpus'
Prompt: 'Nikon D3500 Digital Camera' | Predicted Category: 'digital cameras'
Prompt: 'Bosch Series 6 Dishwasher' | Predicted Category: 'dishwashers'
Prompt: 'Samsung 32 inch Smart TV' | Predicted Category: 'tvs'
Prompt: 'Apple iPhone 13' | Predicted Category: 'mobile phones'
✨ 主要特性
- 该模型是DistilBERT的微调版本,专为电子商务领域的产品分类任务而训练。
- 能够区分多种产品类别,如“CPU”、“数码相机”、“洗碗机”等,有助于在线零售平台的产品组织和分类。
- 利用紧凑的DistilBERT架构,适合在实时应用中使用,包括移动和Web环境。
📚 详细文档
模型详情
模型描述
DistilBERT-ProductClassifier模型在电子商务数据集上进行训练,用于将产品分类到特定类别。它针对高效的文本分类任务进行了优化,设计为在有限的计算资源下也能良好工作。
- 开发者:Adnan AI Labs
- 模型类型:针对文本分类进行微调的DistilBERT
- 语言:英语
- 许可证:Apache 2.0
- 微调基础模型:DistilBERT
模型来源
用途
直接用途
该模型旨在对电子商务平台中基于文本的产品列表进行产品分类。用户可以使用此模型自动对产品进行分类,减少手动标记的需求,提高产品的可发现性。
非适用用途
此模型不适用于与产品分类无关或非电子商务场景的任务。它不用于情感分析、通用文本生成或其他无关的自然语言处理任务。
偏差、风险和局限性
该模型在电子商务数据上进行训练,对于训练范围之外的产品或类别可能表现不佳。此外,某些类别在训练数据中的代表性可能较少,可能会影响这些类别的分类准确性。
建议
对于涉及其他语言或训练数据中未包含的高度专业化产品类别的用例,可能需要进一步微调。用户在将模型输出用于高风险决策之前,应验证模型的输出。
训练详情
训练数据
该模型在一个电子商务数据集上进行训练,该数据集包括各种产品类别,如CPU、数码相机、洗碗机、冰箱冰柜、微波炉、手机、电视和洗衣机。数据经过去重、小写转换和分词等预处理。
训练过程
- 预处理:文本数据进行了清理、小写转换和分词处理。产品描述被截断为128个标记以保持一致性。
- 超参数:微调时学习率为2e-5,批量大小为16,训练3个周期。
- 训练硬件:模型在单个NVIDIA Tesla V100 GPU上训练约3小时。
评估
测试数据、因素和指标
该模型在一个单独的产品描述测试集上进行评估,使用精确率、召回率和F1分数作为评估指标。
总结
该模型总体准确率达到96.16%,在多个产品类别上表现出色。F1分数表明,该模型在“CPU”和“数码相机”类别中表现尤其出色。
技术规格
模型架构和目标
DistilBERT-ProductClassifier模型采用DistilBERT架构,并通过文本分类头进行微调,用于电子商务产品分类任务。
计算基础设施
该模型经过优化,可在CPU和小型GPU上高效运行,适合实时应用。
硬件要求
该模型进行高效推理至少需要4GB的RAM,建议使用现代CPU或基本GPU。
软件要求
- Transformers库:v4.26.0
- Python版本:3.8或更高
- CUDA [可选]:10.2或更高(如果在GPU上运行)
引用
如果您使用此模型,请按以下方式引用:
@misc{distilbert_product_classifier,
author = {Adnan AI Labs},
title = {DistilBERT-ProductClassifier for E-commerce},
year = {2024},
url = {https://huggingface.co/Adnan-AI-Labs/DistilBERT-ProductClassifier}
}
📄 许可证
本项目采用Apache 2.0许可证。








