Transfer learning: domain adaptation by instance-reweighting

In supervised learning, we typically train a model on labeled data (we know “the truth”) and eventually apply the model to unlabeled data (we do not know “the truth”). For example, a spam filtering model might be trained on a public email database, with emails clearly labeled as “spam” or “non-spam”. However, the model will eventually be applied to a personal inbox, where emails are not labeled. An interesting example from the life sciences is training a classifier for predicting protein interactions in some species for which biologically validated interactions are known, and applying this classifier to other species, for which no such validated interactions exist.

But what if, in addition to missing labels, the data we apply our model to (the “target data”) is just very different from our training data (the “source data”)? For example, in a personal inbox both spam and non-spam emails may have very different characteristics compared to the emails in the public email database. Also, in protein interaction prediction, it could be important to consider that species can have very different proteomes, and therefore also different protein interactions.

In cases such as the two outlined above, what we would like to do is make sure that our model performs well on the target data, while still training on the source data. How can we do that?

The covariate shift assumption

In cases where the target data is very different from the source data, we need to think about domain adaptation. Domain adaptation can be seen as a specific case of transfer learning, and can be applied in situations where there is only a covariate shift between source and target data:

P_S(X) \neq P_T(X) \textrm{, but } P_S(Y|X=x) = P_T(Y|X=x)

Here,

  • P_S(X) represents the marginal covariate distribution of source instances.
  • P_T(X) represents the marginal covariate distribution of target instances.
  • P_S(Y|X=x) represents the conditional class distribution given source instance x.
  • P_T(Y|X=x) represents the conditional class distribution given target instance x.

In words, the first part (P_S(X) \neq P_T(X)) simply means that in general you find different emails in the public email database than in your own inbox: the target data is distributionally different from the source data. The second part (P_S(Y|X=x) = P_T(Y|X=x)) means that the class probability of an instance x is independent of whether x originated from the source or target distribution. In the example of spam filtering: if you have a specific email, then the probability of this email being spam stays the same, regardless of whether it originated from the public email database or from your personal inbox.

Model misspecification

Now you might think: if we train a classifier, why would we care whether the source and target covariate distributions are different? We are only interested in the conditional class distribution P_T(Y|X=x), and because the assumption is that P_S(Y|X=x) = P_T(Y|X=x), can we not simply train a classifier on the source data to obtain optimal performance on the target data? Well, ideally, yes. However, it has been shown that under model misspecification, covariate shift can in fact be a problem. And typically, models are indeed misspecified: we do not know what function generated our data, but most likely it was not precisely of the form that we used for fitting the data. For example, fitting a line (e.g. using logistic regression) to separate the classes in the following case would be a clear case of model misspecification:

[Figure: modelmisspecification]
Our data: two features (x1 and x2) and two classes (class1 and class2). It is obvious what function generated this data (points belong to class2 if x1 < 0 and x2 > 0). Therefore, fitting a line is a clear case of model misspecification.

Model misspecification in a transfer learning setting

Back to transfer learning. Remember that in our transfer learning setting, we are training on labeled source data, and will apply the resulting classifier to unlabeled target data. Moreover, the unlabeled target data is distributionally different from the source data. Let’s extend the above example, and separate the data into source and target data:

[Figure: source_and_target_data]
Labeled source data, and unlabeled target data: target samples generally have higher x1 values than source samples, but target class labels are unknown.

You can see that the target data is differently distributed compared to the source data: it tends to have higher x1 values, implying that P_S(X) \neq P_T(X). Furthermore, target class labels are unknown. Therefore, in training a classifier separating class1 from class2, the only thing that we can do is train on the labeled source data. Training a logistic regression classifier on the source data gives the decision boundary in the left two plots:

[Figure: unweighted_training]
Left: Optimal decision boundary for the source data. Middle: The same source data decision boundary is clearly not optimal for the target data. Right: The optimal decision boundary is in fact much steeper.

The decision boundary indeed seems optimal for the source data (left plot). However, it is far from optimal for the target data (middle plot). In fact, the optimal decision boundary of the target data is much steeper (right plot). In this transfer learning setting, the model misspecification implies that it is not possible to find a logistic regression parameterization \theta, such that P_S(Y|X=x, \theta) = P_T(Y|X=x, \theta) for all x. In other words, the optimal model for the source data is different from the optimal model for the target data. This brings us to the following question: Is there a way to train a classifier on the source data, while trying to optimize for performance on the target data?

Re-weighted empirical risk minimization

It turns out, yes: We can train on the source data while optimizing for performance on the target data. Let’s first go through some math to show how (or skip ahead to the implementation in R if you are not interested). Recall that true risk minimization finds a parameterization \theta = \theta^\ast, such that the expected value of the loss function l(x,y,\theta) under the true joint distribution P(x,y) over X and Y is minimized:

\theta^\ast = \underset{\theta}{\textrm{argmin}} \; \mathbb{E}_{(x,y) \sim P(x,y)}\left[ l(x,y,\theta) \right]

Empirical risk minimization approximates true risk minimization by using the empirical joint distribution over X and Y, because the true joint distribution P(x,y) is unknown:

\theta^\ast \approx \underset{\theta}{\textrm{argmin}} \; \frac{1}{N} \sum_{i=1}^{N} l(x_i, y_i, \theta)

Note that in the above, (x_i, y_i) \sim P(x,y).

In our domain adaptation problem, we have two joint distributions, namely the source distribution P_S(x,y) and the target distribution P_T(x,y). In training on the empirical source distribution, we want to optimize for performance on the target distribution. To do this, we use our previous assumption (P_S(Y|X=x) = P_T(Y|X=x) for all x), and apply the following trick for transferring knowledge from our source domain to our target domain:

\begin{aligned}
\underset{\theta}{\textrm{argmin}} \; \mathbb{E}_{(x,y) \sim P_T(x,y)}\left[ l(x,y,\theta) \right]
&= \underset{\theta}{\textrm{argmin}} \; \mathbb{E}_{(x,y) \sim P_S(x,y)}\left[ \frac{P_T(x,y)}{P_S(x,y)} \, l(x,y,\theta) \right] \\
&= \underset{\theta}{\textrm{argmin}} \; \mathbb{E}_{(x,y) \sim P_S(x,y)}\left[ \frac{P_T(x) \, P_T(y|x)}{P_S(x) \, P_S(y|x)} \, l(x,y,\theta) \right] \\
&= \underset{\theta}{\textrm{argmin}} \; \mathbb{E}_{(x,y) \sim P_S(x,y)}\left[ \frac{P_T(x)}{P_S(x)} \, l(x,y,\theta) \right] \\
&\approx \underset{\theta}{\textrm{argmin}} \; \frac{1}{N} \sum_{i=1}^{N} \frac{P_T(x_i)}{P_S(x_i)} \, l(x_i, y_i, \theta)
\end{aligned}

Note that in the above, (x_i, y_i) \sim P_S(x,y). So we started with the normal formulation of true risk minimization under the target distribution, and showed that we can approximate this by re-weighting each source instance (x_i, y_i) in an empirical risk minimization under the source distribution! More specifically, each instance (x_i, y_i) needs to be re-weighted by the ratio of the marginal covariate probabilities \frac{P_T(x_i)}{P_S(x_i)}. Interestingly, the above suggests that doing re-weighted empirical risk minimization is essentially the same as performing importance sampling for computing the expected value of the loss function under the target joint distribution, with the additional assumption that the conditional class probabilities between the source and target data are the same.

How to estimate the marginal probability ratio?

The problem with the above result is that P_T(x) and P_S(x) are difficult to determine. However, we can avoid computing these probabilities directly by interpreting this marginal probability ratio as another probability ratio: the ratio of the probabilities that x_i comes from the target data and from the source data, weighted by the ratio of the source data size N_S and the target data size N_T:

\frac{P_T(x_i)}{P_S(x_i)} \approx \frac{N_S}{N_T} \, \frac{P(x_i \textrm{ comes from the target data})}{P(x_i \textrm{ comes from the source data})}

Why is this? Well, here’s an argument for the discrete case. Suppose we independently draw two random samples, one of size N_S from the source distribution, and one of size N_T from the target distribution. We merge these two random samples, and from this sample of size N_S + N_T we draw a single instance x_i. What is the probability that x_i originated from the target distribution? If n^S_i is the number of occurrences of x_i in the random source sample of size N_S, and n^T_i is the number of occurrences of x_i in the random target sample of size N_T, then the following represents the probability that x_i originated from the target distribution:

P(x_i \textrm{ comes from the target data}) = \frac{n^T_i}{n^T_i + n^S_i}

Similarly, for the source data:

P(x_i \textrm{ comes from the source data}) = \frac{n^S_i}{n^S_i + n^T_i}

Now, what is the expected value of their ratio?

\mathbb{E}\left[ \frac{P(x_i \textrm{ comes from the target data})}{P(x_i \textrm{ comes from the source data})} \right] = \mathbb{E}\left[ \frac{n^T_i}{n^S_i} \right] \approx \frac{\mathbb{E}\left[ n^T_i \right]}{\mathbb{E}\left[ n^S_i \right]} = \frac{N_T \, P_T(x_i)}{N_S \, P_S(x_i)}

Multiplying both sides by \frac{N_S}{N_T} gives exactly the approximation above.

So all we need to do is estimate, for each source instance, the probability that it originated from the target data. How do we do that? One straightforward way of estimating these probabilities is to train a naturally probabilistic classifier, such as a logistic regression classifier, to separate source instances from target instances.

A simple way of implementing re-weighting

We now have everything in place to train a classifier on the source data, while optimizing for performance on the target data:

  1. Compute the source instance weights:
    1. Train a logistic regression classifier separating source data from target data.
    2. Apply the classifier to each source instance x^S_i, thus computing p_i = P(x^S_i \textrm{ comes from the target data}).
    3. For each source instance x^S_i, compute the instance weight w_i as w_i = \frac{p_i}{1-p_i}.
  2. Train a logistic regression classifier on the source data, separating class1 from class2, while re-weighting each source instance x^S_i by w_i.

Note that w_i = \frac{p_i}{1-p_i} estimates \frac{N_T}{N_S} \frac{P_T(x^S_i)}{P_S(x^S_i)}, which differs from the desired ratio \frac{P_T(x^S_i)}{P_S(x^S_i)} only by the constant factor \frac{N_S}{N_T}; since rescaling all instance weights by the same constant does not change the minimizer, we can safely drop this factor.

In R, this could look as follows. First define some functions:


# Function to generate random data.
generate_data <- function(n) {

  range_x1 <- 1
  range_x2 <- 1

  # The features.
  x1 <- runif(n, -range_x1, range_x1)
  x2 <- runif(n, -range_x2, range_x2)

  # Generate class labels.
  y <- (x1 < 0 & x2 > 0) + 1

  # Generate source and target labels
  # (target instances tend to have higher x1 values).
  prob <- (x1 + range_x1) / range_x1 / 2
  s <- (1:n %in% sample(n, n / 2, prob = prob^5)) + 1

  data.frame(
    x1 = x1,
    x2 = x2,
    y = factor(c("class1", "class2")[y]),
    s = factor(c("source", "target")[s])
  )
}

# Function to fit a logistic regression classifier,
# possibly weighted.
fitLRG <- function(df, weights = rep(1, nrow(df))) {
  # Compute the class weights.
  tab <- 1 / table(df$y)
  # Multiply by the instance weights
  weights <- as.numeric(weights * tab[match(df$y, names(tab))])
  # Fit a logistic regression model on the
  # source class label.
  fit <- coef(glmnet(
    x = as.matrix(df[, c("x1", "x2")]),
    y = df$y,
    lambda = seq(1, 0, -0.01),
    weights = weights,
    family = "binomial"
  ))
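  # Return the coefficients at lambda = 0, i.e. the unregularized fit.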
  fit[, ncol(fit)]
}

# Function to compute instance weights
compute_instance_weights <- function(df) {
  # Fit a logistic regression model on the
  # source/target indicator.
  fit <- glmnet(
    x = as.matrix(df[, c("x1", "x2")]),
    y = df$s,
    lambda = seq(1, 0, -0.01),
    family = "binomial"
  )
  # For each instance, compute the probability
  # that it came from the target data
  p <- predict(
    fit,
    newx = as.matrix(df[,c("x1", "x2")]),
    type = "response"
  )
  p <- p[, ncol(p)]
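  # Return the instance weights: the odds p / (1 - p).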
  p / (1 - p)
}

Now let’s do some transfer learning:


# Load a package for fitting logistic regression models.
library(glmnet)

# Set the seed for reproducibility.
set.seed(1)

# Generate some random data.
df <- generate_data(1e3)

# Train an unweighted classifier.
fit_unweighted <- fitLRG(df[df$s == "source",])

# Train a re-weighted classifier:
# 1. Compute the instance weights
weights <- compute_instance_weights(df)
# 2. Train a weighted classifier
fit_reweighted <- fitLRG(
  df[df$s == "source",],
  weights = weights[df$s == "source"]
)
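
To quantify the difference, here is a minimal sketch (not part of the analysis above): for a logistic regression fit, the decision boundary is obtained by setting the linear predictor \beta_0 + \beta_1 x_1 + \beta_2 x_2 to zero, so its slope in the (x1, x2) plane is -\beta_1 / \beta_2:

# Slope of the decision boundary implied by a coefficient vector
# as returned by fitLRG: beta0 + beta1 * x1 + beta2 * x2 = 0
# => x2 = -(beta0 + beta1 * x1) / beta2, i.e. slope = -beta1 / beta2.
boundary_slope <- function(coefs) {
  -coefs["x1"] / coefs["x2"]
}

boundary_slope(fit_unweighted)  # shallower boundary, tuned to the source data
boundary_slope(fit_reweighted)  # steeper boundary, closer to the target optimum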

The results confirm that instance re-weighting indeed leads to a decision boundary that is much closer to the optimal decision boundary for the target data:

[Figure: reweighted_training]
Unweighted training on the source data is clearly not optimal for the target data (left plot). Re-weighted training on the source data separates the target data much better (middle plot), and in fact results in a decision boundary that lies very close to the optimal decision boundary (right plot).
Bias-variance decomposition

In machine learning and statistics, predictive models are typically not perfect. But what does ‘not perfect’ mean? For example, take the case of weather forecasting, and imagine a weatherman who is simply overly optimistic, and too often predicts sunny weather and high temperatures. On the other hand, one could also think of a weatherman who, for some complicated reason, on some days grossly overestimates the temperature and on other days grossly underestimates it, such that on average he is closer to the true temperature than the first weatherman. In absolute terms, the errors in degrees Celsius/Fahrenheit made by both weathermen could be comparable. However, the first weatherman may be considered more reliable because he is consistent, while the second may be considered more accurate because on average he is closer to the truth. This is analogous to the bias-variance trade-off in machine learning and statistics: if we see a model making errors, is this generally the result of bias (e.g. consistently predicting too high temperatures), of variance (e.g. wildly varying temperature predictions across days), of noise (weather is just unpredictable…), or of a combination of all three?

It turns out that, mathematically speaking, the error made by a model can indeed be decomposed into two terms corresponding to bias and variance, plus one additional term representing noise (e.g. daily fluctuations in temperature we cannot account for). Proving this is not difficult. However, as I found many proofs online somewhat lacking in detail, I have written my own, which will consequently be one of the longer proofs you will find online. Suppose the data is generated as y = f(x) + \epsilon, where the noise \epsilon has zero mean and variance \sigma^2, and let \hat{f} denote the model fitted on a random training set (independent of \epsilon). Writing f = f(x) and \hat{f} = \hat{f}(x) for a fixed point x, and taking expectations over both the noise and the random training sets:

\begin{aligned}
\mathbb{E}\left[ (y - \hat{f})^2 \right] &= \mathbb{E}\left[ (f + \epsilon - \hat{f})^2 \right] \\
&= \mathbb{E}\left[ (f - \hat{f})^2 \right] + 2 \, \mathbb{E}\left[ \epsilon \, (f - \hat{f}) \right] + \mathbb{E}\left[ \epsilon^2 \right] \\
&= \mathbb{E}\left[ (f - \hat{f})^2 \right] + 2 \, \mathbb{E}[\epsilon] \, \mathbb{E}\left[ f - \hat{f} \right] + \sigma^2 \\
&= \mathbb{E}\left[ \left( f - \mathbb{E}[\hat{f}] + \mathbb{E}[\hat{f}] - \hat{f} \right)^2 \right] + \sigma^2 \\
&= \left( f - \mathbb{E}[\hat{f}] \right)^2 + 2 \left( f - \mathbb{E}[\hat{f}] \right) \mathbb{E}\left[ \mathbb{E}[\hat{f}] - \hat{f} \right] + \mathbb{E}\left[ \left( \hat{f} - \mathbb{E}[\hat{f}] \right)^2 \right] + \sigma^2 \\
&= \underbrace{\left( f - \mathbb{E}[\hat{f}] \right)^2}_{\textrm{bias}^2} + \underbrace{\mathbb{E}\left[ \left( \hat{f} - \mathbb{E}[\hat{f}] \right)^2 \right]}_{\textrm{variance}} + \underbrace{\sigma^2}_{\textrm{noise}}
\end{aligned}

Here, the first cross term vanishes because \mathbb{E}[\epsilon] = 0 and \epsilon is independent of \hat{f}, and the second cross term vanishes because \mathbb{E}\left[ \mathbb{E}[\hat{f}] - \hat{f} \right] = 0.

A nice example for demonstrating bias and variance is the estimation of the population variance from the sample variance. For example, suppose that we want to estimate the variance in height of all American men from the heights of only 3 American men. In other words, we want to use the sample variance (of 3 American men) as an approximation of the population variance (of all American men). In principle, for points x_1, x_2, \dots, x_n, the variance is computed as follows:

\frac{1}{n}\sum_{i=1}^{n}\left(x_i - \bar{x}\right)^2

However, it turns out that if you want to use the sample variance as an approximation of the population variance, the above calculation is biased, and you need what’s called Bessel’s correction in calculating the variance:

\frac{1}{n-1}\sum_{i=1}^{n}\left(x_i - \bar{x}\right)^2

The above calculation is unbiased: its expected value equals the population variance \sigma^2, whereas the expected value of the uncorrected estimator is \frac{n-1}{n}\sigma^2. Many proofs of this can be found online, three of them on Wikipedia. I will give a demonstration of this bias by simulation: 10000 times, a sample of size 3 is drawn from the standard normal distribution (zero mean and unit variance). For each of the 10000 samples, the variance is calculated with and without Bessel’s correction. The results are summarized in the density plot below.
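
A minimal R sketch of such a simulation (the code behind the plot is not shown here, so details like the seed are assumptions):

# Draw 10000 samples of size 3 from the standard normal distribution.
set.seed(1)
n_rep <- 10000
n <- 3
samples <- matrix(rnorm(n_rep * n), nrow = n_rep)

# Variance without Bessel's correction: divide by n.
var_uncorrected <- apply(samples, 1, function(x) sum((x - mean(x))^2) / length(x))
# Variance with Bessel's correction: divide by n - 1 (this is what R's var() does).
var_corrected <- apply(samples, 1, var)

# The means across the 10000 repetitions reveal the bias.
mean(var_uncorrected)  # close to (n - 1) / n = 2/3: biased downwards
mean(var_corrected)    # close to the true variance of 1: unbiased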

[Figure: bessel_correction]
A sample of size 3 is drawn from the standard normal distribution, and the variance is calculated, both with and without Bessel’s correction. This is repeated 10000 times.

The plot above confirms that in calculating the variance using Bessel’s correction, the average of the 10000 approximations of the population variance is very close to the true value (1), much closer in fact than when not using Bessel’s correction. However, its spread is larger compared to not using Bessel’s correction. In other words, Bessel’s correction leads to lower bias but higher variance. Therefore, Bessel’s correction is weatherman no. 2.

A slightly more involved example of the bias-variance trade-off is the following. Imagine (“population”) data that was generated by taking the sine function on the domain [-\pi, \pi] and adding some Gaussian noise. A complete cycle could look like this:

[Figure: data]
A complete cycle with and without Gaussian noise.

Now, suppose we would be given three training data points, generated in the way described above:

  1. Randomly sample three points x_1, x_2, x_3 from the domain [-\pi, \pi].
  2. Generate Gaussian noise \epsilon_1, \epsilon_2, \epsilon_3, one for each of the x_i.
  3. Compute y_i = \sin(x_i) + \epsilon_i.

I should emphasize that we do not actually know that the three data points were generated using the sine function; we just see the three data points. Based on these three points (x_1, y_1), (x_2, y_2), (x_3, y_3), we want to fit two models: a very simple one (a constant, or 0th-order polynomial), and a more complex one (a line, or 1st-order polynomial). Suppose furthermore that we repeat the sampling and fitting 1000 times, and each time measure the error with respect to the “population” data. For fitting the constant, the results may look like this:

[Figure: model_fits]
Fitting a constant to three points, randomly generated by the sine function and adding noise; repeated 1000 times.

For fitting the line, the results may look like this:

[Figure: model_fits]
Fitting a line to three points, randomly generated by the sine function and adding noise; repeated 1000 times.

It can already be seen that fitting a line seems to better capture the overall increasing trend of the data than does fitting a constant. However, this comes at the expense of high variance in the line fits compared to the constant fits.
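
For concreteness, here is a minimal R sketch of this experiment; the noise level, the seed, and the evaluation grid are assumptions, as the underlying code is not shown in this post:

# Repeat 1000 times: sample three noisy points from the sine curve,
# fit a constant and a line, and record predictions on a fixed grid.
set.seed(1)
n_rep <- 1000
sigma <- 0.3  # assumed noise level
x_grid <- seq(-pi, pi, length.out = 100)
pred_const <- matrix(NA, n_rep, length(x_grid))
pred_line  <- matrix(NA, n_rep, length(x_grid))
for (r in 1:n_rep) {
  x <- runif(3, -pi, pi)
  y <- sin(x) + rnorm(3, sd = sigma)
  pred_const[r, ] <- mean(y)                                      # 0th-order polynomial
  pred_line[r, ]  <- predict(lm(y ~ x), data.frame(x = x_grid))   # 1st-order polynomial
}

# Decompose the error at each grid point into bias^2 and variance.
bias2_const <- (colMeans(pred_const) - sin(x_grid))^2
var_const   <- apply(pred_const, 2, var)
bias2_line  <- (colMeans(pred_line) - sin(x_grid))^2
var_line    <- apply(pred_line, 2, var)

# Averaged over the domain: the line has lower bias but higher variance.
c(bias2 = mean(bias2_const), variance = mean(var_const))
c(bias2 = mean(bias2_line),  variance = mean(var_line))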

This bias-variance trade-off is visualized even more clearly by summarizing the empirical errors across the 1000 fits, decomposing the error into its bias and variance components. For fitting constants, the decomposition looks like this:

[Figure: bias_variance_noise]
Bias-variance decomposition of the errors observed when fitting a constant 1000 times.

For fitting a line, the decomposition looks as follows:

[Figure: bias_variance_noise]
Bias-variance decomposition of the errors observed when fitting a line 1000 times.

It is clear that although the complex model (the line) on average has a better fit (i.e. low bias), this comes at the expense of much larger variance between the individual fits. Hence, in statistical and machine learning modeling, the impact of model complexity on bias and variance should always be considered carefully.