diff --git a/src/convenience.jl b/src/convenience.jl index 566dbca..346146d 100644 --- a/src/convenience.jl +++ b/src/convenience.jl @@ -16,62 +16,6 @@ XVal = my_split["XVal"]; using Random -""" - calculateRMSECV(X, y, regfunction, funcargs, n_splits=1, n_folds=5, rngseed=42, emscpreproc=false, emscdegree=6) - calculateRMSECV(X, y, splits, regfunction, funcargs, emscpreproc=false, emscdegree=6) - -Calculates RMSECV. -Second function calculates RMSECV according to data split given by variable 'split' (which should be output of -function computeCVSplitInds). - -Returns meanrmsecv, rmsecv where rmsecv is kmax x n_splits matrix, and meanrmsecv is vector of length kmax. -""" -function calculateRMSECV(X, y, regfunction, funcargs, n_splits=1, n_folds=5, rngseed=42, emscpreproc=false, emscdegree=6) - -splits = createCVSplitInds(X, n_splits, n_folds, rngseed); -meanrmse, rmse = calculateRMSECV(X, y, splits, regfunction, funcargs, emscpreproc, emscdegree); - -return meanrmsecv, rmsecv -end - -function calculateRMSECV(X, y, splits, regfunction, funcargs, emscpreproc=false, emscdegree=6) - -n_splits = size(splits,2); -B, _ = regfunction(X, y, funcargs...); # <- Slow, but works in general. Maybe add some special cases for known functions? - println("her") -kmax = size(B, 2); -rmsecv = zeros(kmax, nsplits); -n_folds = length(unique(splits[:,1])) - -for i in 1:n_splits - - for j=1:n_folds - XTrain = X[splits[:,j] .!= j,:]; - XTest = X[splits[:,j] .== j,:]; - yTrain = y[splits[:,j] .!= j,:]; - yTest = y[splits[:,j] .== j,:]; - - if emscpreproc - XTrain, output = EMSC(XTrain, emscdegree, "svd", 1, -1, 0); # nRef, baseDeg, intF - XTest, _ = EMSC(XTest, output["model"]); - end - - B, _ = regfunction(XTrain, yTrain, funcargs...); - - for k=1:kmax - yTestPred, _ = predRegression(XTest, B[:,k], yTest); - rmsecv[k, i] += sum((yTestPred - yTest).^2); - end - end - -end - -rmsecv = sqrt.(rmsecv ./ n); -meanrmsecv = mean(rmsecv, dims=2); - -return meanrmsecv, rmsecv -end - """ function createCVSplitInds(X, n_splits=1, n_folds=5, rngseed=42)