PyTorch implementation of the models described in the paper *Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation* (ACL 2019).
- Python 3.6
- PyTorch 0.4
- Numpy
- NLTK
- torchtext
- torchvision
- revtok
- multiset
- ipdb
- CUDA (we recommend the latest version; version 8.0 was used in all our experiments)
- This code is based on dl4mt-nonauto. We mainly modified `model.py` (lines 1103-1199).
The original translation corpora can be downloaded from the respective sources (IWSLT'16 En-De, WMT'16 En-Ro, WMT'14 En-De). We recommend downloading the preprocessed corpora released in dl4mt-nonauto.
Set the correct path to the data in the `data_path()` function located in `data.py`:
Train a NAT model using the cross-entropy loss. This process usually takes about 10 days. You can download our pretrained models here.
$ sh train_iwslt.sh
$ sh train_wmt.sh
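For reference, this pre-training stage optimizes ordinary token-level cross-entropy, computed over all target positions in parallel rather than left to right. A minimal sketch, with illustrative names and shapes rather than the repository's API:

```python
import torch.nn.functional as F

def nat_xent_loss(logits, targets, pad_idx):
    """Token-level cross-entropy for a non-autoregressive decoder.

    logits:  (batch, tgt_len, vocab) -- all positions predicted in parallel
    targets: (batch, tgt_len)        -- gold target tokens
    """
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),  # flatten all positions
        targets.view(-1),
        ignore_index=pad_idx,              # skip padding positions
    )
```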
Take a checkpoint of a pre-trained non-autoregressive model and finetune the checkpoint using the RF-NAT algorithm. This process usually takes about 1 day.
If you want to use the original REINFORCE algorithm, change the flag `--nat_finetune` to `--rf_finetune`.
$ sh rf_iwslt.sh
$ sh rf_wmt.sh
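For intuition, the vanilla REINFORCE objective mentioned above can be sketched as follows: sample a translation (all positions in parallel), score it with a sentence-level reward such as smoothed BLEU, and weight the sample's log-likelihood by that reward. This is an illustrative single-sentence sketch, not the repository's implementation; RF-NAT itself refines this estimator with the sequential rewards described in the paper.

```python
import torch
import torch.nn.functional as F
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def reinforce_loss(logits, reference):
    """Vanilla REINFORCE for one sentence (illustrative sketch).

    logits:    (tgt_len, vocab) -- per-position logits from the NAT decoder
    reference: list of gold target token ids
    """
    log_p = F.log_softmax(logits, dim=-1)
    # Sample one token per position, all positions in parallel.
    sample = torch.multinomial(log_p.exp(), 1).squeeze(-1)    # (tgt_len,)
    log_p_sample = log_p.gather(1, sample.unsqueeze(1)).sum()
    # Sentence-level reward: smoothed BLEU of the sampled translation.
    reward = sentence_bleu([reference], sample.tolist(),
                           smoothing_function=SmoothingFunction().method1)
    # REINFORCE estimator: minimizing this yields -reward * grad log p(sample).
    return -reward * log_p_sample
```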
Take a finetuned checkpoint and train the length prediction model. This process usually takes about 1 day.
$ sh tune_iwslt.sh
$ sh tune_wmt.sh
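Non-autoregressive decoding needs the target length before the decoder runs, so this stage trains a length predictor on top of the finetuned model. One common formulation (an assumption here, not necessarily this codebase's exact design) classifies the length offset between target and source from pooled encoder states:

```python
import torch.nn as nn

class LengthPredictor(nn.Module):
    """Classify delta = len(target) - len(source) from mean-pooled
    encoder states (illustrative sketch; the offset range and the
    pooling scheme are assumptions)."""

    def __init__(self, hidden_dim, max_offset=20):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, 2 * max_offset + 1)

    def forward(self, enc_states, src_mask):
        # enc_states: (batch, src_len, hidden)
        # src_mask:   (batch, src_len) float mask, 1 for real tokens
        pooled = (enc_states * src_mask.unsqueeze(-1)).sum(1) \
                 / src_mask.sum(1, keepdim=True)
        # Logits over offsets in [-max_offset, max_offset].
        return self.proj(pooled)
```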
Decode the test set. This process usually takes about 20 seconds.
$ sh decode_iwslt.sh
$ sh decode_wmt.sh
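Decoding is this fast because every target position is filled by an independent argmax in a single forward pass; there is no left-to-right beam search. A sketch, with an assumed model interface:

```python
import torch

def nat_decode(model, src, tgt_len):
    """One-pass non-autoregressive decoding: an independent argmax at
    every position (the model call signature here is an assumption)."""
    with torch.no_grad():
        logits = model(src, tgt_len)   # (batch, tgt_len, vocab)
    return logits.argmax(dim=-1)       # (batch, tgt_len)
```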
If you find the resources in this repository useful, please consider citing:
@inproceedings{shao2019retrieving,
    title = "Retrieving Sequential Information for Non-Autoregressive Neural Machine Translation",
    author = "Shao, Chenze and
      Feng, Yang and
      Zhang, Jinchao and
      Meng, Fandong and
      Chen, Xilin and
      Zhou, Jie",
    booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2019",
    url = "https://www.aclweb.org/anthology/P19-1288",
    pages = "3013--3024",
}