Semi-supervised Learning for Quantification of Pulmonary Edema in Chest X-Ray Images

Ruizhi Liao¹, Jonathan Rubin², Grace Lam¹, Seth Berkowitz³, Sandeep Dalal²,
William Wells¹,⁴, Steven Horng³, and Polina Golland¹

¹ Massachusetts Institute of Technology, Cambridge, MA, USA
  ruizhi@mit.edu
² Philips Research North America, Cambridge, MA, USA
³ Beth Israel Deaconess Medical Center, Harvard Medical School, Boston, MA, USA
⁴ Brigham and Women’s Hospital, Harvard Medical School, Boston, MA, USA

arXiv:1902.10785v3 [cs.CV] 10 Apr 2019

Abstract. We propose and demonstrate machine learning algorithms to assess the severity of pulmonary edema in chest x-ray images of congestive heart failure patients. Accurate assessment of pulmonary edema in heart failure is critical when making treatment and disposition decisions. Our work is grounded in a large-scale clinical dataset of over 300,000 x-ray images with associated radiology reports. While edema severity labels can be extracted unambiguously from a small fraction of the radiology reports, accurate annotation is challenging in most cases. To take advantage of the unlabeled images, we develop a Bayesian model that includes a variational auto-encoder for learning a latent representation from the entire image set, trained jointly with a regressor that employs this representation for predicting pulmonary edema severity. Our experimental results suggest that modeling the distribution of images jointly with the limited labels improves the accuracy of pulmonary edema scoring compared to a strictly supervised approach. To the best of our knowledge, this is the first attempt to employ machine learning algorithms to automatically and quantitatively assess the severity of pulmonary edema in chest x-ray images.

Keywords: Semi-supervised Learning · Chest X-Ray Images · Congestive Heart Failure.

1   Introduction

We propose and demonstrate a semi-supervised learning algorithm to support clinical decisions in congestive heart failure (CHF) by quantifying pulmonary edema. Limited ground truth labels are one of the most significant challenges in medical image analysis and many other machine learning applications in healthcare. It is of great practical interest to develop machine learning algorithms that take advantage of the entire data set to improve the performance of strictly supervised classification or regression methods. In this work, we develop a Bayesian model that learns probabilistic feature representations from the entire image set with limited labels for predicting edema severity.

    Chest x-ray images are commonly used in CHF patients to assess pulmonary
edema, which is one of the most direct symptoms of CHF [11]. CHF causes pul-
monary venous pressure to elevate, which in turn causes the fluid to leak from
the blood vessels in the lungs into the lung tissue. The excess fluid in the lungs
is called pulmonary edema. Heart failure patients have extremely heterogeneous
responses to treatment [2]. The assessment of pulmonary edema severity will en-
able clinicians to make better treatment plans based on prior patient responses
and will facilitate clinical research studies that require quantitative phenotyping
of the patient status [1]. While we focus on CHF, the quantification of pulmonary
edema is also useful elsewhere in medicine. Quantifying pulmonary edema in a
chest x-ray image could be used as a surrogate for patient intravascular vol-
ume status, which would rapidly advance research in sepsis and other disease
processes where volume status is critical.
    Quantifying pulmonary edema is more challenging than detection of patholo-
gies in chest x-ray images [12, 13] because grading of pulmonary edema severity
relies on much more subtle image findings (features). Accurate grading of the
pulmonary edema severity is challenging for medical experts as well [4]. Our
work is grounded in a large-scale clinical dataset that includes approximately
330,000 frontal view chest x-ray images and associated radiology reports, which
serve as the source of the severity labels. Of these, about 30,000 images are of
CHF patients. Labels extracted from radiology reports via keyword matching are
available for about 6,000 images. Thus our image set includes a large number
of images, but only a small fraction of images is annotated with edema severity
labels.
We use a variational auto-encoder (VAE) to capture the image distribution
from both unlabeled and labeled images to improve the accuracy of edema sever-
ity grading. Auto-encoder neural networks have shown promise for representa-
tional modeling [10]. Earlier work attempted to learn a separate VAE for each
label category from unlabeled and labeled data [9]. We argue and demonstrate
in our experiments that this structure does not fit our application well, because
the pulmonary edema severity score is based on subtle image features and should be
represented as a continuous quantity. Instead, we learn one VAE from the entire
image set. By training the VAE jointly with a regressor, we ensure it captures
compact feature representations for scoring pulmonary edema severity. Similar
setups have also been employed in computer vision [7]. The experimental results
show that our method outperforms the multi-VAE approach [9], the entropy
minimization based self-learning approach [3], and strictly supervised learning.
To the best of our knowledge, this paper demonstrates the first attempt to em-
ploy machine learning algorithms to automatically and quantitatively assess the
severity of pulmonary edema from chest x-ray images.

2   Methods

Let x ∈ R^{n×n} be a 2D x-ray image and y ∈ {0, 1, 2, 3} be the corresponding edema severity label. Our dataset includes a set of N images x = {x_i}_{i=1}^{N}, with the first N_L images annotated with severity labels y = {y_i}_{i=1}^{N_L}. Here, we derive a learning algorithm that constructs a compact probabilistic feature representation z that is learned from all images and used to predict pulmonary edema severity. Fig. 1 illustrates the Bayesian model and the inference algorithm.

Fig. 1. Graphical model and inference algorithm. (a): Probabilistic graphical model, where x represents the chest x-ray image, z the latent feature representation, and y the pulmonary edema severity. (b): Our computational model. We use neural networks for implementing the encoder, decoder, and regressor. The dashed line (decoder) is used in training only. The network architecture is provided in the supplementary material.

Learning. The learning algorithm maximizes the log probability of the data with respect to parameters θ:

    log p(x, y; θ) = Σ_{i=1}^{N_L} log p(x_i, y_i; θ) + Σ_{i=N_L+1}^{N} log p(x_i; θ).   (1)

We model z as a continuous latent variable with a prior distribution p(z), which generates images and predicts pulmonary edema severity. Unlike [9], which constructs a separate encoder q(z|x, y) for each value of the discrete label y, we use a single encoder q(z|x) to capture image structure relevant to the labels. The distribution q(z|x) serves as a variational approximation for p(z|x, y), yielding the lower bound:

    L_1(θ; x_i, y_i) = log p(x_i, y_i; θ) − D_KL(q(z_i|x_i; θ) || p(z_i|x_i, y_i))
                     = E_{q(z_i|x_i; θ)}[ log p(x_i, y_i; θ) + log p(z_i|x_i, y_i) − log q(z_i|x_i; θ) ]
                     = E_{q(z_i|x_i; θ)}[ log p(x_i, y_i|z_i; θ) + log p(z_i) − log q(z_i|x_i; θ) ]
                     = E_{q(z_i|x_i; θ)}[ log p(x_i, y_i|z_i; θ) ] − D_KL(q(z_i|x_i; θ) || p(z_i)).

We assume that x, z, and y form a Markov chain, i.e., y ⊥ x | z, and therefore

    L_1(θ; x_i, y_i) = E_{q(z_i|x_i; θ_E)}[ log p(x_i|z_i; θ_D) ] + E_{q(z_i|x_i; θ_E)}[ log p(y_i|z_i; θ_R) ]
                       − D_KL(q(z_i|x_i; θ_E) || p(z_i)),   (2)

where θ_E are the parameters of the encoder, θ_D are the parameters of the decoder, and θ_R are the parameters of the regressor. Similarly, we have a variational lower bound for log p(x_i; θ):

    L_2(θ; x_i) = E_{q(z_i|x_i; θ_E)}[ log p(x_i|z_i; θ_D) ] − D_KL(q(z_i|x_i; θ_E) || p(z_i)).   (3)
By substituting Eq. (2) and Eq. (3) into Eq. (1), we obtain a lower bound for the log probability of the data and aim to minimize the negative lower bound:

    J(θ; x, y) = − Σ_{i=1}^{N_L} L_1(θ; x_i, y_i) − Σ_{i=N_L+1}^{N} L_2(θ; x_i)
               = Σ_{i=1}^{N} D_KL(q(z_i|x_i; θ_E) || p(z_i)) − Σ_{i=1}^{N_L} E_{q(z_i|x_i; θ_E)}[ log p(y_i|z_i; θ_R) ]
                 − Σ_{i=1}^{N} E_{q(z_i|x_i; θ_E)}[ log p(x_i|z_i; θ_D) ].   (4)

Latent Variable Prior p(z). We let the latent variable prior p(z) be a multivariate normal distribution, which serves to regularize the latent representation of images.

Latent Representation q(z|x). We apply the reparameterization trick used in [10]. Conditioned on image x_i, the latent representation becomes a multivariate Gaussian variable, z_i|x_i ∼ N(z_i; μ_i, Λ_i), where μ_i is a D-dimensional vector [μ_ik]_{k=1}^{D} and Λ_i is a diagonal covariance matrix represented by its diagonal elements as [λ_ik²]_{k=1}^{D}. Thus, the first term in Eq. (4) becomes:

    J_KL(θ_E; x_i) = − (1/2) Σ_{k=1}^{D} ( log λ_ik² − μ_ik² − λ_ik² ) + const.   (5)

We implement the encoder as a neural network fE (x; θE ) that estimates the
mean and the variance of z|x. Samples of z can be readily generated from this
estimated Gaussian distribution. We use one sample per image for training the
model.
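For illustration, a minimal sketch of the reparameterized sampling and of the per-image KL term in Eq. (5) is given below; this is not the implementation used in this work, and the function names and the log-variance parameterization are our assumptions.

    import torch

    def sample_latent(mu, logvar):
        """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).

        mu and logvar are the encoder outputs f_E(x); one sample is drawn per
        image, as in the paper.
        """
        eps = torch.randn_like(mu)
        return mu + torch.exp(0.5 * logvar) * eps

    def kl_term(mu, logvar):
        """J_KL of Eq. (5): KL(q(z|x) || N(0, I)) per image, up to an additive constant."""
        return -0.5 * torch.sum(logvar - mu.pow(2) - logvar.exp(),
                                dim=list(range(1, mu.dim())))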

Ordinal Regression p(y|z). In radiology reports, pulmonary edema severity is categorized into four groups: no / mild / moderate / severe. Our goal is to assess the severity of pulmonary edema as a continuous quantity. We employ an ordinal representation to capture the ordering of the categorical labels. We use a 3-bit representation y_i = [y_ij]_{j=1}^{3} for the four severity levels. The three bits represent the probability of any edema, of moderate or severe edema, and of severe edema respectively (i.e., “no” is [0, 0, 0], “mild” is [1, 0, 0], “moderate” is [1, 1, 0], and “severe” is [1, 1, 1]). This encoding yields probabilistic output, i.e., both the estimate of the edema severity and the uncertainty in the estimate. The three bits are assumed to be conditionally independent given the image:

    p(y_i|z_i; θ_R) = Π_{j=1}^{3} f_R^j(z_i; θ_R)^{y_ij} ( 1 − f_R^j(z_i; θ_R) )^{1−y_ij},
where y_ij is a binary label and f_R^j(z_i; θ_R) is interpreted as the conditional probability p(y_ij = 1|z_i); f_R(·) is implemented as a neural network. The second term in Eq. (4) becomes the cross entropy:

    J_R(θ_E, θ_R; y_i, z_i) = − Σ_{j=1}^{3} y_ij log f_R^j(z_i; θ_R) − Σ_{j=1}^{3} (1 − y_ij) log( 1 − f_R^j(z_i; θ_R) ).   (6)
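A small sketch of the 3-bit ordinal encoding and of the cross-entropy term J_R of Eq. (6) follows; the use of logits and the function names are our assumptions, not the authors' code.

    import torch
    import torch.nn.functional as F

    # Ordinal 3-bit targets for the four severity categories:
    # bit 1: any edema, bit 2: moderate or severe, bit 3: severe.
    ORDINAL_TARGETS = {
        0: [0., 0., 0.],  # no edema
        1: [1., 0., 0.],  # mild
        2: [1., 1., 0.],  # moderate
        3: [1., 1., 1.],  # severe
    }

    def encode_labels(severity):
        """Map integer severity labels (batch,) to 3-bit targets (batch, 3)."""
        return torch.tensor([ORDINAL_TARGETS[int(s)] for s in severity])

    def regression_term(logits, targets):
        """J_R of Eq. (6): cross entropy between the three predicted bit
        probabilities f_R^j(z) (given here as logits) and the binary targets."""
        return F.binary_cross_entropy_with_logits(logits, targets, reduction="sum")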

Decoding p(x|z). We assume that image pixels are conditionally independent
(Gaussian) given the latent representation. Thus, the third term in Eq. (4) be-
comes:

    J_D(θ_E, θ_D; x_i, z_i) = − log N(x_i; f_D(z_i; θ_D), Σ_i)
                            = (1/2) (x_i − f_D(z_i; θ_D))^T Σ_i^{−1} (x_i − f_D(z_i; θ_D)) + const.,   (7)
where fD (·) is a neural network decoder that generates an image implied by the
latent representation z, and Σi is a diagonal covariance matrix.
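Since Σ_i is diagonal (and set to a constant 10·I in the implementation described in Sec. 3), J_D reduces to a scaled squared error; a minimal sketch, with the function name and reduction as our assumptions:

    import torch

    def reconstruction_term(x, x_hat, variance=10.0):
        """J_D of Eq. (7) for Sigma_i = variance * I: weighted squared error
        between the image x and the decoder output x_hat = f_D(z), summed over
        pixels (the additive constant is dropped)."""
        return 0.5 * torch.sum((x - x_hat) ** 2, dim=list(range(1, x.dim()))) / variance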

Loss Function. Combining Eq. (5), Eq. (6) and Eq. (7), we obtain the loss function for training our model:

    J(θ_E, θ_R, θ_D; x, y) = Σ_{i=1}^{N} J_KL(θ_E; x_i) + Σ_{i=1}^{N_L} J_R(θ_E, θ_R; y_i, z_i) + Σ_{i=1}^{N} J_D(θ_E, θ_D; x_i, z_i)
      = − (1/2) Σ_{i=1}^{N} Σ_{k=1}^{D} ( log λ_ik² − μ_ik² − λ_ik² )
        − Σ_{i=1}^{N_L} [ Σ_{j=1}^{3} y_ij log f_R^j(z_i; θ_R) + Σ_{j=1}^{3} (1 − y_ij) log( 1 − f_R^j(z_i; θ_R) ) ]
        + (1/2) Σ_{i=1}^{N} (x_i − f_D(z_i; θ_D))^T Σ_i^{−1} (x_i − f_D(z_i; θ_D)).   (8)

We employ the stochastic gradient-based optimization procedure Adam [8] to minimize the loss function. Our training procedure is outlined in the supplementary material. The pulmonary edema severity category extracted from radiology reports is a discrete approximation of the actual continuous severity level. To capture this, we compute the expected severity:

    ŷ = 0 × (1 − ŷ_1) + 1 × (ŷ_1 − ŷ_2) + 2 × (ŷ_2 − ŷ_3) + 3 × ŷ_3 = ŷ_1 + ŷ_2 + ŷ_3.
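A few lines illustrating this conversion from the three predicted bit probabilities to the expected severity score (a sketch with our own function name; the example values are hypothetical):

    import torch

    def expected_severity(bit_probs):
        """Expected edema severity from the three predicted bit probabilities.

        bit_probs[:, j] estimates P(severity >= j+1), so the sum of the three
        bits is the expected severity level, a continuous score in [0, 3].
        """
        return bit_probs.sum(dim=1)   # y_hat_1 + y_hat_2 + y_hat_3

    # e.g. bit_probs = torch.tensor([[0.9, 0.4, 0.1]]) -> expected severity 1.4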

3   Implementation Details

The size of the chest x-ray images in our dataset varies and is around 3000×3000 pixels. We randomly rotate and translate the images (differently at each epoch) on the fly during training and crop them to 2048×2048 pixels as part of data augmentation. We maintain the original image resolution to preserve subtle differences between different levels of pulmonary edema severity.

Fig. 2. Representative chest x-ray images with varying severity of pulmonary edema (no edema, mild edema, moderate edema, severe edema).
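A sketch of the on-the-fly augmentation described above using torchvision; the rotation and translation ranges below are illustrative placeholders we chose, not values reported in this work.

    from torchvision import transforms

    # Re-drawn randomly at every epoch: small rotation and translation, then a
    # 2048x2048 crop at the original resolution (no downsampling).
    # NOTE: the degree/translation ranges are illustrative guesses.
    augment = transforms.Compose([
        transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
        transforms.RandomCrop(2048),
        transforms.ToTensor(),
    ])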
    The encoder is implemented as a series of residual blocks [5]. The decoder
is implemented as a series of transposed convolutional layers, to build an out-
put image of the same size as the input image (2048×2048). The regressor is
implemented as a series of residual blocks with an averaging pooling layer fol-
lowed by two fully connected layers. The regressor output ŷ has 3 channels. The
latent representation z has a size of 128×128. During training, one sample is
drawn from z per image. The KL-loss (Eq. (5)) and the image reconstruction
error (Eq. (7)) in the loss function are divided by the latent feature size and
the image size respectively. The variances in Eq. (7) are set to 10, which gives a
weight of 0.1 to the image reconstruction error. The learning rate for the Adam
optimizer training is 0.001 and the minibatch size is 4. The model is trained on a
training dataset and evaluated on a separate validation dataset every few epochs
during training. The model checkpoint with the lowest error on the validation
dataset is used for testing. The neural network architecture is provided in the
supplementary material.
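To make the loss-term weighting above concrete, here is a minimal sketch of how the per-image terms might be normalized and combined; the divisions and the 0.1 reconstruction weight follow the text, while the function itself is our framing.

    def combined_loss(kl, recon, regression=None,
                      latent_size=128 * 128, image_size=2048 * 2048):
        """Normalize and sum the loss terms of Eq. (8): the KL term is divided
        by the latent feature size and the reconstruction error by the image
        size (its variance of 10 already downweights it by 0.1). The regression
        term is present only for labeled minibatches."""
        loss = kl / latent_size + recon / image_size
        if regression is not None:
            loss = loss + regression
        return loss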

4   Experiments

Data. Approximately 330,000 frontal view x-ray images and their associated radiology reports were collected as part of routine clinical care in the emergency department of Beth Israel Deaconess Medical Center and during the subsequent in-hospital stay. A subset of the image set has been released [6].
    In this work, we extracted the pulmonary edema severity labels from the re-
ports by searching for keywords that are highly correlated with a specific stage
of pulmonary edema. Due to the high variability of wording in radiology reports,
the same keywords can indicate different clinical findings in varying disease contexts.
For example, perihilar infiltrate means moderate pulmonary edema for a heart
failure patient, but means pneumonia in a patient with a fever. To extract mean-
ingful labels from the reports using the keywords, we limited our label extraction
to a CHF cohort. This cohort selection yielded close to 30,000 images, of which
5,771 images could be labeled via our keyword matching (a toy sketch of this matching step is given below, after Fig. 3). Representative images of each severity level are shown in Fig. 2. The data details are summarized in the supplementary material.

Fig. 3. Summary of the results. Left table: Root mean squared errors and Pearson correlation coefficients for each method. Right plot: Predicted edema severity scores for each label category.
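A toy illustration of the keyword-based label extraction restricted to the CHF cohort; apart from "perihilar infiltrate", which is mentioned in the text, the keyword phrases below are placeholders, not the actual terms used in this study.

    # Illustrative keyword lists only; the study's actual keyword set is not
    # reproduced here. Severity: 0 = none, 1 = mild, 2 = moderate, 3 = severe.
    SEVERITY_KEYWORDS = {
        3: ["severe pulmonary edema"],
        2: ["moderate pulmonary edema", "perihilar infiltrate"],
        1: ["mild pulmonary edema"],
        0: ["no pulmonary edema"],
    }

    def extract_severity(report_text, is_chf_patient):
        """Return a severity label when a keyword matches, else None (unlabeled).

        Matching is restricted to the CHF cohort, since the same wording can
        mean different findings in other disease contexts.
        """
        if not is_chf_patient:
            return None
        text = report_text.lower()
        for severity in (3, 2, 1, 0):          # prefer the most severe match
            if any(keyword in text for keyword in SEVERITY_KEYWORDS[severity]):
                return severity
        return None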

Evaluation. We randomly split the images into training (4,537 labeled im-
ages, 334,664 unlabeled images), validation (628 labeled images), and test (606
labeled images) image sets. There is no patient overlap between the sets. Unla-
beled images of the patients in the validation and test sets are excluded from
training. The labeled data split is 80%/10%/10% into training/validation/test
respectively.
    We evaluated four methods: (i) Supervised: purely supervised training that uses labeled images only; (ii) EM: supervised training with labeled images that imputes labels for unlabeled images and minimizes the entropy of predictions [3]; (iii) DGM: semi-supervised training with deep generative models (multiple VAEs) as described in [9]; (iv) VAE R: our method that learns probabilistic feature representations from the entire image set with limited labels. For the baseline supervised learning method, we investigated different neural network architectures previously demonstrated for chest x-ray images [12, 13] and did not find that the choice of architecture changes the supervised learning results significantly.
    We evaluate the methods on the test image set using the root mean squared
(RMS) error and Pearson correlation coefficient (CC).
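The two metrics in a few lines, assuming NumPy arrays of predicted and reference severity scores (a straightforward sketch, not the evaluation code used in this work):

    import numpy as np
    from scipy.stats import pearsonr

    def evaluate(predicted, reference):
        """Root mean squared error and Pearson correlation coefficient."""
        rms = float(np.sqrt(np.mean((predicted - reference) ** 2)))
        cc, _ = pearsonr(predicted, reference)
        return rms, cc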

Results. Fig. 3 summarizes the prediction performance of the four methods. The
method that jointly learns probabilistic feature representations outperforms the
other three models.

5    Conclusions

In this paper, we demonstrated a regression model augmented with a VAE
trained on a large image dataset with a limited number of labeled images. Our
results suggest that it is difficult for a generative model to learn distinct data
clusters for the labels that rely on subtle image features. In contrast, learning
compact feature representations jointly from images and limited labels can help
inform prediction by capturing structure shared by the image distribution and
the conditional distribution of labels given images.
    We demonstrated the first attempt to employ machine learning algorithms to
automatically and quantitatively assess the severity of pulmonary edema from
chest x-ray images. Our results suggest that granular information about a pa-
tient’s status captured in medical images can be extracted by machine learning
algorithms, which promises to enable clinicians to deliver better care by quantita-
tively summarizing an individual patient’s medical history, for example responses
to different treatments. This work also promises to enable clinical research stud-
ies that require quantitative summarization of patient status.

References

 1. Chakko, S., Woska, D., Martinez, H., De Marchena, E., Futterman, L., Kessler,
    K.M., Myerburg, R.J.: Clinical, radiographic, and hemodynamic correlations in
    chronic congestive heart failure: conflicting results may lead to inappropriate care.
    The American journal of medicine 90(1), 353–359 (1991)
 2. Francis, G.S., Cogswell, R., Thenappan, T.: The heterogeneity of heart failure: will
    enhanced phenotyping be necessary for future clinical trial success? (2014)
 3. Grandvalet, Y., Bengio, Y.: Semi-supervised learning by entropy minimization. In:
    Advances in neural information processing systems. pp. 529–536 (2005)
 4. Hammon, M., Dankerl, P., Voit-Höhne, H.L., Sandmair, M., Kammerer, F.J., Uder,
    M., Janka, R.: Improving diagnostic accuracy in assessing pulmonary edema on
    bedside chest radiographs using a standardized scoring approach. BMC anesthesi-
    ology 14(1), 94 (2014)
 5. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition.
    In: Proceedings of the IEEE conference on CVPR. pp. 770–778 (2016)
 6. Johnson, A.E., Pollard, T.J., Berkowitz, S., Greenbaum, N.R., Lungren, M.P.,
    Deng, C.y., Mark, R.G., Horng, S.: MIMIC-CXR: A large publicly available
    database of labeled chest radiographs. arXiv preprint arXiv:1901.07042 (2019)
 7. Kamnitsas, K., Castro, D., Le Folgoc, L., Walker, I., Tanno, R., Rueckert, D.,
    Glocker, B., Criminisi, A., Nori, A.: Semi-supervised learning via compact latent
    space clustering. In: International Conference on Machine Learning. pp. 2464–2473
    (2018)
 8. Kingma, D.P., Ba, J.: Adam: A method for stochastic optimization. arXiv preprint
    arXiv:1412.6980 (2014)
 9. Kingma, D.P., Mohamed, S., Rezende, D.J., Welling, M.: Semi-supervised learn-
    ing with deep generative models. In: Advances in Neural Information Processing
    Systems. pp. 3581–3589 (2014)
10. Kingma, D.P., Welling, M.: Auto-encoding variational bayes. arXiv preprint
    arXiv:1312.6114 (2013)
11. Mahdyoon, H., Klein, R., Eyler, W., Lakier, J.B., Chakko, S., Gheorghiade, M.:
    Radiographic pulmonary congestion in end-stage congestive heart failure. Ameri-
    can Journal of Cardiology 63(9), 625–627 (1989)
12. Rajpurkar, P., Irvin, J., Zhu, K., Yang, B., Mehta, H., Duan, T., Ding, D., Bagul,
    A., Langlotz, C., Shpanskaya, K., et al.: Chexnet: Radiologist-level pneumonia de-
    tection on chest x-rays with deep learning. arXiv preprint arXiv:1711.05225 (2017)
13. Wang, X., Peng, Y., Lu, L., Lu, Z., Bagheri, M., Summers, R.M.: Chestx-ray8:
    Hospital-scale chest x-ray database and benchmarks on weakly-supervised classi-
    fication and localization of common thorax diseases. In: Proceedings of the IEEE
    conference on CVPR. pp. 3462–3471. IEEE (2017)

Supplementary Material

Algorithm 1 Stochastic learning with minibatches for the model in Fig. 1
    θ_E, θ_R, θ_D ← Initialize parameters
    repeat
        repeat                                            ▷ Training on labeled data
            y_n, x_n ← Random minibatch of n image and label pairs
            z_n ← Samples from q(z|x_n; θ_E) (one sample per image)
            g ← ∇_{θ_E} J_KL(θ_E; x_n) + ∇_{θ_E,θ_R} J_R(θ_E, θ_R; y_n, z_n) + ∇_{θ_E,θ_D} J_D(θ_E, θ_D; x_n, z_n)
            θ_E, θ_R, θ_D ← Update parameters using gradients g (e.g., Adam [8])
        until the last minibatch of the labeled set
        repeat                                            ▷ Training on unlabeled data
            x_n ← Random minibatch of n images
            z_n ← Samples from q(z|x_n; θ_E) (one sample per image)
            g ← ∇_{θ_E} J_KL(θ_E; x_n) + ∇_{θ_E,θ_D} J_D(θ_E, θ_D; x_n, z_n)
            θ_E, θ_D ← Update parameters using gradients g (e.g., Adam [8])
        until the last minibatch of the unlabeled set
    until convergence
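For readers who prefer code, a minimal Python rendering of Algorithm 1 is sketched below; the module and loader names are ours, and the loss helpers are assumed to follow the terms defined in Sec. 2 (e.g., the sketches given there), so this is an illustration rather than the authors' implementation.

    import torch

    def train(encoder, decoder, regressor, labeled_loader, unlabeled_loader,
              sample_latent, kl_term, reconstruction_term, regression_term,
              epochs=100, lr=1e-3):
        """Alternate full passes over labeled and unlabeled minibatches (Algorithm 1)."""
        params = (list(encoder.parameters()) + list(decoder.parameters())
                  + list(regressor.parameters()))
        optimizer = torch.optim.Adam(params, lr=lr)
        for _ in range(epochs):
            for x, y in labeled_loader:            # labeled data: all three terms
                mu, logvar = encoder(x)            # y: 3-bit ordinal targets
                z = sample_latent(mu, logvar)      # one sample per image
                loss = (kl_term(mu, logvar).mean()
                        + regression_term(regressor(z), y)
                        + reconstruction_term(x, decoder(z)).mean())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            for x in unlabeled_loader:             # unlabeled data: no regression term
                mu, logvar = encoder(x)
                z = sample_latent(mu, logvar)
                loss = (kl_term(mu, logvar).mean()
                        + reconstruction_term(x, decoder(z)).mean())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()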

Table 1. Summary of chest x-ray image data and pulmonary edema severity labels.

    Patient cohort   Edema severity    Number of patients   Number of images
    Congestive       No edema          441    (0.69%)       1003    (0.29%)
    heart failure    Mild edema        979    (1.52%)       3058    (0.89%)
    (CHF)            Moderate edema    478    (0.74%)       1414    (0.41%)
                     Severe edema      116    (0.18%)       296     (0.09%)
                     Unlabeled         1880   (2.92%)       22021   (6.40%)
    Others           Unlabeled         60406  (93.95%)      316479  (91.92%)

Fig. Neural network architectures for the encoder, the decoder, and the regressor. (Notation: "k×k conv, c, /s" is a k×k convolution with c output channels and stride s; "×2" denotes 2× upsampling by a transposed convolution.)

Encoder: x-ray, 2048×2048 → 7×7 conv, 8, /4 → 5×5 conv, 8, /2 → 3 × (5×5 conv, 8) → 3×3 conv, 16, /2 → 3 × (3×3 conv, 16) → 3×3 conv, 32, /2 → 3 × (3×3 conv, 32) → 3×3 conv, 64, /2 → 3 × (3×3 conv, 64) → 3×3 conv, 128, /2 → 3 × (3×3 conv, 128) → reshape → z_mean, 16×16×64 and z_var, 16×16×64.

Decoder: z_mean, 16×16×64 and z_var, 16×16×64 → sample and reshape → z, 8×8×256 → 3×3 conv_transpose, 128, ×2 → 3×3 conv_transpose, 64, ×2 → 3×3 conv_transpose, 32, ×2 → 3×3 conv_transpose, 16, ×2 → 3×3 conv_transpose, 8, ×2 → 3×3 conv_transpose, 4, ×2 → 3×3 conv_transpose, 2, ×2 → 3×3 conv_transpose, 1, ×2 → reconstructed x-ray, 2048×2048.

Regressor: z_mean, 16×16×64 and z_var, 16×16×64 → sample and reshape → z, 128×128×1 → 3×3 conv, 8, /2 → 3 × (3×3 conv, 8) → 3×3 conv, 16, /2 → 3 × (3×3 conv, 16) → 3×3 conv, 32, /2 → 3 × (3×3 conv, 32) → 3×3 conv, 64, /2 → 3 × (3×3 conv, 64) → 3×3 conv, 128, /2 → 3 × (3×3 conv, 128) → pool, /2 → fully connected, 512→64 → fully connected, 64→3 → y, 3×1.
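For readers who want to translate the encoder column above into code, a hedged PyTorch sketch that follows the listed strides and channel counts is given below; the residual-block grouping, the 3×3 kernels throughout, padding, activations, and the absence of normalization layers are our assumptions rather than details reported here.

    import torch
    import torch.nn as nn

    class ResidualBlock(nn.Module):
        """Two 3x3 convolutions with an identity (or strided 1x1) shortcut."""
        def __init__(self, in_ch, out_ch, stride=1):
            super().__init__()
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)
            self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
            self.skip = (nn.Conv2d(in_ch, out_ch, 1, stride=stride)
                         if stride != 1 or in_ch != out_ch else nn.Identity())
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            h = self.relu(self.conv1(x))
            h = self.conv2(h)
            return self.relu(h + self.skip(x))

    class Encoder(nn.Module):
        """2048x2048 grayscale x-ray -> (z_mean, z_logvar), each 16x16 with 64 channels."""
        def __init__(self):
            super().__init__()
            self.stem = nn.Sequential(                       # 2048 -> 512
                nn.Conv2d(1, 8, 7, stride=4, padding=3), nn.ReLU(inplace=True))
            self.blocks = nn.Sequential(                     # 512 -> 16, 8 -> 128 channels
                ResidualBlock(8, 8, stride=2),    ResidualBlock(8, 8),
                ResidualBlock(8, 16, stride=2),   ResidualBlock(16, 16),
                ResidualBlock(16, 32, stride=2),  ResidualBlock(32, 32),
                ResidualBlock(32, 64, stride=2),  ResidualBlock(64, 64),
                ResidualBlock(64, 128, stride=2), ResidualBlock(128, 128))

        def forward(self, x):
            h = self.blocks(self.stem(x))                    # (batch, 128, 16, 16)
            z_mean, z_logvar = torch.chunk(h, 2, dim=1)      # two (batch, 64, 16, 16) maps
            return z_mean, z_logvar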