diff --git a/src/convenience.jl b/src/convenience.jl index 3576960..ba12398 100644 --- a/src/convenience.jl +++ b/src/convenience.jl @@ -27,10 +27,10 @@ Returns rmsecv, meanrmsecv, where rmsecv is kmax x n_splits matrix, and meanrmse """ function calculateRMSECV(X, y, regfunction, funcargs; n_splits=1, n_folds=5, rngseed=42, emscpreproc=false, emscdegree=6) -splits = createCVSplitInds(X, n_splits, n_folds, rngseed); -rmsecv, n_comps, meanrmsecv = calculateRMSECV(X, y, splits, regfunction, funcargs, emscpreproc=emscpreproc, emscdegree=emscdegree); +splits = createCVSplitInds(X, n_splits, n_folds, rngseed); +rmsecv, meanrmsecv = calculateRMSECV(X, y, splits, regfunction, funcargs, emscpreproc=emscpreproc, emscdegree=emscdegree); -return rmsecv, n_comps, meanrmsecv +return rmsecv, meanrmsecv end function calculateRMSECV(X, y, splits, regfunction, funcargs; emscpreproc=false, emscdegree=6) @@ -41,7 +41,6 @@ B, _ = regfunction(X, y, funcargs...); # <- Slow, but works in general. Mayb kmax = size(B, 2); rmsecv = zeros(kmax, n_splits); n_folds = length(unique(splits[:,1])) -n_comps = convert(Vector{Int64}, zeros(n_splits)); for i in 1:n_splits @@ -65,13 +64,12 @@ for i in 1:n_splits end end - _, n_comps[i] = findmin(rmsecv[:,i]); end rmsecv = sqrt.(rmsecv ./ n); meanrmsecv = mean(rmsecv, dims=2); -return rmsecv, n_comps, meanrmsecv +return rmsecv, meanrmsecv end """