# Tikhonov-regularisation (TR) helpers built on a thin SVD of the centred and
# regularisation-transformed data matrix.
#
# Names this code expects to find in scope (none of them is defined here):
#   * `regularizationMatrix`  -- defined earlier in src/TR.jl
#   * `OptimizationProblem`, `solve`, `NelderMead`
#         -- the file loads `Optimization` and `OptimizationOptimJL`
#   * `mean`, `svd`, `plot`, `plot!`, `display`
#         -- presumably Statistics, LinearAlgebra and Plots; confirm against the
#            `using` block at the top of the file.

"""
    TRSVD

Container for the decomposition produced by [`TRSVDDecomp`](@ref).

# Fields
- `U`, `s`, `V`: factors of the thin SVD of the centred data matrix after it has been
  right-divided by `regMat`.
- `mX`: column means of the original data matrix (a `1 x p` matrix).
- `regType`, `regParam1`: regularisation settings that were passed to `TRSVDDecomp`.
- `regMat`: regularisation matrix returned by `regularizationMatrix`.
- `n`, `p`: number of rows and columns of the original data matrix.
"""
struct TRSVD
    U::Matrix{Float64}
    s::Vector{Float64}
    V::Matrix{Float64}
    mX::Matrix{Float64}
    regType::String
    regParam1::Float64
    regMat::Matrix{Float64}
    n::Int64
    p::Int64
end

# Return the `k x L` array whose `(j, i)` entry is `s[j] + lambdas[i] / s[j]`
# (`k = length(TR.s)`; `L = length(lambdas)`, or 1 when `lambdas` is a scalar).
_trdenom(TR::TRSVD, lambdas) = (lambdas ./ TR.s' .+ TR.s')'

# Return `(H, resid)`: for every value in `lambdas` (one column each) the diagonal of the
# hat matrix, including the `1/n` contribution of the intercept, and the residuals of the
# fit to `y`. `y` is expected to be mean-centred already.
function _trleverage_resid(TR::TRSVD, y, lambdas)
    denom = _trdenom(TR, lambdas)
    resid = y .- TR.U * ((TR.s .* (TR.U' * y)) ./ denom)
    H = (TR.U .^ 2) * (TR.s ./ denom) .+ (1 / TR.n)
    return H, resid
end

"""
    TRSVDDecomp(X, regType="L2", regParam1=0, regParam2=1e-14)

Centre `X` column-wise, build the regularisation matrix with `regularizationMatrix`,
transform the centred data as `X / regMat` and compute its thin SVD.

Returns a [`TRSVD`](@ref) object that is the input to the other TR functions.
"""
function TRSVDDecomp(X, regType="L2", regParam1=0, regParam2=1e-14)
    n, p = size(X)
    mX = mean(X, dims=1)
    Xc = X .- mX
    regMat = regularizationMatrix(Xc; regType, regParam1, regParam2)
    U, s, V = svd(Xc / regMat, full=false)
    return TRSVD(U, s, V, mX, regType, regParam1, regMat, n, p)
end

"""
    TRRegCoeffs(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14)
    TRRegCoeffs(TR::TRSVD, y, lambdas, my=0)

Calculate regression coefficients of the TR model, one column per value in `lambdas`.

The first method decomposes `X`, centres `y` and returns `(bcoeffs, TR)`.
The second method expects `y` to be centred already and returns `bcoeffs` only. When
`my` (the mean that was removed from `y`) is non-zero, an intercept row
`my .- TR.mX * bcoeffs` is prepended; when `my == 0` no intercept row is added.
"""
function TRRegCoeffs(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14)
    TR = TRSVDDecomp(X, regType, regParam1, regParam2)
    my = mean(y)
    bcoeffs = TRRegCoeffs(TR, y .- my, lambdas, my)
    return bcoeffs, TR
end

function TRRegCoeffs(TR::TRSVD, y, lambdas, my=0)
    denom = _trdenom(TR, lambdas)
    # Solve in the transformed space, then map back through the regularisation matrix.
    bcoeffs = TR.regMat \ (TR.V * ((TR.U' * y) ./ denom))
    if my != 0
        bcoeffs = [my .- TR.mX * bcoeffs; bcoeffs]
    end
    return bcoeffs
end

"""
    TRPress(TR::TRSVD, y, lambdas)
    TRPress(TR::TRSVD, y, lambdas, H, resid)

Return the PRESS values (as a vector, one entry per value in `lambdas`).

`y` is expected to be mean-centred. The five-argument method takes precomputed leverages
`H` and residuals `resid` (one column per lambda) and does not read `TR`, `y` or
`lambdas`.
"""
function TRPress(TR::TRSVD, y, lambdas)
    H, resid = _trleverage_resid(TR, y, lambdas)
    return TRPress(TR, y, lambdas, H, resid)
end

function TRPress(TR::TRSVD, y, lambdas, H, resid)
    rescv = resid ./ (1 .- H)
    return vec(sum(abs2, rescv, dims=1))
end

"""
    TRGCV(TR::TRSVD, y, lambdas)
    TRGCV(TR::TRSVD, y, lambdas, H, resid)

Return the GCV values (as a vector, one entry per value in `lambdas`), computed as the
residual sum of squares divided by the squared mean of `1 .- H`.

`y` is expected to be mean-centred. The five-argument method takes precomputed leverages
`H` and residuals `resid` (one column per lambda) and does not read `TR`, `y` or
`lambdas`.
"""
function TRGCV(TR::TRSVD, y, lambdas)
    H, resid = _trleverage_resid(TR, y, lambdas)
    return TRGCV(TR, y, lambdas, H, resid)
end

function TRGCV(TR::TRSVD, y, lambdas, H, resid)
    return vec(sum(abs2, resid, dims=1) ./ mean(1 .- H, dims=1) .^ 2)
end

"""
    TRLooCV(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14)
    TRLooCV(TR::TRSVD, y, lambdas)

Evaluate PRESS and GCV on the grid `lambdas` and calculate the regression coefficients at
the PRESS-minimal and the GCV-minimal grid value. `y` is centred internally.

Returns
`BPRESS, BGCV, TR, press, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV`.
"""
function TRLooCV(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14)
    TR = TRSVDDecomp(X, regType, regParam1, regParam2)
    return TRLooCV(TR, y, lambdas)
end

function TRLooCV(TR::TRSVD, y, lambdas)
    my = mean(y)
    yc = y .- my

    # Leverages and residuals are shared by both criteria, so compute them once.
    H, resid = _trleverage_resid(TR, yc, lambdas)
    press = TRPress(TR, yc, lambdas, H, resid)
    GCV = TRGCV(TR, yc, lambdas, H, resid)

    idminPRESS = argmin(press)
    idminGCV = argmin(GCV)
    lambdaPRESS = lambdas[idminPRESS]
    lambdaGCV = lambdas[idminGCV]

    BPRESS = TRRegCoeffs(TR, yc, lambdaPRESS, my)
    BGCV = TRRegCoeffs(TR, yc, lambdaGCV, my)

    return BPRESS, BGCV, TR, press, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV
end

"""
    PlotTRLooCV(BPRESS, BGCV, TR, press, GCV, idminPRESS, lambdaPRESS, idminGCV,
                lambdaGCV; lambdas=nothing)

Plot the output of [`TRLooCV`](@ref): the data reconstructed from `TR`, the PRESS and GCV
curves, and the PRESS- and GCV-minimal regression coefficients (first row, i.e. the
intercept, excluded).

# Keywords
- `lambdas`: the grid that was passed to `TRLooCV`. When given, the criteria are plotted
  against `log10.(lambdas)`; when omitted they are plotted against the grid index.
"""
function PlotTRLooCV(BPRESS, BGCV, TR, press, GCV, idminPRESS, lambdaPRESS, idminGCV,
                     lambdaGCV; lambdas=nothing)
    # The SVD was taken of the centred data right-divided by `regMat`, so multiply by
    # `regMat` again before adding the column means back.
    Xrec = ((TR.U .* TR.s') * TR.V') * TR.regMat .+ TR.mX
    plta = plot(Xrec'; legend=false)

    if lambdas === nothing
        xvals, xlab = eachindex(press), "lambda index"
    else
        xvals, xlab = log10.(lambdas), "log10(lambda)"
    end
    pltb = plot(xvals, press; xlabel=xlab, label="PRESS")
    plot!(pltb, xvals, GCV; label="GCV")

    pltc = plot(BPRESS[2:end]; label="B-press")
    plot!(pltc, BGCV[2:end]; label="B-GCV")

    plt = plot(plta, pltb, pltc; layout=(2, 2))
    return display(plt)
end

"""
    TRLooCVNum(TR::TRSVD, y, lambdaInit=1)
    TRLooCVNum(X, y, lambdaInit=1, regType="L2", regParam1=0, regParam2=1e-14)

Find the regularisation parameter value minimising the PRESS statistic with a
Nelder-Mead search started at `lambdaInit`, and return `b, lambda_min`.
`y` is centred internally.
"""
function TRLooCVNum(TR::TRSVD, y, lambdaInit=1)
    my = mean(y)
    yc = y .- my

    # TRPress returns a one-element vector for a scalar lambda; the optimiser needs a
    # scalar objective.
    # NOTE(review): the search is unconstrained, so negative lambda values may be tried.
    pressfunc(lambdaval, _) = first(TRPress(TR, yc, lambdaval[1]))

    prob = OptimizationProblem(pressfunc, [float(lambdaInit)], [])
    lambda_min = first(solve(prob, NelderMead()).u)
    b = TRRegCoeffs(TR, yc, lambda_min, my)

    return b, lambda_min
end

function TRLooCVNum(X, y, lambdaInit=1, regType="L2", regParam1=0, regParam2=1e-14)
    TR = TRSVDDecomp(X, regType, regParam1, regParam2)
    return TRLooCVNum(TR, y, lambdaInit)
end

"""
    TRGCVNum(TR::TRSVD, y, lambdaInit=1)
    TRGCVNum(X, y, lambdaInit=1, regType="L2", regParam1=0, regParam2=1e-14)

Find the regularisation parameter value minimising the GCV criterion with a
Nelder-Mead search started at `lambdaInit`, and return `b, lambda_min`.
`y` is centred internally.
"""
function TRGCVNum(TR::TRSVD, y, lambdaInit=1)
    my = mean(y)
    yc = y .- my

    # TRGCV returns a one-element vector for a scalar lambda; the optimiser needs a
    # scalar objective.
    # NOTE(review): the search is unconstrained, so negative lambda values may be tried.
    gcvfunc(lambdaval, _) = first(TRGCV(TR, yc, lambdaval[1]))

    prob = OptimizationProblem(gcvfunc, [float(lambdaInit)], [])
    lambda_min = first(solve(prob, NelderMead()).u)
    b = TRRegCoeffs(TR, yc, lambda_min, my)

    return b, lambda_min
end

function TRGCVNum(X, y, lambdaInit=1, regType="L2", regParam1=0, regParam2=1e-14)
    TR = TRSVDDecomp(X, regType, regParam1, regParam2)
    return TRGCVNum(TR, y, lambdaInit)
end