Proactive Pseudo-Intervention: Contrastive Learning For Interpretable Vision Models

Dong Wang1, Yuewei Yang1, Chenyang Tao1, Zhe Gan2, Liqun Chen1, Fanjie Kong1, Ricardo Henao1, Lawrence Carin1
1 Duke University, 2 Microsoft Corporation
{dong.wang363, yuewei.yang, chenyang.tao, liqun.chen, ricardo.henao, lcarin}@duke.edu, zhe.gan@microsoft.com

arXiv:2012.03369v2 [cs.CV] 29 Apr 2021

Abstract

Deep neural networks excel at comprehending complex visual signals, delivering performance on par with or even superior to that of human experts. However, ad-hoc visual explanations of model decisions often reveal an alarming level of reliance on non-causal visual cues that strongly correlate with the target label in the training data. As such, deep neural nets suffer compromised generalization to novel inputs collected from different sources, and the reverse engineering of their decision rules offers limited interpretability. To overcome these limitations, we present a novel contrastive learning strategy called Proactive Pseudo-Intervention (PPI) that leverages proactive interventions to guard against image features with no causal relevance. We also devise a novel causally informed salience mapping module to identify key image pixels to intervene on, and show it greatly facilitates model interpretability. To demonstrate the utility of our proposals, we benchmark on both standard natural images and challenging medical image datasets. PPI-enhanced models consistently deliver superior performance relative to competing solutions, especially on out-of-domain predictions and data integration from heterogeneous sources. Further, our causally trained saliency maps are more succinct and meaningful relative to their non-causal counterparts.

Figure 1: Interpretation of bird-classification models using saliency maps generated by LRP (layer-wise relevance propagation) and our model PPI. LRP shows that a naively trained deep model makes decisions based on background cues (habitat, e.g., rocks, bulrushes) that are spuriously correlated with the bird species, while our causally informed PPI mostly focuses on the bird anatomy, which generalizes beyond the natural habitat.

1. Introduction

Deep neural networks hold great promise in applications requiring the analysis and comprehension of complex imagery. Recent advances in hardware, network architectures, and model optimization, along with the increasing availability of large-scale annotated datasets [31, 12, 11], have enabled these models to match and sometimes outperform human experts on a number of tasks, including natural image classification [32], object recognition [20], disease diagnosis [48], and autonomous driving [8], among others.

While deep learning solutions have been positively recognized for their ability to learn black-box models in a purely data-driven manner, their very nature makes them less credible for their inability to communicate the reasoning for making predictions in a way that is comprehensible to humans [26, 43]. This hinders their adoption in consequential applications where the reliability and trustworthiness of a prediction are of primary concern and require expert audit, e.g., in healthcare [48]. To stimulate widespread use of deep learning models, a means of interpreting predictions is necessary. However, model interpretation techniques often reveal a concerning fact: deep learning models tend to assimilate spurious correlations that do not necessarily capture the causal relationship between the input (image)
and output (label) [63]. This issue is particularly notable in small-sample-size (weak supervision) scenarios or when the sources of non-informative variation are overwhelming, and is thus likely to cause severe overfitting. These issues can lead to catastrophic failures on deployment [19, 64].

A growing recognition of the issues associated with the lack of interpretable predictions has been well documented in recent years [1, 26, 43]. This phenomenon has energized researchers to actively seek creative solutions. Among these, two streams of work, namely saliency mapping [68, 54, 10] and causal representation learning (CRL) [28, 65, 2], stand out as some of the most promising directions. Specifically, saliency mapping encompasses techniques for post hoc visualization on the input (image) space to facilitate the interpretation of model predictions. This is done by projecting the key features used in prediction back to the input space, resulting in the commonly known saliency maps. Importantly, these maps do not directly contribute to model learning. Alternatively, CRL solutions are built on the principle of establishing invariance from the data, which entails teasing out sources of variation that are spuriously associated with the model output (labels). CRL models, while emphasizing the differences between causation and correlation, are not subject to the rigor of causal inference approaches, because their goal is not to obtain accurate causal effect estimates but rather to produce robust models with better generalization ability relative to their naively learned counterparts [2].

In this work, we present Proactive Pseudo-Intervention (PPI), a solution that accounts for the needs of causal representation identification and visual verification. Our key insight is the derivation of causally-informed saliency maps, which facilitate visual verification of model predictions and enable learning that is robust to (non-causal) associations. While true causation can only be established through experimental interventions, we leverage tools from contrastive representation learning to synthesize pseudo-interventions from observational data. Our procedure is motivated by the causal argument: perturbing the non-causal features will not change the target label.

To motivate, in Figure 1 we present an example illustrating the benefits of producing causally-informed saliency maps. In this scenario, the task is to classify two bird species (A and B) in the wild. Due to the differences in their natural habitats, A-birds are mostly seen resting on rocks, while B-birds are more commonly found among bulrushes. A deep model, trained naively, will tend to associate the background characteristics with the labels, since these strongly correlate with the bird species (labels) in the training set. This is confirmed by the saliency maps derived from the layer-wise relevance propagation (LRP) technique [4]: the model also attends heavily to the background features, while the difference in bird anatomy is what causally determines the label. If we were provided with an image of a bird in an environment foreign to the images in the training set, the model would be unable to make a reliable prediction, thus causing robustness concerns. This generalization issue worsens with a smaller training sample size. On the other hand, saliency maps from our PPI-enhanced model successfully focus on the bird anatomy, and thus will be robust to environmental changes captured in the input images.

PPI addresses causally-informed reasoning, robust learning, and model interpretation in a unified framework. A new saliency mapping method, named Weight Back Propagation (WBP), is also proposed to generate more concentrated intervention masks for PPI training. The key contributions of this paper include:

• An end-to-end contrastive representation learning strategy, PPI, that employs proactive interventions to identify causally relevant features.

• A fast and architecture-agnostic saliency mapping module, WBP, that delivers better visualization and localization performance.

• Experiments demonstrating significant performance boosts from integrating PPI and WBP relative to competing solutions, especially on out-of-domain predictions, data integration with heterogeneous sources, and model interpretation.

2. Background

Visual Explanations. Saliency mapping collectively refers to a family of techniques to understand and interpret black-box image classification models, such as deep neural networks [1, 26, 43]. These methods project the model's understanding of the targets, i.e., labels, and their predictions back to the input space, which allows for the visual inspection of automated reasoning and for the communication of predictive visual cues to the user or human expert, aiming to shed light on model insights or to build trust in deep-learning-based systems.

In this study, we focus on post hoc saliency mapping strategies, where saliency maps are constructed given an arbitrary prediction model, as opposed to relying on customized model architectures for interpretable predictions [19, 64], or training a separate module to explicitly produce model explanations [19, 21, 6, 18, 53]. Popular solutions under this category include: activation mapping [69, 51], input sensitivity analysis [53], and relevance propagation [4]. Activation mapping based methods fail at visualizing fine-grained evidence, which is particularly important in explaining medical classification models [14, 51, 60]. Input sensitivity analysis based methods produce fine-grained saliency maps; however, these maps are generally less concentrated [10, 18] and less interpretable. Relevance propagation based methods, like LRP
and its variants, use complex rules to prioritize positive or large relevance, making the saliency maps visually appealing to humans. However, our experiments demonstrate that LRP and its variants highlight spuriously correlated features (borderlines and backgrounds). By contrast, our WBP backpropagates the weights through layers to compute the contribution of each input pixel, which is truly faithful to the model, and WBP tends to highlight the target objects themselves rather than the background. At the same time, its simplicity and efficiency allow WBP to easily work with other advanced learning strategies for both model diagnosis and improvement during training.

Our work is in a similar spirit to [18, 10, 6, 60], where meaningful perturbations are applied to the image during model training to improve prediction and facilitate interpretation. Pioneering works have relied on user-supplied "ground-truth" explainable masks to perturb [46, 35, 45]; however, such manual annotations are costly and hence rarely available in practice. Alternatively, perturbations can be computed by solving an optimization for each image. Such strategies are costly in practice and also do not effectively block spurious features. Very recently, exploratory efforts have been made to leverage tools from counterfactual reasoning [21] and causal analysis [40] to derive visual explanations, but these do not lend insights back to model training. Our work represents a fast, principled solution that overcomes the above limitations. It automatically derives explainable masks faithful to the model and data, without explicit supervision from user-generated explanations.

Contrastive Learning. There has been growing interest in exploiting contrastive learning (CL) techniques for representation learning [39, 9, 24, 29, 59]. Originally devised for density estimation [23], CL exploits the idea of learning by comparison to capture the subtle features of data, i.e., positive examples, by contrasting them with negative examples drawn from a carefully crafted noise distribution. These techniques aim to avoid representation collapse, or to promote representation consistency, for downstream tasks. Recent developments, both empirical and theoretical, have connected CL to information-theoretic foundations [59, 22], thus establishing it as a suite of de facto solutions for unsupervised representation learning [9, 24].

The basic form of CL is essentially a binary classification task specified to discriminate positive and negative examples. In such a scenario, the binary classifier is known as the critic function. Maximizing the discriminative power wrt the critic and the representation sharpens the feature encoder. Critical to the success of CL is the choice of an appropriate noise distribution, where the challenging negatives, i.e., those negatives that are more similar to positive examples, are often considered more effective contrasts. In its more generalized form, CL can naturally repurpose the predictor and loss functions without introducing a new critic [59]. Notably, current CL methods are not immune to spurious associations, a point we wish to improve on in this work.

Causality and Interventions. From a causality perspective, humans learn via actively interacting with the environment. We intervene and observe changes in the outcome to infer causal dependencies. Machines instead learn from static observations that are unable to inform the structural dependencies for causal decisions. As such, perturbations to external factors, e.g., surroundings, lighting, viewing angles, may drastically alter machine predictions, while human recognition is less susceptible to such nuisance variations. Formally, this difference is best explained with the do-notation [41]: P(Y | do(x)) = Σ_z P(Y | X = x, z) P(z), where we identify x as the features, e.g., an object in the image, and z as the confounders, e.g., the background in the example above. Note that P(Y | do(x)) is fundamentally different from the conditional likelihood P(Y | X = x) = Σ_z P(Y | X = x, z) P(z | X = x), which machines use for associative reasoning.

Unfortunately, carrying out real interventional studies, i.e., randomized control trials, to intentionally block non-causal associations is oftentimes not a feasible option for practical reasons, e.g., cost and ethics. This work instead advocates the application of synthetic interventions to uncover the underlying causal features from observational data. Specifically, we proactively edit x and its corresponding label y in a data-driven fashion to encourage the model to learn potential causal associations. Our proposal is in line with the growing appreciation for the significance of establishing causality in machine learning models [49]. Via promoting invariance [2], such causally inspired solutions demonstrate superior robustness to superficial features that do not generalize [62]. In particular, [57, 67] showed the importance and effectiveness of accounting for interventional perspectives. Our work brings these causal views to construct a simple solution that explicitly optimizes visual interpretation and model robustness.
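To make the distinction concrete, the following toy calculation (with hypothetical probability tables, not taken from the paper) evaluates both expressions for the bird example, with z the background and x the anatomy features: whenever z both correlates with x and influences Y, the interventional and associative queries disagree, and the gap grows with the strength of the confounding.

```python
# Toy numerical illustration of P(Y|do(x)) vs. P(Y|X=x); all tables are hypothetical.

# Confounder z: background (0 = rocks, 1 = bulrushes)
P_z = {0: 0.5, 1: 0.5}

# Features x: bird anatomy (0 = species-A anatomy, 1 = species-B anatomy).
# In the observational data, anatomy is strongly correlated with the background.
P_x_given_z = {0: {0: 0.9, 1: 0.1},   # on rocks, mostly A-birds
               1: {0: 0.1, 1: 0.9}}   # in bulrushes, mostly B-birds

# Label y = 1 means "species B"; it depends mostly on anatomy, weakly on background.
P_y1_given_xz = {(0, 0): 0.05, (0, 1): 0.10,
                 (1, 0): 0.90, (1, 1): 0.95}

def p_y1_do_x(x):
    """P(Y=1 | do(X=x)) = sum_z P(Y=1 | X=x, z) P(z)."""
    return sum(P_y1_given_xz[(x, z)] * P_z[z] for z in P_z)

def p_y1_given_x(x):
    """P(Y=1 | X=x) = sum_z P(Y=1 | X=x, z) P(z | X=x)."""
    p_x = sum(P_x_given_z[z][x] * P_z[z] for z in P_z)
    return sum(P_y1_given_xz[(x, z)] * P_x_given_z[z][x] * P_z[z] / p_x for z in P_z)

for x in (0, 1):
    print(f"x={x}: P(y=1|do(x)) = {p_y1_do_x(x):.3f}, P(y=1|x) = {p_y1_given_x(x):.3f}")
```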
Figure 2: Illustration of the proposed PPI learning strategy. Input images are intervened on by removing the saliency-map-based masks, which alters the input label (e.g., negative control). For the positive contrast, we use the original input as well as an input masked with a random saliency map. We use WBP to generate the saliency maps.

3. Proactive Pseudo-Intervention

Below we describe the construction of Proactive Pseudo-Intervention (PPI), a causally-informed contrastive learning scheme that seeks to simultaneously improve the accuracy, robustness, generalization and interpretability of deep-learning-based computer vision models.

The PPI learning strategy, schematically summarized in Figure 2, consists of three main components: (i) a saliency mapping module that highlights causally relevant features; (ii) an intervention module that synthesizes contrastive samples; and (iii) the prediction module, which is standard in recent vision models, e.g., VGG [55], ResNet [25], and Inception Net [58]. Motivated by the discussion in our introduction, PPI establishes a feedback loop between the saliency map module and the prediction module, which is interfaced by the synthesized contrastive examples in the intervention module. Under this configuration, the prediction module is encouraged to modify its predictions only when provided with causally-relevant synthetic interventions. Note that components (i) and (ii) do not involve any additional parameters or neural network modules, which makes our strategy readily applicable to the training of virtually any computer vision task without major customization. Details of these building blocks are given below.

3.1. Synthetic causal interventions for contrasts

Key to our formulation is the design of a synthetic intervention strategy that generates contrastive examples to reinforce causal relevance during model training. Given a causal saliency map s_m(x) for an input x wrt label y = m, where m = 1, . . . , M, and M is the number of classes, the synthetic intervention consists of removing (replacing with zero) the causal information from x contained in s_m(x), and then using the result as the contrastive learning signal.

For now, let us assume the causal saliency map s_m(x) is known; the procedure to obtain the saliency map will be addressed in the next section. For notational clarity, we use subscript i to denote entities associated with the i-th training sample, and omit the dependency on learnable parameters. To remove causal information from x_i and obtain a negative contrast x*_i, we apply the following soft-masking:

    x*_i = x_i − T(s_m(x_i)) ⊙ x_i,                                   (1)

where T(·) is a differentiable masking function and ⊙ denotes element-wise (Hadamard) multiplication. Specifically, we use the thresholded sigmoid for masking:

    T(s_m(x_i)) = 1 / (1 + exp(−ω(s_m(x_i) − σ))),                    (2)

where σ and ω > 0 are the threshold and scaling parameters, respectively. We set the scaling ω so that T(s) exhibits a sharp transition from 0 to 1 near σ. Using (1) we define the contrastive loss as

    L_con(θ) = Σ_i ℓ(x*_i, ¬y_i; f_θ),                                (3)

where f_θ is the prediction module, ℓ(x, y; f_θ) is the loss function we wish to optimize, e.g., cross entropy, and ¬ is used to denote that the original class label has been flipped. In the binary case, ¬y = 1 − y, and in the multi-class case it can be interpreted accordingly, e.g., using a one-vs-others cross entropy loss. In practice, we set ℓ(x*, ¬y; f_θ) = −ℓ(x*, y; f_θ). We will show in the experiments that this simple and intuitive causal masking strategy works well in practice (see Tables 2 and 4, and Figure 4). Alternatively, we also consider a hard-masking approach in which a minimal bounding box covering the thresholded saliency map is removed. See the Appendix for details.

Note that we are making the implicit assumption that the saliency map is uniquely determined by the prediction module f_θ. While optimizing (3) explicitly attempts to improve the fit of the prediction module f_θ, it also implicitly informs the causal saliency mapping. This is sensible because if a prediction is made using non-causal features, which implies the associated saliency map s_m(x) is also non-causal, then after applying s_m(x) to x via (1) we should still expect the correct prediction, i.e., the true label, for both the positive (original) and negative (intervened) samples.
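For concreteness, a minimal PyTorch sketch of Eqs. (1)-(3) is given below. It is a schematic reading of the training signal rather than the authors' released implementation; model (the prediction module f_θ) and saliency_fn (e.g., a WBP routine returning per-pixel maps normalized to [0, 1]) are assumed to be supplied elsewhere.

```python
# Schematic sketch of the synthetic intervention of Eqs. (1)-(3); assumes saliency
# maps are normalized to [0, 1] and that `model` / `saliency_fn` are defined elsewhere.
import torch
import torch.nn.functional as F

def soft_mask(s, sigma=0.5, omega=50.0):
    """Thresholded sigmoid T(s) of Eq. (2); a large omega gives a sharp 0-to-1 transition near sigma."""
    return torch.sigmoid(omega * (s - sigma))

def negative_contrast(x, s):
    """Eq. (1): x* = x - T(s) * x, removing the (putatively causal) salient content."""
    return x - soft_mask(s) * x

def contrastive_loss(model, x, y, saliency_fn):
    """Eq. (3), using the practical choice l(x*, not-y; f) = -l(x*, y; f)."""
    s = saliency_fn(model, x, y)      # saliency maps for the true classes, same spatial shape as x
    x_neg = negative_contrast(x, s)   # intervened (negative-contrast) batch
    return -F.cross_entropy(model(x_neg), y)
```

The adversarial positive contrast introduced below reuses the same masking function, but with the saliency map of a different image and the original (unflipped) label.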
Saliency map regularization. Note that naively optimizing (3) can lead to degenerate solutions, for which any saliency map that satisfies causal sufficiency, i.e., encompasses all causal features, is a valid causal saliency map. For example, a trivial solution where the saliency map covers the entire image may be considered causal. To protect against such degeneracy, we propose to regularize the L1-norm of the saliency map to encourage succinct (sparse) representations, i.e., L_reg = ||s_m||_1, for m = 1, . . . , M.

Adversarial positive contrasts. Another concern with solely optimizing (3) is that models can easily overfit to the intervention, i.e., instead of learning to capture causal relevance, the model learns to predict interventional operations. For example, the model can learn to change its prediction when it detects that the input has been intervened on, regardless of whether the image is missing causal features. So motivated, we introduce adversarial positive contrasts:

    x'_i = x_i − T(s_m(x_j)) ⊙ x_i,   i ≠ j,                          (4)

where we intervene with a false saliency map, i.e., s_m(x_j) is the saliency map from a different input x_j, while still encouraging the model to make the correct prediction via

    L_ad(θ) = Σ_i ℓ(x'_i, y_i; f_θ),                                  (5)

where x'_i is the adversarial positive contrast. The complete loss for the proposed model, L = L_cls + L_con + L_reg + L_ad, consists of the standard classification loss L_cls, the contrastive loss in (3), the regularization loss L_reg, and the adversarial loss in (5).

3.2. Saliency Weight Backpropagation

In order to generate saliency maps that inform the decision-driving features in the (raw) pixel space, we describe Weight Back Propagation (WBP), a novel computationally efficient scheme for saliency mapping applicable to arbitrary neural architectures. WBP evaluates the individual contribution of each pixel to the final class-specific prediction, and we empirically find the results to be more causally relevant relative to competing solutions, based on human judgement.

To simplify our presentation, we first consider a vector input and a linear mapping. Let x^l be the internal representation of the data at the l-th layer, with l = 0 being the input layer, i.e., x^0 = x, and l = L being the penultimate logit layer prior to the softmax transformation, i.e., P(y|x) = softmax(x^L). To assign relative importance to each hidden unit in the l-th layer, we notationally collapse all transformations after l into an operator denoted by W~^l, which we call the saliency matrix, satisfying

    x^L = W~^l x^l,   ∀l ∈ [0, . . . , L],                             (6)

where x^L is an M-dimensional vector corresponding to the M distinct classes in y. Though presented in matrix form, this is a slight abuse of notation, i.e., the instantiation of the operator W~^l effectively depends on the input x; thus, all non-linearities have been effectively absorbed into it. We posit that for an object associated with a given label y = m, its causal features are subsumed in the interactions between the m-th row of W~^0 and the input x, i.e.,

    [s_m(x)]_k = [W~^0]_mk [x]_k,                                      (7)

where [s_m(x)]_k denotes the k-th element of the saliency map s_m(x) and [W~^0]_mk is a single element of W~^0. A key observation is that W~^l can be computed recursively. Specifically, let g_l(x^l) be the transformation at the l-th layer, e.g., an affine transformation, convolution, activation, normalization, etc.; then it holds that

    W~^{l+1} x^{l+1} = W~^{l+1} g_l(x^l) = W~^l x^l.                   (8)

This allows for recursive computation of W~^l via

    W~^l = G(W~^{l+1}, g_l),   W~^L = 1,                               (9)

where G(·) is the update rule. We list the update rules for common transformations in deep networks in Table 1, with the corresponding derivations detailed below.

Table 1: WBP update rules for common transformations.

    Transformation       | Update rule G(·)
    Activation layer     | W~^l = h ◦ W~^{l+1}
    FC layer             | W~^l = W~^{l+1} W^l
    Convolutional layer  | W~^l = W~^{l+1} ⊗ [W^l]^{T_{0,1}}_{flip_{2,3}}
    BN layer             | W~^l = (γ/σ) W~^{l+1}
    Pooling layer        | Relocate/distribute W~^{l+1}

Fully-connected (FC) layer. The FC transformation is the most basic operation in deep neural networks. Below we omit the bias term, as it does not directly interact with the input. Assuming g_l(x^l) = W^l x^l, it is readily seen that

    W~^{l+1} x^{l+1} = W~^{l+1} g_l(x^l) = (W~^{l+1} W^l) x^l,        (10)

so W~^l = W~^{l+1} W^l. A graphical illustration with a standard affine mapping and ReLU activation can be found in the Appendix.

Nonlinear activation layer. Considering that an activation layer simply rescales the saliency weight matrices, i.e., x^{l+1} = g_l(x^l) = h_l ◦ x^l, where ◦ is the composition operator, we obtain W~^l = h ◦ W~^{l+1}. Using the ReLU activation as a concrete example, we have h(x^l) = 1{x^l ≥ 0}.

Convolutional layer. The convolution is a generalized form of linear mapping. In practice, convolutions can be expressed as tensor products of the form W~^l = W~^{l+1} ⊗ [W^l]^{T_{0,1}}_{flip_{2,3}}, where W^l ∈ R^{D_2×D_1×(2S+1)×(2S+1)} is the convolution kernel, T_{0,1} is a transpose of dimensions 0 and 1, and flip_{2,3} is an exchange (reversal) of dimensions 2 and 3. See the Appendix for details.
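The following self-contained sketch illustrates the recursion of Eqs. (8)-(10) for the simplest case covered by Table 1, a multilayer perceptron of bias-free linear layers and ReLU activations; it is an illustration of the update rules above, not the authors' implementation.

```python
# Illustration of the WBP recursion (Eqs. 8-10) for a bias-free Linear/ReLU MLP.
import torch
import torch.nn as nn

def wbp_saliency(mlp: nn.Sequential, x: torch.Tensor, target: int) -> torch.Tensor:
    """Return s_target(x) = [W~0]_target * x (Eq. 7) for a flat input x of shape [d]."""
    inputs, h = [], x
    for layer in mlp:                      # forward pass, caching layer inputs for the ReLU masks
        inputs.append(h)
        h = layer(h)

    W_tilde = torch.eye(h.shape[0])        # W~L = identity over the M logits (Eq. 9)
    for layer, x_l in zip(list(mlp)[::-1], inputs[::-1]):
        if isinstance(layer, nn.Linear):
            W_tilde = W_tilde @ layer.weight           # FC rule: W~l = W~{l+1} W^l
        elif isinstance(layer, nn.ReLU):
            W_tilde = W_tilde * (x_l >= 0).float()     # zero the columns of inactive units
        else:
            raise NotImplementedError(type(layer))
    return W_tilde[target] * x             # Eq. (7): row m of W~0, elementwise with the input

# Toy usage: with bias-free layers, W~0 x reproduces the logits exactly (Eq. 6).
net = nn.Sequential(nn.Linear(8, 16, bias=False), nn.ReLU(), nn.Linear(16, 3, bias=False))
x = torch.randn(8)
print(wbp_saliency(net, x, target=int(net(x).argmax())))
```

Convolution, pooling and batch-normalization layers follow the remaining rules of Table 1 in the same backward fold.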
Figure 3: Visualization of the inferred saliency maps. Left: CUB dataset. Right: ImageNet dataset.

Pooling and normalization layer. Summarization and standardization are two other essential operations for the success of deep neural networks, achieved by pooling and batch normalization (BN) techniques, respectively. They too can be considered special instantiations of linear operations. We summarize the update rules for these two popular operations in Table 1.

4. Experiments

To validate the utility of our approach, we consider both natural and medical image datasets, and compare it to existing state-of-the-art solutions. All experiments are implemented in PyTorch. The source code will be available at https://github.com/author_name/PPI. Due to space limitations, details of the experimental setup and additional analyses are deferred to the Appendix.

Datasets. We present our findings on five representative datasets: (i) CIFAR-10 [31]; (ii) ImageNet (ILSVRC2012) [47]; (iii) CUB [61], a natural image dataset with over 12k photos for classification of 200 bird species in the wild, heavily confounded by background characteristics; (iv) GA [34], a new medical image dataset for the prediction of geographic atrophy (GA) using 3D optical coherence tomography (OCT) image volumes, characterized by a small sample size (275 subjects) and high heterogeneity (collected from 4 different facilities); and (v) LIDC-IDRI [33], a public medical dataset of 1,085 lung lesion CT images annotated by 4 radiologists. Detailed specifications are described in the Appendix.

Baselines. The following set of popular saliency mapping schemes are considered as comparators for the proposed approach: (i) Gradient: standard gradient-based salience mapping; (ii) GradCAM [51]: gradient-weighted class activation mapping; and (iii) LRP [4]: layer-wise relevance propagation and its variants.

Hyperparameters. The final loss of the proposed model is a weighted summation of four losses: L = L_cls + w1 L_con + w2 L_reg + w3 L_ad. The weights are simply balanced to match the magnitude of L_cls, i.e., w3 = 1, w2 = 0.1, and w1 = 1 (CUB), 1 (GA), and 10 (LIDC). See Appendix Sec. B for details about the masking parameters σ and ω.

Table 2: Performance improvements achieved by training with PPI on the CUB, CIFAR-10, and GA datasets. We report means and standard deviations (SDs) from 5-fold cross-validation for GA prediction.

    Models          | CUB (Acc) | CIFAR-10 (Acc) | GA (AUC)
    Classification  | 0.662     | 0.881          | 0.877 ± 0.040
    +PPI_Gradient   | 0.673     | 0.885          | 0.890 ± 0.035
    +PPI_LRP        | 0.680     | 0.891          | 0.895 ± 0.037
    +PPI_GradCAM    | 0.683     | 0.895          | 0.908 ± 0.036
    +PPI_WBP        | 0.696     | 0.901          | 0.925 ± 0.023

4.1. Natural Image Datasets

Classification Gains. In this experiment, we investigate how different pairings of PPI and saliency mapping schemes (i.e., GradCAM, LRP, WBP) affect performance. In Table 2, the first row represents a VGG11 model trained with only the classification loss, and the following rows represent VGG11 trained with PPI under different saliency mapping schemes. We see consistent accuracy gains from incorporating PPI training on both the CUB and CIFAR-10 datasets. The gains are most significant when using our WBP for saliency mapping (improving the accuracy from 0.662 to 0.696 on CUB, and from 0.881 to 0.901 on CIFAR-10).

Model Interpretability. In this task, we qualitatively and quantitatively compare the causal relevance of saliency maps generated by our proposed model and its competitors. In Figure 3, we show the saliency maps produced by different approaches for a VGG11 model trained on CUB. Visually, gradient-based solutions (Grad and GradCAM) tend to yield overly dispersed maps, indicating a lack of specificity. LRP gives more appealing saliency maps; however, these maps also heavily attend to spurious background cues that presumably help with predictions. When trained with PPI, the saliency maps attend to the bird's body, and with WBP, the saliency maps focus on the causally related pixels.
Figure 4: Saliency maps on the GA dataset from models trained with and without PPI. Maps from models trained with PPI are more clinically relevant, focusing on retinal layers likely to contain abnormalities or lesions, and are more concentrated.

Figure 5: Quantitative evaluations of the causal relevance of competing saliency maps (higher is better). (a) CUB; (b) ImageNet.

To quantitatively evaluate the causal relevance of competing saliency maps, we adopt the evaluation scheme proposed in [26], which consists of masking out the contributing saliency pixels and then calculating the reduction in prediction score. A larger reduction is considered better for accurately capturing the pixels that 'cause' the prediction. Results are summarized in Figure 5a, where we progressively remove the top-k saliency points, with k = 100, 500, 1000, 5000, 10000 (10000 ≈ 6.6% of all pixels), from the CUB test input images. Our PPI consistently outperforms its counterparts, with its lead being most substantial in the low-k regime. Notably, for large k, PPI removes nearly all predictive signal. This implies that PPI specifically targets the causal features. Quantitative evaluation with additional metrics is provided in the Appendix.
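As a reference for how this metric can be computed, the sketch below gives one plausible implementation of the removal-based score (our paraphrase of the protocol of [26], not the authors' exact script); model and saliency_fn are placeholders assumed to be defined elsewhere, and x is a single image tensor of shape [C, H, W].

```python
# Schematic removal-based evaluation: mask the top-k salient pixels and
# measure the drop in the predicted probability of the original class.
import torch

@torch.no_grad()
def score_drop(model, x, y, saliency_fn, k=1000):
    p_before = model(x.unsqueeze(0)).softmax(-1)[0, y]
    s = saliency_fn(model, x, y)              # per-pixel saliency, shape [H, W]
    idx = s.flatten().topk(k).indices         # top-k most salient locations
    mask = torch.ones_like(s).flatten()
    mask[idx] = 0.0                           # zero out the contributing pixels
    x_masked = x * mask.view_as(s)            # broadcast the mask over channels
    p_after = model(x_masked.unsqueeze(0)).softmax(-1)[0, y]
    return (p_before - p_after).item()        # larger drop = more causal saliency
```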
To test the performance of WBP itself (without PPI training), we compare WBP with different approaches for a VGG11 model trained on ImageNet from the PyTorch model zoo. Figure 3 (right) shows that saliency maps generated by WBP concentrate more on the objects themselves. Also, thanks to the fine resolution of WBP, the model pays more attention to the patterns on the fur to identify the leopard (row 1). This is more visually consistent with human judgement. Figure 5b demonstrates that WBP identifies more causal pixels on ImageNet validation images.

4.2. OCT-GA: Geographic Atrophy Classification

Next we show how the proposed PPI handles the challenges of small training data and heterogeneity in medical image datasets. In this experiment (with our new dataset, which we will make public), each OCT volume consists of 100 scans of a 512 × 1000 sized image [5]. We use a multi-view CNN model [56] to process these 3D OCT inputs, and use it as our baseline solution (see the Appendix for details). We investigate how the different saliency mapping schemes (i.e., Grad, GradCAM, LRP, WBP) work with PPI. For WBP, we also tested the bounding-box variant, denoted WBP (box) (see the Appendix for details). In Table 2, we see consistent performance gains in AUC from incorporating PPI training (from 0.877 to 0.925, further improved to 0.937 by PPI with WBP (box)), accompanied by reductions in model variation, as evaluated by the standard deviations of AUC from the five-fold cross-validation. The gains are most significant when using our WBP for saliency mapping. We further compare the saliency maps generated by these different combinations. We see that without the additional supervision from PPI, competing solutions like Grad, GradCAM and LRP sometimes yield non-sensible saliency maps (attending to image corners). Overall, PPI encourages more concentrated and less noisy saliency maps. Also, the different PPI-based saliency maps agree with each other to a larger extent. Our findings are also verified by experts (co-authors, who are ophthalmologists specializing in GA), confirming that the PPI-based saliency maps are clinically relevant by focusing on retinal layers likely to contain abnormalities or lesions. These results underscore the practical value of the proposed
proactive interventions.

Table 3: AUC results for GA prediction with and without PPI. Models are trained on one site and cross-validated on the other sites. (Darker cell shading in the original indicates better performance.)

    With PPI | A     | B     | C     | D     | Mean  | STD
    A        | 1.000 | 0.906 | 0.877 | 0.865 | 0.912 | 0.061
    B        | 0.851 | 0.975 | 0.863 | 0.910 | 0.900 | 0.056
    C        | 0.954 | 0.875 | 0.904 | 0.931 | 0.916 | 0.034
    D        | 0.824 | 0.846 | 0.853 | 0.904 | 0.857 | 0.034

    No PPI   | A     | B     | C     | D     | Mean  | STD
    A        | 1.000 | 0.854 | 0.832 | 0.827 | 0.878 | 0.082
    B        | 0.810 | 0.874 | 0.850 | 0.906 | 0.860 | 0.040
    C        | 0.860 | 0.779 | 0.873 | 0.862 | 0.843 | 0.043
    D        | 0.748 | 0.792 | 0.836 | 0.961 | 0.834 | 0.092

Cross-domain generalization. Common to medical image applications is that training samples are usually integrated from a number of healthcare facilities (i.e., domains), and that predictions sometimes have to be made on subjects at other facilities. Despite significant efforts to standardize image collection protocols, with different imaging systems operated by technicians of varying skill, apparent domain shifts are likely to compromise the cross-domain performance of these models. We show this phenomenon on the GA dataset in Table 3, where source samples are collected from four different hospitals in different health systems (A, B, C and D; see the Appendix for details). Each cell contains the AUC of the model trained on site X (row) and tested on site Y (column), with same-site predictions made on hold-out samples. A significant performance drop is observed for cross-domain predictions (off-diagonals) compared to in-domain predictions (diagonals). With the application of PPI, the performance gaps between in-domain and cross-domain predictions are considerably reduced. The overall accuracy gains of PPI further justify the utility of causally-inspired modeling. Notably, site D manifests strong spurious correlations that help in-domain prediction but degrade out-of-site generalization, which is partly resolved by the proposed PPI.

4.3. LIDC-IDRI: Lung Lesions Classification

To further examine the practical advantages of the proposed PPI in real-world applications, we benchmark its utility on LIDC-IDRI, a public lung CT scan dataset [3]. We followed the preprocessing steps outlined in [30] to prepare the data, and adopted the experimental setup from [50] to predict lesions. We use Inception v3 [58] as our base model for both standard classification and PPI-enhanced training with various saliency mapping schemes. See the Appendix for details.

Table 4: LIDC-IDRI classification AUC results.

    Models             | AUC
    Tensor Net-X [15]  | 0.823
    DenseNet [27]      | 0.829
    LoTeNet [50]       | 0.874
    Inception v3 [58]  | 0.921
    +PPI_GradCAM       | 0.933
    +PPI_Gradient      | 0.930
    +PPI_LRP           | 0.931
    +PPI_WBP           | 0.935
    +PPI_WBP (box)     | 0.941

Lesion classification. We first compare PPI to other specialized SOTA network architectures. Table 4 summarizes the AUC scores of Tensor Net-X [15], DenseNet [27], LoTeNet [50], and Inception v3 [58], as well as our Inception v3 trained with and without PPI_WBP. The proposed PPI_WBP (box) leads the performance chart by a considerable margin, improving Inception v3 from 0.92 to 0.94.

Figure 6: Saliency maps on LIDC-IDRI. Saliency maps of PPI+WBP are mostly consistent with the ground truths.

Weakly-supervised image segmentation. In Figure 6, we compare the saliency maps generated by GradCAM, WBP, and WBP (box) to the ground-truth lesion masks from expert annotations. Note that we only supplied patch-level labels during training, not the pixel-level expert segmentation masks, which constitutes a challenging weakly-supervised image segmentation task. In line with the observations from the GA experiment, our PPI-training-enhanced WBP saliency maps are mostly consistent with the expert segmentations. Together with Table 4, Figure 6 confirms that the proposed PPI+WBP improves both classification performance and model interpretability.

5. Conclusions

We have presented Proactive Pseudo-Intervention (PPI), a novel interpretable computer vision framework that organically integrates saliency mapping, causal reasoning, synthetic intervention and contrastive learning. PPI couples saliency mapping with contrastive training by creating artificially intervened negative samples absent of causal features. To communicate model insights and facilitate causally-informed reasoning, we derived an architecture-agnostic saliency mapping scheme called Weight Back Propagation (WBP), which faithfully captures the causally-relevant pixels/features for model prediction. Visual inspection of the
saliency maps shows that WBP is more robust to spurious features compared to competing approaches. Empirical results on natural and medical datasets verify that the combination of PPI and WBP consistently delivers performance boosts across a wide range of tasks relative to competing solutions, and the gains are most significant where the application is complicated by small sample sizes, data heterogeneity, or confounding by spurious correlations.

References

[1] Julius Adebayo, Justin Gilmer, Michael Muelly, Ian Goodfellow, Moritz Hardt, and Been Kim. Sanity checks for saliency maps. In Advances in Neural Information Processing Systems, pages 9505–9515, 2018.
[2] Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz. Invariant risk minimization. arXiv preprint arXiv:1907.02893, 2019.
[3] Samuel G Armato III, Geoffrey McLennan, Luc Bidaut, Michael F McNitt-Gray, Charles R Meyer, Anthony P Reeves, Binsheng Zhao, Denise R Aberle, Claudia I Henschke, Eric A Hoffman, et al. The lung image database consortium (LIDC) and image database resource initiative (IDRI): a completed reference database of lung nodules on CT scans. Medical Physics, 38(2):915–931, 2011.
[4] Sebastian Bach, Alexander Binder, Grégoire Montavon, Frederick Klauschen, Klaus-Robert Müller, and Wojciech Samek. On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation. PloS One, 10(7):e0130140, 2015.
[5] David S Boyer, Ursula Schmidt-Erfurth, Menno van Lookeren Campagne, Erin C Henry, and Christopher Brittain. The pathophysiology of geographic atrophy secondary to age-related macular degeneration and the complement pathway as a therapeutic target. Retina (Philadelphia, Pa.), 37(5):819, 2017.
[6] Chun-Hao Chang, Elliot Creager, Anna Goldenberg, and David Duvenaud. Explaining image classifiers by counterfactual generation. In International Conference on Learning Representations, 2018.
[7] Aditya Chattopadhay, Anirban Sarkar, Prantik Howlader, and Vineeth N Balasubramanian. Grad-CAM++: Generalized gradient-based visual explanations for deep convolutional networks. In 2018 IEEE Winter Conference on Applications of Computer Vision (WACV), pages 839–847. IEEE, 2018.
[8] Chenyi Chen, Ari Seff, Alain Kornhauser, and Jianxiong Xiao. DeepDriving: Learning affordance for direct perception in autonomous driving. In Proceedings of the IEEE International Conference on Computer Vision, pages 2722–2730, 2015.
[9] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709, 2020.
[10] Piotr Dabkowski and Yarin Gal. Real time image saliency for black box classifiers. In Advances in Neural Information Processing Systems, pages 6967–6976, 2017.
[11] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pages 248–255. IEEE, 2009.
[12] Li Deng. The MNIST database of handwritten digit images for machine learning research [best of the web]. IEEE Signal Processing Magazine, 29(6):141–142, 2012.
[13] Amit Dhurandhar, Pin-Yu Chen, Ronny Luss, Chun-Chen Tu, Paishun Ting, Karthikeyan Shanmugam, and Payel Das. Explanations based on the missing: Towards contrastive explanations with pertinent negatives. In Advances in Neural Information Processing Systems, pages 592–603, 2018.
[14] Mengnan Du, Ninghao Liu, Qingquan Song, and Xia Hu. Towards explanation of DNN-based prediction with guided feature inversion. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 1358–1367, 2018.
[15] Stavros Efthymiou, Jack Hidary, and Stefan Leichenauer. TensorNetwork for machine learning. arXiv preprint arXiv:1906.06329, 2019.
[16] Dumitru Erhan, Yoshua Bengio, Aaron Courville, and Pascal Vincent. Visualizing higher-layer features of a deep network. University of Montreal, 1341(3):1, 2009.
[17] Ruth Fong, Mandela Patrick, and Andrea Vedaldi. Understanding deep networks via extremal perturbations and smooth masks. In Proceedings of the IEEE International Conference on Computer Vision, pages 2950–2958, 2019.
[18] Ruth C Fong and Andrea Vedaldi. Interpretable explanations of black boxes by meaningful perturbation. In Proceedings of the IEEE International Conference on Computer Vision, pages 3429–3437, 2017.
[19] Hiroshi Fukui, Tsubasa Hirakawa, Takayoshi Yamashita, and Hironobu Fujiyoshi. Attention branch network: Learning of attention mechanism for visual explanation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 10705–10714, 2019.
[20] Ross Girshick, Jeff Donahue, Trevor Darrell, and Jitendra Malik. Rich feature hierarchies for accurate object detection and semantic segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 580–587, 2014.
[21] Yash Goyal, Ziyan Wu, Jan Ernst, Dhruv Batra, Devi Parikh, and Stefan Lee. Counterfactual visual explanations. In ICML, 2019.
[22] Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Guo, Mohammad Gheshlaghi Azar, et al. Bootstrap your own latent: a new approach to self-supervised learning. Advances in Neural Information Processing Systems, 33, 2020.
[23] Michael Gutmann and Aapo Hyvärinen. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pages 297–304, 2010.
[24] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for unsupervised visual representation learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 9729–9738, 2020.
[25] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[26] Sara Hooker, Dumitru Erhan, Pieter-Jan Kindermans, and Been Kim. A benchmark for interpretability methods in deep neural networks. In Advances in Neural Information Processing Systems, pages 9737–9748, 2019.
[27] Gao Huang, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q Weinberger. Densely connected convolutional networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 4700–4708, 2017.
[28] Fredrik Johansson, Uri Shalit, and David Sontag. Learning representations for counterfactual inference. In International Conference on Machine Learning, pages 3020–3029, 2016.
[29] Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. Supervised contrastive learning. arXiv preprint arXiv:2004.11362, 2020.
[30] Simon Kohl, Bernardino Romera-Paredes, Clemens Meyer, Jeffrey De Fauw, Joseph R Ledsam, Klaus Maier-Hein, SM Ali Eslami, Danilo Jimenez Rezende, and Olaf Ronneberger. A probabilistic U-Net for segmentation of ambiguous images. In Advances in Neural Information Processing Systems, pages 6965–6975, 2018.
[31] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
[32] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. ImageNet classification with deep convolutional neural networks. Communications of the ACM, 60(6):84–90, 2017.
[33] Curtis P Langlotz, Bibb Allen, Bradley J Erickson, Jayashree Kalpathy-Cramer, Keith Bigelow, Tessa S Cook, Adam E Flanders, Matthew P Lungren, David S Mendelson, Jeffrey D Rudie, et al. A roadmap for foundational research on artificial intelligence in medical imaging: from the 2018 NIH/RSNA/ACR/The Academy workshop. Radiology, 291(3):781–791, 2019.
[37] ... AI: Interpreting, Explaining and Visualizing Deep Learning, pages 253–265. Springer, 2019.
[38] Grégoire Montavon, Alexander Binder, Sebastian Lapuschkin, Wojciech Samek, and Klaus-Robert Müller. Layer-wise relevance propagation: an overview. In Explainable AI: Interpreting, Explaining and Visualizing Deep Learning, pages 193–209. Springer, 2019.
[39] Aaron van den Oord, Yazhe Li, and Oriol Vinyals. Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748, 2018.
[40] Matthew O'Shaughnessy, Gregory Canal, Marissa Connor, Christopher Rozell, and Mark Davenport. Generative causal explanations of black-box classifiers. Advances in Neural Information Processing Systems, 33, 2020.
[41] Judea Pearl. Causality. Cambridge University Press, 2009.
[42] Vitali Petsiuk, Abir Das, and Kate Saenko. RISE: Randomized input sampling for explanation of black-box models. arXiv preprint arXiv:1806.07421, 2018.
[43] Sylvestre-Alvise Rebuffi, Ruth Fong, Xu Ji, and Andrea Vedaldi. There and back again: Revisiting backpropagation saliency methods. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 8839–8848, 2020.
[44] Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. "Why should I trust you?" Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 1135–1144, 2016.
[45] Laura Rieger, Chandan Singh, William Murdoch, and Bin Yu. Interpretations are useful: penalizing explanations to align neural networks with prior knowledge. In International Conference on Machine Learning, pages 8116–8126. PMLR, 2020.
[46] Andrew Slavin Ross, Michael C Hughes, and Finale Doshi-Velez. Right for the right reasons: training differentiable models by constraining their explanations. In Proceedings of the 26th International Joint Conference on Artificial Intelligence, pages 2662–2670, 2017.
[47] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and
[34] Jessica N Leuschen, Stefanie G Schuman, Katrina P Winter,               Li Fei-Fei. ImageNet Large Scale Visual Recognition Chal-
     Michelle N McCall, Wai T Wong, Emily Y Chew, Thomas                     lenge. International Journal of Computer Vision (IJCV),
     Hwang, Sunil Srivastava, Neeru Sarin, Traci Clemons, et al.             115(3):211–252, 2015. 6
     Spectral-domain optical coherence tomography characteris-        [48]   Paul Sajda. Machine learning for detection and diagnosis of
     tics of intermediate age-related macular degeneration. Oph-             disease. Annu. Rev. Biomed. Eng., 8:537–565, 2006. 1
     thalmology, 120(1):140–150, 2013. 6, 15                          [49]   Bernhard Schölkopf. Causality for machine learning. arXiv
[35] Kunpeng Li, Ziyan Wu, Kuan-Chuan Peng, Jan Ernst, and                   preprint arXiv:1911.10500, 2019. 3
     Yun Fu. Tell me where to look: Guided attention inference        [50]   Raghavendra Selvan and Erik B Dam. Tensor networks for
     network. In Proceedings of the IEEE Conference on Com-                  medical image classification. In Medical Imaging with Deep
     puter Vision and Pattern Recognition, pages 9215–9223,                  Learning, 2020. 8, 17
     2018. 3, 13                                                      [51]   Ramprasaath R Selvaraju, Michael Cogswell, Abhishek Das,
[36] Aravindh Mahendran and Andrea Vedaldi. Salient decon-                   Ramakrishna Vedantam, Devi Parikh, and Dhruv Batra.
     volutional networks. In European Conference on Computer                 Grad-cam: Visual explanations from deep networks via
     Vision, pages 120–135. Springer, 2016. 13                               gradient-based localization. In Proceedings of the IEEE In-
[37] Grégoire Montavon. Gradient-based vs. propagation-based                ternational Conference on Computer Vision, pages 618–626,
     explanations: an axiomatic comparison. In Explainable                   2017. 2, 6, 13
[52] Dasom Seo, Kanghan Oh, and Il-Seok Oh. Regional multi-          [65] Tan Wang, Jianqiang Huang, Hanwang Zhang, and Qianru
     scale approach for visually pleasing explanations of deep            Sun. Visual commonsense representation learning via causal
     neural networks. IEEE Access, 8:8572–8582, 2019. 12                  inference. In Proceedings of the IEEE/CVF Conference on
[53] Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje.              Computer Vision and Pattern Recognition Workshops, pages
     Learning important features through propagating activation           378–379, 2020. 2
     differences. In International Conference on Machine Learn-      [66] Matthew D Zeiler and Rob Fergus. Visualizing and under-
     ing, pages 3145–3153, 2017. 2, 13                                    standing convolutional networks. In European conference on
[54] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman.                computer vision, pages 818–833. Springer, 2014. 12
     Deep inside convolutional networks: Visualising image           [67] Cheng Zhang, Kun Zhang, and Yingzhen Li. A causal
     classification models and saliency maps. arXiv preprint              view on robustness of neural networks. arXiv preprint
     arXiv:1312.6034, 2013. 2                                             arXiv:2005.01095, 2020. 3
[55] Karen Simonyan and Andrew Zisserman. Very deep convo-           [68] Yitian Zhao, Yalin Zheng, Yifan Zhao, Yonghuai Liu, Zhili
     lutional networks for large-scale image recognition. arXiv           Chen, Peng Liu, and Jiang Liu. Uniqueness-driven saliency
     preprint arXiv:1409.1556, 2014. 3                                    analysis for automated lesion detection with applications to
[56] Hang Su, Subhransu Maji, Evangelos Kalogerakis, and Erik             retinal diseases. In International Conference on Medical Im-
     Learned-Miller. Multi-view convolutional neural networks             age Computing and Computer-Assisted Intervention, pages
     for 3d shape recognition. In Proceedings of the IEEE in-             109–118. Springer, 2018. 2
     ternational conference on computer vision, pages 945–953,       [69] Bolei Zhou, Aditya Khosla, Agata Lapedriza, Aude Oliva,
     2015. 7, 16                                                          and Antonio Torralba. Learning deep features for discrimina-
[57] Raphael Suter, Djordje Miladinovic, Bernhard Schölkopf,             tive localization. In Proceedings of the IEEE conference on
     and Stefan Bauer. Robustly disentangled causal mecha-                computer vision and pattern recognition, pages 2921–2929,
     nisms: Validating deep representations for interventional ro-        2016. 2
     bustness. In International Conference on Machine Learning,
     pages 6056–6065. PMLR, 2019. 3
[58] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon
     Shlens, and Zbigniew Wojna. Rethinking the inception archi-
     tecture for computer vision. In Proceedings of the IEEE con-
     ference on computer vision and pattern recognition, pages
     2818–2826, 2016. 4, 8, 17
[59] Yonglong Tian, Dilip Krishnan, and Phillip Isola. Con-
     trastive multiview coding. arXiv preprint arXiv:1906.05849,
     2019. 3
[60] Jorg Wagner, Jan Mathias Kohler, Tobias Gindele, Leon Het-
     zel, Jakob Thaddaus Wiedemer, and Sven Behnke. Inter-
     pretable and fine-grained visual explanations for convolu-
     tional neural networks. In Proceedings of the IEEE Con-
     ference on Computer Vision and Pattern Recognition, pages
     9097–9107, 2019. 2, 3, 12, 13
[61] C. Wah, S. Branson, P. Welinder, P. Perona, and S. Belongie.
     The Caltech-UCSD Birds-200-2011 Dataset. Technical Re-
     port CNS-TR-2011-001, California Institute of Technology,
     2011. 6
[62] Haohan Wang, Zexue He, Zachary C Lipton, and Eric P
     Xing. Learning robust representations by projecting super-
     ficial statistics out. InInternational Conference on Learning
     Representations, 2019. 3
[63] Haofan Wang, Zifan Wang, Mengnan Du, Fan Yang, Zijian
     Zhang, Sirui Ding, Piotr Mardziel, and Xia Hu. Score-cam:
     Score-weighted visual explanations for convolutional neural
     networks. In Proceedings of the IEEE/CVF Conference on
     Computer Vision and Pattern Recognition Workshops, pages
     24–25, 2020. 2, 15
[64] Lezi Wang, Ziyan Wu, Srikrishna Karanam, Kuan-Chuan
     Peng, Rajat Vikram Singh, Bo Liu, and Dimitris N Metaxas.
     Sharpen focus: Learning with attention separability and con-
     sistency. In Proceedings of the IEEE International Confer-
     ence on Computer Vision, pages 512–521, 2019. 2, 14
A. Weight Backpropagation (WBP)

A.1. Graphical Illustration of WBP

    See Figure 7 for a graphical illustration with a standard affine mapping and ReLU activation.

Figure 7: Illustration of Weight Backpropagation (WBP) through a fully connected layer with a ReLU activation.

A.2. Derivation of Convolutional Weight Backpropagation

    Let us denote the input variable as $I \in \mathbb{R}^{H \times W}$, the convolutional filter weight as $W \in \mathbb{R}^{(2S+1) \times (2S+1)}$, the output variable as $O \in \mathbb{R}^{H \times W}$, and the weight backpropagated to $O$ as $\tilde{W} \in \mathbb{R}^{H \times W}$. We omit the bias here because it does not directly interact with the input variables. Denoting the convolution operator by $\otimes$, we have
$$O = I \otimes W, \qquad O_{i,j} = \sum_{i'=-S}^{S} \sum_{j'=-S}^{S} I_{i+i',\, j+j'}\, W_{i'+S,\, j'+S}.$$
Weighting each output location by $\tilde{W}$ and regrouping the sums around each input pixel gives
$$\sum_{i}\sum_{j} O_{i,j}\,\tilde{W}_{i,j}
  = \sum_{i}\sum_{j}\sum_{i'=-S}^{S}\sum_{j'=-S}^{S} I_{i+i',\, j+j'}\, W_{i'+S,\, j'+S}\,\tilde{W}_{i,j}
  = \sum_{i}\sum_{j} I_{i,j} \sum_{i'=-S}^{S}\sum_{j'=-S}^{S} \tilde{W}_{i+i',\, j+j'}\, W_{-i'+S,\, -j'+S}
  = \sum_{i}\sum_{j} I_{i,j}\,\bigl(\tilde{W} \otimes [W]_{\mathrm{flip}}\bigr)_{i,j}.$$
    Hence the weight backpropagated through a convolutional layer is $\tilde{W}^{l} = \tilde{W}^{l+1} \otimes [W^{l}]_{\mathrm{flip}}$. For the 3D case, $I^{l} \in \mathbb{R}^{D_1 \times H \times W}$, the weight backpropagated to $O^{l}$ is $\tilde{W}^{l+1} \in \mathbb{R}^{D_2 \times H \times W}$, and the convolutional weight is $W^{l} \in \mathbb{R}^{D_2 \times D_1 \times (2S+1) \times (2S+1)}$. To match the depth of $\tilde{W}^{l+1}$, $W^{l}$ is transposed in its first two dimensions, so $\tilde{W}^{l} = \tilde{W}^{l+1} \otimes [W^{l}]^{T_{0,1}}_{\mathrm{flip}_{2,3}}$. If the convolutional layer downsizes the input (i.e., uses a stride larger than one), $\tilde{W}^{l+1}$ is padded with zeros around each weight (left, right, up, and down) to account for the input elements that the convolutional filter strides over; the number of inserted zeros equals the stride minus one.
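    The rule above is exactly the adjoint of the forward convolution, which standard deep learning libraries expose as a transposed convolution. The PyTorch sketch below is our own illustration of this correspondence; the function names, tensor shapes, and the ReLU gating convention are assumptions rather than the released implementation.

```python
# Minimal sketch of WBP through a fully connected layer and a conv layer.
# Assumption: weights are pulled back from the output toward the input, with
# inactive ReLU units contributing nothing (cf. Figure 7).
import torch
import torch.nn.functional as F

def wbp_linear(w_tilde_out, weight, relu_mask):
    """w_tilde_out: (B, d_out), weight: (d_out, d_in), relu_mask: (B, d_out).
    The affine map is pulled back by W^T after gating by the active units."""
    return (w_tilde_out * relu_mask) @ weight

def wbp_conv2d(w_tilde_out, weight, stride=1, padding=0):
    """W~^l = W~^{l+1} (*) flip(W^l): flipping the kernel and swapping its
    in/out channel axes is the adjoint of the forward convolution, i.e. a
    transposed convolution; strided layers are handled by the zero insertion
    that conv_transpose2d performs implicitly (stride - 1 zeros per gap)."""
    return F.conv_transpose2d(w_tilde_out, weight, stride=stride, padding=padding)

# Toy usage: pull a uniform output weight map back to the input pixels.
if __name__ == "__main__":
    conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)
    w_tilde = torch.ones(1, 4, 8, 8)                # weight map at the conv output
    saliency = wbp_conv2d(w_tilde, conv.weight, padding=1)
    print(saliency.shape)                           # torch.Size([1, 3, 8, 8])
```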
B. Details on Causal Masking

    In this work, we consider three types of causal masking: (i) the point-wise soft causal masking defined by Equation (2) in the main text, (ii) hard masking, and (iii) box masking. For hard masking, for each image we keep the points whose WBP weight is larger than $k$ times the standard deviation of the WBP weights of the whole image. We tested $k$ from 1 to 7 and obtained similar results; as $k = 7$ performs slightly better, we set $k$ to 7 for all experiments. For box masking, we use the center of mass of the kept points as the center of a box whose height and width are defined as $\mathrm{center}_{h/w} \pm 1.2\,\mathrm{std}_{h/w}$; in this way at least 90% of the filtered points are contained in the box. For soft masking, we set $\omega$ to 100 and $\sigma$ to 0.25. We have also experimented with image-adaptive thresholds instead of a fixed $\sigma$ for all inputs, i.e., setting the threshold to the mean plus $k$ times the standard deviation of the WBP weights of the whole image; we repeated the experiments a few times and the results are consistent. The experimental comparison of the masking methods mentioned above is conducted on the LIDC dataset and summarized in Table 5.

Table 5: Different causal masking methods on LIDC.

    Models                   AUC
    WBP-soft (fixed σ)       0.931
    WBP-soft (adaptive σ)    0.941
    WBP-hard (point)         0.935
    WBP-hard (box)           0.941
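    A minimal NumPy sketch of the hard and box masking variants is given below. The $k = 7$ threshold and the 1.2-std box extent follow the description above; the function names and the handling of images with no retained points are our own assumptions, and the soft mask of Equation (2) is omitted since that equation appears only in the main text.

```python
import numpy as np

def hard_mask(saliency, k=7.0):
    """Keep pixels whose WBP weight exceeds k standard deviations of the whole
    map (adaptive variant: threshold at saliency.mean() + k * saliency.std())."""
    return (saliency > k * saliency.std()).astype(np.float32)

def box_mask(saliency, k=7.0, extent=1.2):
    """Box of half-height/width 1.2 std around the centre of mass of the kept
    points, so that roughly 90% of them fall inside the box."""
    keep = hard_mask(saliency, k)
    ys, xs = np.nonzero(keep)
    if ys.size == 0:                       # nothing salient: return an empty mask
        return np.zeros_like(saliency)
    cy, cx = ys.mean(), xs.mean()          # centre of mass of the kept points
    hy, hx = extent * ys.std(), extent * xs.std()
    y0, y1 = max(int(cy - hy), 0), min(int(cy + hy), saliency.shape[0] - 1)
    x0, x1 = max(int(cx - hx), 0), min(int(cx + hx), saliency.shape[1] - 1)
    mask = np.zeros_like(saliency)
    mask[y0:y1 + 1, x0:x1 + 1] = 1.0
    return mask
```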
C. Related Work

    In this work, we propose a contrastive causal representation learning strategy, i.e., Proactive Pseudo-Intervention (PPI), that leverages proactive interventions to identify causally-relevant image features. This approach is complemented with a novel causal salience map visualization module, i.e., Weight Backpropagation (WBP), that identifies important pixels in the raw input image and thereby greatly facilitates the interpretability of predictions.
    Prior related work is discussed in this section. Compared with alternative post-hoc saliency mapping methods, WBP outperforms them both as a standalone causal saliency map and as a trainable module for model interpretation. Compared with other trainable interpretation models, the proposed PPI+WBP improves both model performance and model interpretation.

C.1. Post-hoc Saliency Maps

    We compare WBP with other post-hoc saliency mapping methods to show why WBP is able to target the causal features and generate more succinct and reliable saliency maps.
    Perturbation Based Methods. These methods perturb individual inputs or neurons and monitor the impact on the output neurons of the network. [66] occludes different segments of an input image and visualizes the change in the activations of subsequent layers. Several methods follow a similar idea, but use other importance measures or occlusion strategies [42, 44, 52]. More elaborate works aim to generate an explanation by optimizing for a perturbed version of the image [18, 17, 10, 14]. [60] proposes a new adversarial defense technique which filters gradients during optimization to achieve fine-grained explanations. However, such perturbation-based methods are computationally intensive and involve sophisticated model designs, which makes them extremely hard to integrate with other advanced learning strategies.
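    For illustration, a bare-bones version of the occlusion probe of [66] might look as follows; the patch size, stride, and fill value are arbitrary choices for this sketch, not settings taken from the cited papers.

```python
import torch

def occlusion_map(model, image, target, patch=16, stride=8, fill=0.0):
    """Slide an occluding patch over `image` (C, H, W) and record how much the
    target logit drops; larger drops mark regions the model relies on."""
    model.eval()
    _, H, W = image.shape
    heat = torch.zeros(H, W)
    with torch.no_grad():
        base = model(image.unsqueeze(0))[0, target].item()
        for y in range(0, H - patch + 1, stride):
            for x in range(0, W - patch + 1, stride):
                occluded = image.clone()
                occluded[:, y:y + patch, x:x + patch] = fill
                drop = base - model(occluded.unsqueeze(0))[0, target].item()
                heat[y:y + patch, x:x + patch] += drop
    return heat
```

Each patch position requires a full forward pass, which is precisely the computational cost criticized above.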
    Backpropagation Based Methods. Backpropagation based methods (BBM) propagate an importance signal from an output neuron backwards through the layers to the input. These methods are usually fast to compute and produce fine-grained importance/relevancy maps. WBP is one such method.
    The pioneering methods in this category backpropagate a gradient to the image, and a number of studies extend this work by manipulating the gradient. These methods are discussed and compared in [36, 16]. However, the resulting maps are generally less concentrated [10, 18] and less interpretable. Other BBMs, such as Layer-wise Relevance Propagation (LRP) [4] and DeepLift [53], employ top-down relevancy propagation rules. DeepLift is sensitive to the reference inputs, which requires more human effort and background knowledge to produce appealing saliency maps; this dependence on reference inputs also limits its usefulness for model diagnosis and for coupling with learning strategies that continuously improve model performance. LRP decomposes the relevance $R$ of a neuron $k$ in the upper layer onto every connected neuron $j$ in the lower layer; commonly used decomposition rules are listed in Table 6, and the decomposition is distributed through gradients under the suggested implementation [38]. Our experiments on the GA and CUB datasets show that vanilla LRP performs similarly to gradient-based methods, which is also demonstrated in [37]. The variants of LRP use more complex rules to prioritize positive or large relevance, making the saliency maps visually appealing to humans. However, our experiments demonstrate the unfaithfulness of LRP and its variants, as they highlight spuriously correlated features (image borders and backgrounds). By contrast, our WBP backpropagates the weights through the layers to compute the contribution of each input pixel, which is truly faithful to the model, and WBP tends to highlight the target objects themselves rather than the background. At the same time, its simplicity and efficiency make WBP easy to combine with other advanced learning strategies, for both model diagnosis and model improvement during training.

Table 6: A list of commonly used LRP rules [38].

    Rule      Formula
    LRP-0     $R_j = \sum_k \frac{a_j w_{jk}}{\sum_{0,j} a_j w_{jk}} R_k$
    LRP-ε     $R_j = \sum_k \frac{a_j w_{jk}}{\epsilon + \sum_{0,j} a_j w_{jk}} R_k$
    LRP-γ     $R_j = \sum_k \frac{a_j (w_{jk} + \gamma w_{jk}^{+})}{\sum_{0,j} a_j (w_{jk} + \gamma w_{jk}^{+})} R_k$
    LRP-αβ    $R_j = \sum_k \left( \alpha \frac{(a_j w_{jk})^{+}}{\sum_{0,j} (a_j w_{jk})^{+}} - \beta \frac{(a_j w_{jk})^{-}}{\sum_{0,j} (a_j w_{jk})^{-}} \right) R_k$
    flat      $R_j = \sum_k \frac{1}{\sum_j 1} R_k$
    w²        $R_i = \sum_j \frac{w_{ij}^2}{\sum_i w_{ij}^2} R_j$
    z^B       $R_i = \sum_j \frac{x_i w_{ij} - l_i w_{ij}^{+} - h_i w_{ij}^{-}}{\sum_i \left( x_i w_{ij} - l_i w_{ij}^{+} - h_i w_{ij}^{-} \right)} R_j$
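    As a concrete point of comparison, a minimal NumPy sketch of the ε-stabilized LRP rule for a single linear layer is shown below; it is our own illustration with the bias term omitted for brevity, not the reference LRP implementation.

```python
import numpy as np

def lrp_eps_linear(a, weight, relevance_out, eps=1e-6):
    """LRP-epsilon for one linear layer (cf. Table 6):
        R_j = sum_k a_j w_jk / (eps + sum_j a_j w_jk) * R_k
    a: (d_in,) lower-layer activations; weight: (d_in, d_out);
    relevance_out: (d_out,) relevance arriving from the upper layer."""
    z = a @ weight                               # aggregated contributions per upper neuron k
    s = relevance_out / (z + eps * np.sign(z))   # stabilised redistribution factors
    return a * (weight @ s)                      # relevance assigned to each lower neuron j
```

The activation term $a_j$ is what ties the LRP decomposition to the input statistics, whereas WBP pulls back only the weights.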
    Activation Based Methods. Methods in this category (such as CAM, Grad-CAM, guided Grad-CAM, and Grad-CAM++) use a linear combination of class activation maps from convolutional layers to derive a saliency map; the main difference between them is how the linear combination weights are computed. The generation of saliency maps is easy, and these methods can be coupled with advanced training strategies to improve training [35]. However, they fail at visualizing fine-grained evidence, which is particularly important in explaining medical classification models. Additionally, it is not guaranteed that the resulting explanations are faithful and reflect the decision-making process of the model [14, 51, 60]. Grad-CAM++ [7] introduces higher-order derivatives to capture pixel-level importance, but its high computational cost in calculating the second- and third-order derivatives makes it impractical for training purposes.
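    The shared recipe can be summarized in a few lines; in this sketch (our own simplification) the per-channel weights are the only ingredient that differs across variants, e.g., class-specific classifier weights for CAM versus spatially averaged gradients for Grad-CAM.

```python
import numpy as np

def activation_saliency(feature_maps, channel_weights):
    """Weighted combination of the last conv layer's activation maps.
    feature_maps: (C, H, W); channel_weights: (C,)"""
    cam = np.tensordot(channel_weights, feature_maps, axes=1)   # (H, W)
    cam = np.maximum(cam, 0)                                    # keep positive evidence only
    return cam / (cam.max() + 1e-8)                             # normalise for visualization
```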
C.2. Interpretable Models

    Unlike the post-hoc saliency map generation described above, an alternative approach is to train a separate module to explicitly produce model explanations [19, 21, 6, 18, 53]. Such post-hoc causal explanations can be generated with black-box classifiers based on a learned low-dimensional representation of the data [40]. Related to our work, an adversarial-based visual explanation method is developed in [60], highlighting the key features in the input image for