Simple function fitting demo
    using PyPlot, PyCall   # use PyPlot and PyCall
    using Distributions    # use Julia's distributions library
Set the parameter W. W is an array with elements W[1], W[2], and W[3].
    # true param
    W = Array([1.0, 0.0, 1.0])
    W[1]
    Out: 1.0
    W[2]
    Out: 0.0
    W[3]
    Out: 1.0
    # generate data
    sigma = 0.5
    N = 20
    X = linspace(-0.4,2.4,N)
    Out: -0.4:0.14736842105263157:2.4
Here the interval of length 2.8 from -0.4 to 2.4 is split into 19 steps (20 points), so 0.14736842105263157 x 19 = 2.8.
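The step size follows directly from the endpoints and the number of points. A minimal sketch of that arithmetic, assuming Julia 1.x, where `linspace` is no longer available and `range` with the `stop` and `length` keywords plays the same role (the name `dx` is introduced here only for illustration):

```julia
# Assumes Julia 1.x: `linspace(a, b, N)` was removed in 1.0 in favor of `range`.
N = 20
X = range(-0.4, stop=2.4, length=N)

# N points span N - 1 intervals, so the spacing is 2.8 / 19.
dx = (2.4 - (-0.4)) / (N - 1)

# Compare the step stored in the range with the hand-computed one.
println(step(X) ≈ dx)
```

The same reasoning applies to any number of points: the divisor is always one less than the point count.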
    X_min = minimum(X)
    Out: -0.4
    X_max = maximum(X)
    Out: 2.4
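The listings in this post never show where `Y` comes from, although the regression steps use it. One plausible reconstruction, given the true parameter `W` and the noise level `sigma`, is to sample the quadratic `W[1] + W[2]*x + W[3]*x^2` with additive Gaussian noise. This is an assumption, not confirmed code from the post, and because the noise is random the numeric outputs shown here will not be reproduced exactly. The sketch assumes Julia 1.x, and the seed is arbitrary:

```julia
using Random  # standard library, used only to make the sketch repeatable

Random.seed!(0)  # arbitrary seed, not from the original post

W = [1.0, 0.0, 1.0]                    # true parameters, as set above
sigma = 0.5                            # noise standard deviation
N = 20
X = range(-0.4, stop=2.4, length=N)    # Julia 1.x replacement for linspace

# ASSUMPTION: quadratic signal plus zero-mean Gaussian noise with std sigma.
Y = [W[1] + W[2]*x + W[3]*x^2 + sigma*randn() for x in X]
```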
Next, the same interval of length 2.8 from -0.4 to 2.4 is split into 99 steps (100 points), so 0.028282828282828285 x 99 = 2.8.
    # regression1
    X_all = linspace(X_min, X_max, 100)
    Out: -0.4:0.028282828282828285:2.4
Computing the weight W1 as the linear regression coefficient with the y-intercept fixed at zero gives W1 = sum(x*y) / sum(x^2).
    W1 = sum(Y.*X) / sum(X.^2)
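This ratio can be justified by minimizing the squared error of the intercept-free model y ≈ w_1 x. A short derivation, writing x_n and y_n for the elements of X and Y (notation introduced here for illustration):

```latex
E(w_1) = \sum_{n=1}^{N} (y_n - w_1 x_n)^2,
\qquad
\frac{dE}{dw_1} = -2 \sum_{n=1}^{N} x_n \,(y_n - w_1 x_n) = 0
\;\Longrightarrow\;
w_1 = \frac{\sum_{n=1}^{N} x_n y_n}{\sum_{n=1}^{N} x_n^2}
```

The numerator and denominator correspond to `sum(Y.*X)` and `sum(X.^2)` in the code.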
    sum(Y.*X)
    Out: 79.22002482506197
    sum(X.^2)
    Out: 34.44210526315789
    W1
    Out: 2.3000924078181195
    Y1 = [W1*x for x in X_all]
    Out: 100-element Array{Float64,1}:
     -0.920037
     -0.854984
     -0.789931
     -0.724878
     -0.659824
     -0.594771
     -0.529718
     -0.464665
     -0.399612
     -0.334559
     -0.269506
     -0.204453
     -0.1394
      ⋮
      4.80464
      4.86969
      4.93474
      4.9998
      5.06485
      5.1299
      5.19496
      5.26001
      5.32506
      5.39012
      5.45517
      5.52022
That completes the linear regression.
Next comes quadratic regression. First, create a container for X: a 3×20 array.
    # regression2
    X2 = zeros(3, N)
    Out: 3×20 Array{Float64,2}:
     0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
     0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
     0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
    X2[1,:] = 1
    X2[2,:] = X
    X2[3,:] = X.^2
    X2
    Out: 3×20 Array{Float64,2}:
      1.0    1.0         1.0        …  1.0      1.0      1.0      1.0
     -0.4   -0.252632   -0.105263      1.95789  2.10526  2.25263  2.4
      0.16   0.0638227   0.0110803     3.83335  4.43213  5.07435  5.76
W2 generalizes sum(Y.*X) / sum(X.^2) to the matrix case: it is the least-squares solution inv(X2*X2') * X2*Y.
    W2 = inv(X2*X2') * X2*Y
    Out: 3-element Array{Float64,1}:
      0.964462
     -0.0204804
      0.957519
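The expression `inv(X2*X2') * X2*Y` is the normal-equation solution of the least-squares problem for the model `Y ≈ X2' * W`. A minimal sketch, assuming Julia 1.x and using noise-free synthetic targets (an illustrative choice, not the data from this post) so that the recovered weights can be compared with the true ones. It also shows the backslash solve, which gives the same result without forming an explicit inverse; the names `W2_inv` and `W2_solve` are hypothetical:

```julia
using LinearAlgebra  # standard library: matrix inverse and linear solves

N = 20
X = range(-0.4, stop=2.4, length=N)    # Julia 1.x replacement for linspace
W = [1.0, 0.0, 1.0]                    # true parameters

# Design matrix: one row per basis function (1, x, x^2), one column per point.
X2 = zeros(3, N)
X2[1, :] .= 1.0       # Julia 1.x needs broadcast assignment for a scalar
X2[2, :] = X
X2[3, :] = X .^ 2

# ASSUMPTION: noise-free targets, used here only so the answer is known.
Y = X2' * W

W2_inv   = inv(X2 * X2') * X2 * Y      # form used in the post
W2_solve = (X2 * X2') \ (X2 * Y)       # same solution via a linear solve

println(W2_inv)
println(W2_solve)
```

With noisy data, as in the post, the estimate lands near the true parameters rather than on them, which matches the values 0.964462, -0.0204804, 0.957519 shown above against the true 1.0, 0.0, 1.0.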
    Y2 = [W2[1] + W2[2]*x + W2[3]*x^2 for x in X_all]
    Out: 100-element Array{Float64,1}:
     1.12586
     1.10438
     1.08443
     1.06602
     1.04913
     1.03378
     1.01996
     1.00768
     0.996922
     0.987699
     0.980007
     0.973848
     0.96922
     ⋮
     5.09978
     5.2131
     5.32796
     5.44435
     5.56227
     5.68173
     5.80271
     5.92523
     6.04928
     6.17486
     6.30198
     6.43062
Finally, display the graph with PyPlot.
    # show data
    figure()
    plot(X_all, Y1, "b-")
    plot(X_all, Y2, "g-")
    plot(X, Y, "ko")
    legend(["model1","model2","data"], loc="upper left", fontsize=16)
    xlabel("\$x\$", fontsize=20)
    ylabel("\$y\$", fontsize=20)
    show()