Active Learning under Label Shift
Supplementary Materials
8 Proofs
8.1 Proof of Theorem 1
We formalize the violation of label shift assumptions resulting from subsampling as label shift drift
[Azizzadenesheli et al., 2019].
Lemma 1. The drift from label shift is bounded by:
$$1 - \mathbb{E}_{X,Y \sim P_{test}}\left[\frac{P_{med}(x \mid y)}{P_{test}(x \mid y)}\right] \le \|r_{s \to m}\|_\infty \, \mathrm{err}(h_0, r_{s \to m}) \tag{15}$$
Proof. The drift is equivalent to an expectation of importance weights,
$$1 - \mathbb{E}_{X,Y \sim P_{test}}\left[\frac{P_{med}(x \mid y)}{P_{test}(x \mid y)}\right] = 1 - \sum_{X,Y} P_{med}(x \mid y)\, P_{test}(y) = 1 - \sum_{X,Y} P_{med}(x, y)\, \frac{P_{test}(y)}{P_{med}(y)} = 1 - \mathbb{E}_{X,Y \sim P_{med}}\left[\frac{P_{test}(y)}{P_{med}(y)}\right] \tag{16}$$
Drift can therefore be estimated in practice by randomly labeling subsampled points and measuring the average importance weight value (see the sketch after this proof). We can further expand the value of drift as:
$$\begin{aligned}
1 - \mathbb{E}_{X,Y \sim P_{med}}\left[\frac{P_{test}(y)}{P_{med}(y)}\right] &= 1 - \sum_{X,Y} C\, P_{src}(x, y)\, P_{ss}(h_0(x))\, \frac{P_{test}(y)}{P_{med}(y)} = 1 - C\, \mathbb{E}_{X,Y \sim P_{src}}\left[P_{ss}(h_0(x))\, \frac{P_{test}(y)}{P_{med}(y)}\right] \\
&= 1 - C\, \mathbb{E}_{X,Y \sim P_{src}}\left[P_{ss}(y)\, \frac{P_{test}(y)}{P_{med}(y)}\right] - C\, \mathbb{E}_{X,Y \sim P_{src}}\left[\big(P_{ss}(h_0(x)) - P_{ss}(y)\big)\, \frac{P_{test}(y)}{P_{med}(y)}\right] \\
&= 1 - \sum_{Y}\left[P_{med}(y)\, \frac{P_{test}(y)}{P_{med}(y)}\right] - C\, \mathbb{E}_{X,Y \sim P_{src}}\left[\big(P_{ss}(h_0(x)) - P_{ss}(y)\big)\, \frac{P_{test}(y)}{P_{med}(y)}\right]
\end{aligned} \tag{17}$$
where $C$ is a constant such that $P_{ss} = \frac{1}{C}\frac{P_{med}}{P_{src}}$ and $P_{med}$ denotes the target medial distribution. Since $C\, P_{src}(y)\, P_{ss}(y) = P_{med}(y)$, the first term equals $1 - \sum_Y P_{test}(y) = 0$. The second term corresponds to a weighted $L_1$ error on $P_{src}$:
$$C\, \mathbb{E}_{X,Y \sim P_{src}}\left[\big(P_{ss}(h_0(x)) - P_{ss}(y)\big)\, \frac{P_{test}(y)}{P_{med}(y)}\right] \le \|r_{s \to m}\|_\infty\, \mathbb{E}_{X,Y \sim P_{src}}\left[\big|\mathbb{1}[h_0(x) \ne y]\big|\, \frac{P_{test}(y)}{P_{med}(y)}\right] = \|r_{s \to m}\|_\infty\, \mathrm{err}(h_0, r_{s \to m}) \tag{18}$$
where $\mathrm{err}(h_0, r)$ denotes the importance weighted 0/1-error of a blackbox predictor $h_0$ on $P_{src}$. As the first term vanishes, the drift is bounded by the importance weighted error of the blackbox hypothesis.
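To make the estimation remark above concrete, the following is a minimal sketch of how the drift in Equation (16) could be estimated empirically, assuming access to per-class probabilities of the test and medial distributions and to the labels of a randomly labeled batch of subsampled points; the function and variable names are illustrative rather than taken from the paper.

```python
import numpy as np

def estimate_drift(sampled_labels, p_test, p_med):
    """Estimate drift as 1 minus the average importance weight P_test(y) / P_med(y)
    over randomly labeled points retained by the subsampler (cf. Eq. 16)."""
    weights = p_test[sampled_labels] / p_med[sampled_labels]
    return 1.0 - weights.mean()

# toy usage: uniform medial distribution over 3 classes, mildly imbalanced test distribution
rng = np.random.default_rng(0)
p_med = np.array([1/3, 1/3, 1/3])
p_test = np.array([0.5, 0.3, 0.2])
sampled_labels = rng.integers(0, 3, size=1000)   # labels revealed for subsampled points
print(estimate_drift(sampled_labels, p_test, p_med))
```

Under exact label shift the subsampled labels follow $P_{med}$ and the estimate concentrates around zero; larger values indicate drift.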
Plugging Lemma 1 into Theorem 2 in [Azizzadenesheli et al., 2019] yields a generalization of Theorem 1 where the number of unlabeled datapoints from the test distribution is $n$.
Theorem 4. With probability $1 - \delta$, for all $n \ge 1$:
$$\|\hat{r}_{m \to t} - r_{m \to t}\|_2 \le O\left(\frac{2}{\sigma_{\min}}\left(\|\theta_{m \to t}\|_2 \sqrt{\frac{\log(nk/\delta)}{n}} + \sqrt{\frac{\log(n/\delta)}{n}} + \frac{\log(n/\delta)}{n}\right) + \|\theta_{s \to m}\|_\infty\, \mathrm{err}(h_0, r_{m \to t})\right) \tag{19}$$
where $\sigma_{\min}$ denotes the smallest singular value of the confusion matrix and $\mathrm{err}(h_0, r)$ denotes the importance weighted 0/1-error of a blackbox predictor $h_0$ on $P_{src}$.
Theorem 1 follows by setting $n \to \infty$.
8.2 Theorem 2 and Theorem 3 Proofs
We will prove Theorem 2 and Theorem 3 for the general case where the number of unlabeled datapoints from the test distribution is $n$. For the case depicted in the main paper, set $n \to \infty$.
First, we review the IWAL-CAL active learning algorithm [Beygelzimer et al., 2010]. Let $\mathrm{err}_{S_i}(h) \in [0, 1]$ denote the error of hypothesis $h \in H$ as estimated on $S_i$, while $\mathrm{err}_{P_{test}}(h)$ denotes the expected error of $h$ on $P_{test}$. We next define,
$$h^* := \operatorname*{argmin}_{h \in H} \mathrm{err}_{P_{test}}(h), \qquad h_k := \operatorname*{argmin}_{h \in H} \mathrm{err}_{S_{k-1}}(h),$$
$$h'_k := \operatorname*{argmin}\left\{\mathrm{err}_{S_{k-1}}(h) \,\middle|\, h \in H,\ h(D^{(k)}_{unlab}) \ne h_k(D^{(k)}_{unlab})\right\},$$
$$G_k := \mathrm{err}_{S_{k-1}}(h'_k) - \mathrm{err}_{S_{k-1}}(h_k)$$
IWAL-CAL employs a sampling probability $P_t = \min\{1, s\}$ for the $s \in (0, 1)$ which solves the equation,
$$G_t = \left(\frac{c_1}{\sqrt{s}} - c_1 + 1\right)\sqrt{\frac{C_0 \log t}{t - 1}} + \left(\frac{c_2}{s} - c_2 + 1\right)\frac{C_0 \log t}{t - 1}$$
where $C_0$ is a constant bounded in Theorem 2 and $c_1 := 5 + 2\sqrt{2}$, $c_2 := 5$.
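For illustration, the following is a minimal sketch of how this sampling probability could be computed: substituting $x = 1/\sqrt{s}$ turns the displayed equation into a quadratic in $x$ with a closed-form root. The function name and the handling of small $t$ are assumptions of the sketch, not part of the IWAL-CAL specification.

```python
import math

def sampling_probability(G_t, t, C0):
    """Solve G_t = (c1/sqrt(s) - c1 + 1)*sqrt(C0*log t/(t-1)) + (c2/s - c2 + 1)*C0*log t/(t-1)
    for s and return P_t = min{1, s}."""
    c1, c2 = 5.0 + 2.0 * math.sqrt(2.0), 5.0
    if t < 2:                                  # too few rounds to form the bound; query everything
        return 1.0
    A = math.sqrt(C0 * math.log(t) / (t - 1))  # coefficient of the 1/sqrt(s) group
    B = C0 * math.log(t) / (t - 1)             # coefficient of the 1/s group
    if G_t <= A + B:                           # at s = 1 the right-hand side equals A + B
        return 1.0
    # quadratic in x = 1/sqrt(s): c2*B*x^2 + c1*A*x + ((1 - c1)*A + (1 - c2)*B - G_t) = 0
    a, b, c = c2 * B, c1 * A, (1.0 - c1) * A + (1.0 - c2) * B - G_t
    x = (-b + math.sqrt(b * b - 4.0 * a * c)) / (2.0 * a)
    return min(1.0, 1.0 / (x * x))

# toy usage with hypothetical values of the empirical error gap G_t
print(sampling_probability(G_t=0.5, t=100, C0=1.0))
```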
The most involved step in deriving generalization and sample complexity bounds for MALLS is bounding the
deviation of empirical risk estimates. This is done through the following theorem.
Theorem 5. Let $Z_i := (X_i, Y_i, Q_i)$ be our source data set, where $Q_i$ is the indicator random variable of whether $(X_i, Y_i)$ is sampled as labeled data. The following holds for all $n \ge 1$ and all $h \in H$ with probability $1 - \delta$:
$$\begin{aligned}
\big|\mathrm{err}(h, Z_{1:n}) - \mathrm{err}(h^*, Z_{1:n}) - \mathrm{err}(h) + \mathrm{err}(h^*)\big| \le O\Bigg( &(2 + \|\theta\|_2)\left(\sqrt{\frac{\varepsilon_n}{P_{\min,n}(h)}} + \frac{\varepsilon_n}{P_{\min,n}(h)}\right) \\
&+ \frac{2\, d_\infty(P_{test}, P_{src}) \log(2n|H|/\delta)}{3n} + \sqrt{\frac{2\, d_2(P_{test}, P_{src}) \log(2n|H|/\delta)}{n}} \\
&+ \|r_{s \to m}\|_\infty\, \mathrm{err}(h_0, r_{s \to m}) \\
&+ \frac{2}{\sigma_{\min}}\left(\|\theta_{m \to t}\|_2 \sqrt{\frac{\log(nk/\delta)}{\lambda n}} + \sqrt{\frac{\log(n/\delta)}{\lambda n}} + \frac{\log(n/\delta)}{n}\right) + \|\theta_{s \to m}\|_\infty\, \mathrm{err}(h_0, r_{m \to t}) \Bigg)
\end{aligned} \tag{20}$$
where $\varepsilon_n := \frac{16 \log\left(2(2 + n \log_2 n)\, n(n+1)\, |H| / \delta\right)}{n}$.
For reading convenience, we set $P_{src} := P_{ulb}$. This deviation bound plugs into IWAL-CAL to yield generalization and sample complexity bounds. In the remainder of this appendix section, we detail our proof of Theorem 5. We proceed by expressing Theorem 5 in a more general form with a bounded function $f : X \times Y \to [-1, 1]$ which will eventually represent $\mathrm{err}(h) - \mathrm{err}(h^*)$.
We borrow notation for the terms $W, Q$ from [Beygelzimer et al., 2010], where $Q_i$ is an indicator random variable indicating whether the $i$th datapoint is labeled and $W_i := \frac{Q_i \tilde{Q}_i}{P_i}\, r_{m \to t}(i)\, f(x_i, y_i)$. We use the shorthand $r(i)$ for the $y_i$th component of the importance weight vector $r$. Similarly, the indicator random variable $\tilde{Q}_i$ indicates whether the $i$th data sample is retained by the subsampler. The expectation $\mathbb{E}_i[W]$ is taken over the randomness of $Q$ and $\tilde{Q}$.
We also borrow the label shift notation of [Azizzadenesheli et al., 2019], define $k$ as the (finite) size of the output space, and denote estimated importance weights with hats, e.g. $\hat{r}$. We also introduce a variant of $W$ using estimated importance weights $\hat{r}$: $\hat{W}_i := \frac{Q_i \tilde{Q}_i}{P_i}\, \hat{r}_{m \to t}(i)\, f(x_i, y_i)$. Finally, we follow [Cortes et al., 2010] and use $d_\alpha(P \| P')$ to denote $2^{D_\alpha(P \| P')}$, where $D_\alpha(P \| P') := \frac{1}{\alpha - 1}\log_2 \sum_i \frac{P_i^\alpha}{P_i'^{\alpha - 1}}$ is the Rényi divergence of distributions $P$ and $P'$.
We seek to bound with high probability,
$$|\Delta| := \left|\frac{1}{n}\left(\sum_{i=1}^n \hat{W}_i\right) - \mathbb{E}_{x,y \sim P_{test}}[f(x, y)]\right| \le |\Delta_1| + |\Delta_2| + |\Delta_3| + |\Delta_4| \tag{21}$$
where,
$$\Delta_1 := \mathbb{E}_{x,y \sim P_{test}}[f(x, y)] - \mathbb{E}_{x,y \sim P_{src}}[W_i], \qquad \Delta_2 := \mathbb{E}_{x,y \sim P_{src}}[W_i] - \frac{1}{n}\sum_{i=1}^n \mathbb{E}_i[W_i],$$
$$\Delta_3 := \frac{1}{n}\sum_{i=1}^n \mathbb{E}_i[W_i] - \mathbb{E}_i[\hat{W}_i], \qquad \Delta_4 := \frac{1}{n}\sum_{i=1}^n \mathbb{E}_i[\hat{W}_i] - \hat{W}_i$$
$\Delta_1$ corresponds to the drift from label shift introduced by subsampling, $\Delta_2$ to finite-sample variance, and $\Delta_3$ to label shift estimation errors. The final $\Delta_4$ corresponds to the variance from random sampling.
We bound $\Delta_4$ using a martingale technique from [Zhang, 2005], also adopted by [Beygelzimer et al., 2010]. We take Lemmas 1 and 2 from [Zhang, 2005] as given. We now proceed in a fashion similar to the proof of Theorem 1 from [Beygelzimer et al., 2010]. We begin with a generalization of Lemma 6 in [Beygelzimer et al., 2010].
Lemma 2. If $0 < \lambda < \frac{3 P_i}{\hat{r}_{m \to t}(i)}$, then
$$\log \mathbb{E}_i[\exp(\lambda(\hat{W}_i - \mathbb{E}_i[\hat{W}_i]))] \le \frac{\hat{r}_i\, \hat{r}_{m \to t}(i)\, \lambda^2}{2 P_i \left(1 - \frac{\hat{r}_{m \to t}(i)\, \lambda}{3 P_i}\right)} \tag{22}$$
where $\hat{r}_i := \hat{r}_{m \to t}(i)\, \mathbb{E}_i[\tilde{Q}_i]$. If $\mathbb{E}_i[\hat{W}_i] = 0$ then
$$\log \mathbb{E}_i[\exp(\lambda(\hat{W}_i - \mathbb{E}_i[\hat{W}_i]))] = 0 \tag{23}$$
Proof. First, we bound the range and variance of $\hat{W}_i$. The range is trivial:
$$|\hat{W}_i| \le \frac{Q_i \tilde{Q}_i\, \hat{r}_{m \to t}(i)}{P_i} \le \frac{\hat{r}_{m \to t}(i)}{P_i} \tag{24}$$
Since subsampling and importance weighting ideally correct the underlying label shift, we can simplify the variance as,
$$\mathbb{E}_i[(\hat{W}_i - \mathbb{E}_i[\hat{W}_i])^2] \le \frac{\hat{r}_i\, \hat{r}_{m \to t}(i)}{P_i} f(x_i, y_i)^2 - 2\hat{r}_i^2 f(x_i, y_i)^2 + \hat{r}_i^2 f(x_i, y_i)^2 \le \frac{\hat{r}_i\, \hat{r}_{m \to t}(i)}{P_i} \tag{25}$$
Following [Beygelzimer et al., 2010], we choose a function $g(x) := (\exp(x) - x - 1)/x^2$ for $x \ne 0$ so that $\exp(x) = 1 + x + x^2 g(x)$ holds. Note that $g(x)$ is non-decreasing. Thus,
$$\begin{aligned}
\mathbb{E}_i[\exp(\lambda(\hat{W}_i - \mathbb{E}_i[\hat{W}_i]))] &= \mathbb{E}_i[1 + \lambda(\hat{W}_i - \mathbb{E}_i[\hat{W}_i]) + \lambda^2 (\hat{W}_i - \mathbb{E}_i[\hat{W}_i])^2\, g(\lambda(\hat{W}_i - \mathbb{E}_i[\hat{W}_i]))] \\
&= 1 + \lambda^2\, \mathbb{E}_i[(\hat{W}_i - \mathbb{E}_i[\hat{W}_i])^2\, g(\lambda(\hat{W}_i - \mathbb{E}_i[\hat{W}_i]))] \\
&\le 1 + \lambda^2\, \mathbb{E}_i[(\hat{W}_i - \mathbb{E}_i[\hat{W}_i])^2\, g(\lambda\, \hat{r}_{m \to t}(i)/P_i)] \\
&= 1 + \lambda^2\, \mathbb{E}_i[(\hat{W}_i - \mathbb{E}_i[\hat{W}_i])^2]\, g(\lambda\, \hat{r}_{m \to t}(i)/P_i) \\
&\le 1 + \frac{\lambda^2\, \hat{r}_i\, \hat{r}_{m \to t}(i)}{P_i}\, g\!\left(\frac{\hat{r}_{m \to t}(i)\, \lambda}{P_i}\right)
\end{aligned} \tag{26}$$
where the first inequality follows from our range bound and the second follows from our variance bound. The first claim then follows from the definition of $g(x)$ and the facts that $\exp(x) - x - 1 \le x^2 / (2(1 - x/3))$ for $0 \le x < 3$ and $\log(1 + x) \le x$. The second claim follows from the definition of $\hat{W}_i$ and the fact that $\mathbb{E}_i[\hat{W}_i] = \hat{r}_i f(X_i, Y_i)$.
The following lemma is an analogue of Lemma 7 in [Beygelzimer et al., 2010].
Lemma 3. Pick any $t \ge 0$, $p_{min} > 0$ and let $E$ be the joint event
$$\frac{1}{n}\sum_{i=1}^n \hat{W}_i - \frac{1}{n}\sum_{i=1}^n \mathbb{E}_i[\hat{W}_i] \ge (1 + M)\sqrt{\frac{t}{2 n p_{min}}} + \frac{t}{3 n p_{min}}$$
and
$$\min\left\{\frac{P_i}{\hat{r}_{m \to t}(i)} : 1 \le i \le n \wedge \mathbb{E}_i[\hat{W}_i] \ne 0\right\} \ge p_{min} \tag{27}$$
Then $\Pr(E) \le e^{-t}$ where $M := \frac{1}{n}\sum_{i=1}^n \hat{r}_i$.
Proof. We follow [Beygelzimer et al., 2010] and let
$$\lambda := \frac{3 p_{min} \sqrt{\frac{2t}{9 n p_{min}}}}{1 + \sqrt{\frac{2t}{9 n p_{min}}}} \tag{28}$$
Note that $0 < \lambda < 3 p_{min}$. By Lemma 2, we know that if $\min\{P_i / \hat{r}_{m \to t}(i) : 1 \le i \le n \wedge \mathbb{E}_i[\hat{W}_i] \ne 0\} \ge p_{min}$ then
$$\frac{1}{n}\sum_{i=1}^n \log \mathbb{E}_i[\exp(\lambda(\hat{W}_i - \mathbb{E}_i[\hat{W}_i]))] \le \frac{1}{n}\sum_{i=1}^n \frac{\hat{r}_i\, \hat{r}_{m \to t}(i)\, \lambda^2}{2 P_i\left(1 - \frac{\hat{r}_{m \to t}(i)\, \lambda}{3 P_i}\right)} \le \lambda M \sqrt{\frac{t}{2 n p_{min}}} \tag{29}$$
and
$$\frac{t}{n\lambda} = \sqrt{\frac{t}{2 n p_{min}}} + \frac{t}{3 n p_{min}} \tag{30}$$
Let $E'$ be the event that
$$\frac{1}{n}\sum_{i=1}^n (\hat{W}_i - \mathbb{E}_i[\hat{W}_i]) \ge \frac{1}{n\lambda}\sum_{i=1}^n \log \mathbb{E}_i[\exp(\lambda(\hat{W}_i - \mathbb{E}_i[\hat{W}_i]))] + \frac{t}{n\lambda} \tag{31}$$
and let $E''$ be the event $\min\{P_i / \hat{r}_{m \to t}(i) : 1 \le i \le n \wedge \mathbb{E}_i[\hat{W}_i] \ne 0\} \ge p_{min}$. Together, the above two equations imply $E \subseteq E' \cap E''$. By [Zhang, 2005]'s Lemmas 1 and 2, $\Pr(E) \le \Pr(E' \cap E'') \le \Pr(E') \le e^{-t}$.
The following is an immediate consequence of the previous lemma.
Lemma 4. Pick any $t \ge 0$ and $n \ge 1$. Assume $1 \le \hat{r}_{m \to t}(i)/P_i \le r_{max}$ for all $1 \le i \le n$, and let $R_n := \max\{\hat{r}_{m \to t}(i)/P_i : 1 \le i \le n \wedge \mathbb{E}_i[\hat{W}_i] \ne 0\} \cup \{1\}$. We have
$$\Pr\left(\frac{1}{n}\sum_{i=1}^n \hat{W}_i - \frac{1}{n}\sum_{i=1}^n \mathbb{E}_i[\hat{W}_i] \ge (1 + M)\sqrt{\frac{R_n t}{2n}} + \frac{R_n t}{3n}\right) \le 2(2 + \log_2 r_{max})\, e^{-t/2} \tag{32}$$
Proof. This proof follows identically to that of Lemma 8 in [Beygelzimer et al., 2010].
We can finally bound $\Delta_4$ by bounding the remaining free quantity $M$.
Lemma 5. With probability at least $1 - \delta$, the following holds over all $n \ge 1$ and $h \in H$:
$$|\Delta_4| \le (2 + \|\hat{\theta}\|_2)\left(\sqrt{\frac{\varepsilon_n}{P_{\min,n}(h)}} + \frac{\varepsilon_n}{P_{\min,n}(h)}\right) \tag{33}$$
where $\varepsilon_n := \frac{16 \log\left(2(2 + n \log_2 n)\, n(n+1)\, |H| / \delta\right)}{n}$ and $P_{\min,n}(h) = \min\{P_i : 1 \le i \le n \wedge h(X_i) \ne h^*(X_i)\} \cup \{1\}$.
Proof. We define the $k$-sized vector $\tilde{\ell}(j) := \frac{1}{n}\sum_{i=1}^n \mathbb{1}[y_i = j]\, \hat{\theta}(j)$. Here, $v(j)$ is an abuse of notation and denotes the $j$th element of a vector $v$. Note that we can write $M$ by instead summing over labels, $M = \frac{1}{n}\sum_{i=1}^n \hat{\theta}_i = \sum_{j=1}^k \tilde{\ell}(j)$. Applying the Cauchy-Schwarz inequality, we have that $\frac{1}{n}\sum_{i=1}^n \hat{\theta}_i \le \frac{1}{n}\|\hat{\theta}\|_2 \|\dot{\ell}\|_2$, where $\dot{\ell}$ is another $k$-sized vector with $\dot{\ell}(j) := \sum_{i=1}^n \mathbb{1}[y_i = j]$. Since $\|\dot{\ell}\|_2 \le n$, we have that $M \le 1 + \|\hat{\theta}\|_2$. The rest of the claim follows by Lemma 4 and a union bound over hypotheses and datapoints.
The term $\Delta_3$ is bounded with Theorem 1. We now bound $\Delta_2$. This is a simple generalization bound for an importance weighted estimate of $f$.
Lemma 6. For any $\delta > 0$, with probability at least $1 - \delta$, for all $n \ge 1$ and $h \in H$:
$$|\Delta_2| \le \frac{2\, d_\infty(P_{test}, P_{src}) \log(2n|H|/\delta)}{3n} + \sqrt{\frac{2\, d_2(P_{test}, P_{src}) \log(2n|H|/\delta)}{n}} \tag{34}$$
Proof.
This inequality is a direct application of Theorem 2 from [Cortes et al., 2010].
The following lemma bounds the remaining term $\Delta_1$.
Lemma 7. For all $n \ge 1$, $h \in H$:
$$|\Delta_1| \le \|r_{s \to m}\|_\infty\, \mathrm{err}(h_0, r_{s \to m}) \tag{35}$$
Proof.
This inequality follows from our Lemma 1 and [Azizzadenesheli et al., 2019]’s Theorem 2.
Theorem 5 follows by applying a triangle inequality over $\Delta_1, \Delta_2, \Delta_3, \Delta_4$. If a warm start of $m$ datapoints sampled from $P_{warm}$ is used, the deviation bound is instead:
$$\begin{aligned}
\big|\mathrm{err}(h, Z_{1:n}) - \mathrm{err}(h^*, Z_{1:n}) - \mathrm{err}(h) + \mathrm{err}(h^*)\big| \le O\Bigg( &\left(2 + \frac{n\|\theta_{u \to t}\|_2 + m\|\theta_{w \to t}\|_2}{n + m}\right)\left(\sqrt{\frac{\varepsilon_n}{P_{\min,n}(h)}} + \frac{\varepsilon_n}{P_{\min,n}(h)}\right) \\
&+ \frac{2\, d_\infty(P_{test}, P_{src}) \log(2n|H|/\delta)}{3(n + m)} + \sqrt{\frac{2\, d_2(P_{test}, P_{src}) \log(2n|H|/\delta)}{n + m}} \\
&+ \frac{n}{n + m}\, \|r_{s \to m}\|_\infty\, \mathrm{err}(h_0, r_{s \to m}) \\
&+ \frac{n}{n + m} \cdot \frac{2}{\sigma_{\min}}\left(\|\theta_{m \to t}\|_2 \sqrt{\frac{\log(nk/\delta)}{\lambda n}} + \sqrt{\frac{\log(n/\delta)}{\lambda n}} + \frac{\log(n/\delta)}{n}\right) + \|\theta_{s \to m}\|_\infty\, \mathrm{err}(h_0, r_{m \to t}) \Bigg)
\end{aligned}$$
The only change is that the variance and subsampling terms are scaled by $\frac{n}{n+m}$, both of which disappear in the limit where $n \gg m$. For the remainder of this proof, we continue to set $m = 0$.
Theorem 2 follows by replacing the deviation bound in [Beygelzimer et al., 2010]'s Theorem 2 with our Theorem 5. Theorem 3 similarly follows from [Beygelzimer et al., 2010]'s Theorem 3 but with two additions. First, $\lambda n$ datapoints are sampled for label shift estimation. Second, the number of datapoints which are either accepted or rejected by the active learning algorithm can be much smaller than the number of datapoints sampled from $P_{src}$ due to subsampling. We can determine this proportion with an upper-tail Chernoff bound.
Lemma 8. When $\epsilon < 2^{-(2e - 1)\, n / \|r_{s \to m}\|_\infty}$, given $n$ datapoints from $P_{src}$, subsampling will yield $n'$ datapoints where,
$$\Pr\left(n' \ge \frac{n}{\|r_{s \to m}\|_\infty} + \log_2\left(\frac{1}{\epsilon}\right)\right) \le \epsilon \tag{36}$$
Proof. The number of subsampled datapoints is a sum of independent Bernoulli trials with mean $\mu$,
$$\mu = \mathbb{E}_{y \sim P_{src}}[P_{ss}(y)] = \mathbb{E}_{y \sim P_{src}}\left[C\, \frac{P_{med}(y)}{P_{src}(y)}\right] = \mathbb{E}_{y \sim P_{med}}[C] = C \tag{37}$$
where $C$ is a constant such that $C\, \frac{P_{med}(y)}{P_{src}(y)} \le 1$ for all labels $y$. Thus, $\mu = C \le 1 / \|r_{s \to m}\|_\infty$.
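The subsampling step analyzed in Lemma 8 admits a short sketch. Below is a minimal illustration, under stated assumptions, of a subsampler that retains a point with probability $P_{ss}(\hat{y}) = C\, P_{med}(\hat{y}) / P_{src}(\hat{y})$, where $\hat{y} = h_0(x)$ is the blackbox prediction and $C$ is the largest constant keeping every retention probability at most one; the function name and toy distributions are illustrative, not the paper's implementation.

```python
import numpy as np

def subsample_mask(pred_labels, p_med, p_src, rng=None):
    """Retain each point with probability P_ss(y) = C * P_med(y) / P_src(y),
    where y is the blackbox prediction for that point and C is the largest
    constant keeping every retention probability at most 1."""
    rng = rng or np.random.default_rng()
    ratio = p_med / p_src                 # medial-to-source ratios r_{s->m}
    C = 1.0 / ratio.max()                 # mu = C <= 1 / ||r_{s->m}||_inf
    p_keep = C * ratio[pred_labels]       # per-point retention probabilities
    return rng.random(len(pred_labels)) < p_keep

# toy usage: uniform medial distribution over 3 classes, imbalanced source pool
p_src = np.array([0.6, 0.3, 0.1])
p_med = np.array([1/3, 1/3, 1/3])
preds = np.random.default_rng(0).choice(3, size=10000, p=p_src)  # stand-in for h_0(x)
mask = subsample_mask(preds, p_med, p_src, np.random.default_rng(1))
print(mask.mean())  # roughly equals C
```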
9 Supplementary Experiments
9.1 NABirds Regional Species Experiment
We conduct an additional experiment on the NABirds dataset using the grandchild level of the class label hierarchy,
which results in 228 classes in total. These classes correspond to individual species and present a significantly
larger output space than considered in Figure 6. For realism, we retain the dataset's original training distribution as the source distribution, sampling i.i.d. from the original split in the experiment. To simulate a scenario where a bird species classifier is adapted to a new region with new bird frequencies, we induce an imbalance in the target distribution to render certain birds more common than others. Table 1 reports the average accuracy of our framework at different label budgets; we observe consistent gains across budgets.
Strategy         Acc (854 Labels)   Acc (1708 Labels)   Acc (3416 Labels)
MALLS (MC-D)     0.51               0.53                0.56
Vanilla (MC-D)   0.46               0.48                0.50
Random           0.38               0.40                0.42

Table 1: NABirds (species) Experiment Average Accuracy
9.2 Change in distribution
To further analyze the learning behavior of MALLS, we can analyze the label distribution of datapoints selected
by the active learner. In Figure 8, MC-Dropout, Max-Margin and Max-Entropy strategies are evaluated on
CIFAR100 under
canonical label shift
. By analyzing the uniformity bias and the rate of convergence to the target
distribution, we can observe that MALLS exhibits a unique sampling bias which cannot be explained away as
simply a class-balancing bias. This indicates that MALLS may be successful in recovering information from
distorted uncertainty estimates.
Figure 8: Average L2 distance between labeled class distribution and uniform/target distribution with 95%
confidence intervals on 10 runs of experiments on CIFAR100 in the
canonical label shift
setting. MALLS (denoted
by ALLS) converges to the target label distribution slower than vanilla active learning but with a similar uniform
sampling bias. This suggests MALLS leverages a sampling bias different from that of vanilla active learning or
naive class-balanced sampling.
10 Experiment Details
We list our detailed experimental settings and hyperparameters which are necessary for reproducing our results.
Across all experiments, we use a stochastic gradient descent (SGD) optimizer with base learning rate 0.1, finetune learning rate 0.02, momentum rate 0.9, and weight decay 5e-4. We also share the same batch size of 128 and RLLS [Azizzadenesheli et al., 2019] regularization constant of 2e-6
across all experiments. As suggested in our
analysis, we employ a uniform medial distribution to achieve a balance between distance to the target and distance
to the source distributions. For computational efficiency, all experiments are conducted with minibatch-mode
active learning. In other words, rather than retraining models upon each additional label, multiple labels are
queried simultaneously. Table 2 lists the specific hyperparameters for each experiment, categorized by dataset.
Table 3 lists the specific parameters of simulated label shifts (if any) created for individual experiments. Figure
numbers reference figures in the main paper and appendix. “Dir” is short for Dirichlet distribution, “Inh” is
short for inherent distribution, and “Uni” is short for uniform distribution.
Dataset     Model      # Datapoints   Epochs (init/fine)   # Batches   # Classes
NABirds1    Resnet-34  30,000         60/10                20          21
NABirds2    Resnet-34  30,000         60/10                20          228
CIFAR       Resnet-18  40,000         80/10                40          10
CIFAR100    Resnet-18  40,000         80/10                40          100

Table 2: Dataset-wide statistics and parameters
Figure    Dataset    Warm Ratio   Source Dist   Target Dist   Canonical?   Dirichlet α
5(a)      MNIST      0.1          Dir           Dir           Yes          0.1
5(b)      CIFAR      0.4          Dir           Dir           Yes          0.4
6(a-b)    CIFAR100   0.4          Dir           Dir           Yes          0.1
6(c-d)    NABirds1   1.0          Inh           Inh           No           N/A
7(a-b)    CIFAR      0.3          Dir           Dir           Yes          0.7
7(c)      CIFAR      0.3          Dir           Dir           Yes          0.7
7(d)      CIFAR100   0.4          Dir           Dir           Yes          0.1
8(a)      CIFAR100   0.4          Dir           Dir           Yes          3.0
8(b)      CIFAR100   0.4          Dir           Dir           Yes          0.7
8(c)      CIFAR100   0.4          Dir           Dir           Yes          0.4
8(d)      CIFAR100   0.4          Dir           Dir           Yes          0.1
9(a)      CIFAR100   0.4          Dir           Uni           No           1.0
9(b)      CIFAR100   0.3          Uni           Dir           No           0.1
9(c-d)    CIFAR100   0.4          Dir           Dir           Yes          0.1
T1(g-i)   NABirds1   1.0          N/A           Dir           No           0.1
8         CIFAR100   0.4          Dir           Dir           Yes          0.1

Table 3: Label Shift Setting Parameters (in order of paper)
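As a concrete illustration of the simulated label shift settings in Table 3, the following is a minimal sketch of how a Dirichlet-parameterized target label distribution could be drawn and used to subsample a dataset split; the function name, the per-class budget rule, and the toy label pool are illustrative assumptions, not the paper's exact pipeline.

```python
import numpy as np

def dirichlet_label_shift(labels, alpha, num_classes, rng=None):
    """Draw a class distribution from Dir(alpha) and subsample indices of `labels`
    so that the retained points approximately follow that distribution."""
    rng = rng or np.random.default_rng()
    target_dist = rng.dirichlet([alpha] * num_classes)   # smaller alpha gives a sharper shift
    counts = np.bincount(labels, minlength=num_classes)
    # largest total size such that no class needs more points than are available
    budget = int(min(counts[c] / target_dist[c] for c in range(num_classes) if target_dist[c] > 0))
    keep = []
    for c in range(num_classes):
        idx = np.flatnonzero(labels == c)
        keep.extend(rng.choice(idx, size=int(budget * target_dist[c]), replace=False))
    return np.array(keep), target_dist

# toy usage: 10-class pool of 40,000 points, as in the CIFAR experiments
rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=40000)
idx, dist = dirichlet_label_shift(labels, alpha=0.4, num_classes=10, rng=rng)
print(len(idx), np.round(dist, 3))
```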