TRSegCVUpdateFair

This commit is contained in:
Joakim Skogholt 2024-04-25 17:23:45 +02:00
parent f50a442350
commit 82e2cd4496
2 changed files with 42 additions and 4 deletions

View file

@ -49,8 +49,8 @@ for i = 1:n
ys = ydata .- my;
U, s, V = svd(Xs, full=false);
denom = broadcast(.+, broadcast(./, lambdasu, s'), s')'; # denom = lambda/s + lambda = (lambda + s's) / lambda
denom2 = broadcast(.+, ones(n-1), broadcast(./, lambdasu', s.^2)) # denom2 = 1 + lambda/(s's) = (s's + lambda) / (s's)
denom = broadcast(.+, broadcast(./, lambdasu, s'), s')';
denom2 = broadcast(.+, ones(n-1), broadcast(./, lambdasu', s.^2))
# Calculating regression coefficients and residual
bcoeffs = V * broadcast(./, (U' * ys), denom) .+ bOld .- V * broadcast(./, V' * bOld, denom2);
@ -511,9 +511,9 @@ for j = 1:length(lambdas)
ydata = y[vec(.!inds)];
mX = mean(Xdata, dims=1);
my = mean(ydata);
my = mean(ydata);
Xs = Xdata .- mX;
ys = ydata .- my;
ys = ydata .- my;
betas = [Xs; sqrt(lambdas[j]) * I(p)] \ [ys; sqrt(lambdas[j]) * bOld];
rmsecvman[j] += sum((y[vec(inds)] - ((X[vec(inds),:] .- mX) * betas .+ my)).^2);
@ -525,3 +525,40 @@ rmsecvman = sqrt.(1/n .* rmsecvman);
return rmsecvman
end
"""
K-fold CV for the Ridge regression problem, using the 'SVD-trick' for calculating the regression coefficients.
"""
function TRSegCVUpdateFair(X, y, lambdas, cv, bOld)
cv = cvfolds
n, p = size(X);
rmsecvman = zeros(length(lambdasu));
nfolds = length(unique(cv));
for i = 1:nfolds
inds = (cv .== i);
Xdata = X[vec(.!inds),:];
ydata = y[vec(.!inds)];
mX = mean(Xdata, dims=1);
my = mean(ydata);
Xs = Xdata .- mX;
ys = ydata .- my;
U, s, V = svd(Xs, full=false);
denom = broadcast(.+, broadcast(./, lambdasu, s'), s')';
denom2 = broadcast(.+, ones(n-sum(inds)), broadcast(./, lambdasu', s.^2));
# Calculating regression coefficients
bcoeffs = V * broadcast(./, (U' * ys), denom) .+ bOld .- V * broadcast(./, V' * bOld, denom2);
rmsecvman += sum((y[vec(inds)] .- ((X[vec(inds),:] .- mX) * bcoeffs .+ my)).^2, dims=1)';
end
rmsecvman = sqrt.(1/n .* rmsecvman);
return rmsecvman
end

View file

@ -26,6 +26,7 @@ export plegendre
export TRLooCVUpdateFair
export TRLooCVUpdateNaive
export TRSegCVUpdateNaive
export TRSegCVUpdateFair
include("convenience.jl")
include("TR.jl")