From 1c0c6edf0ff2c5e96f56541b0a6ac8e0fd5c393e Mon Sep 17 00:00:00 2001 From: Joakim Skogholt Date: Thu, 25 Apr 2024 18:09:32 +0200 Subject: [PATCH] Hopefully finished with general functions --- src/TR.jl | 47 +++++++++++++++++++++++++++++++++++++++++------ src/Ting.jl | 2 ++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/TR.jl b/src/TR.jl index 0d62df6..36fbe64 100644 --- a/src/TR.jl +++ b/src/TR.jl @@ -469,9 +469,9 @@ The LS problem is solved explicitly and no shortcuts are used. """ function TRSegCVNaive(X, y, lambdas, cvfolds) -n, p = size(X); -rmsecvman = zeros(length(lambdas)); -nfolds = length(unique(cvfolds)); +n, p = size(X); +rmsecv = zeros(length(lambdas)); +nfolds = length(unique(cvfolds)); for j = 1:length(lambdas) for i = 1:nfolds @@ -489,9 +489,9 @@ for j = 1:length(lambdas) end end -rmsecvman = sqrt.(1/n .* rmsecvman); +rmsecv = sqrt.(1/n .* rmsecv); -return rmsecvman +return rmsecv end """ @@ -527,7 +527,7 @@ end """ -K-fold CV for the Ridge regression problem, using the 'SVD-trick' for calculating the regression coefficients. +K-fold CV for the Ridge regression update problem, using the 'SVD-trick' for calculating the regression coefficients. """ function TRSegCVUpdateFair(X, y, lambdas, cv, bOld) @@ -559,4 +559,39 @@ end rmsecvman = sqrt.(1/n .* rmsecvman); return rmsecvman +end + +""" +K-fold CV for the Ridge regression problem, using the 'SVD-trick' for calculating the regression coefficients. +""" +function TRSegCVFair(X, y, lambdas, cv) + +n, p = size(X); +rmsecv = zeros(length(lambdas)); +nfolds = length(unique(cv)); + +for i = 1:nfolds + inds = (cv .== i); + Xdata = X[vec(.!inds),:]; + ydata = y[vec(.!inds)]; + + mX = mean(Xdata, dims=1); + my = mean(ydata); + Xs = Xdata .- mX; + ys = ydata .- my; + + U, s, V = svd(Xs, full=false); + + denom = broadcast(.+, broadcast(./, lambdas, s'), s')'; + denom2 = broadcast(.+, ones(n-sum(inds)), broadcast(./, lambdas', s.^2)); + + # Calculating regression coefficients + bcoeffs = V * broadcast(./, (U' * ys), denom); + rmsecv += sum((y[vec(inds)] .- ((X[vec(inds),:] .- mX) * bcoeffs .+ my)).^2, dims=1)'; + +end + +rmsecv = sqrt.(1/n .* rmsecv); + + return rmsecv end \ No newline at end of file diff --git a/src/Ting.jl b/src/Ting.jl index ca8de51..59d88d5 100644 --- a/src/Ting.jl +++ b/src/Ting.jl @@ -27,6 +27,8 @@ export TRLooCVUpdateFair export TRLooCVUpdateNaive export TRSegCVUpdateNaive export TRSegCVUpdateFair +export TRSegCVNaive +export TRSegCVFair include("convenience.jl") include("TR.jl")