Transfer learning: domain adaptation by instance-reweighting

In supervised learning, we typically train a model on labeled data (we know “the truth”) and eventually apply the model to unlabeled data (we do not know “the truth”). For example, a spam filtering model might be trained on a public email database, with emails clearly labeled as “spam” or “non-spam”. However, the model will eventually be applied to a personal inbox, where emails are not labeled. An interesting example from the life sciences is training a classifier to predict protein interactions in a species for which biologically validated interactions are known, and applying this classifier to other species for which no such validated interactions exist.

But what if, in addition to missing labels, the data we apply our model to (the “target data”) is just very different from our training data (the “source data”)? For example, in a personal inbox both spam and non-spam emails may have very different characteristics compared to the emails in the public email database. Also, in protein interaction prediction, it could be important to consider that species can have very different proteomes, and therefore also different protein interactions.

In cases such as the two outlined above, what we would like to do is make sure that our model performs well on the target data, while still training on the source data. How can we do that?

The covariate shift assumption

In cases where the target data is very different from the source data, we need to think about domain adaptation. Domain adaptation can be seen as a specific case of transfer learning. The approach discussed here applies to situations where there is only a covariate shift between source and target data:

P_S(X) \neq P_T(X) \textrm{, but } P_S(Y|X=x) = P_T(Y|X=x)

Here,

  • P_S(X) represents the marginal covariate distribution of source instances.
  • P_T(X) represents the marginal covariate distribution of target instances.
  • P_S(Y|X=x) represents the conditional class distribution given source instance x.
  • P_T(Y|X=x) represents the conditional class distribution given target instance x.

In words, the first part (P_S(X) \neq P_T(X)) simply means that in general you find different emails in the public email database than in your own inbox: the target data is distributionally different from the source data. The second part (P_S(Y|X=x) = P_T(Y|X=x)) means that the class probability of an instance x is independent of whether x originated from the source or target distribution. In the example of spam filtering: if you have a specific email, then the probability of this email being spam stays the same, regardless of whether it originated from the public email database or from your personal inbox.
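To make the assumption concrete, here is a minimal R sketch of covariate shift. All choices in it are assumptions made for illustration (Gaussian covariates with different means, and a hypothetical labeling function p_y_given_x shared by both domains): the marginals differ, while instances with similar x have the same class frequencies in both domains.

```r
# Minimal covariate-shift simulation (illustrative assumptions only):
# P_S(X) = N(-1, 1), P_T(X) = N(1, 1), shared P(Y = 1 | X = x).
set.seed(42)

# Hypothetical shared conditional class probability.
p_y_given_x <- function(x) plogis(2 * x)

n <- 1e5
# Different marginal covariate distributions.
x_source <- rnorm(n, mean = -1)
x_target <- rnorm(n, mean = 1)

# Labels are drawn from the same conditional in both domains.
y_source <- rbinom(n, 1, p_y_given_x(x_source))
y_target <- rbinom(n, 1, p_y_given_x(x_target))

# The marginals differ (difference in means is about 2) ...
mean_shift <- mean(x_target) - mean(x_source)

# ... but near the same x, the class frequencies agree.
in_bin <- function(x) abs(x) < 0.1
frac_source <- mean(y_source[in_bin(x_source)])
frac_target <- mean(y_target[in_bin(x_target)])

print(c(mean_shift = mean_shift,
        frac_source = frac_source,
        frac_target = frac_target))
```

Note that the overall class proportions still differ between the two domains (compare mean(y_source) and mean(y_target)), even though P(Y|X=x) is shared.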

Model misspecification

Now you might think: when training a classifier, why should we care whether the source and target covariate distributions are different? We are only interested in the conditional class distribution P_T(Y|X=x), and because we assume that P_S(Y|X=x) = P_T(Y|X=x), can we not simply train a classifier on the source data to obtain optimal performance on the target data? Ideally, yes. However, it has been shown that under model misspecification, covariate shift can in fact be a problem, and in practice models typically are misspecified: we do not know what function generated our data, and most likely it was not precisely of the form that we used for fitting the data. For example, fitting a line (e.g. using logistic regression) to separate the classes in the following case would be a clear case of model misspecification:

Our data: two features (x1 and x2) and two classes (class1 and class2). It is obvious what function generated this data (points belong to class2 if x1 < 0 and x2 > 0). Therefore, fitting a line is a clear case of model misspecification.
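The claim that a line cannot represent this labeling rule can be checked directly. The following R sketch is an illustration under assumptions (a regular grid of points instead of the random data in the plot, and base R's glm instead of the glmnet fit used later in the post): even the best logistic regression fit leaves some points misclassified, because no single line separates the class2 quadrant from the other three.

```r
# Illustrative check of model misspecification on a regular grid.
grid <- seq(-0.95, 0.95, by = 0.1)
df <- expand.grid(x1 = grid, x2 = grid)

# Labeling rule from the figure: class2 (coded 1) if x1 < 0 and x2 > 0.
df$y <- as.integer(df$x1 < 0 & df$x2 > 0)

# Logistic regression: its decision boundary is a straight line in (x1, x2).
fit <- glm(y ~ x1 + x2, family = binomial, data = df)

# Predict class2 whenever the fitted probability exceeds 0.5.
pred <- as.integer(fitted(fit) > 0.5)
accuracy <- mean(pred == df$y)
print(accuracy)  # below 1: the quadrant is not linearly separable
```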

Model misspecification in a transfer learning setting

Back to transfer learning. Remember that in our transfer learning setting, we are training on labeled source data, and will apply the resulting classifier to unlabeled target data. Moreover, the unlabeled target data is distributionally different from the source data. Let’s extend the above example, and separate the data into source and target data:

Labeled source data, and unlabeled target data: target samples generally have higher x1 values than source samples, but target class labels are unknown.

You can see that the target data is differently distributed compared to the source data: it tends to have higher x1 values, implying that P_S(X) \neq P_T(X). Furthermore, target class labels are unknown. Therefore, in training a classifier separating class1 from class2, the only thing that we can do is train on the labeled source data. Training a logistic regression classifier on the source data gives the decision boundary in the left two plots:

Left: Optimal decision boundary for the source data. Middle: The same source data decision boundary is clearly not optimal for the target data. Right: The optimal decision boundary is in fact much steeper.

The decision boundary indeed seems optimal for the source data (left plot). However, it is far from optimal for the target data (middle plot). In fact, the optimal decision boundary for the target data is much steeper (right plot). Because of the model misspecification, there is no logistic regression parameterization \theta such that P(Y|X=x, \theta) = P_S(Y|X=x) = P_T(Y|X=x) for all x, so the best achievable fit depends on where the covariates are concentrated. In other words, the optimal model for the source data is different from the optimal model for the target data. This brings us to the following question: Is there a way to train a classifier on the source data, while optimizing for performance on the target data?
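To compare boundaries like these numerically, a fitted coefficient vector can be turned into a line. The helper below, decision_boundary, is a hypothetical name and not part of the original code; it assumes coefficients ordered as (intercept, x1, x2), which is the order returned by fitLRG in the implementation later in the post, and a nonzero x2 coefficient. A "steeper" boundary then simply means a larger absolute slope.

```r
# Hypothetical helper: turn logistic regression coefficients
# c(intercept, x1, x2) into the decision boundary x2 = intercept + slope * x1,
# i.e. the line where the fitted probability equals 0.5.
decision_boundary <- function(coefs) {
  coefs <- as.numeric(coefs)
  b0 <- coefs[1]; b1 <- coefs[2]; b2 <- coefs[3]
  # b0 + b1 * x1 + b2 * x2 = 0  <=>  x2 = -b0 / b2 - (b1 / b2) * x1
  c(intercept = -b0 / b2, slope = -b1 / b2)
}

# Example with made-up coefficients.
decision_boundary(c(1, 2, -1))  # intercept 1, slope 2
```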

Re-weighted empirical risk minimization

It turns out that yes, we can train on the source data while optimizing for performance on the target data. Let’s first go through some math to show how (or skip to the implementation using R if you are not interested). Recall that true risk minimization finds a parameterization \theta = \theta^\ast such that the expected value of the loss function l(x,y,\theta) under the true joint distribution P(x,y) over X and Y is minimized:

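\theta^\ast = \arg\min_{\theta} \mathbb{E}_{(x,y) \sim P(x,y)}\left[l(x,y,\theta)\right] = \arg\min_{\theta} \int \sum_{y} P(x,y)\, l(x,y,\theta)\, dx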

Empirical risk minimization approximates true risk minimization by using the empirical joint distribution over X and Y, because the true joint distribution P(x,y) is unknown:

\theta^\ast \approx \arg\min_{\theta} \frac{1}{N} \sum_{i=1}^{N} l(x_i,y_i,\theta)

Note that in the above, (x_i, y_i) \sim P(x,y).
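As a concrete instance of empirical risk minimization, the sketch below minimizes the average log loss of a logistic regression model with optim and checks that it recovers the same coefficients as glm's maximum-likelihood fit. The simulated data and the names empirical_risk and empirical_risk_grad are assumptions for illustration; the optional weight vector w (all ones by default) is where per-instance weights would enter.

```r
# Minimal sketch of (optionally weighted) empirical risk minimization,
# assuming the log loss of logistic regression and simulated data.
set.seed(1)
n <- 500
X <- cbind(1, x1 = rnorm(n), x2 = rnorm(n))  # first column: intercept
y <- rbinom(n, 1, plogis(0.5 + X[, "x1"] - 2 * X[, "x2"]))

# Numerically stable log(1 + exp(eta)).
softplus <- function(eta) pmax(eta, 0) + log1p(exp(-abs(eta)))

# Empirical risk: (1/N) * sum_i w_i * l(x_i, y_i, theta), with log loss.
empirical_risk <- function(theta, X, y, w = rep(1, length(y))) {
  eta <- drop(X %*% theta)
  mean(w * (softplus(eta) - y * eta))
}

# Gradient of the empirical risk with respect to theta.
empirical_risk_grad <- function(theta, X, y, w = rep(1, length(y))) {
  eta <- drop(X %*% theta)
  drop(crossprod(X, w * (plogis(eta) - y))) / length(y)
}

# Extra arguments X and y are passed on to fn and gr by optim.
opt <- optim(
  par = rep(0, ncol(X)), fn = empirical_risk, gr = empirical_risk_grad,
  X = X, y = y, method = "BFGS",
  control = list(reltol = 1e-12, maxit = 1000)
)

# Maximum likelihood via glm minimizes the same (unweighted) risk.
fit_glm <- glm(y ~ X - 1, family = binomial)
print(rbind(optim = opt$par, glm = unname(coef(fit_glm))))
```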

In our domain adaptation problem, we have two joint distributions, namely the source distribution P_S(x,y) and the target distribution P_T(x,y). When training on the empirical source distribution, we want to optimize for performance on the target distribution. To do this, we use our previous assumption (P_S(Y|X=x)=P_T(Y|X=x) for all x), and apply the following trick for transferring knowledge from our source domain to our target domain:

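\begin{aligned}
\theta^\ast &= \arg\min_{\theta} \mathbb{E}_{(x,y) \sim P_T(x,y)}\left[l(x,y,\theta)\right] \\
&= \arg\min_{\theta} \int \sum_{y} P_T(y|x)\, P_T(x)\, l(x,y,\theta)\, dx \\
&= \arg\min_{\theta} \int \sum_{y} P_S(y|x)\, P_S(x)\, \frac{P_T(x)}{P_S(x)}\, l(x,y,\theta)\, dx \\
&= \arg\min_{\theta} \mathbb{E}_{(x,y) \sim P_S(x,y)}\left[\frac{P_T(x)}{P_S(x)}\, l(x,y,\theta)\right] \\
&\approx \arg\min_{\theta} \frac{1}{N_S} \sum_{i=1}^{N_S} \frac{P_T(x_i)}{P_S(x_i)}\, l(x_i,y_i,\theta)
\end{aligned}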

Note that in the above, (x_i, y_i) \sim P_S(x,y). So we started with the usual formulation of true risk minimization under the target distribution, and showed that we can approximate it by re-weighting each source instance (x_i, y_i) in an empirical risk minimization under the source distribution! More specifically, each instance (x_i, y_i) is re-weighted by the ratio of the marginal covariate probabilities \frac{P_T(x_i)}{P_S(x_i)} (which requires P_S(x) > 0 wherever P_T(x) > 0). Interestingly, this shows that re-weighted empirical risk minimization is essentially importance sampling for computing the expected value of the loss function under the target joint distribution, with the additional assumption that the conditional class probabilities of the source and target data are the same.
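The connection to importance sampling can be checked numerically when both marginals are known. This R sketch assumes P_S = N(0, 1), P_T = N(1, 1), and a toy function f(x) = x^2 in place of the loss (all illustrative choices): averaging f over source samples weighted by P_T(x)/P_S(x) recovers the target expectation, while the unweighted average does not.

```r
# Importance-weighting check with known densities (illustrative assumptions):
# P_S = N(0, 1), P_T = N(1, 1), toy "loss" f(x) = x^2.
set.seed(7)
n <- 1e6
x_source <- rnorm(n, mean = 0)

# Importance weights P_T(x) / P_S(x); computable here because both
# densities are known (in practice they have to be estimated).
w <- dnorm(x_source, mean = 1) / dnorm(x_source, mean = 0)

f <- function(x) x^2

# Target expectation estimated from source samples only.
weighted_estimate <- mean(w * f(x_source))
# Unweighted average estimates E_S[x^2] = 1 instead.
unweighted_estimate <- mean(f(x_source))

# Exact target value: E_T[x^2] = variance + mean^2 = 1 + 1 = 2.
print(c(weighted = weighted_estimate,
        unweighted = unweighted_estimate,
        exact = 2))
```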

How to estimate the marginal probability ratio?

The problem with the above result is that P_T(x) and P_S(x) are difficult to estimate directly. However, we can avoid computing these probabilities directly by interpreting this marginal probability ratio as another probability ratio: the ratio of the probabilities that x_i comes from the target data and from the source data, multiplied by the ratio of the source data size N_S to the target data size N_T:

\frac{P_T(x_i)}{P_S(x_i)} \approx \frac{N_S}{N_T} \frac{P(x_i \textrm{ comes from the target data})}{P(x_i \textrm{ comes from the source data})}

Why is this? Well, here’s an argument for the discrete case. Suppose we independently draw two random samples, one of size N_S from the source distribution, and one of size N_T from the target distribution. We merge these two random samples, and from this sample of size N_S + N_T we draw a single instance x_i. What is the probability that x_i originated from the target distribution? If n^S_i is the number of occurrences of x_i in the random source sample of size N_S, and n^T_i is the number of occurrences of x_i in the random target sample of size N_T, then the following represents the probability that x_i originated from the target distribution:

P(x_i \textrm{ comes from the target data}) = \frac{n^T_i}{n^T_i+n^S_i}

Similarly, for the source data:

P(x_i \textrm{ comes from the source data}) = \frac{n^S_i}{n^S_i+n^T_i}

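Now, what is their ratio? For large samples, the counts are close to their expected values, \mathbb{E}[n^T_i] = N_T P_T(x_i) and \mathbb{E}[n^S_i] = N_S P_S(x_i), so:

\frac{P(x_i \textrm{ comes from the target data})}{P(x_i \textrm{ comes from the source data})} = \frac{n^T_i}{n^S_i} \approx \frac{N_T P_T(x_i)}{N_S P_S(x_i)}

Rearranging gives exactly the approximation for \frac{P_T(x_i)}{P_S(x_i)} stated above.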
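The counting argument can also be simulated. In this R sketch the covariate takes three hypothetical values with made-up source and target probabilities; the counts n^S_i and n^T_i from the two samples give the estimate \frac{N_S}{N_T} \frac{P(x_i \textrm{ comes from the target data})}{P(x_i \textrm{ comes from the source data})}, which should be close to the true ratio P_T(x_i)/P_S(x_i).

```r
# Simulation of the counting argument for a discrete covariate
# (values and probabilities are made up for illustration).
set.seed(3)
values <- c("a", "b", "c")
p_source <- c(a = 0.5, b = 0.3, c = 0.2)
p_target <- c(a = 0.2, b = 0.3, c = 0.5)
N_S <- 1e5
N_T <- 5e4

src <- sample(values, N_S, replace = TRUE, prob = p_source)
tgt <- sample(values, N_T, replace = TRUE, prob = p_target)

# Occurrence counts n^S_i and n^T_i for each value.
n_S <- table(factor(src, levels = values))
n_T <- table(factor(tgt, levels = values))

# P(x_i comes from the target data) in the merged sample.
p_from_target <- n_T / (n_T + n_S)

# N_S / N_T * P(target) / P(source), compared with the true P_T / P_S.
estimated_ratio <- as.numeric((N_S / N_T) * p_from_target / (1 - p_from_target))
true_ratio <- unname(p_target / p_source)
print(rbind(estimated = estimated_ratio, true = true_ratio))
```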

So all we need to do is estimate, for each source instance, the probability that it originated from the target data. How do we do that? One straightforward way of estimating these probabilities is to train a naturally probabilistic classifier, such as a logistic regression classifier, that separates source instances from target instances.
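As a sketch of this classifier-based estimate, the R code below fits a logistic regression (base R's glm) separating source from target samples and converts its probabilities into density ratios. It assumes one-dimensional Gaussian marginals P_S = N(0, 1) and P_T = N(1, 1), chosen because the true ratio exp(x - 1/2) is then known; with equal variances the log-odds are exactly linear in x, so logistic regression is well specified here, which will generally not be the case for real data. The function name estimate_ratio is hypothetical.

```r
# Classifier-based density ratio estimation (illustrative assumptions:
# P_S = N(0, 1), P_T = N(1, 1), so that P_T(x) / P_S(x) = exp(x - 1/2)).
set.seed(11)
N_S <- 20000
N_T <- 10000
df <- data.frame(
  x = c(rnorm(N_S, mean = 0), rnorm(N_T, mean = 1)),
  is_target = rep(c(0, 1), c(N_S, N_T))
)

# Probabilistic classifier separating source (0) from target (1).
domain_fit <- glm(is_target ~ x, family = binomial, data = df)

# Hypothetical helper: estimated P_T(x) / P_S(x) = N_S / N_T * p / (1 - p).
estimate_ratio <- function(x) {
  p <- predict(domain_fit, newdata = data.frame(x = x), type = "response")
  unname((N_S / N_T) * p / (1 - p))
}

x_check <- c(-1, 0, 1, 2)
print(rbind(estimated = estimate_ratio(x_check), true = exp(x_check - 0.5)))
```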

A simple way of implementing re-weighting

We now have everything in place to train a classifier on the source data, while optimizing for performance on the target data:

  1. Compute the source instance weights:
    1. Train a logistic regression classifier separating source data from target data.
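    2. Apply the classifier to each source instance x^S_i, thus computing p_i = P(x^S_i\textrm{ comes from the target data}).
    3. For each source instance x^S_i, compute the instance weight w_i = \frac{p_i}{1-p_i}. The factor \frac{N_S}{N_T} can be dropped: it is the same for every source instance, and multiplying all weights by the same positive constant does not change the minimizer.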
  2. Train a logistic regression classifier on the source data, separating class1 from class2, while re-weighting each source instance x^S_i by w_i.

In R, this could look as follows. First define some functions:


# Function to generate random data.
generate_data <- function(n) {

  range_x1 <- 1
  range_x2 <- 1

  # The features.
  x1 <- runif(n, -range_x1, range_x1)
  x2 <- runif(n, -range_x2, range_x2)

  # Generate class labels.
  y <- (x1 < 0 & x2 > 0) + 1

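  # Generate source and target labels: half of the instances
  # become target data, sampled with probability weights
  # ((x1 + 1) / 2)^5, so target instances tend to have high x1.
  prob <- (x1 + range_x1) / range_x1 / 2
  s <- (1:n %in% sample(n, n / 2, prob = prob^5)) + 1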

  data.frame(
    x1 = x1,
    x2 = x2,
    y = factor(c("class1", "class2")[y]),
    s = factor(c("source", "target")[s])
  )
}

# Function to fit a logistic regression classifier,
# possibly weighted.
fitLRG <- function(df, weights = rep(1, nrow(df))) {
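  # Class weights: inverse class frequencies, so that
  # both classes carry the same total weight.
  tab <- 1 / table(df$y)
  # Multiply by the instance weights
  weights <- as.numeric(weights * tab[match(df$y, names(tab))])
  # Fit a logistic regression model on the
  # class label (class1 vs. class2).
  fit <- coef(glmnet(
    x = as.matrix(df[, c("x1", "x2")]),
    y = df$y,
    lambda = seq(1, 0, -0.01),
    weights = weights,
    family = "binomial"
  ))
  # Return the coefficients (intercept, x1, x2) for the
  # last (smallest) lambda in the path, i.e. the unpenalized fit.
  fit[, ncol(fit)]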
}

# Function to compute instance weights
compute_instance_weights <- function(df) {
  # Fit a logistic regression model on the
  # source/target indicator.
  fit <- glmnet(
    x = as.matrix(df[, c("x1", "x2")]),
    y = df$s,
    lambda = seq(1, 0, -0.01),
    family = "binomial"
  )
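  # For each instance, compute the probability that it
  # came from the target data (glmnet models the second
  # factor level of y, here "target").
  p <- predict(
    fit,
    newx = as.matrix(df[, c("x1", "x2")]),
    type = "response"
  )
  # Keep the predictions for the last (smallest) lambda.
  p <- p[, ncol(p)]
  # Odds p / (1 - p); the constant N_S / N_T is omitted.
  p / (1 - p)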
}

Now let’s do some transfer learning:


# Load a package for fitting logistic regression models.
library(glmnet)

# Set the seed for reproducibility.
set.seed(1)

# Generate some random data.
df <- generate_data(1e3)

# Train an unweighted classifier.
fit_unweighted <- fitLRG(df[df$s == "source",])

# Train a re-weighted classifier:
# 1. Compute the instance weights
weights <- compute_instance_weights(df)
# 2. Train a weighted classifier
fit_reweighted <- fitLRG(
  df[df$s == "source",],
  weights = weights[df$s == "source"]
)
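Because the target labels are known in this simulation (they are only hidden from the classifier), one way to quantify the improvement is target-data accuracy. The helper accuracy_lrg below is a hypothetical addition, not part of the original code; it assumes coefficients ordered (intercept, x1, x2) that model the probability of "class2", as returned by fitLRG, and is checked here on a tiny made-up data frame.

```r
# Hypothetical helper: accuracy of logistic regression coefficients
# c(intercept, x1, x2) on a data frame with columns x1, x2 and a factor y
# with levels "class1"/"class2". Assumes the coefficients model the
# probability of "class2", so a positive linear predictor means class2.
accuracy_lrg <- function(coefs, df) {
  coefs <- as.numeric(coefs)
  eta <- coefs[1] + coefs[2] * df$x1 + coefs[3] * df$x2
  pred <- ifelse(eta > 0, "class2", "class1")
  mean(pred == as.character(df$y))
}

# Toy check with made-up coefficients: predict class2 whenever x1 < 0.
toy <- data.frame(
  x1 = c(-1, -0.5, 0.5, 1),
  x2 = c(0, 0, 0, 0),
  y = factor(c("class2", "class1", "class1", "class1"),
             levels = c("class1", "class2"))
)
accuracy_lrg(c(0, -1, 0), toy)  # 3 of 4 correct: 0.75
```

Applied to the objects above, it could be called as accuracy_lrg(fit_unweighted, df[df$s == "target", ]) and accuracy_lrg(fit_reweighted, df[df$s == "target", ]).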

The results confirm that instance re-weighting indeed leads to a decision boundary that is much closer to the optimal decision boundary for the target data:

Unweighted training on the source data is clearly not optimal for the target data (left plot). Re-weighted training on the source data separates the target data much better (middle plot), and in fact results in a decision boundary that lies very close to the optimal decision boundary (right plot).

 

 
