Added k-fold cv

This commit is contained in:
Joakim Skogholt 2023-05-13 07:37:37 +02:00
parent 3254f847e0
commit f6e7820917
2 changed files with 93 additions and 14 deletions

View file

@ -28,7 +28,9 @@ export createDataSplitBinaryStratified
export importData
export calculateRMSE
export predRegression
export modelSelectionStatistics
export modelSelection
export createCVSplitInds
export calculateRMSECV
export PCR
export bidiag2

View file

@ -15,7 +15,64 @@ XVal = my_split["XVal"];
using Random
"""
    calculateRMSECV(X, y, regfunction, funcargs; n_splits=1, n_folds=5, rngseed=42, emscpreproc=false, emscdegree=6)
    calculateRMSECV(X, y, splits, regfunction, funcargs; emscpreproc=false, emscdegree=6)

Calculate the root mean squared error of cross-validation (RMSECV).

The first method builds the fold assignment itself by calling
`createCVSplitInds(X, n_splits, n_folds, rngseed)` and then forwards to the second
method. The second method uses the fold assignment given in `splits` (which should be
the output of `createCVSplitInds`).

Returns `rmsecv, n_comps, meanrmsecv`, where `rmsecv` is a `kmax x n_splits` matrix
(`kmax` being the number of columns of the coefficient matrix returned by
`regfunction`), `n_comps[i]` is the row index with the smallest error for split `i`,
and `meanrmsecv` is the mean of `rmsecv` over the splits.
"""
function calculateRMSECV(X, y, regfunction, funcargs; n_splits=1, n_folds=5, rngseed=42, emscpreproc=false, emscdegree=6)
    # Generate the fold membership first, then let the split-based method do the work.
    cvsplits = createCVSplitInds(X, n_splits, n_folds, rngseed)
    return calculateRMSECV(
        X, y, cvsplits, regfunction, funcargs;
        emscpreproc=emscpreproc, emscdegree=emscdegree,
    )
end
"""
    calculateRMSECV(X, y, splits, regfunction, funcargs; emscpreproc=false, emscdegree=6)

Calculate RMSECV for the fold assignment given in `splits`.

# Arguments
- `X`, `y`: predictors and responses; rows are selected with the same fold mask.
- `splits`: fold membership, one column per split. Folds are looped over as
  `1:n_folds`, with `n_folds` taken from the number of distinct values in the first
  column (presumably the output of `createCVSplitInds`, coded `1, 2, ..., n_folds`).
- `regfunction`: called as `regfunction(X, y, funcargs...)`; its first return value is
  used as a coefficient matrix with one column per candidate model.
- `funcargs`: extra arguments splatted into `regfunction`.

# Keywords
- `emscpreproc`: if `true`, `EMSC` is fitted on each training fold and the resulting
  model is applied to the corresponding test fold.
- `emscdegree`: second argument passed to `EMSC` when fitting.

# Returns
`rmsecv, n_comps, meanrmsecv`, where `rmsecv` is `kmax x n_splits`, `n_comps[i]` is
the row index with the smallest accumulated error for split `i`, and `meanrmsecv` is
`mean(rmsecv, dims=2)`.
"""
function calculateRMSECV(X, y, splits, regfunction, funcargs; emscpreproc=false, emscdegree=6)
    n = size(X, 1)
    n_splits = size(splits, 2)

    # One fit on the full data, used only to find the number of candidate models.
    # Slow, but works in general. Maybe add special cases for known functions?
    Bfull, _ = regfunction(X, y, funcargs...)
    kmax = size(Bfull, 2)

    rmsecv = zeros(kmax, n_splits)  # accumulates squared errors until the final sqrt
    n_folds = length(unique(splits[:, 1]))
    n_comps = zeros(Int64, n_splits)

    for i in 1:n_splits
        foldids = splits[:, i]
        for j in 1:n_folds
            # Build the fold mask once and reuse it for X and y.
            testmask = foldids .== j
            trainmask = .!testmask

            XTrain = X[trainmask, :]
            XTest = X[testmask, :]
            yTrain = y[trainmask, :]
            yTest = y[testmask, :]

            if emscpreproc
                # Fit on the training fold only, then apply that model to the test fold.
                XTrain, output = EMSC(XTrain, emscdegree, "svd", 1, -1, 0)  # nRef, baseDeg, intF
                XTest, _ = EMSC(XTest, output["model"])
            end

            B, _ = regfunction(XTrain, yTrain, funcargs...)

            for k in 1:kmax
                yTestPred, _ = predRegression(XTest, B[:, k], yTest)
                rmsecv[k, i] += sum(abs2, yTestPred - yTest)
            end
        end
        # sqrt and division by n are monotone, so the minimiser can be picked on the sums.
        n_comps[i] = argmin(rmsecv[:, i])
    end

    # Dividing by n assumes every sample is tested exactly once per split.
    rmsecv = sqrt.(rmsecv ./ n)
    meanrmsecv = mean(rmsecv, dims=2)

    return rmsecv, n_comps, meanrmsecv
end
"""
function createCVSplitInds(X, n_splits=1, n_folds=5, rngseed=42)
@ -27,7 +84,6 @@ Extra samples if any are assigned to the lower indices.
"""
function createCVSplitInds(X, n_splits=1, n_folds=5, rngseed=42)
n = size(X,1)
splits = convert(Matrix{Int64}, zeros(n, n_splits)); # fold membership coded as 1, 2, ..., n_folds
fold_size = convert(Int64, floor(n/n_folds))
@ -54,13 +110,32 @@ end
"""
function modelSelectionStatistics(results)
function modelSelection(results, results_type, selection_rule="min")
Takes as input the rmse output from calculateRMSE and returns
the number of components minimising the validation error
together with the test set results.
### Inputs
- results : Matrix/Tensor with results, output from calculateRMSE[CV].
- results_type : "k-fold" or "train-val-test".
- selection_rule : "min" only for now. Can add 1 S.E., Chi^2, etc.
### Outputs:
- results_sel : Results for the selected number of components.
- n_comps : The number of components chosen for each split.
"""
function modelSelectionStatistics(results)
function modelSelection(results, results_type, selection_rule="min")
if results_type == "k-fold"
n_splits = size(results, 2);
n_comps = convert(Vector{Int64}, zeros(n_splits));
results_sel = zeros(n_splits);
for i in 1:n_splits
_, n_comps[i] = findmin(results[:,i]);
results_sel[i] = results[n_comps[i], i];
end
elseif results_type == "train-val-test"
n_iter = size(results, 3);
results_sel = zeros(n_iter);
@ -70,12 +145,14 @@ for i=1:n_iter
_, n_comps[i] = findmin(results[2,:,i]);
results_sel[i] = results[3, n_comps[i], i];
end
end
return results_sel, n_comps
end
"""
function predRegression(X, beta, y)
function predRegression(X::Vector{Float64}, beta, y)
@ -163,7 +240,7 @@ end
meanrmse = dropdims(mean(rmse, dims=3), dims=3);
return meanrmse, rmse
return rmse, meanrmse
end