Saturday, June 5, 2010

Cross Validation

One widely used method of data partitioning is cross validation. Its purpose is to assess the accuracy of a model, especially for prediction tasks. Naive researchers often test the accuracy of their model on the same data that was used to build the model, and the result yields a low error. This is misleading because the model has already seen that data, so the low error says little about how well the model generalizes (applies to other data sets).

Let's try a linear regression model. We fit the steam data with the polyfit() function and then evaluate the model over the domain of the data.

load steam;                          % steam data: average temperature (x) and steam usage (y)
[p,s]=polyfit(x,y,1);                % fit a first-degree polynomial (straight line)
xf=linspace(min(x),max(x));          % evaluation points spanning the data range
yf=polyval(p,xf);
plot(x,y,'.'); hold on;
plot(xf,yf,'r'); hold off;
xlabel('Average Temperature (\circ F)');
ylabel('Steam per Month (pounds)');
axis tight;
% prediction error on the training data itself
yhat=polyval(p,x);
pe = mean((y-yhat).^2);

The output of the above code is a scatter plot of the steam data with the fitted regression line, and the prediction error is
>> pe
pe =
    0.7289


However, as stated before, this is not the way to evaluate the performance of the model. We should leave some of the data out for testing and use the rest for training the model.

For example, if we leave out 20% of the data for testing, the method looks like this:

%cross validation method 20% testing
indtest=2:5:25;   % every fifth observation (20% of the data) forms the test set
xtest=x(indtest);
ytest=y(indtest);
xtrain=x;
ytrain=y;
xtrain(indtest)=[];
ytrain(indtest)=[];
[p,s]=polyfit(xtrain,ytrain,1);
yf=polyval(p,xtest);
%prediction error
pe_cross=mean((ytest-yf).^2);

The prediction error for this 20% hold-out split is
>> pe_cross
pe_cross =
    0.9930
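
The test indices above were chosen deterministically (every fifth point). A more common choice is a random hold-out split. Here is a minimal sketch in the same style, assuming the same steam data is loaded; the variable names (ntest, perm, pe_rand) are just for illustration:

% random 20% hold-out split (indices change on every run)
n=length(x);
ntest=round(0.2*n);            % number of test points
perm=randperm(n);              % random ordering of 1..n
indtest=perm(1:ntest);         % first 20% of the permutation forms the test set
xtest=x(indtest);
ytest=y(indtest);
xtrain=x;
ytrain=y;
xtrain(indtest)=[];
ytrain(indtest)=[];
[p,s]=polyfit(xtrain,ytrain,1);
pe_rand=mean((ytest-polyval(p,xtest)).^2);

Because the split is random, pe_rand varies from run to run.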

The more general form of the cross validation method is called K-fold cross validation. Here, the data are partitioned into K parts. One part is set aside for testing, and the other (K-1) parts are used for training (building the model). The process iteratively trains and tests the model K times so that every partition serves as the test set exactly once.

k=5;
n=length(x);

indcross=1+mod(randperm(n),k);   % randomly assign each observation a fold label from 1 to k
pe_ncross=zeros(1,k);
for i=1:k
    indtest=indcross==i;
    xtest=x(indtest);
    ytest=y(indtest);
    xtrain=x(~indtest);
    ytrain=y(~indtest);
    [p,s]=polyfit(xtrain,ytrain,1);
    yf=polyval(p,xtest);
    %prediction error
    pe_ncross(i)=mean((ytest-yf).^2);
end
penf=mean(pe_ncross)


The average prediction error for this 5-fold cross validation is
>> penf
penf =
    0.8732
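
If K is pushed all the way to n, each observation is held out exactly once; this special case is known as leave-one-out cross validation. A minimal sketch in the same style as the loop above (again assuming the steam data is loaded; pe_loo and pe_loo_mean are illustrative names):

% leave-one-out cross validation (K = n)
n=length(x);
pe_loo=zeros(1,n);
for i=1:n
    indtest=false(size(x));
    indtest(i)=true;                   % hold out the i-th observation
    [p,s]=polyfit(x(~indtest),y(~indtest),1);
    pe_loo(i)=(y(indtest)-polyval(p,x(indtest))).^2;
end
pe_loo_mean=mean(pe_loo)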


Compare this to MATLAB's crossval() function (from the Statistics Toolbox):
% fit on the training fold, predict the test fold
fitf=@(xtrain,ytrain,xtest)(polyval(polyfit(xtrain,ytrain,1),xtest));
% x' and y' pass the observations as column vectors (crossval treats rows as observations)
cvMse=crossval('mse',x',y','kfold',5,'predfun',fitf)

cvMse =
    0.8910
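
The small difference between penf and cvMse comes from the random fold assignments; each run partitions the data differently. As far as I can tell from its documented interface, crossval can also perform leave-one-out cross validation, which removes that randomness, via its 'leaveout' option. A brief sketch using the same fitf as above (cvMseLoo is an illustrative name):

cvMseLoo=crossval('mse',x',y','leaveout',1,'predfun',fitf)

This should give essentially the same answer as the hand-written leave-one-out loop earlier.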
