PCA on training and test data

In the past few months, I have heard several talks where dimension reduction (e.g. taking the top k principal components) was performed on the full data set before splitting the data into training and test sets. My first intuition was that this kind of “peeking” at the test set would inflate the accuracy on the test set. On the other hand, one could argue that as long as the dimension reduction is unsupervised (not aware of the class labels), it should make no difference. After simulating some examples, I can’t find a situation where the accuracy on a test set that was used in the dimension reduction is inflated relative to a “doubly” held-out test set.

Here is some code I have used to look into this. I simulate a two-class mixture of Gaussians in n dimensions, where some dimensions separate the classes better than others. I run PCA on the first 2/3 of the rows, train a logistic regression on the first 1/3, and compare predictions on the middle 1/3 (which the PCA saw) and the last 1/3 (which it did not):


compare <- function(m = 300, n = 10, how.much.signal = 1, k = 1) {
  # some random gaussian data
  x <- matrix(rnorm(m*n), nrow=m, ncol=n)
  # shift the class-1 rows by ramped means, so later columns carry more signal
  y <- rep(0:1, times=m/2)
  ramped.noise <- sapply(1:n/n * how.much.signal, function(i) rnorm(m/2,i,1))
  x[y == 1,] <- x[y == 1,] + ramped.noise
  # find the principal components using the first 2/3 of the rows
  pc <- prcomp(x[1:(m*2/3),])
  # project all rows onto the top k PCs, then train a model on the first 1/3
  # (prcomp centers x, so these scores differ from pc$x by a constant
  #  shift, which the logistic regression intercept absorbs)
  data <- data.frame(y=y, x %*% pc$rotation[,1:k])
  fit <- glm(y ~ ., data=data[1:(m/3),], family="binomial")
  # predict on all rows (link scale, so > 0 means fitted probability > 0.5)
  prediction <- predict(fit, newdata=data) > 0
  # test on middle 1/3
  middle.third <- (m/3 + 1):(m*2/3)
  error.middle <- mean(y[middle.third] != prediction[middle.third])
  # test on last 1/3
  last.third <- (m*2/3 + 1):m
  error.last <- mean(y[last.third] != prediction[last.third])
  return(error.last - error.middle)
}

repeated.comparison <- replicate(1000, compare())

A 95% confidence interval for the mean difference in misclassification rate (last third minus middle third) over the 1000 replicates:

> t.test(repeated.comparison)$conf.int
[1] -0.0073857758  0.0002057758
attr(,"conf.level")
[1] 0.95
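
For contrast, here is a minimal sketch of the fully held-out pipeline, where prcomp is fit on the training third only and all other rows are projected onto those components with predict(). The helper compare.no.peek is only illustrative (it is not part of the experiment above); it mirrors compare() but never lets test rows influence the rotation, and returns the held-out misclassification rate:

compare.no.peek <- function(m = 300, n = 10, how.much.signal = 1, k = 1) {
  # same simulated data as in compare()
  x <- matrix(rnorm(m*n), nrow=m, ncol=n)
  y <- rep(0:1, times=m/2)
  ramped.noise <- sapply(1:n/n * how.much.signal, function(i) rnorm(m/2,i,1))
  x[y == 1,] <- x[y == 1,] + ramped.noise
  train <- 1:(m/3)
  test <- (m/3 + 1):m
  # fit the PCA on the training rows only
  pc <- prcomp(x[train,])
  # predict() centers new rows with the training means before rotating
  scores <- predict(pc, newdata=x)[, 1:k, drop=FALSE]
  data <- data.frame(y=y, scores)
  fit <- glm(y ~ ., data=data[train,], family="binomial")
  prediction <- predict(fit, newdata=data) > 0
  mean(y[test] != prediction[test])
}

Replicating this alongside compare() gives a direct way to check whether excluding the test rows from the PCA step changes the held-out error.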

