Minor change to calculateRMSECV

This commit is contained in:
Joakim Skogholt 2023-05-13 07:52:47 +02:00
parent f6e7820917
commit 0941657ced

View file

@ -27,10 +27,10 @@ Returns rmsecv, meanrmsecv, where rmsecv is kmax x n_splits matrix, and meanrmse
"""
function calculateRMSECV(X, y, regfunction, funcargs; n_splits=1, n_folds=5, rngseed=42, emscpreproc=false, emscdegree=6)
splits = createCVSplitInds(X, n_splits, n_folds, rngseed);
rmsecv, n_comps, meanrmsecv = calculateRMSECV(X, y, splits, regfunction, funcargs, emscpreproc=emscpreproc, emscdegree=emscdegree);
splits = createCVSplitInds(X, n_splits, n_folds, rngseed);
rmsecv, meanrmsecv = calculateRMSECV(X, y, splits, regfunction, funcargs, emscpreproc=emscpreproc, emscdegree=emscdegree);
return rmsecv, n_comps, meanrmsecv
return rmsecv, meanrmsecv
end
function calculateRMSECV(X, y, splits, regfunction, funcargs; emscpreproc=false, emscdegree=6)
@ -41,7 +41,6 @@ B, _ = regfunction(X, y, funcargs...); # <- Slow, but works in general. Mayb
kmax = size(B, 2);
rmsecv = zeros(kmax, n_splits);
n_folds = length(unique(splits[:,1]))
n_comps = convert(Vector{Int64}, zeros(n_splits));
for i in 1:n_splits
@ -65,13 +64,12 @@ for i in 1:n_splits
end
end
_, n_comps[i] = findmin(rmsecv[:,i]);
end
rmsecv = sqrt.(rmsecv ./ n);
meanrmsecv = mean(rmsecv, dims=2);
return rmsecv, n_comps, meanrmsecv
return rmsecv, meanrmsecv
end
"""