Added fast LooCVUpdate

This commit is contained in:
Joakim Skogholt 2024-02-05 14:46:37 +01:00
parent 7325edec01
commit 63e68c3175
2 changed files with 83 additions and 0 deletions

View file

@ -2,6 +2,88 @@
"""
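Updates regression coefficients by solving the augmented TR problem [Xs; sqrt(lambda)*I] * beta = [ys; sqrt(lambda)*b_old]
for every candidate lambda, and selects lambda by fast leave-one-out cross-validation (PRESS) and by generalized cross-validation (GCV).
Note that several regularization types are supported, but the regularization is on the difference between the new and old
regression coefficients, and so most regularization types are probably not meaningful.
Inputs:
- X : Data matrix (n x p). The regularization matrix is p x p with p = size(X,2)
- y : Response vector (length n)
- lambdas : Vector of candidate regularization parameter values
- bold : Old regression coefficients (b_old above)
- regType : "L2" (identity matrix),
"bc" (boundary condition, forces zero on right endpoint for derivative regularization),
"legendre" (no boundary condition, but fills out reg. mat. with lower order polynomial trends to get square matrix),
"std" (standardization, diagonal matrix holding the column standard deviations of X) or
"GL" (Grünwald-Letnikov fractional derivative regularization)
- regParam1 : Degree of derivative regularization for "bc" and "legendre" (0 gives L2); order of the fractional derivative for "GL"
- regParam2 : For regType=="legendre" added polynomials are multiplied by sqrt(regParam2)
Output:
- bpress, bgcv : Regression coefficients for the lambda minimizing PRESS and GCV, respectively (an intercept is prepended as first element when mean(y) != 0)
- rmsecv, GCV : Leave-one-out RMSE and GCV value for each lambda
- idminPRESS, lambdaPRESS : Index into lambdas and value of the PRESS-optimal lambda
- idminGCV, lambdaGCV : Index into lambdas and value of the GCV-optimal lambda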
"""
function TRLooCVUpdate(X, y, lambdas, bold, regType="L2", regParam1=1, regParam2=1)
    # Purpose: for every value in `lambdas`, solve the centred coefficient-update problem
    #     min_b ||ys - Xs*b||^2 + lambda * ||L*(b - bold)||^2
    # (Xs, ys are the mean-centred X, y and L is the regularization matrix chosen by
    # `regType`), and pick lambda by leave-one-out CV (PRESS) and by GCV, using a single
    # SVD instead of refitting once per left-out sample.
    #
    # Arguments:
    #   X         - n x p data matrix
    #   y         - response, length n
    #   lambdas   - candidate regularization parameters
    #   bold      - old regression coefficients, length p (no intercept)
    #   regType   - "L2", "bc", "legendre", "std" or "GL"
    #   regParam1 - derivative degree ("bc"/"legendre") or fractional order ("GL")
    #   regParam2 - the "legendre" polynomial rows are scaled by sqrt(regParam2)
    #
    # Returns (bpress, bgcv, rmsecv, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV).
    # Throws ArgumentError for an unrecognised `regType`.
    n, p = size(X)

    # Centre predictors and response; the intercept is reconstructed at the end.
    mX = mean(X, dims=1)
    X = X .- mX
    my = mean(y)
    y = y .- my

    # Build the p x p regularization matrix L.
    if regType == "bc"
        regMat = [I(p); zeros(regParam1, p)]
        for _ in 1:regParam1
            regMat = diff(regMat, dims=1)
        end
    elseif regType == "legendre"
        regMat = [I(p); zeros(regParam1, p)]
        for _ in 1:regParam1
            regMat = diff(regMat, dims=1)
        end
        # The last regParam1 rows are replaced by scaled polynomial trends.
        P, _ = plegendre(regParam1 - 1, p)
        regMat[end-regParam1+1:end, :] = sqrt(regParam2) * P
    elseif regType == "L2"
        regMat = I(p)
    elseif regType == "std"
        regMat = Diagonal(vec(std(X, dims=1)))
    elseif regType == "GL"  # Grünwald-Letnikov fractional derivative regularization
        # regParam1 is alpha (order of the fractional derivative).
        glCoeffs = ones(p)
        for k in 2:p
            glCoeffs[k] = (1 - (regParam1 + 1) / (k - 1)) * glCoeffs[k-1]
        end
        regMat = zeros(p, p)
        for i in 1:p
            regMat[i:end, i] = glCoeffs[1:end-i+1]
        end
    else
        throw(ArgumentError(
            "unknown regType $(repr(regType)); expected \"L2\", \"bc\", \"legendre\", \"std\" or \"GL\""))
    end

    # Change of variables bt = L*b turns the penalty into plain L2 on (bt - L*bold),
    # so the old coefficients have to be mapped into the same coordinates as X.
    X = X / regMat
    boldT = regMat * bold
    U, s, V = svd(X, full=false)

    lamRow = reshape(collect(lambdas), 1, :)  # 1 x nLambda
    s2 = s .^ 2
    d = s2 .+ lamRow                          # s_i^2 + lambda_j, one row per singular value
    filt = s2 ./ d                            # weight on the data term
    shrink = lamRow ./ d                      # weight on the old coefficients
    Uty = U' * y
    Vtb = V' * boldT

    # Fitted values and residuals of the full-data fit, one column per lambda.
    resid = y .- U * (filt .* Uty .+ (s .* shrink) .* Vtb)

    # Diagonal of the hat matrix; 1/n accounts for the unpenalised intercept.
    H = U .^ 2 * filt .+ 1 / n
    rescv = resid ./ (1 .- H)
    press = vec(sum(abs2, rescv, dims=1))
    rmsecv = sqrt.(press ./ n)
    GCV = vec(sum(abs2, resid, dims=1) ./ mean(1 .- H, dims=1) .^ 2)

    idminPRESS = argmin(press)
    idminGCV = argmin(GCV)
    lambdaPRESS = lambdas[idminPRESS]
    lambdaGCV = lambdas[idminGCV]

    # Coefficients consistent with the residuals above. The last term keeps the part of
    # the old coefficients lying outside the row space of X (relevant when p > n), which
    # the data cannot change.
    bcoeffs = V * ((s .* Uty) ./ d .+ shrink .* Vtb) .+ (boldT .- V * Vtb)
    bcoeffs = regMat \ bcoeffs

    # An intercept row is prepended only when the response mean is non-zero.
    if my != 0
        bcoeffs = [my .- mX * bcoeffs; bcoeffs]
    end

    bpress = bcoeffs[:, idminPRESS]
    bgcv = bcoeffs[:, idminGCV]

    return bpress, bgcv, rmsecv, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV
end
"""
### TO DO: ADD FRACTIONAL DERIVATIVE REGULARIZATION <-- Check that it is correctly added:) ###

View file

@ -20,6 +20,7 @@ export importData
export TRLooCV
export TRSegCV
export regularizationMatrix
export TRLooCVUpdate
include("convenience.jl")
include("TR.jl")