pytorch使用SGM进行多标签分类的序列生成模型

Song • 1228 次浏览 • 0 个回复 • 2019年03月17日

多标签分类是自然语言处理中的一项重要但具有挑战性的任务。它比单标签分类更复杂,因为标签往往是相关的。现有方法倾向于忽略标签之间的相关性。此外,文本的不同部分可以不同地用于预测不同的标签,现有模型不考虑这些标签。在本文中,LancoPKU(北大学语言计算与机器学习小组)建议将多标签分类任务视为序列生成问题,并应用具有新颖解码器结构的序列生成模型来解决它。大量实验结果表明,LancoPKU(北大学语言计算与机器学习小组)提出的方法在很大程度上优于以前的工作。对实验结果的进一步分析表明,所提出的方法不仅捕获标签之间的相关性,

一、多标签分类的序列生成模型

这是LancoPKU(北大学语言计算与机器学习小组)的论文SGM的代码:多标签分类的序列生成模型PDF

注意:
提供的代码基于RCV1-V2数据集。如果需要在其他数据集上运行代码,请相应地修改与数据集的特定名称相关的所有程序语句。

二、数据集

  • RCV1-V2
  • AAPD 有两个数据集可在这里获取

三、要求

  • Ubuntu 16.0.4
  • Python 3.5
  • Pytorch 0.3.1

代码复现

LancoPKU(北大学语言计算与机器学习小组)在RCV1-V2数据集上提供SGM模型和SGM + GE模型的预训练检查点,以帮助您重现LancoPKU(北大学语言计算与机器学习小组)报告的实验结果。详细的复制步骤如下:

  • 请先点击上面提供的链接下载RCV1-V2数据集和检查点,然后将它们放在文件夹./data/data/
  • 预处理: python3 preprocess.py
  • 预测: python3 predict.py -gpus id -log log_name

1、预处理

python3 preprocess.py

请记住下载数据集并将其放在文件夹./data/data/

2、训练

python3 train.py -gpus id -log log_name

3、评估

python3 predict.py -gpus id -restore checkpoint -log log_name

代码地址:lancopku/SGM 论文地址:SGM: Sequence Generation Model for Multi-label Classification


原创文章,转载请注明 :pytorch使用SGM进行多标签分类的序列生成模型 - pytorch中文网
原文出处: https://ptorch.com/news/240.html
问题交流群 :168117787
提交评论
要回复文章请先登录注册
用户评论
  • 没有评论
Pytorch是什么?关于Pytorch! python通过opencv使用图片制作简单视频