SDA: Improving Text Generation with Self Data Augmentation

Ping Yu¹, Ruiyi Zhang², Yang Zhao¹, Yizhe Zhang³, Chunyuan Li³, Changyou Chen¹
¹ The State University of New York at Buffalo
² Duke University
³ Microsoft Research, Redmond
pingyu@buffalo.edu

Abstract

Data augmentation has been widely used to improve deep neural networks in many
research fields, such as computer vision. However, less work has been done in
the context of text, partially due to its discrete nature and the complexity of
natural languages. In this paper, we propose to improve the standard maximum
likelihood estimation (MLE) paradigm by incorporating a self-imitation-learning
phase for automatic data augmentation. Unlike most existing sentence-level
augmentation strategies, which are only applied to specific models, our method
is more general and can be easily adapted to any MLE-based training procedure.
In addition, our framework allows task-specific evaluation metrics to be
designed to flexibly control the generated sentences, for example, in terms of
controlling vocabulary usage and avoiding nontrivial repetitions. Extensive
experimental results demonstrate the superiority of our method on two synthetic
and several standard real datasets, significantly improving related baselines.

Introduction

Deep learning models are often considered data-hungry, in the sense that the
more data we have, the better performance we can achieve. In practice, it is
usually too laborious and time-consuming to annotate a large amount of training
data. An appropriate data augmentation method can significantly boost the
performance of a deep-learning model. Generally speaking, augmenting continuous
data such as images is relatively simple, as one can easily infer the meaning
of an image even if it is manipulated by flipping, adding noise, resizing, etc.
However, for discrete data such as those in the natural-language-processing
(NLP) field, data augmentation is more challenging. This is partly because even
a single-word change can affect the meaning of a whole sentence.
   Recent advances in NLP have advocated adopting pre-trained models to
effectively leverage extensive unlabeled data collected from the internet. Such
a learning paradigm has achieved remarkable success in generating human-level
sentences and articles. Nevertheless, we argue that a large amount of data
still benefits the finetuning stage, which can be achieved by data
augmentation. In fact, data augmentation for text is not new. The traditional
way is to replace several words in a sentence with their synonyms (words or
phrases having similar meanings) (Zhang, Zhao, and LeCun 2015; Wang and Yang
2015; Kobayashi 2018). However, the effect of such word-level replacement is
limited. A better option is to use sentence-level augmentation, which has been
successfully used in many applications. For example, Sennrich, Haddow, and
Birch (2015) and Sugiyama and Yoshinaga (2019) used the back-translation method
to generate more training data to improve neural machine translation (NMT);
Kafle, Yousefhussien, and Kanan (2017) generated question-and-answer pairs
using an LSTM for the Visual Question Answering (VQA) task. A noticeable
drawback of these sentence-level text augmentation methods, however, is the
lack of universality, in the sense that they are only applicable to specific
models.
   In this paper, we overcome this limitation by proposing a
self-imitation-learning based method to generate augmented text data. In
general, our method contains two learning phases:
i) The MLE Training stage: our model starts from the regular MLE based
   training and runs until convergence;
ii) The Self Data Augmentation stage: this phase maintains a buffer D_B to
   store the top-k ranked generated data samples as the self-augmented data.
   In training, it also updates a discriminator to alternately match the
   model-generated data with the self-augmented data and the training data.
In particular, the rationale of our self-data-augmentation stage is supported
by a proposed self-inverse reinforcement-learning framework, which reveals that
the alternate matching in the discriminator leads to learning a policy that
prefers generating higher-quality text. The major contributions of this work
are summarized in three aspects:
• We propose a flexible sentence-level text augmentation technique that can be
  combined with any MLE based text generation method to improve model
  performance.
• We show how to control the generated data by defining task-specific
  evaluation metrics to rank the self-augmented data in the buffer, e.g.,
  controlling vocabulary usage and avoiding nontrivial repetition.
• We evaluate our model on both unconditional and conditional generation
  tasks, achieving significant improvements relative to the baselines. We
  further conduct human evaluations, grammar check evaluations, and
  memorization evaluations to demonstrate our method's superiority. Ablation
  studies are conducted to investigate the effects of varying training set
  size and removing the imitation-learning component.
Preliminaries and Related Works

Text augmentation

The traditional way of text augmentation mainly focuses on the word level.
Zhang, Zhao, and LeCun (2015) proposed to select a word according to a
geometric distribution and replace it with one of its synonyms. Wang and Yang
(2015) found synonyms with k-nearest neighbors using word embeddings. Kobayashi
(2018) proposed to replace a word with the output of a bidirectional RNN
language model. EDA (Wei and Zou 2019) proposed four operations: synonym
replacement, random insertion, random swap, and random deletion. As
acknowledged, the performance of word-level text augmentation is limited
because it relies on simple word replacement within a sentence without
considering the sentence structure.
   There have been some attempts at sentence-level text augmentation, but they
were only designed for specific applications. For translation, Sennrich,
Haddow, and Birch (2015) designed back translation. Fadaee, Bisazza, and Monz
(2017) generated new sentence pairs containing rare words in new, synthetically
created contexts. Zheng et al. (2019) designed a mirror architecture to augment
translation pairs for non-parallel data. For the Visual Question Answering
(VQA) task, Kafle, Yousefhussien, and Kanan (2017) generated question-answer
pairs using an LSTM. It is worth noting that the aforementioned methods are not
universal in terms of models.
   By contrast, our framework is a universal sentence-level method for text
augmentation, which can be easily combined with popular MLE based training
methods. Besides, our augmentation process can be easily controlled (e.g., to
control vocabulary usage and avoid nontrivial repetitions in the generation) by
incorporating user-defined metrics.

Text generation as reinforcement learning

Text generation can be cast as a reinforcement-learning problem. To generate a
text sequence, neural language models (Mikolov et al. 2010) typically adopt an
auto-regressive manner. Given a dataset D = {y_{1:T}^{(i)}}, denote
y_{1:t-1}^{(i)} as all tokens before token t, where T is the length of the
sentence y^{(i)}. The standard MLE approach trains the model p_θ by minimizing

    L_{\mathrm{MLE}}\big(p_\theta, y^{(i)}\big) = -\sum_{t=1}^{T} \log p_\theta\big(y_t^{(i)} \mid y_{1:t-1}^{(i)}\big).

The above sequential text generation process has been reformulated as
reinforcement learning in SeqGAN (Yu et al. 2017). In the following, we omit
the sentence index "i" when the context is clear. At timestep t, the state s is
the currently produced tokens y_{1:t-1} and the action a is the next token y_t
to be selected. The policy model (or generator), denoted G_θ(y_t | y_{1:t-1}),
is usually stochastic and defines the probability of generating the next token
given the previous ones. The state transition is deterministic after an action
has been chosen, i.e., δ^a_{s,s'} = 1 for the next state s' = y_{1:t} if the
current state is s = y_{1:t-1} and the action is a = y_t, and δ^a_{s,s''} = 0
for any other next state s''. In these methods, a discriminator network,
denoted D_φ, is usually learned to provide the reward for the REINFORCE
algorithm (Williams 1992). The discriminator can only calculate a reward for a
complete sentence rather than a short-term reward, which is a problem for
long-text generation. Yu et al. (2017) designed the N-time Monte Carlo search
(MC) to calculate a roll-out reward. Specifically, they applied Monte Carlo
search with a roll-out policy G_θ to sample the unknown last T - t tokens. They
represent an N-time Monte Carlo search as

    \{y_{1:T}^{(1)}, \ldots, y_{1:T}^{(N)}\} = \mathrm{MC}^{G_\theta}(y_{1:t}; N),                      (1)

where y_{1:t} = (y_1, \ldots, y_t) and y_{t+1:T}^{(n)} is sampled based on the
generator policy G_θ and the current state. To reduce the variance and get a
more accurate assessment of the action value, they run the roll-out policy from
the current state until the end of the sequence N times to obtain a batch of
output samples. The roll-out reward Q_{D_\phi}^{G_\theta} is calculated as

    Q_{D_\phi}^{G_\theta}(s = y_{1:t-1}, a = y_t) =
    \begin{cases}
      \frac{1}{N}\sum_{n=1}^{N} D_\phi\big(y_{1:T}^{(n)}\big), & \text{for } t < T \\
      D_\phi(y_{1:t}),                                         & \text{for } t = T.
    \end{cases}                                                                                          (2)

The objective function w.r.t. the generator parameters θ is

    J(\theta) = \sum_{t=1}^{T} \mathbb{E}_{y_t \sim G_\theta(y_t \mid y_{1:t-1})}
                \big[Q_{D_\phi}^{G_\theta}(y_{1:t-1}, y_t)\big].                                         (3)

The gradient is calculated through the policy gradient:

    \nabla_\theta J(\theta) = \sum_{t=1}^{T} \mathbb{E}_{y_t \sim G_\theta(y_t \mid y_{1:t-1})}
    \big[\nabla_\theta \log G_\theta(y_t \mid y_{1:t-1}) \cdot Q_{D_\phi}^{G_\theta}(y_{1:t-1}, y_t)\big].   (4)

LeakGAN (Guo et al. 2018) improved the SeqGAN framework by leaking feature
information from the discriminator to the generator and treating it as
additional guidance for training the generator.
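To make Eqs. (1)-(4) concrete, the sketch below implements the N-time Monte
Carlo roll-out reward and a REINFORCE-style loss in Python/PyTorch. The
generator and discriminator interfaces (next_token_probs, score, start_token)
are illustrative assumptions, not the authors' actual code.

# Minimal sketch of the roll-out reward (Eqs. 1-2) and policy-gradient loss
# (Eqs. 3-4). Interface names are placeholders for model-specific routines.
import torch

def rollout_reward(generator, discriminator, prefix, T, N):
    """Estimate Q_{D_phi}^{G_theta}(s = prefix[:-1], a = prefix[-1])."""
    t = prefix.size(0)
    if t == T:                                   # complete sentence: score it directly
        return discriminator.score(prefix)
    scores = []
    for _ in range(N):                           # N Monte Carlo roll-outs
        seq = prefix.clone()
        for _ in range(T - t):                   # sample the unknown last T - t tokens
            probs = generator.next_token_probs(seq)   # distribution over the vocabulary
            nxt = torch.multinomial(probs, 1)
            seq = torch.cat([seq, nxt])
        scores.append(discriminator.score(seq))  # D_phi(y_{1:T}^{(n)})
    return torch.stack(scores).mean()

def policy_gradient_loss(generator, discriminator, T, N):
    """Negative of Eq. (3); autograd on this loss reproduces the gradient (4)."""
    seq = generator.start_token()                # y_1
    loss = 0.0
    for _ in range(T - 1):
        probs = generator.next_token_probs(seq)
        y_t = torch.multinomial(probs, 1)
        seq = torch.cat([seq, y_t])
        q = rollout_reward(generator, discriminator, seq, T, N).detach()
        loss = loss - torch.log(probs[y_t]) * q  # -log G_theta(y_t | y_{1:t-1}) * Q
    return loss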
Self-Improved Text Augmentation

Imitation learning is a framework for training a policy to imitate training
data. The recently developed self-imitation learning (Oh et al. 2018) explores
the possibility of using the generated data to improve a model further. In this
paper, we leverage these ideas and propose to use self-augmented data to
improve MLE-based models through data augmentation.
   Specifically, in text generation, we treat an MLE-based model,
p_θ(y_t | y_{1:t-1}), as a generator network G_θ(y_t | y_{1:t-1}), which aims
to learn a policy to generate the next word based on the previous words.
Example generators include LSTM-based, CNN-based and Transformer-based models.
On top of this, we add a discriminator network D_φ to provide rewards, so that
the reinforcement-learning setting becomes applicable. Based on this setting,
our self data-augmentation method consists of two stages in training. The first
stage is the general MLE-based model training (the MLE Training stage); the
second stage further refines the MLE-based model with some self-augmented data
and is called the Self Data Augmentation stage. Note that for a general
MLE-based model, we typically add the second stage after the convergence of the
MLE training to further improve the model. Below we explain these two stages in
detail, with a formal explanation of the second stage given in the next
section.

The MLE Training stage

In this stage, the generator G_θ(y_t | y_{1:t-1}) and the discriminator D_φ are
trained separately. The generator is trained with the general MLE training
mechanism, and the discriminator is trained in a GAN framework (Goodfellow et
al. 2014). Specifically, the generator is trained with the objective

    J_g(\pi) = \sum_{t=1}^{T} \log G_\theta(y_t \mid y_{1:t-1}).                                         (5)

The discriminator acts as a classifier to discriminate the generated sentences
from the ground-truth training sentences, with the objective

    J_d = \mathbb{E}_{y_{1:T} \sim p_{\mathrm{data}}}[\log D_\phi(y_{1:T})]
          + \mathbb{E}_{y_{1:T} \sim G_\theta}[\log(1 - D_\phi(y_{1:T}))],                               (6)

where p_data represents the training data distribution and G_θ represents the
generated-sentence distribution. In a large-model setting, one can choose to
use a pre-trained MLE-based model and only train the discriminator with its
generated sentences and the training data.

The Self Data Augmentation stage

Despite its simplicity, an MLE-trained model usually generates low-quality text
in practice. A natural idea for improvement is to select some high-quality
sentences from the generated data as augmented data to enlarge the original
training dataset. However, we note two issues with this approach: i) it is hard
to control the percentage of augmented data mixed with the training data; and
ii) if we directly mix augmented data with training data, it might have an
adverse effect, as there might be some low-quality data among the augmented
data. To overcome these issues, we propose to separate the augmented and
training data by using a buffer D_B to store the self-augmented data. In this
way, i) we can conveniently control the percentage of augmented data at any
time; and ii) the low-quality augmented data will not have a direct effect on
the training data.
   To apply the imitation-learning framework in our setting, we need to address
the issue that there is only one data source in imitation learning. If we were
to mix parts of the training data and self-augmented data into a minibatch at
every epoch, the training usually becomes unstable because the data in one
small minibatch are not representative enough. To overcome this, we propose to
divide the Self Data Augmentation stage into two sub-steps, which alternately
update the discriminator using the training data (called the imitating training
data step) and the self-augmented data in the buffer D_B (called the imitating
self-augmented data step), respectively.

Remark 1 Although these two steps in the Self Data Augmentation stage seem
heuristic, we show that they are actually derived from a more principled view
of a self-imitation learning framework, which is detailed in the next section.

   Before describing the two steps, we first define a reference metric used to
select generated sentences to construct the self-augmented data in the buffer
D_B. Some examples of specific metrics are given in the experiment part.

Definition 1 (Reference Metric) A reference metric is defined as a function,
m : S → R, mapping a sentence s ∈ S to a real value measuring the goodness of
the sentence.

The imitating self-augmented data step  In this step, we first update the
buffer D_B with the generated sentences associated with the highest
reference-metric values. Specifically, we associate each sentence in the buffer
with a reference metric score. When new data are generated, we replace the
low-score sentences in the buffer with the high-score generated sentences, and
only keep the k highest-score sentences, where k is a hyperparameter depending
on the size of the training set and the complexity of the task. After this, we
assign the self-augmented data label 1 and the currently generated data label
0, which are then used to update the discriminator with the objective (6) (but
replacing the data distribution p_data with the self-augmented data
distribution p_aug in this setting).
   As mentioned above, the discriminator provides rewards to the generator to
measure the quality of the currently generated data. If a generated sentence is
similar to sentences in the buffer D_B, the discriminator will output a larger
value. In this way, a large reward is fed to the generator, which is then
updated in a reinforcement-learning setting. Specifically, given the currently
generated sentence y_{1:T} ∼ G_θ, a reward Q_{D_\phi}^{G_\theta} is first
calculated by (1) and (2). Finally, the generator is updated with the following
gradient:

    \nabla_\theta J(\theta) = \sum_{t=1}^{T} \mathbb{E}_{y_t \sim G_\theta(y_t \mid y_{1:t-1})}
    \big[\nabla_\theta \log G_\theta(y_t \mid y_{1:t-1}) \cdot Q_{D_\phi}^{G_\theta}\big].               (7)

The imitating training data step  This step is the same as the previous step,
except that the data fed to the discriminator are the training data with label
1 (instead of the self-augmented data) and the currently generated data with
label 0, through (6). After training the discriminator, the reward is
calculated through (1) and (2), which is then used to update the generator with
the objective (7). Replacing the augmented data of the imitating self-augmented
data step with training data makes the generator imitate the training data.

The whole process

To summarize, the Self Data Augmentation stage contains two steps: the
imitating self-augmented data step imitates the self-augmented data, while the
imitating training data step imitates the training data. These two steps run
alternately. In the beginning, we propose to run more imitating training data
steps and gradually increase the share of imitating self-augmented data steps.
In our experiments, we run six imitating training data steps at the beginning;
after updating the discriminator and generator, we run one imitating
self-augmented data step. After every ten epochs, we decrease the number of
imitating training data steps by one until it reduces to one. The whole
algorithm is described in Algorithm 1.
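Before turning to Algorithm 1, the sketch below illustrates how the buffer D_B
can be maintained as a top-k pool ranked by the reference metric m(·). The
heap-based bookkeeping and the string representation of sentences are our
assumptions; the paper only specifies a top-k buffer.

# Sketch of the top-k buffer D_B used in the imitating self-augmented data step.
import heapq
import itertools

class AugmentationBuffer:
    def __init__(self, k, reference_metric):
        self.k = k
        self.metric = reference_metric      # m : sentence -> real value
        self._heap = []                     # min-heap of (score, tie-breaker, sentence)
        self._counter = itertools.count()   # tie-breaker so sentences are never compared

    def update(self, generated_sentences):
        """Replace low-score entries with higher-scoring generated sentences."""
        for s in generated_sentences:
            item = (self.metric(s), next(self._counter), s)
            if len(self._heap) < self.k:
                heapq.heappush(self._heap, item)
            elif item[0] > self._heap[0][0]:     # beats the current worst entry
                heapq.heapreplace(self._heap, item)

    def sentences(self):
        """Self-augmented data: labeled 1 when updating the discriminator."""
        return [s for _, _, s in self._heap]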
Algorithm 1: Self Data Augmentation

Require: i) Generator network G_θ; ii) Discriminator network D_φ; iii) Training
    dataset D_E; iv) Reference metric m(s); v) the buffer D_B to store
    self-augmented data.
 1: Initialize π_θ, r_φ; initialize the buffer D_B ← ∅
 2: # Stage I: MLE Training stage
 3: repeat
 4:    Train generator G_θ via MLE with (5)
 5:    Train discriminator D_φ with (6)
 6: until converge
 7: # Stage II: Self Data Augmentation stage
 8: repeat
 9:    Generate a set of sentences W from policy G_θ; update the buffer D_B by
       choosing the sentences from W with the highest reference-metric values.
10:    if imitating self-augmented data step then
11:       Sample sentences S ∼ π_θ and S_B ∼ D_B
12:       Update the discriminator D_φ using S and S_B with (6)
13:    else if imitating training data step then
14:       Sample sentences S ∼ π_θ and S_E ∼ D_E
15:       Update the discriminator D_φ using S and S_E with (6)
16:    end if
17:    Update the reward Q_{D_φ}^{G_θ} with (1) and (2)
18:    Update the generator G_θ with (7)
19: until converge
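For readability, the sketch below mirrors Algorithm 1 together with the step
schedule described above (six imitating-training-data steps at first, decreased
by one every ten epochs down to one). The training routines are injected as
callables; their names and signatures are placeholders rather than the paper's
implementation.

# High-level sketch of Algorithm 1 with the step schedule used in the paper.
def self_data_augmentation(train_mle_epoch, train_discriminator,
                           train_generator_pg, generate, buffer, train_data,
                           mle_epochs=100, sda_epochs=100):
    # Stage I: MLE Training stage (objectives (5) and (6))
    for _ in range(mle_epochs):
        train_mle_epoch(train_data)
        train_discriminator(real=train_data, fake=generate())

    # Stage II: Self Data Augmentation stage
    imitate_train_steps = 6                       # start with more imitating-training steps
    for epoch in range(sda_epochs):
        buffer.update(generate())                 # keep top-k sentences by the reference metric
        for _ in range(imitate_train_steps):      # imitating training data step
            train_discriminator(real=train_data, fake=generate())
            train_generator_pg()                  # rewards via (1)-(2), gradient (7)
        # imitating self-augmented data step
        train_discriminator(real=buffer.sentences(), fake=generate())
        train_generator_pg()
        if (epoch + 1) % 10 == 0 and imitate_train_steps > 1:
            imitate_train_steps -= 1              # decay toward one step per epoch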
Understanding the Self Data Augmentation Stage as IRL with Self Enhancement

In this section, we show that the Self Data Augmentation stage, which consists
of alternately updating the discriminator with the two data sources, can be
explained as a procedure to solve an inverse reinforcement learning (IRL)
problem with self-enhancement. The goal of IRL is to recover the optimal reward
function r*(·, ·) as well as the corresponding policy π* (Ho and Ermon 2016a):

    \{r^*, \pi^*\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L(\pi, r) - \psi(r)     (8)
                   = \arg\max_{r \in \mathbb{R}^{S \times A}} \sum_{s,a} \rho_E(s,a) r(s,a)
                     - \Big[\max_{\pi \in \Pi} H(\pi) + \sum_{s,a} \rho(s,a) r(s,a)\Big] - \psi(r),

where ρ(s,a) is the occupancy measure, defined as the stationary joint
distribution of (s,a) under policy π, and ψ(·) is the regularizer.
   The traditional regularization in IRL only considers constraining the model
capacity of the policy (Ho and Ermon 2016b), which makes it hard to adapt to
diverse tasks. Natural languages have intrinsic characteristics based on syntax
and semantics. To leverage this prior knowledge, we propose incorporating
user-defined evaluation measures into ψ(r) so that the learned policy aligns
well with the pre-defined metrics when generating unseen sentences.
Specifically, we define ψ(r) in (8) to be a mixture of a capacity regularizer
ψ_1(r) and a self-guided regularizer ψ_0(r):

    \psi(r) = -\lambda \psi_0(r) + \psi_1(r),                                                            (9)

with λ > 0 a hyper-parameter balancing the two terms.

Capacity regularizer  It has been shown in (Ho and Ermon 2016b) that the
adversarial regularizer can provide better generalization. To imitate the
expert policy, we adopt the capacity regularizer ψ_1(r) from GAIL (Ho and Ermon
2016b); its specific form is given in the Appendix.

Self-guided regularizer  To promote generating suboptimal demonstrations that
are well aligned with the diverse tasks, we propose the self-guided regularizer
as a weighted linear function over rewards:

    \psi_0(r) = \sum_{s,a} q(s,a) r(s,a),                                                                (10)

where q(s,a) is a probability function constructed via evaluation metrics; the
specific form of q is defined in the Appendix. Generally, this regularizer
encourages higher rewards in regions where q(s,a) is large.
   By plugging the capacity regularizer and the self-guided regularizer into
(8), we formulate our objective as an optimization problem to find the optimal
reward r* and the optimal policy π*:

    \{\pi^*, r^*\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}}
        -2H(\pi) - \psi_1(r) + \sum_{s,a} q(s,a) r(s,a)
        + \sum_{s,a} \rho_E(s,a) r(s,a) - 2\sum_{s,a} \rho(s,a) r(s,a),                                  (11)

where we have set λ = 1 in (9), and the factor "2" is introduced for
formulation convenience. Problem (11) is called IRL with self-enhancement.

Solving (11) with a two-step algorithm  Due to the introduction of q(s,a),
solving (11) directly is difficult. We reformulate (11) with an augmented
policy π_1:

    \min_{\pi_1, \pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}}
        -H(\pi_1) + \sum_{s,a} \big(q(s,a) - \rho(s,a)\big) r(s,a)
        - H(\pi) - \psi_1(r) + \sum_{s,a} \big(\rho_E(s,a) - \rho(s,a)\big) r(s,a),
        \quad \text{s.t. } r_1 = r,\ \pi_1 = \pi.                                                        (12)

Objective (12) now allows us to decompose the original problem into two
sub-problems. Adopting ideas from ADMM (Boyd et al. 2011), we decompose (12)
into two sub-problems, which serve as the objectives of the imitating
self-augmented data step (objective (13)) and the imitating training data step
(objective (14)) in our proposed data-augmentation framework:

    \min_{\pi_1 \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L_1(\pi_1, r_1) \triangleq
        -H(\pi_1) + \sum_{s,a} \big(q(s,a) - \rho(s,a)\big) r(s,a) + \lambda W_1(\pi, \pi_1),            (13)

    \min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}} L_2(\pi, r) \triangleq
        -H(\pi) - \psi_1(r) + \sum_{s,a} \big(\rho_E(s,a) - \rho(s,a)\big) r(s,a)
        + \lambda W_1(\pi, \pi_1).                                                                       (14)

The regularizer W_1(π, π_1) denotes the 1-Wasserstein distance and is
introduced to accommodate the constraint π_1 = π of the original problem (12).
The two phases (13) and (14) are solved alternately. More details are shown in
the Appendix.
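For intuition, the regrouping below (written under the λ = 1 convention used in
(11)) restates how the objective of (11) splits into the two parts that, once
coupled by the W_1(π, π_1) term, become (13) and (14); this is a rearrangement
of the terms above, not an additional result.

    -2H(\pi) - \psi_1(r) + \sum_{s,a} q\,r + \sum_{s,a} \rho_E\,r - 2\sum_{s,a} \rho\,r
    = \underbrace{\Big[-H(\pi) + \sum_{s,a}\big(q(s,a)-\rho(s,a)\big)r(s,a)\Big]}_{\text{imitating self-augmented data, cf. (13)}}
    + \underbrace{\Big[-H(\pi) - \psi_1(r) + \sum_{s,a}\big(\rho_E(s,a)-\rho(s,a)\big)r(s,a)\Big]}_{\text{imitating training data, cf. (14)}}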
Experiments

We conduct experiments on two synthetic datasets and two real datasets for the
unconditional generation task. In the unconditional generation task, the input
to the generator G_θ is pure Gaussian noise, and the generated sentences should
match the training data distribution. We further test our method on a
conditional generation task, unsupervised style transfer, detailed in the
Appendix.

Experimental settings

Model architectures and baselines  Two model bases are used:
   i) Both the generator and discriminator bases are RNN architectures. We test
this model base with only the MLE training (denoted RNN). For the ablation
study, we implement this model base with the Imitation Learning stage but with
only the imitating training data step, without using augmented data, denoted
RNN+RL. Our method combined with this model base is denoted RNN+SDA. Since the
EDA (Wei and Zou 2019) method can be applied with the general MLE training
paradigm, we compare our model with RNN+EDA. In addition, Li et al. (2019) and
Roller et al. (2020) demonstrate the usefulness of the unlikelihood training
mechanism for controlling vocabulary usage and nontrivial repetition, which can
be easily combined with MLE training models. We thus compare our SDA method
with unlikelihood training, denoted RNN+UN, in the vocabulary usage control and
repetition control experiments.
   ii) The second model base adapts the "leak" architecture from LeakGAN (Guo
et al. 2018). The discriminator D_φ consists of a feature extractor and an
output layer (a softmax classification layer) whose output lies in [0, 1]. The
generator is designed as a hierarchical RNN network and can use sentence
features to generate the next word. We call this model base with only the MLE
training LeakRNN. We further implement the models LeakRNN+RL, LeakRNN+SDA,
LeakRNN+EDA and LeakRNN+UN in the same way as for the first model base.

Datasets  We use two synthetic datasets and several real datasets, including
the COCO image-caption dataset and the EMNLP2017 WMT dataset* for unconditional
generation, and the Yelp review dataset for conditional generation in the
Appendix. The two synthetic datasets are generated following (Yu et al. 2017;
Guo et al. 2018). Table 9 in the Appendix summarizes statistics of all the
datasets used in this paper.

* http://statmt.org/wmt17/translation-task.html

Synthetic-data experiments

In order to better verify the effectiveness of our method and track the
training process, we adopt the synthetic data experiments of (Yu et al. 2017;
Guo et al. 2018). We leverage a randomly initialized LSTM as an oracle model to
generate the real data distribution p(x_t | x_1, ..., x_{t-1}), which serves as
a stand-in for a human evaluator in real-world problems. We then use NLL to
track the performance of the model during training, defined as

    \mathrm{NLL}_{\mathrm{oracle}} = -\mathbb{E}_{y_{1:T} \sim G_\theta}
        \Big[\sum_{t=1}^{T} \log G_{\mathrm{oracle}}(y_t \mid y_{1:t-1})\Big],                           (15)

where G_θ and G_oracle denote our generator and the oracle network,
respectively. During training, we leverage G_θ to generate 100,000 sentences
and calculate NLL_oracle every 5 training epochs.
   In the experiment, we first run the MLE Training stage for 100 epochs and
store the generated data in the buffer D_B. Then we update our discriminator
and generator using the training data and the self-augmented data selected by
the reference metric

    m(s) \triangleq \mathbb{E}_{y_{1:T} \sim G_\theta}
        \Big[\sum_{t=1}^{T} \log G_\theta(y_t \mid y_{1:t-1})\Big].                                      (16)

As we can see from Fig. 1a and 1b, before the 100th epoch both the RNN and
LeakRNN models converge. At stage 2, imitation learning (both with and without
self augmentation) further improves model performance. However, we observed
that RNN+RL makes the training process less stable. This phenomenon becomes
more obvious as the training process becomes more complicated. We suspect that
learning a policy for text is more challenging due to the complexity of text
itself. Adding our self-augmented step enlarges the training dataset, making it
easier to learn.

Real data experiments

For the real data experiments, we run both the MLE Training stage and the Self
Data Augmentation stage for 100 epochs. At the Imitation Learning stage,
generated sentences are stored in the buffer D_B. The BLEU score (Papineni et
al. 2002) is calculated between sentences in the buffer D_B and 1,000 randomly
selected samples from the training dataset. We use the BLEU3 score as our
reference metric m(·) to construct the buffer.
   According to our observations on the real data experiments, directly
applying the policy gradient makes the model vulnerable and unstable, which is
consistent with the synthetic experiment results. The time at which we conduct
the test has a significant impact on the results. Nevertheless, monitoring the
BLEU score in real time is time-consuming and laborious due to the large
testing dataset. Therefore, we examine the BLEU score every ten epochs in the
Imitation Learning stage and pick the best results for RNN+RL and LeakRNN+RL.
Since our method is stable, we only test the BLEU score after 100 epochs for
RNN+SDA and LeakRNN+SDA. In addition to the results presented below, we also
put some example sentences in the Appendix.
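The BLEU3 reference metric described above can be computed with standard
tooling; a minimal sketch using NLTK's sentence_bleu against a fixed random
sample of 1,000 training sentences is given below. The choice of NLTK, the
smoothing method, and the function names are our assumptions, not the paper's.

# Sketch of the BLEU3 reference metric m(.) used to rank generated sentences.
import random
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def make_bleu3_metric(train_sentences, n_refs=1000, seed=0):
    rng = random.Random(seed)
    n_refs = min(n_refs, len(train_sentences))
    refs = [s.split() for s in rng.sample(train_sentences, n_refs)]
    smooth = SmoothingFunction().method1
    def m(sentence):
        # BLEU with uniform weights over 1- to 3-grams (BLEU3)
        return sentence_bleu(refs, sentence.split(),
                             weights=(1 / 3, 1 / 3, 1 / 3),
                             smoothing_function=smooth)
    return m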
Figure 1: Training curves on the synthetic dataset with sentence length 20 (a)
and length 40 (b). The blue line before 100 epochs is our first model base,
RNN; after 100 epochs, the solid blue line is RNN+SDA and the dotted blue line
is RNN+RL. The yellow line before 100 epochs is our second model base,
LeakRNN; after 100 epochs, the solid yellow line is LeakRNN+SDA and the dotted
yellow line is LeakRNN+RL. (c): Test BLEU score against the training dataset.
A lower value on the y-axis means the generated sentences achieve higher
accuracy; a lower value on the x-axis means a higher ability to generate
sentences instead of memorizing them.

Short-text generation  The original COCO dataset consists of a total of 20,734
words and 417,126 sentences. Considering the large gap between the average
length (11) and the maximum length (50) of the sentences in this dataset, we
only randomly selected sentences with lengths less than 32, ending up with a
training dataset of 120,000 sentences and a test dataset of 10,000 sentences.
The experimental results are shown in Table 1. Although we do not pick the best
result over the training epochs, our self augmentation method (SDA) still
surpasses the other baselines on all metrics, demonstrating the effectiveness
of our data-augmentation scheme.

Table 1: BLEU scores on the COCO image-caption test dataset.

    Model                BLEU2 ↑   BLEU3 ↑   BLEU4 ↑   BLEU5 ↑
    RNN                   0.800     0.599     0.367     0.201
    RNN+RL                0.819     0.601     0.370     0.211
    RNN+EDA               0.822     0.610     0.373     0.220
    RNN+SDA (ours)        0.826     0.612     0.378     0.222
    LeakRNN               0.916     0.789     0.594     0.404
    LeakRNN+RL            0.921     0.797     0.602     0.416
    LeakRNN+EDA           0.932     0.838     0.691     0.526
    LeakRNN+SDA (ours)    0.943     0.856     0.738     0.599

Long-text generation  To increase the text generation difficulty, we use whole
sentences (with a maximum length of 50) without discarding low-frequency words
on the EMNLP2017 WMT dataset. The experimental settings are the same as those
used for the COCO dataset. Results are presented in Table 2. Our method
outperforms the other baselines.

Table 2: BLEU scores on the test set of EMNLP2017.

    Model           BLEU2 ↑   BLEU3 ↑   BLEU4 ↑   BLEU5 ↑
    RNN              0.626     0.350     0.160     0.085
    RNN+RL           0.630     0.354     0.164     0.087
    RNN+EDA          0.639     0.361     0.172     0.106
    RNN+SDA          0.641     0.368     0.176     0.109
    LeakRNN          0.914     0.692     0.461     0.280
    LeakRNN+RL       0.916     0.707     0.469     0.289
    LeakRNN+EDA      0.921     0.738     0.509     0.312
    LeakRNN+SDA      0.929     0.744     0.519     0.330

Analysis and Discussion

To further evaluate the effectiveness and usefulness of our method, we carry
out the following analyses and discussions on the more challenging EMNLP WMT
dataset, which has longer sentences and contains much richer sentence
structures and syntax than the COCO dataset.
   This section first presents a human evaluation on Amazon Mechanical Turk, a
grammar check evaluation, and a memorization evaluation. Then we carry out
ablation studies to explore the effect of different training data sizes and the
necessity of imitation learning. Finally, we design two task-specific reference
metrics to control vocabulary usage and nontrivial repetitions.

Human evaluations

To further evaluate the generation quality, we conduct a human evaluation on
Amazon Mechanical Turk; 3 judges are asked to rate over 100 randomly sampled
sentences from each model on a scale from 1 to 5. The results shown in Table 3
indicate the outstanding performance of our proposed methods.

Table 3: Human evaluation results.

    Methods          LeakRNN   LeakRNN+RL   LeakRNN+SDA   Real
    Human scores ↑    2.97        2.67         3.11       4.10

Grammar check evaluations

We employ the grammar check tool† for the grammar check evaluation, a
professional tool widely adopted for checking English writing in terms of
contextual spelling, grammar, punctuation, sentence structure, style, and
vocabulary enhancement; performance is measured by the critical issue rate
(CIR). The results are shown in Table 4. Please refer to the Appendix for more
details.

† https://www.grammarly.com/

Memorization evaluations

This experiment aims at comparing the generalization ability of different
methods. To this end, we note that BigGAN (Brock, Donahue, and Simonyan 2018)
carried out a nearest neighbor analysis on ImageNet (Deng et al. 2009) to show
that their generator does not merely memorize data from the training set, which
is a direct reflection of a model's generalization ability. Accordingly, we
compare the BLEU score between our generated sentences and the training dataset
with LeakRNN and LeakRNN+RL. Figure 1c shows a better ability of our model to
generate sentences instead of memorizing sentences directly from the training
dataset.

Improvement with different training data sizes

To test our method's performance with different sizes of the training set, we
split the training dataset of EMNLP2017 into 5%, 10%, and 50%, respectively. We
compare our method with the MLE baseline LeakRNN. The results are shown in
Table 7. Interestingly, with our data augmentation scheme, our method boosts
performance over the model base more significantly on smaller training data. We
speculate that, since overfitting tends to be more severe when training on
smaller datasets, our method can seamlessly alleviate the overfitting problem
as a data augmentation method.
Table 4: Grammar evaluation results.

    Model               Ave Len ↑   CIR ↓
    Training Dataset     27.526      2.583
    LeakRNN              23.990      9.880
    LeakRNN+RL           26.280     12.929
    LeakRNN+SDA          27.850      9.874

Table 5: Vocabulary usage control.

    Model            BLEU2 ↑   Occurrence
    LeakRNN           0.914       0.2%
    LeakRNN+UN        0.915       0.9%
    LeakRNN+SDA       0.928       0.5%

Table 6: Repetition control.

    Model            BLEU3 ↑   self-BLEU3 ↓
    LeakRNN           0.914        0.838
    LeakRNN+UN        0.912        0.800
    LeakRNN+SDA       0.920        0.816

Table 7: Performance comparison on different sizes of the training dataset.

    Size    Model            BLEU2 ↑   Increase %
    5%      LeakRNN           0.8199
            LeakRNN+SDA       0.8537     4.12%
    10%     LeakRNN           0.8551
            LeakRNN+SDA       0.8746     2.28%
    50%     LeakRNN           0.9024
            LeakRNN+SDA       0.9181     1.74%
    100%    LeakRNN           0.9144
            LeakRNN+SDA       0.9290     1.60%

The necessity of imitation learning

We conduct experiments to verify the necessity of imitation learning for
mimicking both the training data and the self-augmented data. To this end,
instead of leveraging imitation learning to update the generator alternately,
we mix the training data with 5% of the buffer D_B at each epoch and use these
mixed data to update the generator after the MLE Training stage. Mixup (Zhang
et al. 2017) has been successfully applied to data augmentation for images; we
tried this method, but it did not perform better than EDA (Wei and Zou 2019)
for text augmentation, so we do not include its results in this article. Our
results show that LeakRNN achieves a 0.916 BLEU2 score and the model whose
generator is updated with mixed data achieves 0.919, versus 0.943 with our
LeakRNN+SDA model. These results indicate the benefits of using imitation
learning.

Vocabulary usage control

It has been observed that generative models tend to generate common words too
frequently and rare words too infrequently, compared with the human
distribution (Holtzman et al. 2018; Welleck et al. 2019; Li et al. 2019; Roller
et al. 2020). We first analyze our training data distribution, which contains
over 7 million words, with 5,722 unique words in the dictionary. That means
each word appears 1,260 times on average. However, 20 words occur less than ten
times, and 43 words occur less than 20 times.
   Our model can be used to control vocabulary usage during training, owing to
the flexibility of defining a specific metric to select generated sentences. To
this end, we collect the words with fewer than 50 occurrences into a special
word list C_rare. We define a new reference metric as

    m(s) \triangleq \mathrm{BLEU}(s) + 1/o^2,                                                            (17)

where o is the number of occurrences in the training dataset of a word in the
collection C_rare. For example, the word "project" occurs 20 times in the
training dataset; once a generated sentence contains the word "project," we add
0.0025 to our reference metric. If the generated sentence contains two special
words, we add the 1/o^2 term twice.
   As a comparison, we adopt unlikelihood estimation (Welleck et al. 2019; Li
et al. 2019; Roller et al. 2020) and combine it with the LeakRNN model (named
LeakRNN+UN) to control vocabulary usage with a modified objective:

    J_g(\pi) = \sum_{t=1}^{T}
    \begin{cases}
      \log G_\theta(y_t \mid y_{1:t-1}),         & \text{for } y_t \notin C_{\mathrm{rare}} \\
      \log G_\theta(y_t \mid y_{1:t-1}) + 1/o,   & \text{for } y_t \in C_{\mathrm{rare}}.
    \end{cases}

   Table 5 shows that our method can control the appearance of rare words to a
certain extent, and can also improve the quality of the generated sentences.

Nontrivial repetition control

A related issue to vocabulary usage is that generative models also tend to
repeat themselves (Holtzman et al. 2019). Roller et al. (2020) use unlikelihood
training to penalize 3-gram repetition in generated sentences. We use this
model as a baseline, denoted LeakRNN+UN. Again, we use equation (17) as the
reference metric, with o representing the number of repeated 3-gram phrases in
a sentence. We adopt the self-BLEU3 score (Zhu et al. 2018) to measure the
repetition of 3-gram phrases in the generated sentences.
   Table 6 shows that unlikelihood training can significantly reduce the
repetition rate at the cost of lowering generation quality, i.e., it helps
control repetition with the risk of saying something that makes less sense. In
contrast, our method can seamlessly control repetition while keeping generation
quality.
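For the repetition-control variant of (17), o counts repeated 3-grams within a
sentence; one simple counting convention (ours, since the paper does not spell
it out) is sketched below.

# Sketch: count repeated 3-grams inside a sentence, used as o in the
# repetition-control variant of Eq. (17).
from collections import Counter

def repeated_trigrams(sentence):
    tokens = sentence.split()
    trigrams = [tuple(tokens[i:i + 3]) for i in range(len(tokens) - 2)]
    counts = Counter(trigrams)
    return sum(c - 1 for c in counts.values() if c > 1)   # extra occurrences only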
                                                                                                              Conclusion
ate common words too frequently, and rare words too in-                             We propose a general and flexible text augmentation method
frequently, as compared with human distribution (Holtzman                           that could be easily incorporated into the popular MLE based
et al. 2018; Welleck et al. 2019; Li et al. 2019; Roller et al.                     methods to boosting model performance. Essentially, we
2020). We first analyze our training data distribution, which                       propose to add a Self Data Augmentation stage after the
contains over 7 million words, with 5,722 unique characters                         standard MLE Training stage. At this stage, our generator
in the dictionary. That means each word appears 1,260 times                         alternatively imitates training sentences and self-augmented
on average. However, 20 words occur less than ten times, 43                         sentences selected by a reference metric. We show formally
words less than 20 times.                                                           that doing so can enhance the generator to focus on higher-
   Our model could be used to control vocabulary usage dur-                         quality sentences via the proposed IRL formulation with
ing training due to the flexibility of defining a specific metric                   self-enhancement. Moreover, by defining different reference
to select generated sentences. To this end, we collect charac-                      metrics, we can emphasize some characteristics of the gener-
ters with occurrences of less than 50 to be a special character                     ated sentences, e.g., vocabulary usage, and nontrivial repeti-
list Crare . We define a new reference metric as:                                   tion control. Extensive experiments are conducted for both
                        m(s) , BLEU(s) + 1/o2 ,                         (17)        unconditional and conditional text-generation, verifying the
                                                                                    effectiveness of our proposed data-augmentation framework
where o means occurrences times in the training dataset for                         for NLP.
words in the collection Crare . For example, the word ”project”
Table 6 shows that unlikelihood training can significantly reduce the repetition rate at the cost of lowering generation quality, i.e., it helps control repetition at the risk of saying something that makes less sense. In contrast, our method can seamlessly control repetition while maintaining generation quality.

Conclusion
We propose a general and flexible text-augmentation method that can be easily incorporated into popular MLE-based methods to boost model performance. Essentially, we propose to add a Self Data Augmentation stage after the standard MLE training stage. In this stage, our generator alternately imitates training sentences and self-augmented sentences selected by a reference metric. We show formally that doing so encourages the generator to focus on higher-quality sentences via the proposed IRL formulation with self-enhancement. Moreover, by defining different reference metrics, we can emphasize particular characteristics of the generated sentences, e.g., vocabulary usage and nontrivial repetition control. Extensive experiments are conducted for both unconditional and conditional text generation, verifying the effectiveness of our proposed data-augmentation framework for NLP.

References
Ambrosio, L.; Gigli, N.; and Savaré, G. 2005. Gradient Flows in Metric Spaces and in the Space of Probability Measures. Lectures in Mathematics ETH Zürich.
Bahdanau, D.; Brakel, P.; Xu, K.; Goyal, A.; Lowe, R.; Pineau, J.; Courville, A.; and Bengio, Y. 2017. An actor-critic algorithm for sequence prediction. In ICLR.
Boyd, S.; Parikh, N.; Chu, E.; Peleato, B.; and Eckstein, J. 2011. Distributed Optimization and Statistical Learning via the Alternating Direction Method of Multipliers. Found. Trends Mach. Learn. 3(1): 1–122. ISSN 1935-8237. doi:10.1561/2200000016. URL http://dx.doi.org/10.1561/2200000016.
Brock, A.; Donahue, J.; and Simonyan, K. 2018. Large scale GAN training for high fidelity natural image synthesis. arXiv preprint arXiv:1809.11096.
Caccia, M.; Caccia, L.; Fedus, W.; Larochelle, H.; Pineau, J.; and Charlin, L. 2018. Language GANs falling short. arXiv preprint arXiv:1811.02549.
Deng, J.; Dong, W.; Socher, R.; Li, L.-J.; Li, K.; and Fei-Fei, L. 2009. ImageNet: A Large-Scale Hierarchical Image Database. In CVPR09.
Fadaee, M.; Bisazza, A.; and Monz, C. 2017. Data augmentation for low-resource neural machine translation. arXiv preprint arXiv:1705.00440.
Finn, C.; Levine, S.; and Abbeel, P. 2016. Guided Cost Learning: Deep Inverse Optimal Control via Policy Optimization. In Proceedings of the 33rd International Conference on International Conference on Machine Learning - Volume 48, ICML'16, 49–58. JMLR.org. URL http://dl.acm.org/citation.cfm?id=3045390.3045397.
Goodfellow, I.; Pouget-Abadie, J.; Mirza, M.; Xu, B.; Warde-Farley, D.; Ozair, S.; Courville, A.; and Bengio, Y. 2014. Generative Adversarial Nets. In Neural Information Processing Systems (NIPS).
Guo, J.; Lu, S.; Cai, H.; Zhang, W.; Yu, Y.; and Wang, J. 2018. Long Text Generation via Adversarial Training with Leaked Information. In AAAI.
Ho, J.; and Ermon, S. 2016a. Generative Adversarial Imitation Learning. In Lee, D. D.; Sugiyama, M.; Luxburg, U. V.; Guyon, I.; and Garnett, R., eds., Advances in Neural Information Processing Systems 29, 4565–4573. Curran Associates, Inc. URL http://papers.nips.cc/paper/6391-generative-adversarial-imitation-learning.pdf.
Ho, J.; and Ermon, S. 2016b. Generative adversarial imitation learning. In NIPS.
Holtzman, A.; Buys, J.; Du, L.; Forbes, M.; and Choi, Y. 2019. The curious case of neural text degeneration. arXiv preprint arXiv:1904.09751.
Holtzman, A.; Buys, J.; Forbes, M.; Bosselut, A.; Golub, D.; and Choi, Y. 2018. Learning to write with cooperative discriminators. arXiv preprint arXiv:1805.06087.
Hu, Z.; Yang, Z.; Liang, X.; Salakhutdinov, R.; and Xing, E. P. 2017. Controllable text generation. In ICML.
Kafle, K.; Yousefhussien, M.; and Kanan, C. 2017. Data augmentation for visual question answering. In Proceedings of the 10th International Conference on Natural Language Generation, 198–202.
Kobayashi, S. 2018. Contextual augmentation: Data augmentation by words with paradigmatic relations. arXiv preprint arXiv:1805.06201.
Li, J.; Jia, R.; He, H.; and Liang, P. 2018. Delete, retrieve, generate: A simple approach to sentiment and style transfer. arXiv preprint arXiv:1804.06437.
Li, M.; Roller, S.; Kulikov, I.; Welleck, S.; Boureau, Y.-L.; Cho, K.; and Weston, J. 2019. Don't Say That! Making Inconsistent Dialogue Unlikely with Unlikelihood Training. arXiv preprint arXiv:1911.03860.
Liu, Q.; Li, L.; Tang, Z.; and Zhou, D. 2018. Breaking the Curse of Horizon: Infinite-Horizon Off-policy Estimation. In NIPS.
Mikolov, T.; Karafiát, M.; Burget, L.; Černocký, J.; and Khudanpur, S. 2010. Recurrent neural network based language model. In Eleventh Annual Conference of the International Speech Communication Association.
Ng, A. Y.; Harada, D.; and Russell, S. J. 1999. Policy Invariance Under Reward Transformations: Theory and Application to Reward Shaping. In Proceedings of the Sixteenth International Conference on Machine Learning, ICML '99, 278–287. San Francisco, CA, USA: Morgan Kaufmann Publishers Inc. ISBN 1-55860-612-2. URL http://dl.acm.org/citation.cfm?id=645528.657613.
Ng, A. Y.; and Russell, S. J. 2000. Algorithms for Inverse Reinforcement Learning. In Proceedings of the Seventeenth International Conference on Machine Learning, ICML '00, 663–670. San Francisco, CA, USA: Morgan Kaufmann Publishers Inc. ISBN 1-55860-707-2. URL http://dl.acm.org/citation.cfm?id=645529.657801.
Oh, J.; Guo, Y.; Singh, S.; and Lee, H. 2018. Self-imitation learning. arXiv preprint arXiv:1806.05635.
Papineni, K.; Roukos, S.; Ward, T.; and Zhu, W.-J. 2002. BLEU: a method for automatic evaluation of machine translation. In Proceedings of the 40th Annual Meeting on Association for Computational Linguistics, 311–318. Association for Computational Linguistics.
Ranzato, M.; Chopra, S.; Auli, M.; and Zaremba, W. 2016. Sequence level training with recurrent neural networks. In ICLR.
Rennie, S. J.; Marcheret, E.; Mroueh, Y.; Ross, J.; and Goel, V. 2016. Self-critical sequence training for image captioning. In CVPR.
Roller, S.; Dinan, E.; Goyal, N.; Ju, D.; Williamson, M.; Liu, Y.; Xu, J.; Ott, M.; Shuster, K.; Smith, E. M.; et al. 2020. Recipes for building an open-domain chatbot. arXiv preprint arXiv:2004.13637.
Sennrich, R.; Haddow, B.; and Birch, A. 2015. Improving neural machine translation models with monolingual data. arXiv preprint arXiv:1511.06709.
Shen, T.; Lei, T.; Barzilay, R.; and Jaakkola, T. 2017. Style transfer from non-parallel text by cross-alignment. In NIPS.
Song, J.; Ren, H.; Sadigh, D.; and Ermon, S. 2018. Multi-Agent Generative Adversarial Imitation Learning. In NIPS.
Sudhakar, A.; Upadhyay, B.; and Maheswaran, A. 2019. Transforming delete, retrieve, generate approach for controlled text style transfer. arXiv preprint arXiv:1908.09368.
Sugiyama, A.; and Yoshinaga, N. 2019. Data augmentation using back-translation for context-aware neural machine translation. In Proceedings of the Fourth Workshop on Discourse in Machine Translation (DiscoMT 2019), 35–44.
Wang, W. Y.; and Yang, D. 2015. That's so annoying!!!: A lexical and frame-semantic embedding based data augmentation approach to automatic categorization of annoying behaviors using #petpeeve tweets. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, 2557–2563.
Wang, X.; Chen, W.; Wang, Y.-F.; and Wang, W. Y. 2018. No metrics are perfect: Adversarial reward learning for visual storytelling. arXiv preprint arXiv:1804.09160.
Wei, J. W.; and Zou, K. 2019. EDA: Easy data augmentation techniques for boosting performance on text classification tasks. arXiv preprint arXiv:1901.11196.
Welleck, S.; Kulikov, I.; Roller, S.; Dinan, E.; Cho, K.; and Weston, J. 2019. Neural text generation with unlikelihood training. arXiv preprint arXiv:1908.04319.
Williams, R. J. 1992. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning 8(3-4): 229–256.
Yang, Z.; Hu, Z.; Dyer, C.; Xing, E. P.; and Berg-Kirkpatrick, T. 2018. Unsupervised text style transfer using language models as discriminators. In Advances in Neural Information Processing Systems, 7287–7298.
Yu, L.; Zhang, W.; Wang, J.; and Yu, Y. 2017. SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient. In AAAI.
Zhang, H.; Cisse, M.; Dauphin, Y. N.; and Lopez-Paz, D. 2017. mixup: Beyond empirical risk minimization. arXiv preprint arXiv:1710.09412.
Zhang, X.; Zhao, J.; and LeCun, Y. 2015. Character-level convolutional networks for text classification. In Advances in Neural Information Processing Systems, 649–657.
Zheng, Z.; Zhou, H.; Huang, S.; Li, L.; Dai, X.-Y.; and Chen, J. 2019. Mirror-Generative Neural Machine Translation. In International Conference on Learning Representations.
Zhu, Y.; Lu, S.; Zheng, L.; Guo, J.; Zhang, W.; Wang, J.; and Yu, Y. 2018. Texygen: A Benchmarking Platform for Text Generation Models. arXiv preprint arXiv:1802.01886.
Preliminaries about Inverse Reinforcement Learning
An alternative way (Ho and Ermon 2016b; Liu et al. 2018) to write the RL objective (equation 3) is
$$J(\pi) = \mathbb{E}_{s_t \sim d_\pi(\cdot),\, y_t \sim G_\theta(y_t \mid y_{1:t-1})}\!\left[ Q_{D_\phi}^{G_\theta}(y_{1:t-1}, y_t) \right],\ \text{with}\ d_\pi(s) \triangleq \lim_{T \to \infty} \Big(\sum_{t=0}^{T} \gamma^{t-1} d_{\pi,t}(s)\Big)\Big/\Big(\sum_{t=0}^{T} \gamma^{t}\Big), \qquad (18)$$
where d_{π,t}(s) is the distribution of s_t when executing policy π starting from some initial state distribution, and d_π(s) is the state marginal distribution.

To unify our notation with that of reinforcement learning, in the following derivations we use r(s_t, a_t) to represent Q_{D_φ}^{G_θ}(y_{1:t−1}, y_t), and a_t ∼ π(·|s_t) to represent y_t ∼ G_θ(y_t|y_{1:t−1}).

The occupancy measure ρ(s, a) is defined as the stationary joint distribution of (s, a) under policy π, i.e., ρ(s, a) ≜ d_π(s)π(a|s). Thus, given an environment, there is a one-to-one correspondence between the policy π and the occupancy measure ρ (Ho and Ermon 2016b). Consequently, RL can be equivalently rewritten as
$$\pi^* \triangleq \mathrm{RL}(\pi) = \arg\max_{\pi \in \Pi}\ \mathbb{E}_{(s,a) \sim \rho}\,[r(s, a)]. \qquad (19)$$

To formulate text generation as RL, we may consider the process as an agent taking actions to generate next words, conditioned on the states of observed words, i.e., s_t = (w_1, . . . , w_{t−1}) and a_t = w_t. The goal of IRL is to learn the optimal reward function from the expert trajectories D_E, which in text generation is represented by a training dataset. Typically, D_E consists of state-action pairs generated from an expert policy π_E, with the corresponding occupancy measure ρ_E. The goal of IRL is thus to recover the optimal reward function r*(·, ·) as well as the corresponding policy π*. Formally, we have for IRL‡:
$$\{r^*, \pi^*\} \triangleq \mathrm{IRL}(\pi_E) = \arg\max_{r \in \mathbb{R}^{S \times A}}\ \sum_{s,a} \rho_E(s,a)\,r(s,a) - \Big[\max_{\pi \in \Pi} H(\pi) + \sum_{s,a} \rho(s,a)\,r(s,a)\Big] = \arg\max_{r \in \mathbb{R}^{S \times A}} \min_{\pi \in \Pi} L(\pi, r),\ \ \text{where}\ \ L(\pi, r) \triangleq -H(\pi) + \sum_{s,a} \rho_E(s,a)\,r(s,a) - \sum_{s,a} \rho(s,a)\,r(s,a). \qquad (20)$$
‡ The original goal of IRL is to find the optimal reward function. In this paper, we generalize the definition to finding both the optimal reward and the optimal policy.

Intuitively, the objective enforces the expert policy π_E to induce higher rewards (the "max" part) than all other policies (the "min" part). Note this setting is sub-optimal if the given data is noisy, as the learned policy is expected to perform worse than the expert policy (and equal to the expert policy at optimality). Besides, the IRL objective is ill-defined (Ng and Russell 2000; Finn, Levine, and Abbeel 2016): it often induces multiple valid solutions because the solution space is too flexible. For example, one can assign an arbitrary reward to trajectories not visited by the expert, as long as these trajectories yield lower rewards than the expert trajectories (Song et al. 2018). To address these problems, some constraints are typically placed on the reward function, e.g., a convex reward functional, ψ : R^{S×A} → R, is usually introduced as a regularizer (Ho and Ermon 2016b). Finally, the IRL problem is formulated as
$$\{\pi^*, r^*\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}}\ L(\pi, r) - \psi(r). \qquad (21)$$
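To make the state-action correspondence above concrete, the following minimal sketch (our illustration; the class and the reward stub are hypothetical, not part of the paper) treats generation as an MDP in which a state is the current token prefix and an action is the next token:

    # Minimal sketch (illustrative only): text generation viewed as an MDP,
    # with states s_t = (w_1, ..., w_{t-1}) and actions a_t = w_t.
    from typing import Callable, List, Tuple

    class TextGenEnv:
        def __init__(self, reward_fn: Callable[[List[str], str], float],
                     eos_token: str = "<eos>", max_len: int = 20):
            self.reward_fn = reward_fn      # e.g., r(s_t, a_t) from a learned reward model
            self.eos_token = eos_token
            self.max_len = max_len

        def reset(self) -> List[str]:
            self.state: List[str] = []      # empty prefix
            return list(self.state)

        def step(self, action: str) -> Tuple[List[str], float, bool]:
            reward = self.reward_fn(self.state, action)
            self.state = self.state + [action]
            done = action == self.eos_token or len(self.state) >= self.max_len
            return list(self.state), reward, done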
Self-Improved Inverse RL
In this paper, we advocate IRL as a more natural framework for text generation. To this end, we combine a capacity regularizer ψ1(r) with a self-guided regularizer ψ0(r) into the overall regularizer
$$\psi(r) \triangleq \psi_1(r) - \lambda\,\psi_0(r), \qquad (22)$$
where λ > 0 is the weighting hyper-parameter to balance the two terms.

Capacity Regularizer  It has been shown in (Ho and Ermon 2016b) that the adversarial regularizer can provide better generalization. To imitate the expert policy, we adopt the capacity regularizer ψ1(r) in GAIL (Ho and Ermon 2016b).

Self-Guided Regularizer  To promote generating suboptimal demonstrations that are well aligned with the diverse tasks, we propose the self-guided regularizer as a weighted linear function over the rewards:
$$\psi_0(r) = \sum_{s,a} q(s,a)\,r(s,a), \qquad (23)$$
where q(s, a) is a probability function constructed via evaluation metrics (to be defined below). Generally, this regularizer encourages higher rewards in regions where q(s, a) is large. In text generation, evaluation metrics are used to assess certain quality aspects of a generated sentence, such as the BLEU or ROUGE scores. Formally, we define the notion of a reference metric.

Definition 2 (Reference Metric) A reference metric is defined as a function, m : S → R, mapping a sentence s ∈ S to a real value measuring the goodness of the sentence.

For simplicity, we assume that a higher value of m(s) indicates a higher quality of the sentence s. Note that standard IRL with a fixed pool of expert trajectories does not necessarily learn a reward function that matches the reference metric. We conjecture this may be caused by the fact that the expert trajectories alone can be noisy and unrepresentative.

To enhance the state space in terms of the reference metric, we focus on a sub-space, S̄ ⊆ S, with high reference-metric values, i.e., S̄ ≜ {s ∈ S : m(s) > m*} for some threshold m*. Next, the concentrated distribution q(s, a) is defined as
$$q(s,a) = \frac{1}{Z}\, d_\pi(s)\,\pi(a \mid s)\,\delta(m(s) > m^*) \triangleq \bar{d}_\pi(s)\,\pi(a \mid s),$$
where the indicator function δ(m(s) > m*) equals 1 if m(s) > m* and 0 otherwise, and Z is the normalizer ensuring that q is a proper distribution. In contrast to the original occupancy measure, ρ(s, a) = d_π(s)π(a|s) for all state-action pairs with s ∈ S, q(s, a) is more concentrated towards regions with higher reference-metric values. In other words, q(s, a) is larger than ρ(s, a) on its support, as its probability mass lies only on the sub-space S̄ with high reference-metric values. This method is related to reward shaping (Ng, Harada, and Russell 1999), but in a more principled manner: unlike the traditional reward-shaping strategy, where the policy is directly trained to maximize a certain heuristic and can therefore be noisy, the uncertainty in q(s, a) still allows an automatic and softened exploration-exploitation trade-off.
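In practice, sampling from q(s, a) can be approximated by rejection filtering: roll out the current policy and keep only sentences whose reference metric exceeds m*. The sketch below is a minimal illustration under that assumption (function names are ours, not the paper's):

    # Sketch: approximate sampling from q(s, a) by keeping only generated
    # sentences whose reference metric exceeds m*. `sample_sentence` and
    # `reference_metric` are assumed to be supplied by the user.
    import random
    from typing import Callable, List

    def build_augmented_pool(sample_sentence: Callable[[], List[str]],
                             reference_metric: Callable[[List[str]], float],
                             m_star: float,
                             num_samples: int = 10000) -> List[List[str]]:
        pool = []
        for _ in range(num_samples):
            s = sample_sentence()                 # roll out the current policy (s ~ d_pi)
            if reference_metric(s) > m_star:      # indicator delta(m(s) > m*)
                pool.append(s)
        return pool

    def sample_from_q(pool: List[List[str]]) -> List[str]:
        # Accepted rollouts are (approximate) draws from the renormalized
        # distribution; choosing one uniformly simulates sampling from q.
        return random.choice(pool)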
By plugging (22) and (23) into (21), we formulate our objective as an optimization problem over the optimal reward r* and the optimal policy π*:
$$\{\pi^*, r^*\} = \arg\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}}\ -2H(\pi) - \psi_1(r) + \sum_{s,a} q(s,a)\,r(s,a) + \sum_{s,a} \rho_E(s,a)\,r(s,a) - 2\sum_{s,a} \rho(s,a)\,r(s,a), \qquad (24)$$
where we have set λ = 1 in (22), and the factor "2" is introduced for formulation convenience.

Introducing auxiliary copies π1 and r1 of the policy and the reward, (24) can be equivalently rewritten with consensus constraints as
$$\min_{\pi_1, \pi \in \Pi}\ \max_{r \in \mathbb{R}^{S \times A}}\ -H(\pi_1) + \sum_{s,a}\big(q(s,a) - \rho(s,a)\big)\,r(s,a) - H(\pi) - \psi_1(r) + \sum_{s,a}\big(\rho_E(s,a) - \rho(s,a)\big)\,r(s,a), \quad \text{s.t. } r_1 = r,\ \pi_1 = \pi. \qquad (25)$$

Objective (25) now allows us to decompose the original problem into two sub-problems. Adopting ideas from ADMM (Boyd et al. 2011), we decompose (25) into an imitating-self-augmented-data step and an imitating-training-data step:

imitating self-augmented data step:
$$\min_{\pi_1 \in \Pi} \max_{r \in \mathbb{R}^{S \times A}}\ L_1(\pi_1, r_1) \triangleq -H(\pi_1) + \sum_{s,a}\big(q(s,a) - \rho(s,a)\big)\,r(s,a) + \lambda W_1(\pi, \pi_1); \qquad (26)$$

imitating training data step:
$$\min_{\pi \in \Pi} \max_{r \in \mathbb{R}^{S \times A}}\ L_2(\pi, r) \triangleq -H(\pi) - \psi_1(r) + \sum_{s,a}\big(\rho_E(s,a) - \rho(s,a)\big)\,r(s,a) + \lambda W_1(\pi, \pi_1). \qquad (27)$$
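The two steps in (26) and (27) are alternated during training. The schematic below is our own sketch of that alternation (all callables are user-supplied placeholders, not the paper's API); the W1 coupling term and the entropy bonus are assumed to be handled inside the supplied update routines:

    # Schematic of alternating the optimization of (26) and (27).
    from typing import Callable, List, Sequence

    def self_data_augmentation_stage(
        sample_sentence: Callable[[], List[str]],            # roll out the current policy
        sample_training_batch: Callable[[int], Sequence[List[str]]],
        reference_metric: Callable[[List[str]], float],
        update_reward: Callable[[Sequence[List[str]], Sequence[List[str]]], None],
        update_policy: Callable[[Sequence[List[str]]], None],
        m_star: float,
        num_rounds: int = 100,
        batch_size: int = 256,
    ) -> None:
        """Alternate the two steps in (26) and (27)."""
        for _ in range(num_rounds):
            # Imitating self-augmented data step, Eq. (26): high-metric
            # self-generated sentences play the role of "expert" data.
            rollouts = [sample_sentence() for _ in range(batch_size)]
            augmented = [s for s in rollouts if reference_metric(s) > m_star]
            if augmented:
                update_reward(augmented, rollouts)   # discriminator/reward update
                update_policy(rollouts)              # policy-gradient update

            # Imitating training data step, Eq. (27): real training sentences
            # play the role of expert trajectories.
            real_batch = sample_training_batch(batch_size)
            rollouts = [sample_sentence() for _ in range(batch_size)]
            update_reward(real_batch, rollouts)
            update_policy(rollouts)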
The regularizer W1(π, π1) above is introduced to accommodate the constraint π1 = π in the original problem (25). The two phases (26) and (27) are solved alternately. By parametrizing π, π1 and r with DNNs, they can be updated by solving min-max problems similar to GAIL (Ho and Ermon 2016a). The only difference lies in the W1 regularizer, whose gradients with respect to the policy parameters need to be derived. Fortunately, this can be done with the following theorem.

Theorem 1  Let π be parameterized with θ, denoted as πθ. The gradient of W1(πθ, π1) w.r.t. θ can be calculated as
$$\nabla_\theta W_1(\pi_\theta, \pi_1) = \int \frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1)\cdot \nabla_\theta \pi_\theta(Y)\, dY = \int \frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1)\cdot \frac{\nabla_\theta \pi_\theta(Y)}{\pi_\theta(Y)}\cdot \pi_\theta(Y)\, dY = \int \frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1)\cdot \nabla_\theta \log \pi_\theta(Y)\cdot \pi_\theta(Y)\, dY = \mathbb{E}_{Y \sim \pi_\theta}\!\left[\frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1)(Y)\cdot \nabla_\theta \log \pi_\theta(Y)\right],$$
where the first equality follows by the chain rule, with δW1/δπθ defined as the first variation of W1 (Ambrosio, Gigli, and Savaré 2005). Furthermore, rewriting W1(πθ, π1) in its dual form as W1(πθ, π1) = max_{‖D‖_Lip ≤ 1} E_{Y∼πθ}[D(Y)] − E_{Y′∼π1}[D(Y′)], with D* denoting the optimal discriminator, the gradient can be written as
$$\nabla_\theta W_1(\pi_\theta, \pi_1) = \mathbb{E}_{Y \sim \pi_\theta}\!\left[D^*(Y)\,\nabla_\theta \log \pi_\theta(Y)\right] + \mathbb{E}_{Y \sim \pi_\theta,\, Y' \sim \pi_1}\!\left[\nabla_\theta Y\, \frac{\partial D^*(Y)}{\partial Y} - \nabla_\theta Y'\, \frac{\partial D^*(Y')}{\partial Y'}\right]. \qquad (28)$$
Theorem 1 applies to the parameterized π1 as well. It is now clear that, to calculate gradients of the W1 term, one i) first draws samples from the two policies to update the discriminator D, and ii) then uses (28) to calculate ∇θW1. This consequently allows us to optimize (26) and (27) by gradient descent. It is worth noting that when the two policies in the W1 arguments are close enough to each other, the second part of the RHS of (28) becomes relatively small, allowing us to drop it for computational efficiency. The whole algorithm is illustrated in Algorithm 1.
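A minimal PyTorch-style sketch of this two-step recipe is given below (our illustration, not the released implementation); it keeps only the first term of (28), as the text above allows, and omits the Lipschitz constraint on the discriminator, which would be enforced in practice, e.g., by weight clipping or a gradient penalty:

    # Sketch: one update of the W1(pi_theta, pi_1) term only, following steps i) and ii).
    import torch
    import torch.nn as nn

    def w1_gradient_step(policy_logprob, sample_policy, sample_other,
                         discriminator: nn.Module,
                         disc_opt: torch.optim.Optimizer,
                         policy_opt: torch.optim.Optimizer,
                         n: int = 64):
        """`sample_policy(n)` / `sample_other(n)` return a batch of sequence encodings
        (tensors); `policy_logprob(x)` returns log pi_theta(x) with gradients.
        All components are placeholders for the user's own models."""
        # i) update the discriminator on samples from both policies
        x_theta, x_1 = sample_policy(n), sample_other(n)
        disc_loss = -(discriminator(x_theta).mean() - discriminator(x_1).mean())
        disc_opt.zero_grad()
        disc_loss.backward()
        disc_opt.step()

        # ii) score-function estimate of the first term of Eq. (28):
        #     E_{Y ~ pi_theta}[ D*(Y) * grad_theta log pi_theta(Y) ]
        x_theta = sample_policy(n)
        with torch.no_grad():
            weights = discriminator(x_theta).squeeze(-1)
        policy_loss = (weights * policy_logprob(x_theta)).mean()
        policy_opt.zero_grad()
        policy_loss.backward()   # descending this surrogate decreases the W1 term
        policy_opt.step()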
Proof of Theorem 1
The gradient of W1 can be expanded, based on the definition of the functional derivative, as
$$\nabla_\theta W_1(\pi_\theta, \pi_1) = \int \frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1)\cdot \nabla_\theta \pi_\theta(Y)\, dY = \mathbb{E}_{Y \sim \pi_\theta}\!\left[\frac{\delta W_1}{\delta \pi_\theta}(\pi_\theta, \pi_1)\cdot \nabla_\theta \log \pi_\theta(Y)\right],$$
where the first equality follows by the chain rule, with δW1/δπθ defined as the first variation of the functional W1 (Ambrosio, Gigli, and Savaré 2005).

To prove the second claim, by substituting the optimal discriminator D* into the dual formula of W1 and noting that D* is implicitly dependent on θ, we have
$$W_1(\pi_\theta, \pi_1) = \mathbb{E}_{Y \sim \pi_\theta}\,[D^*(Y)] - \mathbb{E}_{Y' \sim \pi_1}\,[D^*(Y')].$$
$$\Rightarrow\ \frac{\partial W_1(\pi_\theta, \pi_1)}{\partial \theta} = \frac{\partial}{\partial \theta}\int D^*(Y)\,\pi_\theta(Y)\, dY - \mathbb{E}_{Y' \sim \pi_1}\,[\nabla_\theta D^*(Y')] = \int \nabla_\theta \pi_\theta(Y)\, D^*(Y)\, dY + \int \nabla_\theta D^*(Y)\,\pi_\theta(Y)\, dY - \mathbb{E}_{Y' \sim \pi_1}\,[\nabla_\theta D^*(Y')] = \mathbb{E}_{Y \sim \pi_\theta}\,[D^*(Y)\,\nabla_\theta \log \pi_\theta(Y)] + \mathbb{E}_{Y \sim \pi_\theta}\,[\nabla_\theta D^*(Y)] - \mathbb{E}_{Y' \sim \pi_1}\,[\nabla_\theta D^*(Y')] = \mathbb{E}_{Y \sim \pi_\theta}\,[D^*(Y)\,\nabla_\theta \log \pi_\theta(Y)] + \mathbb{E}_{Y \sim \pi_\theta,\, Y' \sim \pi_1}\!\left[\nabla_\theta Y\,\frac{\partial D^*(Y)}{\partial Y} - \nabla_\theta Y'\,\frac{\partial D^*(Y')}{\partial Y'}\right],$$
which completes the proof.
                                                                     where the goal is to transfer one sentence in one style (e.g.,
                                                                     positive) to a similar sentence but with a different style (e.g.,
        More Details of Imitation Learning                           negative). For a fair comparison, we use the same data and
                                                                     its split method as in (Shen et al. 2017). To measure whether
This describes our solution to the imitation learning phase
                                                                     the original test sentences have been transferred to the de-
Policy update  Given the learned reward network rφ, the policy can be optimized by standard policy-gradient-based methods, with the following objective:
$$\pi = \arg\max_{\pi \in \Pi}\ H(\pi) + \mathbb{E}_{\pi}[r_\phi(s,a)] - \lambda W_1(\pi_1, \pi). \qquad (32)$$

Capacity Regularizer  It has been shown in (Ho and Ermon 2016b) that the adversarial regularizer can provide better generalization. To imitate the expert policy, we adopt the regularizer in GAIL (Ho and Ermon 2016b), which defines ψ1 with the following form:
$$\psi_1(r) \triangleq \begin{cases} \mathbb{E}_{\pi_E}[g(r(s,a))] & \text{if } r(s,a) \geq 0 \\ +\infty & \text{otherwise,} \end{cases} \qquad \text{where}\quad g(x) = \begin{cases} -x - \log(1 - e^{x}) & \text{if } x < 0 \\ +\infty & \text{otherwise.} \end{cases}$$

Unsupervised Style Transfer
Our framework can also be extended to the task of style transfer. The key changes are: i) define the generator as a conditional distribution, instead of the unconditional distribution used for text generation; ii) define the buffer DB to store two bins of sentences, each corresponding to one style; and iii) include an additional reconstruction loss in order to preserve the reversibility of the learned policy. Detailed descriptions are given in the Appendix.
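For concreteness, the sketch below (our own illustration; all names are hypothetical, not the paper's implementation) shows one way the three changes could be organized: a style-conditioned generator call, a buffer with one bin per style, and a reconstruction term added to the policy loss:

    # Sketch of the three modifications for style transfer (illustrative only).
    from collections import defaultdict
    from typing import Callable, Dict, List

    def sample_conditional(generator: Callable[[List[str], str], List[str]],
                           sentence: List[str], target_style: str) -> List[str]:
        """Change i): the generator is a conditional distribution G(y | x, style)."""
        return generator(sentence, target_style)

    class StyleTransferBuffer:
        """Change ii): buffer D_B with one bin of sentences per style label."""
        def __init__(self):
            self.bins: Dict[str, List[List[str]]] = defaultdict(list)

        def add(self, sentence: List[str], style: str) -> None:
            self.bins[style].append(sentence)

    def total_policy_loss(transfer_loss: float,
                          reconstruction_loss: float,
                          recon_weight: float = 1.0) -> float:
        """Change iii): add a reconstruction term so that transferring a sentence
        to the target style and back should recover the original sentence."""
        return transfer_loss + recon_weight * reconstruction_loss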
We apply the method to non-parallel text style transfer, where the goal is to transfer a sentence in one style (e.g., positive) into a similar sentence with a different style (e.g., negative). For a fair comparison, we use the same data and split as in (Shen et al. 2017). To measure whether the original test sentences have been transferred to the desired sentiment, we follow the settings of (Shen et al. 2017) and employ a pretrained CNN classifier, which achieves an accuracy of 97.4% on the validation set, to evaluate the transferred sentences. The results are presented in Table 8. Note that both (Li et al. 2018) and (Sudhakar, Upadhyay, and Maheswaran 2019) are supervised approaches. It is observed that our method obtains much higher accuracy than other methods, including the controllable text generation method (Hu et al. 2017).