Added missing stuff from TR.jl file

This commit is contained in:
Joakim Skogholt 2023-05-13 20:17:21 +02:00
parent 25eb5a7c2a
commit ad235d7913

233
src/TR.jl
View file

@ -8,6 +8,21 @@ using Optimization
using OptimizationOptimJL using OptimizationOptimJL
"""
    TRSVD

Result of `TRSVDDecomp`: the thin SVD of the centred, transformed data matrix together
with the quantities needed to map results back to the original variables.
Objects of this type are the input to the other TR functions.
"""
struct TRSVD
U::Matrix{Float64}  # left singular vectors of the centred `X / regMat`
s::Vector{Float64}  # singular values of the centred `X / regMat`
V::Matrix{Float64}  # right singular vectors of the centred `X / regMat`
mX::Matrix{Float64}  # column means of the original `X` (from `mean(X, dims=1)`)
regType::String  # regularization type handed to `regularizationMatrix`
regParam1::Float64  # first regularization parameter handed to `regularizationMatrix`
regMat::Matrix{Float64}  # regularization matrix returned by `regularizationMatrix`
n::Int64  # number of rows of `X`
p::Int64  # number of columns of `X`
end
""" """
### TO DO: ADD FRACTIONAL DERIVATIVE REGULARIZATION ### ### TO DO: ADD FRACTIONAL DERIVATIVE REGULARIZATION ###
@ -53,4 +68,222 @@ elseif regType == "std"
end end
return regMat return regMat
end
"""
    TRSVDDecomp(X, regType="L2", regParam1=0, regParam2=1e-14)

Centre the columns of `X`, build the regularization matrix with `regularizationMatrix`,
transform the centred data as `X / regMat` and take its thin SVD.

Returns a `TRSVD` object holding the SVD factors, the column means, the regularization
settings and matrix, and the size of `X`. The object is the input to the other TR
functions.
"""
function TRSVDDecomp(X, regType="L2", regParam1=0, regParam2=1e-14)
    nrows, ncols = size(X)
    colmeans = mean(X; dims=1)
    centred = X .- colmeans

    # The regularization matrix is built from the centred data.
    regMat = regularizationMatrix(centred; regType, regParam1, regParam2)

    factors = svd(centred / regMat; full=false)
    return TRSVD(
        factors.U, factors.S, factors.V, colmeans, regType, regParam1, regMat, nrows, ncols
    )
end
"""
    TRRegCoeffs(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14)
    TRRegCoeffs(TR::TRSVD, y, lambdas, my=0)

Calculate regression coefficients of the TR model for every value in `lambdas`
(one column of coefficients per value).

The first method centres `X` (via `TRSVDDecomp`) and `y` itself and returns
`(bcoeffs, TR)`; `bcoeffs` always carries the intercept in its first row.

The second method expects `y` to be centred already and returns `bcoeffs` only.
`my` is the mean that was removed from `y`: when it is non-zero an intercept row
`my .- TR.mX * b` is prepended, when it is `0` (the default) only the slopes are returned.
"""
function TRRegCoeffs(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14)
    TR = TRSVDDecomp(X, regType, regParam1, regParam2)
    my = mean(y)

    # Ask for the slopes only and add the intercept here. Forwarding `my` would drop the
    # intercept row whenever mean(y) is exactly zero, because the TRSVD method treats
    # `my == 0` as "no intercept requested".
    slopes = TRRegCoeffs(TR, y .- my, lambdas)
    bcoeffs = [my .- TR.mX * slopes; slopes]
    return bcoeffs, TR
end

function TRRegCoeffs(TR::TRSVD, y, lambdas, my=0)
    # denom[i, j] = lambdas[j] / s[i] + s[i]
    denom = (lambdas ./ TR.s' .+ TR.s')'

    # Coefficients in the transformed space, then mapped back through regMat.
    bcoeffs = TR.regMat \ (TR.V * ((TR.U' * y) ./ denom))

    if my != 0
        bcoeffs = [my .- TR.mX * bcoeffs; bcoeffs]
    end
    return bcoeffs
end
"""
    TRPress(TR::TRSVD, y, lambdas)
    TRPress(TR::TRSVD, y, lambdas, H, resid)

Calculate the PRESS value for every value in `lambdas` and return them as a vector.

The first method computes the residuals `resid` and the leverage-type matrix `H`
(one column per lambda value) from `TR` and `y`, then delegates to the second method.
The second method takes precomputed `H` and `resid`, divides the residuals by `1 .- H`
and sums their squares column-wise. `y` is used as given; no centring is done here.
"""
function TRPress(TR::TRSVD, y, lambdas)
    # denom[i, j] = lambdas[j] / s[i] + s[i]
    denom = (lambdas ./ TR.s' .+ TR.s')'
    resid = y .- TR.U * ((TR.s .* (TR.U' * y)) ./ denom)
    H = (TR.U .^ 2) * (TR.s ./ denom) .+ 1 / TR.n

    # The companion method takes (TR, y, lambdas, H, resid); `denom` is not an argument.
    return TRPress(TR, y, lambdas, H, resid)
end

function TRPress(TR::TRSVD, y, lambdas, H, resid)
    rescv = resid ./ (1 .- H)
    return vec(sum(abs2, rescv; dims=1))
end
"""
    TRGCV(TR::TRSVD, y, lambdas)
    TRGCV(TR::TRSVD, y, lambdas, H, resid)

Calculate the GCV value for every value in `lambdas` and return them as a vector.

The first method computes the residuals `resid` and the leverage-type matrix `H`
(one column per lambda value) from `TR` and `y`, then delegates to the second method.
The second method takes precomputed `H` and `resid` and returns, per column, the sum of
squared residuals divided by the squared column mean of `1 .- H`.
`y` is used as given; no centring is done here.
"""
function TRGCV(TR::TRSVD, y, lambdas)
    # denom[i, j] = lambdas[j] / s[i] + s[i]
    denom = (lambdas ./ TR.s' .+ TR.s')'
    resid = y .- TR.U * ((TR.s .* (TR.U' * y)) ./ denom)
    H = (TR.U .^ 2) * (TR.s ./ denom) .+ 1 / TR.n

    # The companion method takes (TR, y, lambdas, H, resid); `denom` is not an argument.
    return TRGCV(TR, y, lambdas, H, resid)
end

function TRGCV(TR::TRSVD, y, lambdas, H, resid)
    rss = sum(abs2, resid; dims=1)
    meancompl = mean(1 .- H; dims=1)
    return vec(rss ./ meancompl .^ 2)
end
"""
    TRLooCV(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14)
    TRLooCV(TR::TRSVD, y, lambdas)

Evaluate PRESS and GCV for every value in `lambdas` and calculate the regression
coefficients at the PRESS-minimal and GCV-minimal values. `y` is centred internally.

Returns `BPRESS, BGCV, TR, press, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV`.
`BPRESS` and `BGCV` always carry the intercept in their first row.
"""
function TRLooCV(X, y, lambdas, regType="L2", regParam1=0, regParam2=1e-14)
    TR = TRSVDDecomp(X, regType, regParam1, regParam2)
    return TRLooCV(TR, y, lambdas)
end

function TRLooCV(TR::TRSVD, y, lambdas)
    my = mean(y)
    yc = y .- my

    # denom[i, j] = lambdas[j] / s[i] + s[i]; H and resid have one column per lambda.
    denom = (lambdas ./ TR.s' .+ TR.s')'
    H = (TR.U .^ 2) * (TR.s ./ denom) .+ 1 / TR.n
    resid = yc .- TR.U * ((TR.s .* (TR.U' * yc)) ./ denom)

    press = TRPress(TR, yc, lambdas, H, resid)
    GCV = TRGCV(TR, yc, lambdas, H, resid)

    idminPRESS = argmin(press)
    idminGCV = argmin(GCV)
    lambdaPRESS = lambdas[idminPRESS]
    lambdaGCV = lambdas[idminGCV]

    # Request the slopes only and add the intercept here. Forwarding `my` to TRRegCoeffs
    # would drop the intercept row whenever mean(y) is exactly zero.
    withintercept(b) = [my .- TR.mX * b; b]
    BPRESS = withintercept(TRRegCoeffs(TR, yc, lambdaPRESS))
    BGCV = withintercept(TRRegCoeffs(TR, yc, lambdaGCV))

    return BPRESS, BGCV, TR, press, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV
end
"""
    PlotTRLooCV(BPRESS, BGCV, TR, press, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV,
                lambdas=nothing)

Plot the output of `TRLooCV` in three panels: the data reconstructed from `TR`
(`U * diagm(s) * V'` plus the column means), the PRESS and GCV curves, and the
PRESS- and GCV-minimal regression coefficients (first row, the intercept, excluded).

`lambdas` is the grid that `press` and `GCV` were evaluated on; the curves are drawn
against `log10.(lambdas)`. When it is not given, a module-level variable named `lambdas`
is used if one exists, otherwise the curves are drawn against the grid index.
"""
function PlotTRLooCV(
    BPRESS, BGCV, TR, press, GCV, idminPRESS, lambdaPRESS, idminGCV, lambdaGCV,
    lambdas=nothing,
)
    grid = lambdas
    if grid === nothing && isdefined(@__MODULE__, :lambdas)
        # Legacy path: the body used to read a global `lambdas` that was never passed in.
        grid = getfield(@__MODULE__, :lambdas)
    end

    if grid === nothing
        xvals, xlab = eachindex(press), "lambda index"
    else
        xvals, xlab = log10.(grid), "log10(lambda)"
    end

    plta = plot((TR.U * diagm(TR.s) * TR.V' .+ TR.mX)', legend=false)

    pltb = plot(xvals, press, xlabel=xlab, label="PRESS")
    plot!(pltb, xvals, GCV, label="GCV")

    pltc = plot(BPRESS[2:end], label="B-press")
    plot!(pltc, BGCV[2:end], label="B-GCV")

    plt = plot(plta, pltb, pltc, layout=(2, 2))
    return display(plt)
end
"""
    TRLooCVNum(TR::TRSVD, y, lambdaInit=1)
    TRLooCVNum(X, y, lambdaInit=1, regType="L2", regParam1=0, regParam2=1e-14)

Find the regularization parameter value minimising the PRESS statistic with a
Nelder-Mead search started at `lambdaInit`, and return `b, lambda_min`, where `b` holds
the regression coefficients (intercept in the first row) at that value.
`y` is centred internally.
"""
function TRLooCVNum(TR::TRSVD, y, lambdaInit=1)
    my = mean(y)
    # Centre before building the objective so the closure captures a variable that is
    # never reassigned.
    yc = y .- my

    # TRPress returns a one-element vector for a scalar lambda; reduce it to a scalar.
    pressfunc(lambdaval) = first(TRPress(TR, yc, lambdaval[1]))

    prob = OptimizationProblem((x, p) -> pressfunc(x), [float(lambdaInit)], [])
    sol = solve(prob, NelderMead())[1]

    # Request the slopes only and add the intercept here. Forwarding `my` to TRRegCoeffs
    # would drop the intercept row whenever mean(y) is exactly zero.
    slopes = TRRegCoeffs(TR, yc, sol)
    b = [my .- TR.mX * slopes; slopes]
    return b, sol[1]
end

function TRLooCVNum(X, y, lambdaInit=1, regType="L2", regParam1=0, regParam2=1e-14)
    TR = TRSVDDecomp(X, regType, regParam1, regParam2)
    return TRLooCVNum(TR, y, lambdaInit)
end
"""
    TRGCVNum(TR::TRSVD, y, lambdaInit=1)
    TRGCVNum(X, y, lambdaInit=1, regType="L2", regParam1=0, regParam2=1e-14)

Find the regularization parameter value minimising the GCV statistic with a
Nelder-Mead search started at `lambdaInit`, and return `b, lambda_min`, where `b` holds
the regression coefficients (intercept in the first row) at that value.
`y` is centred internally.
"""
function TRGCVNum(TR::TRSVD, y, lambdaInit=1)
    my = mean(y)
    # Centre before building the objective so the closure captures a variable that is
    # never reassigned.
    yc = y .- my

    # TRGCV returns a one-element vector for a scalar lambda; reduce it to a scalar.
    gcvfunc(lambdaval) = first(TRGCV(TR, yc, lambdaval[1]))

    prob = OptimizationProblem((x, p) -> gcvfunc(x), [float(lambdaInit)], [])
    sol = solve(prob, NelderMead())[1]

    # Request the slopes only and add the intercept here. Forwarding `my` to TRRegCoeffs
    # would drop the intercept row whenever mean(y) is exactly zero.
    slopes = TRRegCoeffs(TR, yc, sol)
    b = [my .- TR.mX * slopes; slopes]
    return b, sol[1]
end

function TRGCVNum(X, y, lambdaInit=1, regType="L2", regParam1=0, regParam2=1e-14)
    TR = TRSVDDecomp(X, regType, regParam1, regParam2)
    return TRGCVNum(TR, y, lambdaInit)
end