From aea04b8b626e81b846948cd7ba96b2d9fe0deed9 Mon Sep 17 00:00:00 2001 From: Joakim Skogholt Date: Thu, 18 May 2023 13:26:55 +0200 Subject: [PATCH] Added TRVirSV and TRSegCV + various fixes --- src/TR.jl | 108 ++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 97 insertions(+), 11 deletions(-) diff --git a/src/TR.jl b/src/TR.jl index 147499d..6c0617f 100644 --- a/src/TR.jl +++ b/src/TR.jl @@ -1,13 +1,6 @@ - - - - - -using Optimization -using OptimizationOptimJL - - +using Optimization # For numerical minimization of PRESS statistic +using OptimizationOptimJL # For numerical minimization of PRESS statistic struct TRSVD U::Matrix{Float64} @@ -40,8 +33,8 @@ end """ ### TO DO: ADD FRACTIONAL DERIVATIVE REGULARIZATION ### - regularizationMatrix(X; regType="legendre", regParam1=0, regParam2=1e-14) - regularizationMatrix(p::Int64; regType="legendre", regParam1=0, regParam2=1e-14) + regularizationMatrix(X; regType="L2", regParam1=0, regParam2=1e-14) + regularizationMatrix(p::Int64; regType="L2", regParam1=0, regParam2=1e-14) Calculates and returns square regularization matrix. @@ -130,6 +123,99 @@ end +""" + function TRVirCV(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14) + +Segmented virtual cross-validation (VirCV) for TR models. +Outputs: b, press, lambda_min, lambda_min_ind, GCV +b are (virtual) press-minimal regression coefficients. +""" +function TRVirCV(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14) + +U_segments = TRSegmentOrth(X, segments); +bs = vec(sum(U_segments, dims=1).^2); + +n, p = size(X); +mX = mean(X, dims=1); +X = X .- mX; +my = mean(y); +y = vec(y .- my); +X = U_segments' * X; +y = U_segments' * y; +regMat = regularizationMatrix(p; regType, regParam1, regParam2); +X = X / regMat; +U, s, V = svd(X, full=false); + +denom = broadcast(.+, broadcast(./, lambdas, s'), s')'; +H = broadcast(.+, U.^2 * broadcast(./, s, denom), bs./n); +resid = broadcast(.-, y, U * broadcast(./, s .* (U'*y), denom)); +rescv = broadcast(./, resid, broadcast(.-, 1, H)); +press = vec(sum(rescv.^2, dims=1)); +#rmsecv = sqrt.(1/n .* press); +GCV = vec(broadcast(./, sum(resid.^2, dims=1), mean(broadcast(.-, 1, H), dims=1).^2)); + +lambda_min, lambda_min_ind = findmin(press); +lambda_min_ind = lambda_min_ind[1]; + +denom2 = broadcast(.+, lambda_min ./ s', s')'; +b = V * broadcast(./, (U' * y), denom2); +b = regMat \ b; +b = [my .- mX*b; b]; + +return b, press, lambda_min, lambda_min_ind, GCV +end + + + +""" + function TRSegCV(X, y, lambdas, folds, regType="L2", regParam1=0, regParam2=1e-14) + +Segmented cross-validation based on the Sherman-Morrison-Woodbury updating formula. +Inputs: + - X : Data matrix + - y : Response vector + - lambdas : Vector of regularization parameter values + - folds : Vector of length n indicating segment membership for each sample + - regType, regParam1, regParam2 : Inputs to regularizationMatrix function + +Outputs: rmsecv, b, lambda_min, lambda_min_ind. +b are regression coefficients corresponding to the lambda value minimising the CV-error. +""" +function TRSegCV(X, y, lambdas, folds, regType="L2", regParam1=0, regParam2=1e-14) + +TR = TRSVDDecomp(X, regType, regParam1, regParam2); +n_seg = maximum(folds); +n_lambdas = length(lambdas); +my = mean(y); +y = y .- my; + +denom = broadcast(.+, broadcast(./, lambdas, TR.s'), TR.s')'; +resid = broadcast(.-, y, TR.U * broadcast(./, TR.s .* (TR.U'*y), denom)); +rescv = zeros(TR.n, n_lambdas); +sdenom = sqrt.(broadcast(./, TR.s, denom))'; + +for seg in 1:n_seg + + Useg = TR.U[vec(cv .== seg),:]; + Id = 1.0 * I(size(Useg,1)) .- 1/TR.n; + + for k in 1:n_lambdas + Uk = Useg .* sdenom[k,:]'; + rescv[vec(cv .== seg), k] = (Id - Uk * Uk') \ resid[vec(cv .== seg), k]; + end +end + +press = sum(rescv.^2, dims=1)'; +rmsecv = sqrt.(1/TR.n .* press); + +lambda_min, lambda_min_ind = findmin(rmsecv) +lambda_min_ind = lambda_min_ind[1] +b = TRRegCoeffs(TR, y, lambda_min, my) + +return b, rmsecv, lambda_min, lambda_min_ind +end + + """ TRRegCoeffs(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14) TRRegCoeffs(TR::TRSVD, y, lambdas, my=0)