Cost-Sensitive Training for Autoregressive Models
Abstract
Training autoregressive models to better predict under the test metric, instead of maximizing the likelihood, has been reported to be beneficial in several use cases but brings additional complications, which prevent wider adoption. In this paper, we follow the learning-to-search approach (Daumé III et al., 2009; Leblond et al., 2018) and investigate several of its components. First, we propose a way to construct a reference policy based on an alignment between the model output and the ground truth. Our reference policy is optimal when applied to the Kendall-tau distance between permutations (which appear in the word-ordering task) and helps when working with the METEOR score for machine translation. Second, we observe that the learning-to-search approach benefits from choosing the costs related to the test metrics. Finally, we study the effect of different learning objectives and find that the standard KL loss only learns several high-probability tokens and can be replaced with ranking objectives that target these tokens explicitly.
1 Introduction
Autoregressive models are a popular choice for many applications, including machine translation, image captioning, and code generation. These models output predictions one by one, and each prediction depends on the previous ones (by prediction we mean choosing a token from the set of available tokens). Modern autoregressive models are usually trained with the maximum likelihood estimation (MLE) approach, which has at least two disturbing properties: exposure bias (Ranzato et al., 2016) and loss-evaluation mismatch (Wiseman & Rush, 2016).
Exposure bias: training with the MLE objective requires the model to take ground-truth sequences as inputs, but at the testing stage the model generates the output sequence step by step, taking its own output from the previous steps as input. This discrepancy means that the model never sees its own errors during training and thus never explicitly learns to correct them.
Loss-evaluation mismatch: the MLE approach attempts to learn a full probabilistic model of the ground truth, which requires enormous, diverse datasets and heuristic prediction algorithms (Stahlberg & Byrne, 2019). Learning directly to predict under the target cost function might require less data and be more compatible with available prediction algorithms.
These properties have motivated Bahdanau et al. (2017); Edunov et al. (2018); Zhang et al. (2019) to work on new training methods. In this paper, we follow the learning-to-search (L2S) line of work (Daumé III et al., 2009; Chang et al., 2015; Leblond et al., 2018). This approach reduces sequence prediction to cost-sensitive multiclass classification, where the classes correspond to the tokens available at each individual prediction. The costs of the classes come from the metrics between the ground-truth sequence and the sequence generated by some policy, which is an important hyperparameter.
The L2S approach has a lot of similarities with reinforcement learning (RL); the critical difference is access to a reference policy that chooses tokens optimally w.r.t. the annotation and cost function (such policies correspond to the expert oracles of the RL world, where they are usually not available). Reference policies are often difficult to construct for non-trivial test metrics like the BLEU score (Papineni et al., 2002), so one resorts to using approximations.
In this paper, we investigate different components of the learning-to-search approach (the reference policy, costs, and loss function) on the tasks of word ordering, neural machine translation (NMT), and code generation. First, we propose a reference policy based on the alignment between the predictions and the ground truth and prove its optimality w.r.t. the Kendall-tau distance between permutations (which appear in the word-ordering task). The constructed alignment can also be used to approximate the METEOR score (Denkowski & Lavie, 2014). Next, we experiment with different costs for word ordering, code generation, and neural machine translation and show that the learning-to-search approach benefits from choosing the costs related to the test metrics. Although this is the expected behavior, it is not always observed in practice (Shen et al., 2016a; Wieting et al., 2019).
For the training objective, Leblond et al. (2018) use the KL divergence between the distributions obtained from the costs and from the model. With this loss, we observe that training requires an extreme value of the parameter controlling the scale of the costs, which corresponds to a degenerate target distribution. In this mode, the model learns only a few (1–2) tokens corresponding to the lowest costs, and the expensive computation of the other costs appears to be useless. We consider alternative training losses that depend only on the ordering of the costs. First, we redefine the target distribution for the KL loss based on the cost ordering. Second, we adapt the top-k ListMLE loss from the ranking literature (Xia et al., 2009) to the L2S pipeline. We report that the ordering-based losses can outperform the original KL loss.
The rest of the paper is organized as follows: we describe training with the MLE and SeaRNN objectives in Section 2 and give the details of the experiments in Section 3. In Section 4, we choose the test metrics and define the costs and reference policy based on these metrics; in Section 5, we discuss training with different losses. Sections 6 and 7 provide the related work and conclusion, respectively.
2 Training with learning-to-search
Consider an autoregressive model with input x and sequential output y = (y_1, …, y_T). At each step t, the goal is to predict a new output y_t given the previous outputs y_1, …, y_{t−1} and the input x. Such models are usually trained with maximum likelihood estimation, which consists in maximizing the following objective:
L(θ) = Σ_{t=1}^{T} log p_θ(y_t | y_1, …, y_{t−1}, x).   (1)
A common practice is to model the conditional probability p_θ(y_t | y_1, …, y_{t−1}, x) by a neural network with an encoder-decoder architecture. The encoder maps the input to a latent space, which the decoder uses to produce the output. The encoder and decoder are usually built with RNN (Bahdanau et al., 2015) or Transformer (Vaswani et al., 2017) blocks.
At the training stage, the input of the model is the ground-truth token from the previous step (teacher forcing). At the testing stage, the ground-truth tokens are not available, so the input of the model is its own output from the previous step.
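To make objective (1) concrete, the following sketch (ours, in plain Python with hand-written toy distributions standing in for a real decoder) computes the per-sequence negative log-likelihood under teacher forcing:

```python
import math

def mle_loss(step_distributions, target):
    """Negative log-likelihood of a target sequence.

    step_distributions[t] is the model's probability distribution over
    the vocabulary at step t, computed with the ground-truth prefix
    target[:t] fed as input (teacher forcing).
    """
    return -sum(math.log(dist[token])
                for dist, token in zip(step_distributions, target))

# Toy example: vocabulary {0, 1}, two prediction steps.
dists = [{0: 0.9, 1: 0.1}, {0: 0.2, 1: 0.8}]
loss = mle_loss(dists, [0, 1])  # -log(0.9) - log(0.8)
```

Maximizing (1) over the training set is equivalent to minimizing this loss averaged over the training sequences.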
2.1 SeaRNN algorithm
One alternative to the MLE training of autoregressive models is the learning-to-search approach (Daumé III et al., 2009; Chang et al., 2015; Leblond et al., 2018). We use the SeaRNN algorithm (Leblond et al., 2018). To compute the loss for the t-th token (t = 1, …, T) of an autoregressive model, the algorithm performs the following steps:

Construct the prefix of length t − 1 according to the roll-in policy of choosing tokens.

Sample a set of tokens to try. For tasks with small vocabularies, we can try all available tokens; otherwise, this is too expensive, and we sample. Sampling can be done uniformly from all available tokens, by choosing a fixed number of neighbors of the current position in the ground truth, or by selecting the tokens with the top probabilities according to the current model.

Try adding each sampled token to the prefix and complete each resulting sequence with the roll-out policy of choosing tokens.

Measure how close each completed sequence is to the ground truth, obtain the costs, and compute the training loss.
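The steps above can be sketched as follows; `toy_rollout` and `hamming_cost` are simplified stand-ins (ours) for a real roll-out policy and a test-metric-based cost function:

```python
def searnn_costs(prefix, candidate_tokens, rollout, cost_fn, ground_truth):
    """Steps 3-4: append each candidate token to the roll-in prefix,
    complete the sequence with the roll-out policy, and measure its cost."""
    return {token: cost_fn(rollout(prefix + [token], ground_truth), ground_truth)
            for token in candidate_tokens}

def toy_rollout(prefix, ground_truth):
    """Toy roll-out: append the not-yet-used ground-truth tokens in order."""
    remaining = list(ground_truth)
    for token in prefix:
        if token in remaining:
            remaining.remove(token)
    return prefix + remaining

def hamming_cost(prediction, ground_truth):
    """Toy cost: number of positions where the sequences disagree."""
    return sum(p != g for p, g in zip(prediction, ground_truth))

costs = searnn_costs(["a"], ["b", "c"], toy_rollout, hamming_cost,
                     ["a", "b", "c"])
# choosing "b" preserves the ground-truth order and gets the lowest cost
```

The resulting cost dictionary is then converted into a training signal by one of the losses discussed in Section 5.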
A popular choice of such a training loss is the KL divergence between the distributions defined from the costs and from the model (Section 5.1). The costs are typically defined according to the test metric; they inform the model how good or bad the token choices were and allow the SeaRNN algorithm to optimize the test metric during training.
Roll-in and roll-out policies. An important component of the method is the policy for choosing tokens. We need these tokens to construct a prefix (roll-in) and to complete a sequence (roll-out); the completed sequences are used for computing the costs. Unlike MLE training with teacher forcing, in SeaRNN the selected tokens are used as the input of the model.
Both roll-in and roll-out policies can be reference, learned, or mixed. The learned policy chooses the most probable word according to the current model. The reference policy at roll-in acts as teacher forcing and always outputs the ground-truth tokens. At roll-out, the reference policy completes the roll-in prefix optimally w.r.t. the test metric. The optimal reference policy w.r.t. metrics like BLEU or METEOR is hard to define, so we need to approximate it (we also refer to these approximations as the reference policy). Finally, we can mix the reference and learned policies: we choose the reference or learned policy with some fixed probability either for each sequence (mixed policy) or for each step (mixed-cells policy).
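A minimal sketch (ours) of the two mixing schemes; `p_ref` is our name for the mixing probability:

```python
import random

def policy_schedule(p_ref, seq_len, per_cell, rng=random):
    """Decide between 'reference' and 'learned' for each step.

    mixed (per_cell=False): one draw for the whole sequence;
    mixed-cells (per_cell=True): an independent draw at every step."""
    if per_cell:
        return ["reference" if rng.random() < p_ref else "learned"
                for _ in range(seq_len)]
    choice = "reference" if rng.random() < p_ref else "learned"
    return [choice] * seq_len
```

In the mixed mode, all steps of one sequence share a single draw, so the schedule is constant within a sequence; in the mixed-cells mode, the two policies can interleave within one sequence.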
3 Experimental setup
3.1 Tasks and Datasets
Word ordering. The goal of this task is to recover the order of words in a permuted sentence. We use the English part of the Multi30k dataset (Elliott et al., 2016). Following Gu et al. (2019), we randomly permute the sentences to obtain the inputs.
Code generation. The task is to generate Python code from descriptions in natural language. We use the standard sequence-to-sequence framework and the Django dataset (Oda et al., 2015).
Neural machine translation. The task is to translate sentences from a source language to a target language. The neural machine translation experiments are conducted on the Multi30k dataset (Elliott et al., 2016); the source language is German and the target one is English.
3.2 Models
In all experiments, we use the standard encoder-decoder architecture with the attention mechanism of Bahdanau et al. (2015), gated recurrent units (Cho et al., 2014), and a bidirectional encoder. The models have 2 layers and are regularized with dropout at rate 0.3. We train the models with the Adam optimizer. For inference, we use greedy decoding. For the word-ordering task, we constrain the model to always output a permutation of the input (by masking the output of the softmax layer). The learning rate and other hyperparameters and training details are provided in Appendix A.
4 Cost-sensitive training
In this section, we describe the test metrics that we use and the way we define the costs and reference policy based on these metrics.
We start with the BLEU metric and the corresponding costs and reference policy (Section 4.1). Next, we describe the Kendall-tau distance and propose a reference policy, which we prove to be optimal w.r.t. this metric (Section 4.2). In Section 4.3, we describe the METEOR metric and propose a reference policy and a way to approximate the METEOR score for computing the costs. Finally, we train the models to optimize these metrics and report the results (Section 4.4).
4.1 Training with BLEU
BLEU (Papineni et al., 2002) is a widely-used metric for text generation tasks. It is based on n-grams of different order n (the standard choice is BLEU-4, which uses n-grams up to order 4). All our tasks are measured with BLEU: word ordering (Wiseman & Rush, 2016), machine translation (Papineni et al., 2002), and code generation (Oda et al., 2015). When training to maximize BLEU, we use the reference policy proposed by Leblond et al. (2018): they try adding every suffix of the ground-truth sequence to the current prediction and pick the one with the highest BLEU-1 score. They follow Bahdanau et al. (2017) and use a sentence-level smoothed version of BLEU-4 as the costs.
However, BLEU is known to have drawbacks: it does not rely on word meaning or grammatical structure (Callison-Burch et al., 2006); it correlates with human evaluation less than other metrics (Sun, 2010); and it is difficult to optimize (Wieting et al., 2019). We consider alternative metrics for all our tasks and investigate how different cost functions influence the results of training.
4.2 Training with Kendall-tau
For the word-ordering task, we use the Kendall-tau distance as an additional metric to BLEU. The Kendall-tau distance is a standard way to measure the difference between two permutations: it counts the number of pairwise disagreements between them. The only difference between the predicted and ground-truth sequences in the word-ordering task is the order of tokens, so we can view the predicted sequence as a permutation of the ground truth. Specifically, we use the Kendall-tau distance between the permutation corresponding to the predicted sequence and the identity permutation corresponding to the ground-truth sequence.
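Concretely, with the ground truth mapped to the identity permutation, the Kendall-tau distance is the number of inversions in the predicted permutation; a direct O(n²) sketch (ours):

```python
def kendall_tau_distance(perm):
    """Pairwise disagreements (inversions) between `perm` and the identity
    permutation, i.e. pairs i < j with perm[i] > perm[j]."""
    n = len(perm)
    return sum(perm[i] > perm[j] for i in range(n) for j in range(i + 1, n))

kendall_tau_distance([2, 0, 1])  # pairs (2,0) and (2,1) are inverted -> 2
```

An O(n log n) merge-sort-based count is preferable for long sequences, but the quadratic version is sufficient for sentence-length permutations.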
The BLEU reference policy of SeaRNN cannot be used for the word-ordering task because it does not return a permutation of the input. We propose a reference policy based on the alignment between the ground-truth sequence and the current predictions. The proposed policy completes the current predictions with the missing elements, adding them in the order they appear in the ground truth. We prove that this reference policy is optimal w.r.t. the Kendall-tau distance (Appendix B).
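The proposed completion can be sketched as follows (ours; duplicate words are handled as a multiset via `Counter`):

```python
from collections import Counter

def reference_completion(prefix, ground_truth):
    """Complete `prefix` with the missing ground-truth tokens,
    inserting them in the order they appear in the ground truth."""
    # Counter subtraction keeps only the still-missing occurrences.
    remaining = Counter(ground_truth) - Counter(prefix)
    suffix = []
    for token in ground_truth:
        if remaining[token] > 0:
            suffix.append(token)
            remaining[token] -= 1
    return prefix + suffix

reference_completion(["b"], ["a", "b", "c"])  # -> ["b", "a", "c"]
```

By construction, the appended suffix contains no internal inversions, which is the property used in the optimality proof of Appendix B.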
4.3 Training with METEOR
For the neural machine translation task, METEOR (Denkowski & Lavie, 2014) is, along with BLEU, a popular choice for quality estimation. METEOR computes its score from an alignment module. The alignment constructed by METEOR maximizes the number of covered tokens and minimizes the number of chunks (a chunk is a contiguous subsequence with the correct internal order).
Since the direct computation of METEOR at each training step is computationally expensive, we construct the alignment similarly to Section 4.2 and define a simplified version of this metric as the cost function (we call it sMETEOR). We extend the proposed alignment reference policy: instead of adding the missing elements in the order they appear in the ground truth, we group them into chunks (see Appendix C for details) and add the missing chunks in the order that minimizes the number of chunks, as in the METEOR computation.
The standard METEOR score computes the alignments in a more sophisticated way: it aligns not only exactly matching tokens but also tokens under other types of matching. For simplicity, we allow only exact matching in sMETEOR. For the code generation task, we also use only exact matching in the standard METEOR metric (we refer to it as METEOR-e.m.).
4.4 Results
We apply the alignment policy proposed in Section 4.2 to the word-ordering task and compare two settings: training with the BLEU costs and with the Kendall-tau costs. The results show that learning-to-search outperforms MLE for both metrics and that the training method benefits from using the correct cost function (Table 1).
In the neural machine translation and code generation tasks, our experiments show that even with our simplified version of METEOR, sMETEOR (defined in Section 4.3), as the costs and our alignment policy, we can maximize the original METEOR metric (Table 2); the learning-to-search results outperform MLE. However, the difference between training with the BLEU and sMETEOR costs is not very large in terms of the METEOR metric. This is probably because code generation and neural machine translation are more challenging tasks than word ordering, with many more factors affecting training.
Table 1: Results on the word-ordering task.

Training method    | BLEU  | Kendall-tau
MLE                | 53.98 | 15.20
BLEU costs         | 55.19 | 15.19
Kendall-tau costs  | 50.74 | 14.30
Table 2: Results on neural machine translation (NMT) and code generation.

                 |       NMT      | Code generation
Training method  | BLEU  | METEOR | BLEU  | METEOR-e.m.
MLE              | 38.71 | 36.73  | 51.46 | 70.42
BLEU costs       | 40.52 | 37.76  | 60.69 | 75.84
sMETEOR costs    | 39.86 | 37.87  | 59.60 | 76.29
5 The ranks of the costs are more important than the values
In the previous section, we discussed training with different costs, which can be of different scales. In the original KL loss, the scale of the costs is controlled by the scale parameter α, which is important and requires tuning (Section 5.1). To evaluate the importance of the scale of the costs, we try two ordering-based losses. The first replaces the distribution formed from the cost values with a distribution formed from the cost order (Section 5.2). The second is based on a loss from learning-to-rank (Section 5.3). We compare training with the different losses and investigate what matters more for training: the ordering or the scale. We discuss the effect of the scale parameter on the quality of the model in Section 5.4.
For the next experiments, we choose the type of costs that provides the larger performance improvement w.r.t. MLE for each task: the Kendall-tau costs for word ordering and the BLEU costs for neural machine translation and code generation.
Table 3: Comparison of the training losses (word ordering is measured with the Kendall-tau distance; code generation and NMT with BLEU).

Task       | MLE   | KL    | KL (q, 0.9) | KL (q, 0.7) | top-1 ListMLE | top-2 ListMLE
Word ord.  | 15.20 | 14.30 | 14.03       | 14.05       | 14.15         | 14.14
Code gen.  | 51.46 | 60.69 | 58.90       | 59.40       | 59.42         | 40.41
NMT        | 38.71 | 40.52 | 40.72       | 40.53       | 40.28         | 38.94
5.1 Original KL loss
The widely-used objective that incorporates the costs is the KL divergence between the model and cost distributions at each prediction step (Leblond et al., 2018; Welleck et al., 2018; Sabour et al., 2019; Welleck et al., 2019). The model distribution is the model output, which gives the probability of each token in the dictionary at the current step. The standard way to convert the costs into a distribution is the softmax function. The costs can be of different scales in different cases, which affects the target distribution; an additional scale parameter α (the inverse of the temperature) controls the scale of the costs.
The KL loss, up to a constant, equals the cross-entropy between the model and cost distributions at each prediction step:

ℓ_t = −Σ_{a=1}^{A} q_a log p_a,   (2)

q_a = exp(−α c_a) / Σ_{b=1}^{A} exp(−α c_b),   (3)

p_a = p_θ(a | y_1, …, y_{t−1}, x).   (4)

Here α is the scale parameter, A is the number of sampled tokens, and c_a is the cost of token a. The loss (2) relies on the scale of the cost function, and it is important to tune the scale parameter, which appears to be highly correlated with the model quality (Section 5.4).
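A sketch (ours) of the cost-to-distribution conversion in (3); large α pushes the target toward a degenerate, nearly one-hot distribution:

```python
import math

def cost_to_target(costs, alpha):
    """Softmax of the negated, scaled costs, as in equation (3)."""
    exps = [math.exp(-alpha * c) for c in costs]
    z = sum(exps)
    return [e / z for e in exps]

q_mild = cost_to_target([0.1, 0.2, 0.9], alpha=1.0)    # relatively flat
q_sharp = cost_to_target([0.1, 0.2, 0.9], alpha=100.0) # nearly one-hot
```

This illustrates the effect discussed in Section 5.4: at the high α values required for good performance, only the lowest-cost entries carry noticeable probability mass.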
5.2 Ordering-based KL loss
Instead of using the target distribution obtained by applying the softmax function to the costs, we define the target distribution from only the ordering of the costs. We parameterize the target distribution with one parameter τ ∈ (0, 1), similarly to the stick-breaking process:

q_{σ(i)} = τ (1 − τ)^{i−1} for i < A,   q_{σ(A)} = (1 − τ)^{A−1}.

Here σ is a permutation corresponding to the non-decreasing order of the cost values (σ(1) gives the index of the token with the smallest cost value).
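A sketch (ours) of this target distribution; note that it depends only on the ordering of the costs, not on their values or scale:

```python
def ordering_target(costs, tau):
    """Stick-breaking target: the i-th cheapest token gets mass
    tau * (1 - tau)**(i - 1); the most expensive one takes the remainder."""
    n = len(costs)
    order = sorted(range(n), key=lambda i: costs[i])
    q = [0.0] * n
    for rank, index in enumerate(order):
        q[index] = tau * (1 - tau) ** rank if rank < n - 1 else (1 - tau) ** rank
    return q

ordering_target([0.3, 0.1, 0.9], tau=0.9)  # most mass on the cheapest token
```

Because only the sorted order of the costs enters the formula, rescaling all costs by a constant leaves the target distribution unchanged.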
5.3 Top-k ListMLE loss
We can explicitly learn the order of the cost values with the top-k ListMLE loss (Xia et al., 2009), which comes from learning-to-rank. The order of the costs corresponds to the ground-truth permutation σ that we want to learn.
The top-k ListMLE loss (Xia et al., 2009) is the negative log-likelihood of the top-k subgroup in the ground-truth permutation σ:

ℓ = −Σ_{i=1}^{k} [ s_{σ(i)} − log Σ_{j=i}^{A} exp(s_{σ(j)}) ],   (5)

where s_a denotes the model's score (log-probability before normalization) for token a.
When k equals one, this loss is equivalent to the target learning used by Leblond et al. (2018).
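A sketch (ours) of loss (5); `scores` are the model's unnormalized log-probabilities, and the target permutation is the non-decreasing order of the costs:

```python
import math

def topk_listmle(scores, costs, k):
    """Negative log-likelihood of the top-k cost ordering under a
    Plackett-Luce model of the scores, as in equation (5)."""
    order = sorted(range(len(costs)), key=lambda i: costs[i])
    loss = 0.0
    for i in range(k):
        # Probability of picking the i-th cheapest token among the survivors.
        log_denominator = math.log(sum(math.exp(scores[j]) for j in order[i:]))
        loss -= scores[order[i]] - log_denominator
    return loss

aligned = topk_listmle([3.0, 2.0, 1.0], [0.1, 0.2, 0.9], k=2)
misaligned = topk_listmle([1.0, 2.0, 3.0], [0.1, 0.2, 0.9], k=2)
```

The loss is lower when the model's scores agree with the cost ordering on the top-k tokens, and entries outside the top k affect it only through the normalizing sums.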
5.4 Results and discussion
We compare all the losses of Sections 5.1, 5.2, and 5.3 and observe that for the word-ordering and machine translation tasks the ordering-based losses perform better than the original KL and MLE losses (Table 3); the best results are achieved with the ordering-based KL loss. Information about the ordering of the costs appears to be sufficient for training. Let us discuss why the performance does not improve when using all the values of the costs in the original KL loss (2). We investigate the importance of the scale parameter α in the original KL loss. When increasing this parameter, q from (3) gets close to the degenerate distribution, where all the probability mass is concentrated at one point (Figure 1).
We observe that training with the original KL loss (2) works only for high values of the scale parameter α: lower values either provide no improvement over MLE or degrade performance. This means that the KL loss learns only 1–2 entries of the target distribution, those with the lowest costs. The influence of the other costs is very low, so computing them appears to be effectively useless, while it is the most computationally expensive part of the training process.
6 Related work
The learning-to-search approach is used in different tasks. Welleck et al. (2018) applied L2S to multiset prediction. Unlike in our applications, in multiset prediction the costs are naturally designed to be equal for all tokens from the multiset, and the sizes of the multisets are smaller than the dictionary sizes in our tasks. Leblond et al. (2018) proposed the SeaRNN algorithm and demonstrated its advantages on OCR, spelling correction, and neural machine translation. Sabour et al. (2019) used the L2S approach for speech recognition. Welleck et al. (2019) used L2S to generate text without pre-specifying a generation order. They also studied several variants of the reference policy for their task; however, these policies do not depend on the test metric directly.
Some works investigate the components of training algorithms alternative to MLE. While proposing minimum risk training (MRT), Shen et al. (2016b) studied the effect of different costs for neural machine translation. They concluded that training with sentence-level BLEU improves the results in terms of the corpus-level metric. They also noticed that training with costs based on some other metrics does not lead to improvements in the corresponding corpus-level metrics. Wieting et al. (2019) claimed that BLEU is hard to optimize and proposed an alternative metric for MRT called SIMILE. They used costs based on this metric to improve the translation quality in terms of both the BLEU and SIMILE metrics.
Choshen et al. (2019) found that neural machine translation models trained with common RL methods improve the translation quality only when the correct translations are already at the top of the distribution of a pretrained model used for initialization. They noticed that such an improvement might be easier to achieve with reranking methods instead of RL. Their finding is very close to the effect we see in the model trained with the original KL loss, although the SeaRNN algorithm does not require pretraining. Edunov et al. (2018) investigated different losses from structured prediction in the context of neural machine translation and found that combining sequence-level and token-level losses performs better. The losses in our work are computed at the token level but use sentence-level information when obtaining the costs.
7 Conclusion
In this study, we investigate the effect of different components of the SeaRNN algorithm (Leblond et al., 2018): the reference policy, the costs, and the loss functions. For word ordering, we show that the performance improves when choosing the costs related to the test metric and training with our proposed reference policy, which is optimal w.r.t. the Kendall-tau distance. When optimizing the BLEU and METEOR metrics, the method benefits less from using the correct costs, and more work is required to understand this effect. We observe that the original KL loss tends to learn only the top tokens of the target distribution and does not fully utilize the costs. We propose losses based only on the ordering of the costs and demonstrate that they can perform better than the original KL loss.
Acknowledgments
This research is in part based on the work supported by Samsung Research, Samsung Electronics, and by the Russian Science Foundation grant no. 19-71-30020. We also thank NRU HSE for providing computational resources.
References
 Bahdanau et al. (2015) Bahdanau, D., Cho, K., and Bengio, Y. Neural machine translation by jointly learning to align and translate. In ICLR, 2015.
 Bahdanau et al. (2017) Bahdanau, D., Brakel, P., Xu, K., Goyal, A., Lowe, R., Pineau, J., Courville, A., and Bengio, Y. An actor-critic algorithm for sequence prediction. In ICLR, 2017.
 Callison-Burch et al. (2006) Callison-Burch, C., Osborne, M., and Koehn, P. Re-evaluating the role of BLEU in machine translation research. In EACL, 2006.
 Chang et al. (2015) Chang, K.-W., Krishnamurthy, A., Agarwal, A., Daumé III, H., and Langford, J. Learning to search better than your teacher. In ICML, 2015.
 Cho et al. (2014) Cho, K., Van Merrienboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., and Bengio, Y. Learning phrase representations using RNN encoder-decoder for statistical machine translation. In EMNLP, 2014.
 Choshen et al. (2019) Choshen, L., Fox, L., Aizenbud, Z., and Abend, O. On the weaknesses of reinforcement learning for neural machine translation. 2019.
 Daumé III et al. (2009) Daumé III, H., Langford, J., and Marcu, D. Search-based structured prediction. Machine Learning, 2009.
 Denkowski & Lavie (2014) Denkowski, M. and Lavie, A. Meteor universal: Language specific translation evaluation for any target language. In the 9th Workshop on Statistical Machine Translation, 2014.
 Edunov et al. (2018) Edunov, S., Ott, M., Auli, M., Grangier, D., and Ranzato, M. Classical structured prediction losses for sequence to sequence learning. In ACL, 2018.
 Elliott et al. (2016) Elliott, D., Frank, S., et al. Multi30k: Multilingual English-German image descriptions. In the 5th Workshop on Vision and Language, 2016.
 Gu et al. (2019) Gu, J., Liu, Q., and Cho, K. Insertion-based decoding with automatically inferred generation order. Transactions of the ACL, 7:661–676, 2019.
 Leblond et al. (2018) Leblond, R., Alayrac, J.-B., Osokin, A., and Lacoste-Julien, S. SEARNN: Training RNNs with global-local losses. In ICLR, 2018.
 Oda et al. (2015) Oda, Y., Fudaba, H., et al. Learning to generate pseudo-code from source code using statistical machine translation. In ASE, 2015.
 Papineni et al. (2002) Papineni, K., Roukos, S., Ward, T., and Zhu, W.-J. BLEU: a method for automatic evaluation of machine translation. In ACL, 2002.
 Ranzato et al. (2016) Ranzato, M., Chopra, S., Auli, M., and Zaremba, W. Sequence level training with recurrent neural networks. In ICLR, 2016.
 Sabour et al. (2019) Sabour, S., Chan, W., and Norouzi, M. Optimal completion distillation for sequence learning. In ICLR, 2019.
 Shen et al. (2016a) Shen, S., Cheng, Y., He, Z., He, W., Wu, H., Sun, M., and Liu, Y. Minimum risk training for neural machine translation. In ACL, 2016a.
 Shen et al. (2016b) Shen, S., Cheng, Y., He, Z., He, W., Wu, H., Sun, M., and Liu, Y. Minimum risk training for neural machine translation. In ACL, 2016b.
 Stahlberg & Byrne (2019) Stahlberg, F. and Byrne, B. On NMT search errors and model errors: Cat got your tongue? In EMNLP, 2019.
 Sun (2010) Sun, Y. Mining the correlation between human and automatic evaluation at sentence level. In LREC, 2010.
 Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. In NIPS, 2017.
 Welleck et al. (2018) Welleck, S., Yao, Z., Gai, Y., Mao, J., Zhang, Z., and Cho, K. Loss functions for multiset prediction. In NeurIPS, 2018.
 Welleck et al. (2019) Welleck, S., Brantley, K., Daumé III, H., and Cho, K. Non-monotonic sequential text generation. In ICML, 2019.
 Wieting et al. (2019) Wieting, J., Berg-Kirkpatrick, T., Gimpel, K., and Neubig, G. Beyond BLEU: Training neural machine translation with semantic similarity. In ACL, 2019.
 Wiseman & Rush (2016) Wiseman, S. and Rush, A. M. Sequence-to-sequence learning as beam-search optimization. In EMNLP, 2016.
 Xia et al. (2009) Xia, F., Liu, T.-Y., and Li, H. Statistical consistency of top-k ranking. In NIPS, 2009.
 Zhang et al. (2019) Zhang, W., Feng, Y., Meng, F., You, D., and Liu, Q. Bridging the gap between training and inference for neural machine translation. In ACL, 2019.
Appendix A Training details
Table 4: Hyperparameters.

Parameter       | Word ordering | Code generation | NMT
max length      | 80            | 50              | 50
roll-in         | mixed         | mixed-cells     | mixed-cells
roll-out        | mixed         | mixed           | mixed
max iterations  | 25000         | 10000           | 10000
batch size      | 32            | 32              | 128
embedding size  | 500           | 128             | 500
hidden size     | 500           | 256             | 500
Additional hyperparameters can be found in Table 4; we chose these values based on performance on the validation set.
When training the word-ordering models, we share the source and target embeddings, and the tokens to try are all ground-truth tokens that are not yet in the prefix. When training the models on the code generation and neural machine translation tasks, we follow Leblond et al. (2018) and choose 5 tokens before and after the current position in the ground truth plus 15 tokens that correspond to the top probabilities of the model distribution. The same mixing probability of the reference or learned policy step is used for both the mixed and mixed-cells modes (Leblond et al., 2018).
We observed that the model receives a wrong signal from the costs defined with the default METEOR parameters. When computing the approximate METEOR scores for neural machine translation, we therefore use non-default values of the metric's parameters. For code generation, we use the language-independent METEOR parameters described in the METEOR documentation.
Appendix B Optimal policy for the Kendall-tau distance
Proposition.
The reference policy that inserts the missing tokens from the ground truth in the order of their positions in the ground truth is optimal w.r.t. the Kendall-tau distance.
Proof.
The ground truth corresponds to the identity permutation. Given a prefix, we seek the completion that achieves the lowest possible Kendall-tau distance between the resulting permutation and the identity. The Kendall-tau distance (the number of inversions) decomposes into three terms: inversions within the prefix, inversions between prefix and suffix elements, and inversions within the suffix. The first term is fixed by the prefix. The second term is also fixed: whether a prefix element and a suffix element form an inversion depends only on which elements are in the suffix, not on their order. The last term corresponds to the suffix and is minimal when the completion has the minimal number of inversions; we achieve zero inversions when all tokens in the suffix come in the order they appear in the ground truth, which is exactly what our alignment policy outputs.
Appendix C Alignment policy for the METEOR metric
To optimize METEOR, we propose a policy based on the alignment between the ground-truth sequence and the current prediction; we construct the alignment at each step of the prediction. If the current policy is the reference policy, it performs one of the following steps:

if it is the first step, the reference policy returns the first token of the ground truth;

if it is not the first step and we can continue the already started chunk, the reference policy returns the next token, which continues this chunk;

if it is not the first step but we cannot continue the started chunk (the next token is already used or the current token is the end-of-sequence token), we start a new chunk with the first unused token of the ground truth.
If the policy at the current step is learned and the model's output aligns with a ground-truth token that is not yet used, we update our alignment.
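A simplified sketch (ours) of the three reference-policy cases above, assuming unique tokens for readability (the full policy works through the alignment and handles duplicates):

```python
def meteor_reference_step(ground_truth, used, last_token):
    """Next token chosen by the chunk-based reference policy.

    used: set of ground-truth tokens already produced;
    last_token: the previously produced token (None on the first step)."""
    if last_token is not None and last_token in ground_truth:
        position = ground_truth.index(last_token)
        if position + 1 < len(ground_truth) and ground_truth[position + 1] not in used:
            return ground_truth[position + 1]  # case 2: continue the chunk
    for token in ground_truth:                 # cases 1 and 3: start a new chunk
        if token not in used:
            return token                       # first unused ground-truth token
    return None                                # the whole ground truth is covered
```

On the first step the loop returns the first ground-truth token (case 1); when the started chunk cannot be continued, it starts a new chunk with the earliest unused token (case 3).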