Hopefully finished with general functions

This commit is contained in:
Joakim Skogholt 2024-04-25 18:09:32 +02:00
parent 3a4293b572
commit 1c0c6edf0f
2 changed files with 43 additions and 6 deletions

View file

@ -469,9 +469,9 @@ The LS problem is solved explicitly and no shortcuts are used.
""" """
function TRSegCVNaive(X, y, lambdas, cvfolds) function TRSegCVNaive(X, y, lambdas, cvfolds)
n, p = size(X); n, p = size(X);
rmsecvman = zeros(length(lambdas)); rmsecv = zeros(length(lambdas));
nfolds = length(unique(cvfolds)); nfolds = length(unique(cvfolds));
for j = 1:length(lambdas) for j = 1:length(lambdas)
for i = 1:nfolds for i = 1:nfolds
@ -489,9 +489,9 @@ for j = 1:length(lambdas)
end end
end end
rmsecvman = sqrt.(1/n .* rmsecvman); rmsecv = sqrt.(1/n .* rmsecv);
return rmsecvman return rmsecv
end end
""" """
@ -527,7 +527,7 @@ end
""" """
K-fold CV for the Ridge regression problem, using the 'SVD-trick' for calculating the regression coefficients. K-fold CV for the Ridge regression update problem, using the 'SVD-trick' for calculating the regression coefficients.
""" """
function TRSegCVUpdateFair(X, y, lambdas, cv, bOld) function TRSegCVUpdateFair(X, y, lambdas, cv, bOld)
@ -559,4 +559,39 @@ end
rmsecvman = sqrt.(1/n .* rmsecvman); rmsecvman = sqrt.(1/n .* rmsecvman);
return rmsecvman return rmsecvman
end
"""
K-fold CV for the Ridge regression problem, using the 'SVD-trick' for calculating the regression coefficients.
"""
function TRSegCVFair(X, y, lambdas, cv)
n, p = size(X);
rmsecv = zeros(length(lambdas));
nfolds = length(unique(cv));
for i = 1:nfolds
inds = (cv .== i);
Xdata = X[vec(.!inds),:];
ydata = y[vec(.!inds)];
mX = mean(Xdata, dims=1);
my = mean(ydata);
Xs = Xdata .- mX;
ys = ydata .- my;
U, s, V = svd(Xs, full=false);
denom = broadcast(.+, broadcast(./, lambdas, s'), s')';
denom2 = broadcast(.+, ones(n-sum(inds)), broadcast(./, lambdas', s.^2));
# Calculating regression coefficients
bcoeffs = V * broadcast(./, (U' * ys), denom);
rmsecv += sum((y[vec(inds)] .- ((X[vec(inds),:] .- mX) * bcoeffs .+ my)).^2, dims=1)';
end
rmsecv = sqrt.(1/n .* rmsecv);
return rmsecv
end end

View file

@ -27,6 +27,8 @@ export TRLooCVUpdateFair
export TRLooCVUpdateNaive export TRLooCVUpdateNaive
export TRSegCVUpdateNaive export TRSegCVUpdateNaive
export TRSegCVUpdateFair export TRSegCVUpdateFair
export TRSegCVNaive
export TRSegCVFair
include("convenience.jl") include("convenience.jl")
include("TR.jl") include("TR.jl")