From 9ed3bdaea8746bf66f8520cc896f100585c6b9e2 Mon Sep 17 00:00:00 2001 From: Joakim Skogholt Date: Wed, 24 Apr 2024 10:32:53 +0200 Subject: [PATCH] Fixed update regression coefficients and added Naive and Fair function for loocv update --- src/TR.jl | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++-- src/Ting.jl | 2 ++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/src/TR.jl b/src/TR.jl index 2a7d869..02d1cdd 100644 --- a/src/TR.jl +++ b/src/TR.jl @@ -1,5 +1,66 @@ +""" +Solves the model update problem explicitly as a least squares problem with stacked matrices. +In practice the most naive way of approaching the update problem +""" +function TRLooCVUpdateNaive(X, y, lambdasu, bold) + +n, p = size(X); +rmsecvman = zeros(length(lambdasu)); +for i = 1:n + inds = setdiff(1:n, i); + Xdata = X[inds,:]; + ydata = y[inds]; + mX = mean(Xdata, dims=1); + my = mean(ydata); + Xs = Xdata .- mX; + ys = ydata .- my; + p2 = size(Xdata, 2); + + for j = 1:length(lambdasu) + betas = [Xs; sqrt(lambdasu[j]) * I(p2)] \ [ys ; sqrt(lambdasu[j]) * bold]; + rmsecvman[j] += (y[i] - (((X[i,:]' .- mX) * betas)[1] + my))^2; + end +end + +rmsecvman = sqrt.(1/n .* rmsecvman); + +return rmsecvman +end + +""" +Uses the 'svd-trick' for efficient calculation of regression coefficients, but does not use leverage corrections. +Hence regression coefficients are calculated for all lambda values +""" +function TRLooCVUpdateFair(X, y, lambdasu, bold) + +n, p = size(X); +rmsecvman = zeros(length(lambdasu)) + +for i = 1:n + inds = setdiff(1:n, i); + Xdata = X[inds,:]; + ydata = y[inds]; + + mX = mean(Xdata, dims=1); + my = mean(ydata); + Xs = Xdata .- mX; + ys = ydata .- my; + U, s, V = svd(Xs, full=false); + + denom = broadcast(.+, broadcast(./, lambdasu, s'), s')'; # denom = lambda/s + lambda = (lambda + s's) / lambda + denom2 = broadcast(.+, ones(n-1), broadcast(./, lambdasu', s.^2)) # denom2 = 1 + lambda/(s's) = (s's + lambda) / (s's) + + # Calculating regression coefficients and residual + bcoeffs = V * broadcast(./, (U' * ys), denom) .+ bold .- V * broadcast(./, V' * bold, denom2); + rmsecvman += ((y[i] .- ((X[i,:]' .- mX) * bcoeffs .+ my)).^2)'; +end + +rmsecvman = sqrt.(1/n .* rmsecvman); + +return rmsecvman, bcoeffs +end """ Fast k-fold cv for updating regression coefficients @@ -70,7 +131,7 @@ end # Calculating rmsecv and regression coefficients press = sum(rescv.^2, dims=1)'; rmsecv = sqrt.(1/n .* press); -bcoeffs = V * broadcast(./, (U' * y), denom); +bcoeffs = V * broadcast(./, (U' * ys), denom) .+ bold .- V * broadcast(./, V' * bold, denom2); bcoeffs = regMat \ bcoeffs; # Creating regression coefficients for uncentred data @@ -148,7 +209,7 @@ lambdarmsecv = lambdas[idminrmsecv]; lambdagcv = lambdas[idmingcv]; # Calculating regression coefficients -bcoeffs = V * broadcast(./, (U' * y), denom); +bcoeffs = V * broadcast(./, (U' * ys), denom) .+ bold .- V * broadcast(./, V' * bold, denom2); bcoeffs = regMat \ bcoeffs; if my != 0 diff --git a/src/Ting.jl b/src/Ting.jl index 6d01d21..fe9e3c9 100644 --- a/src/Ting.jl +++ b/src/Ting.jl @@ -23,6 +23,8 @@ export regularizationMatrix export TRLooCVUpdate export TRSegCVUpdate export plegendre +export TRLooCVUpdateFair +export TRLooCVUpdateNaive include("convenience.jl") include("TR.jl")