Pythia - FAIR A-STAR团队进行视觉问答研究的模块化框架

Song • 171 次浏览 • 0 个回复 • 2018年10月20日

FAIR A-STAR团队进行视觉问答研究的模块化框架

一、Pythia

Pythia是视觉问答应用研究的模块化框架,它基于2018年Facebook AI Research(FAIR)VQA 挑战赛获奖团队A-STAR的基础(Visual Question Answering,VQA即视觉问答)。请查看他们的论文了解更多详情。

(A-STAR:看见,谈话,行为和理性的代理人。)

二、安装pythia环境

  • 安装AnacondaAnaconda推荐:https://www.continuum.io/downloads)。
  • 安装cudnn v7.0cuda.9.0
  • 配置pythia环境
conda create --name vqa python=3.6

source activate vqa
pip install demjson pyyaml

pip install http://download.pytorch.org/whl/cu90/torch-0.3.1-cp36-cp36m-linux_x86_64.whl

pip install torchvision
pip install tensorboardX

三、快速开始

官方提供预处理的数据文件,以直接开始训练和评估。除了使用原有的train2014val2014,我们拆分val2014val2train2014minival2014,并使用train2014+ val2train2014训练和minival2014进行验证。

下载数据。此步骤可能需要一些时间。检查自述文件末尾的文件大小。

git clone git@github.com:facebookresearch/pythia.git
cd Pythia

mkdir data
cd data
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/vqa2.0_glove.6B.300d.txt.npy
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/vocabulary_vqa.txt
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/answers_vqa.txt
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/imdb.tar.gz
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/rcnn_10_100.tar.gz
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/detectron.tar.gz
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/large_vocabulary_vqa.txt
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/large_vqa2.0_glove.6B.300d.txt.npy
gunzip imdb.tar.gz 
tar -xf imdb.tar

gunzip rcnn_10_100.tar.gz 
tar -xf rcnn_10_100.tar
rm -f rcnn_10_100.tar

gunzip detectron.tar.gz
tar -xf detectron.tar
rm -f detectron.tar

可选的命令行参数train.py

python train.py -h

usage: train.py [-h] [--config CONFIG] [--out_dir OUT_DIR] [--seed SEED]
                [--config_overwrite CONFIG_OVERWRITE] [--force_restart]

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       config yaml file
  --out_dir OUT_DIR     output directory, default is current directory
  --seed SEED           random seed, default 1234, set seed to -1 if need a
                        random seed between 1 and 100000
  --config_overwrite CONFIG_OVERWRITE
                        a json string to update yaml config file
  --force_restart       flag to force clean previous result and restart
                        training

运行模型

cd ../
python train.py

如果出现内存不足错误,请尝试:

python train.py --config_overwrite '{data:{image_fast_reader:false}}'

运行带有微调的检测器功能的模型

python train.py --config config/keep/detectron.yaml

检查默认运行的结果

cd results/default/1234

results文件夹包含以下信息

results
|_ default
|  |_ 1234 (default seed)
|  |  |_config.yaml
|  |  |_best_model.pth
|  |  |_best_model_predict_test.pkl 
|  |  |_best_model_predict_test.json (json file for predicted results on test dataset)
|  |  |_model_00001000.pth (mpdel snapshot at iter 1000)
|  |  |_result_on_val.txt
|  |  |_ ...
|  |_(other_cofig_setting)
|  |  |_...
|_ (other_config_file)
|

tensorbord的日志文件存储在boards/

1、预处理数据集

如果要自己从原始VQA数据集开始并预处理数据,请参考data_preprocess.md中的说明。 如果从快速入门下载所有数据,则无需此部分。

使用预训练模型进行测试 注意:下面的所有这些模型都经过了包含验证集的训练

描述 性能(test-dev) 链接
detectron_100_resnet_most_data 70.01 https://s3-us-west-1.amazonaws.com/pythia-vqa/pretrained_models/detectron_100_resnet_most_data.tar.gz
baseline 68.05 https://s3-us-west-1.amazonaws.com/pythia-vqa/pretrained_models/baseline.tar.gz
baseline +VG +VisDal +mirror 68.98 https://s3-us-west-1.amazonaws.com/pythia-vqa/pretrained_models/most_data.tar.gz
detectron_finetune 68.49 https://s3-us-west-1.amazonaws.com/pythia-vqa/pretrained_models/detectron.tar.gz
detectron_finetune+VG +VisDal +mirror 69.24 https://s3-us-west-1.amazonaws.com/pythia-vqa/pretrained_models/detectron_most_data.tar.gz

2、最佳预训练模型

最好的预训练模型可以按如下方式下载:

mkdir pretrained_models/
cd pretrained_models
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/pretrained_models/detectron_100_resnet_most_data.tar.gz
gunzip detectron_100_resnet_most_data.tar.gz 
tar -xf detectron_100_resnet_most_data.tar
rm -f detectron_100_resnet_most_data.tar

通过固定的100 bounding boxes获得ResNet152功能和Detectron功能

cd data
wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/detectron_fix_100.tar.gz
gunzip detectron_fix_100.tar.gz
tar -xf detectron_fix_100.tar
rm -f detectron_fix_100.tar

wget https://s3-us-west-1.amazonaws.com/pythia-vqa/data/resnet152.tar.gz
gunzip resnet152.tar.gz
tar -xf resnet152.tar
rm -f resnet152.tar

在VQA test2015数据集上测试最佳模型

python run_test.py --config pretrained_models/detectron_100_resnet_most_data/1234/config.yaml \
--model_path pretrained_models/detectron_100_resnet_most_data/1234/best_model.pth \
--out_prefix test_best_model

结果将保存为json文件test_best_model.json,此文件可以上载到EvalAI上的评估服务器(https://evalai.cloudcv.org/web/challenges/challenge-page/80/submission)。

3、全部模型

下载上面的所有型号

python ensemble.py --res_dirs pretrained_models/ --out ensemble_5.json

结果将保存在ensemble_5.json。这个集合可以在test-dev上获得71.65的准确度。

a、运行30个模型

要运行30个预训练模型的集合,请按如下方式下载模型和图像功能。这在test-dev上的准确度为72.18

wget https://s3-us-west-1.amazonaws.com/pythia-vqa/ensembled.tar.gz

4、自定义配置

要更改模型或调整参数,请参阅config_help.md

5、Docker演示

使用nvidia-docker以交互方式快速试用模型

git clone https://github.com/facebookresearch/pythia.git
nvidia-docker build pythia -t pythia:latest
nvidia-docker run -ti --net=host pythia:latest

这将打开一个带有演示模型的jupyter笔记,您可以以交互方式提问。

6、AWS s3数据集摘要

在这里,我们列出了AWS S3存储中一些大文件的大小。

描述 大小
data/rcnn_10_100.tar.gz 71.0GB
data/detectron.tar.gz 106.2 GB
data/detectron_fix_100.tar.gz 162.6GB
data/resnet152.tar.gz 399.6GB
ensembled.tar.gz 462.1GB

原创文章,转载请注明 :Pythia - FAIR A-STAR团队进行视觉问答研究的模块化框架 - pytorch中文网
原文出处: https://ptorch.com/news/213.html
问题交流群 :168117787
提交评论
要回复文章请先登录注册
用户评论
  • 没有评论
Pytorch是什么?关于Pytorch! pytorch句子嵌入(InferSent)和NLI的训练代码。