Added fast k-fold update of reg coeffs

This commit is contained in:
Joakim Skogholt 2024-02-05 15:32:24 +01:00
parent 63e68c3175
commit 06cceabc6e
2 changed files with 76 additions and 0 deletions

View file

@ -1,7 +1,82 @@
"""
    TRSegCVUpdate(X, y, lambdas, cv, bold, regType="L2", regParam1=0, regParam2=1e-14)

Fast k-fold (segmented) cross-validation for updating regression coefficients.

`X` and `y` are mean-centred, `X` is right-divided by a regularisation matrix selected by
`regType`, and one thin SVD of the result is reused for every value in `lambdas` and
every segment in `cv`.

# Arguments
- `X`: `n × p` predictor matrix.
- `y`: response values; assumed to be a length-`n` vector — TODO confirm with callers.
- `lambdas`: candidate regularisation parameters.
- `cv`: segment label for each sample; segments `1:maximum(cv)` are visited.
- `bold`: previous regression coefficients; enters the cross-validation residuals
  through `V' * bold`.
- `regType`: one of `"bc"`, `"legendre"`, `"L2"`, `"std"` or `"GL"`.
- `regParam1`: number of `diff` passes for `"bc"`/`"legendre"`; order of the fractional
  derivative for `"GL"`.
- `regParam2`: the `"legendre"` rows of the regularisation matrix are scaled by
  `sqrt(regParam2)`.

# Returns
A tuple `(b_lambda_min, rmsecv, lambda_min, lambda_min_ind)`:
- `b_lambda_min`: coefficient column at the index minimising `rmsecv`; an intercept is
  prepended as first element when the mean of `y` is non-zero.
- `rmsecv`: `length(lambdas) × 1` matrix of cross-validated RMSE values.
- `lambda_min`: the minimum value found in `rmsecv` (an error value, not a lambda).
- `lambda_min_ind`: index into `lambdas` at which that minimum occurs.

# Throws
- `ArgumentError` if `regType` is not one of the supported names.
"""
function TRSegCVUpdate(X, y, lambdas, cv, bold, regType="L2", regParam1=0, regParam2=1e-14)
    n, p = size(X)

    # Centre the data once; both means are needed again for the intercept.
    mX = mean(X, dims=1)
    X = X .- mX
    my = mean(y)
    y = y .- my

    # Build the regularisation matrix.
    if regType == "bc"
        regMat = [I(p); zeros(regParam1, p)]
        for _ in 1:regParam1
            regMat = diff(regMat, dims=1)
        end
    elseif regType == "legendre"
        regMat = [I(p); zeros(regParam1, p)]
        for _ in 1:regParam1
            regMat = diff(regMat, dims=1)
        end
        P, _ = plegendre(regParam1 - 1, p)
        regMat[end-regParam1+1:end, :] = sqrt(regParam2) * P
    elseif regType == "L2"
        regMat = I(p)
    elseif regType == "std"
        regMat = Diagonal(vec(std(X, dims=1)))
    elseif regType == "GL" # Grünwald-Letnikov fractional derivative regularization
        # regParam1 is alpha (order of fractional derivative)
        C = ones(p)
        for k in 2:p
            C[k] = (1 - (regParam1 + 1) / (k - 1)) * C[k-1]
        end
        regMat = zeros(p, p)
        for i in 1:p
            regMat[i:end, i] = C[1:end-i+1]
        end
    else
        throw(ArgumentError("unsupported regType: $(regType)"))
    end

    # Work in the transformed coordinates and factorise once.
    X = X / regMat
    U, s, V = svd(X, full=false)
    n_seg = maximum(cv)
    n_lambdas = length(lambdas)

    # Both arrays hold one column per lambda and one row per singular value:
    # denom = s + lambda / s, denom2 = 1 + lambda / s^2.
    denom = (lambdas ./ s' .+ s')'
    # A scalar 1 (rather than a length-n vector) keeps this valid when n > p.
    denom2 = 1 .+ lambdas' ./ s .^ 2

    # Full-data residuals for every lambda, including the contribution of bold.
    Uty = U' * y
    fitted = U * (s .* Uty ./ denom .+ s .* (1 .- 1 ./ denom2) .* (V' * bold))
    resid = y .- fitted

    # Segment-wise residual update: one small linear solve per (segment, lambda).
    rescv = zeros(n, n_lambdas)
    sdenom = sqrt.(s ./ denom)'
    for seg in 1:n_seg
        inseg = vec(cv .== seg)
        Useg = U[inseg, :]
        Id = 1.0 * I(size(Useg, 1)) .- 1 / n
        for k in 1:n_lambdas
            Uk = Useg .* sdenom[k, :]'
            rescv[inseg, k] = (Id - Uk * Uk') \ resid[inseg, k]
        end
    end

    press = sum(rescv .^ 2, dims=1)'
    rmsecv = sqrt.(1 / n .* press)

    # NOTE(review): unlike `resid`, these coefficients carry no `bold` term — confirm
    # whether the returned coefficients are meant to include it.
    bcoeffs = V * (Uty ./ denom)
    bcoeffs = regMat \ bcoeffs

    # Prepend the intercept row when y had a non-zero mean.
    if my != 0
        bcoeffs = [my .- mX * bcoeffs; bcoeffs]
    end

    lambda_min, lambda_min_ind = findmin(rmsecv)
    lambda_min_ind = lambda_min_ind[1] # row of the CartesianIndex = index into lambdas
    b_lambda_min = bcoeffs[:, lambda_min_ind]

    return b_lambda_min, rmsecv, lambda_min, lambda_min_ind
end
""" """
Updates regression coefficient by solving the augmented TR problem [Xs; sqrt(lambda)*I] * beta = [ys; sqrt(lambda)*b_old] Updates regression coefficient by solving the augmented TR problem [Xs; sqrt(lambda)*I] * beta = [ys; sqrt(lambda)*b_old]

View file

@ -21,6 +21,7 @@ export TRLooCV
export TRSegCV export TRSegCV
export regularizationMatrix export regularizationMatrix
export TRLooCVUpdate export TRLooCVUpdate
export TRSegCVUpdate
include("convenience.jl") include("convenience.jl")
include("TR.jl") include("TR.jl")