Fixed centring error in SegCV and SegCVUpdate, added plegendre, same form for both updates as in paper draft

This commit is contained in:
Joakim Skogholt 2024-02-09 12:44:45 +01:00
parent 06cceabc6e
commit 38ca6f1aa7
2 changed files with 97 additions and 66 deletions

156
src/TR.jl
View file

@ -4,29 +4,27 @@
"""
Fast k-fold cv for updating regression coefficients
"""
function TRSegCVUpdate(X, y, lambdas, cv, bold, regType="L2", regParam1=0, regParam2=1e-14)
function TRSegCVUpdate(X, y, lambdas, cv, bOld, regType="L2", derOrder=0)
n, p = size(X);
mX = mean(X, dims=1);
X = X .- mX;
my = mean(y);
y = y .- my;
# Finding appropriate regularisation matrix
if regType == "bc"
regMat = [I(p); zeros(regParam1,p)]; for i = 1:regParam1 regMat = diff(regMat, dims = 1); end
regMat = [I(p); zeros(derOrder,p)];
for i = 1:derOrder regMat = diff(regMat, dims = 1); end
elseif regType == "legendre"
regMat = [I(p); zeros(regParam1,p)]; for i = 1:regParam1 regMat = diff(regMat, dims = 1); end
P, _ = plegendre(regParam-1, p);
regMat[end-regParam1+1:end,:] = sqrt(regParam2) * P;
regMat = [I(p); zeros(derOrder,p)];
for i = 1:derOrder regMat = diff(regMat, dims = 1); end
P, _ = plegendre(derOrder-1, p);
epsilon = 1e-14;
regMat[end-derOrder+1:end,:] = sqrt(epsilon) * P;
elseif regType == "L2"
regMat = I(p);
elseif regType == "std"
regMat = Diagonal(vec(std(X, dims=1)));
elseif regType == "GL" # Grünwald-Letnikov fractional derivative regulariztion
# regParam1 is alpha (order of fractional derivative)
elseif regType == "GL" # GL fractional derivative regulariztion
C = ones(p)*1.0;
for k in 2:p
C[k] = (1-(regParam1+1)/(k-1)) * C[k-1];
C[k] = (1-(derOrder+1)/(k-1)) * C[k-1];
end
regMat = zeros(p,p);
@ -36,18 +34,25 @@ elseif regType == "GL" # Grünwald-Letnikov fractional derivative regulariztion
end
end
X = X / regMat;
U, s, V = svd(X, full=false);
n_seg = maximum(cv);
n_lambdas = length(lambdas);
# Preliminary calculations
n, p = size(X);
mX = mean(X, dims=1);
X = X .- mX;
my = mean(y);
y = y .- my;
X = X / regMat;
U, s, V = svd(X, full=false);
n_seg = maximum(cv);
n_lambdas = length(lambdas);
# Finding residuals
denom = broadcast(.+, broadcast(./, lambdas, s'), s')';
denom2 = broadcast(.+, ones(n), broadcast(./, lambdas', s.^2))
resid = broadcast(.-, y, U * (broadcast(./, s .* (U'*y), denom) + s .* broadcast(.-, 1, broadcast(./, 1, denom2)) .* (V' * bold)))
yhat = broadcast(./, s .* (U'*y), denom)
yhat += s .* broadcast(.-, 1, broadcast(./, 1, denom2)) .* (V' * bOld)
resid = broadcast(.-, y, U * yhat)
# Finding cross-validated residuals
rescv = zeros(n, n_lambdas);
sdenom = sqrt.(broadcast(./, s, denom))';
@ -57,20 +62,23 @@ for seg in 1:n_seg
Id = 1.0 * I(size(Useg,1)) .- 1/n;
for k in 1:n_lambdas
Uk = Useg .* sdenom[k,:]';
Uk = Useg .* sdenom[k,:]';
rescv[vec(cv .== seg), k] = (Id - Uk * Uk') \ resid[vec(cv .== seg), k];
end
end
press = sum(rescv.^2, dims=1)';
rmsecv = sqrt.(1/n .* press);
# Calculating rmsecv and regression coefficients
press = sum(rescv.^2, dims=1)';
rmsecv = sqrt.(1/n .* press);
bcoeffs = V * broadcast(./, (U' * y), denom);
bcoeffs = regMat \ bcoeffs;
# Creating regression coefficients for uncentred data
if my != 0
bcoeffs = [my .- mX*bcoeffs; bcoeffs];
end
# Finding rmsecv-minimal lambda value and associated regression coefficients
lambda_min, lambda_min_ind = findmin(rmsecv)
lambda_min_ind = lambda_min_ind[1]
b_lambda_min = bcoeffs[:,lambda_min_ind]
@ -82,68 +90,65 @@ end
Updates regression coefficient by solving the augmented TR problem [Xs; sqrt(lambda)*I] * beta = [ys; sqrt(lambda)*b_old]
Note that many regularization types are supported but the regularization is on the difference between new and old reg. coeffs.
and so most regularization types are probably not meaningful.
Inputs:
- X/p : Size of regularization matrix or data matrix (size of reg. mat. will then be size(X,2)
- regType : "L2" (returns identity matrix)
"bc" (boundary condition, forces zero on right endpoint for derivative regularization) or
"legendre" (no boundary condition, but fills out reg. mat. with lower order polynomial trends to get square matrix)
"std" (standardization, FILL OUT WHEN DONE)
- regParam1 : Int64, Indicates degree of derivative regularization (0 gives L\\_2)
- regParam2 : For regType=="plegendre" added polynomials are multiplied by sqrt(regParam2)
Output
"""
function TRLooCVUpdate(X, y, lambdas, bold, regType="L2", regParam1=1, regParam2=1)
n, p = size(X);
mX = mean(X, dims=1);
X = X .- mX;
my = mean(y);
y = y .- my;
function TRLooCVUpdate(X, y, lambdas, bOld, regType="L2", derOrder=0)
# Finding appropriate regularisation matrix
if regType == "bc"
regMat = [I(p); zeros(regParam1,p)]; for i = 1:regParam1 regMat = diff(regMat, dims = 1); end
regMat = [I(p); zeros(derOrder,p)];
for i = 1:derOrder regMat = diff(regMat, dims = 1); end
elseif regType == "legendre"
regMat = [I(p); zeros(regParam1,p)]; for i = 1:regParam1 regMat = diff(regMat, dims = 1); end
P, _ = plegendre(regParam-1, p);
regMat[end-regParam1+1:end,:] = sqrt(regParam2) * P;
regMat = [I(p); zeros(derOrder,p)];
for i = 1:derOrder regMat = diff(regMat, dims = 1); end
P, _ = plegendre(derOrder-1, p);
epsilon = 1e-14;
regMat[end-derOrder+1:end,:] = sqrt(epsilon) * P;
elseif regType == "L2"
regMat = I(p);
elseif regType == "std"
regMat = Diagonal(vec(std(X, dims=1)));
elseif regType == "GL" # Grünwald-Letnikov fractional derivative regulariztion
# regParam1 is alpha (order of fractional derivative)
elseif regType == "GL" # GL fractional derivative regulariztion
C = ones(p)*1.0;
for k in 2:p
C[k] = (1-(regParam1+1)/(k-1)) * C[k-1];
C[k] = (1-(derOrder+1)/(k-1)) * C[k-1];
end
regMat = zeros(p,p);
for i in 1:p
regMat[i:end, i] = C[1:end-i+1];
end
end
X = X / regMat;
U, s, V = svd(X, full=false)
# Preliminary calculations
n, p = size(X);
mX = mean(X, dims=1);
X = X .- mX;
my = mean(y);
y = y .- my;
X = X / regMat;
U, s, V = svd(X, full=false);
n_seg = maximum(cv);
n_lambdas = length(lambdas);
# Main calculations
denom = broadcast(.+, broadcast(./, lambdas, s'), s')';
denom2 = broadcast(.+, ones(n), broadcast(./, lambdas', s.^2))
resid = broadcast(.-, y, U * (broadcast(./, s .* (U'*y), denom) + s .* broadcast(.-, 1, broadcast(./, 1, denom2)) .* (V' * bold)))
yhat = broadcast(./, s .* (U'*y), denom);
yhat += s .* broadcast(.-, 1, broadcast(./, 1, denom2)) .* (V' * bOld)
resid = broadcast(.-, y, U * yhat)
H = broadcast(.+, U.^2 * broadcast(./, s, denom), 1/n);
rescv = broadcast(./, resid, broadcast(.-, 1, H));
press = vec(sum(rescv.^2, dims=1));
rmsecv = sqrt.(1/n .* press);
GCV = vec(broadcast(./, sum(resid.^2, dims=1), mean(broadcast(.-, 1, H), dims=1).^2));
gcv = vec(broadcast(./, sum(resid.^2, dims=1), mean(broadcast(.-, 1, H), dims=1).^2));
idminPRESS = findmin(press)[2][1]; # First index selects coordinates, second selects '1st coordinate'
idminGCV = findmin(GCV)[2][1]; # First index selects coordinates, second selects '1st coordinate'
lambdaPRESS = lambdas[idminPRESS];
lambdaGCV = lambdas[idminGCV];
# Finding lambda that minimises rmsecv and GCV
idminrmsecv = findmin(press)[2][1];
idmingcv = findmin(gcv)[2][1];
lambdarmsecv = lambdas[idminrmsecv];
lambdagcv = lambdas[idmingcv];
# Calculating regression coefficients
bcoeffs = V * broadcast(./, (U' * y), denom);
bcoeffs = regMat \ bcoeffs;
@ -151,11 +156,11 @@ if my != 0
bcoeffs = [my .- mX*bcoeffs; bcoeffs];
end
bpress = bcoeffs[:, idminPRESS];
bgcv = bcoeffs[:, idminGCV];
brmsecv = bcoeffs[:, idminrmsecv];
bgcv = bcoeffs[:, idmingcv];
return bpress, bgcv, rmsecv, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV
return brmsecv, bgcv, rmsecv, gcv, idminrmsecv, lambdarmsecv, idmingcv, lambdagcv
end
@ -338,8 +343,6 @@ U, s, V = svd(X, full=false);
n_seg = maximum(cv);
n_lambdas = length(lambdas);
my = mean(y);
y = y .- my;
denom = broadcast(.+, broadcast(./, lambdas, s'), s')';
resid = broadcast(.-, y, U * broadcast(./, s .* (U'*y), denom));
@ -372,3 +375,30 @@ b_lambda_min = bcoeffs[:,lambda_min_ind]
return b_lambda_min, rmsecv, lambda_min, lambda_min_ind
end
"""
function plegendre(d, p)
Calculates orthonormal Legendre polynomials using a QR factorisation.
Inputs:
- d : polynomial degree
- p : size of vector
Outputs:
- Q : (d+1) x p matrix with basis
- R : matrix from QR-factorisation
"""
function plegendre(d, p)
P = ones(p, d+1);
x = (-1:2/(p-1):1)';
for k in 1:d
P[:,k+1] = x.^k;
end
factorisation = qr(P);
Q = Matrix(factorisation.Q)';
R = Matrix(factorisation.R);
return Q, R
end

View file

@ -22,6 +22,7 @@ export TRSegCV
export regularizationMatrix
export TRLooCVUpdate
export TRSegCVUpdate
export plegendre
include("convenience.jl")
include("TR.jl")