Julia の GaussianProcesses のライブラリを入れようとする
背景
連続腕バンディット(それもトンプソンサンプリング)をやってみようと思った。 ベルヌーイ分布のトンプソンサンプリングは、ベータ分布の乱数生成だけが肝だが、ライブラリを使ったり、最悪Cの関数を参考にしながら自分で実装すれば特に問題はない。 ところが、ガウス過程の乱数生成は、まずガウス過程の分布を計算するのが一苦労なのと、その分布から乱数を生成する必要があり、またこれも面倒である可能性がある。 とりあえず、フォン・ノイマンの棄却法とか、分布さえ求められればなんとかなりそうな気配はあるので、ガウス過程の分布を求めるところからスタートする。
Julia の GaussianProcesses のライブラリを入れる。
なぜ Julia かといえば、特に理由はないが、python はすでに多くやられていて、特に問題なくインストールとかできそうだから。
早速、README.md どおりインストールを開始する。
julia> Pkg.add("GaussianProcesses") INFO: Cloning cache of Distances from https://github.com/JuliaStats/Distances.jl.git INFO: Cloning cache of GaussianProcesses from https://github.com/STOR-i/GaussianProcesses.jl.git INFO: Cloning cache of LineSearches from https://github.com/JuliaNLSolvers/LineSearches.jl.git INFO: Cloning cache of Optim from https://github.com/JuliaNLSolvers/Optim.jl.git INFO: Cloning cache of PositiveFactorizations from https://github.com/timholy/PositiveFactorizations.jl.git INFO: Cloning cache of ScikitLearnBase from https://github.com/cstjean/ScikitLearnBase.jl.git INFO: Installing Distances v0.4.1 INFO: Installing GaussianProcesses v0.4.0 INFO: Installing LineSearches v0.1.5 INFO: Installing Optim v0.7.8 INFO: Installing PositiveFactorizations v0.0.4 INFO: Installing ScikitLearnBase v0.3.0 INFO: Package database updated INFO: METADATA is out-of-date — you may not have the latest version of GaussianProcesses INFO: Use `Pkg.update()` to get the latest versions of your packages julia> Pkg.update() INFO: Updating METADATA... INFO: Updating cache of Combinatorics... INFO: Updating cache of StatsBase... INFO: Updating cache of Roots... INFO: Updating cache of PolynomialFactors... INFO: Updating cache of Distributions... INFO: Updating cache of Rmath... INFO: Updating cache of Compat... INFO: Updating cache of PDMats... INFO: Computing changes... INFO: Cloning cache of IterTools from https://github.com/JuliaCollections/IterTools.jl.git INFO: Upgrading Combinatorics: v0.4.0 => v0.4.1 INFO: Upgrading Compat: v0.25.2 => v0.26.0 INFO: Installing IterTools v0.1.0 INFO: Upgrading PDMats: v0.6.0 => v0.7.0 INFO: Upgrading Rmath: v0.1.6 => v0.1.7 INFO: Upgrading Roots: v0.3.1 => v0.4.0 INFO: Upgrading StatsBase: v0.15.0 => v0.17.0 INFO: Removing Iterators v0.3.1 INFO: Removing PolynomialFactors v0.0.5 INFO: Removing Primes v0.1.3 INFO: Building Rmath
ちょっと GaussianProcesses は古いとでた。
ではヘルプを見てみよう。
julia> help?> GP search: get_bigfloat_precision gperm CachingPool log1p logspace getpid graphemes isgraph getipaddr set_bigfloat_precision with_bigfloat_precision Couldn't find GP Perhaps you meant gc, !, !=, $, %, &, *, +, -, .%, .*, .+, .-, ./, .<, .>, .\, .^, .÷, .≠, .≤, .≥, /, //, :, <, <:, <<, <=, ==, =>, >, >=, >>, I, \, ^, cd, cp, e, eu, fd, im, in, lq, lu, mv, pi, qr, rm, |, |>, ~, ×, ÷, γ, π, φ, ∈, ∉, ∋, ∌, √, ∛, ∩, ∪, ≈, ≉, ≠, ≡, ≢ or ≤ No documentation found. Binding GP does not exist.
おっと、いきなりこれか。using
してないからじゃないの?
julia> using GaussianProcesses INFO: Precompiling module Optim. INFO: Recompiling stale cache file /Users/xxxxxx/.julia/lib/v0.5/PDMats.ji for module PDMats. INFO: Precompiling module Distances. INFO: Precompiling module ScikitLearnBase.
したら、
help?> GP search: get_bigfloat_precision GP gperm CachingPool log1p logspace getpid graphemes isgraph getipaddr GaussianProcesses set_bigfloat_precision with_bigfloat_precision ...
でるじゃん。
サンプルを試す。
julia> using PyPlot, GaussianProcesses ERROR: ArgumentError: Module PyPlot not found in current path. Run `Pkg.add("PyPlot")` to install the PyPlot package. in require(::Symbol) at ./loading.jl:365 in require(::Symbol) at /Applications/Julia-0.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
やはり。
julia> Pkg.add("PyPlot") ....
うーん、conda とか、numpy とかがインストールされていくのを見ると、julia でやっている意味はあるのかって一瞬思う。 (一瞬だったが)
その後 README.md にあるコマンドを順に実行する。
次は出力なので気をつける。
Dim = 1 ...
julia> plot(gp)
それっぽいのが出てきた。
カーネル関数の超パラメータの調整
カーネル関数には事前にセットされた数々の超パラメータがある。この超パラメータを調整する方法として、第二種の最尤推定法とよばれる方法がある。 PRML本だと、ガウス過程の第二種の最尤推定法は「6.4.3 超パラメータの学習」で触れられている。
では、この超パラメータの調整をREAEME.mdどおり optimize!(gp)
で行ってみよう。
julia> optimize!(gp) #Optimise the hyperparameters WARNING: DifferentiableFunction(args...) is deprecated, use OnceDifferentiable(args...) instead. in depwarn(::String, ::Symbol) at ./deprecated.jl:64 in DifferentiableFunction(::Function, ::Vararg{Function,N}) at ./deprecated.jl:50 in #get_optim_target#20(::Bool, ::Bool, ::Bool, ::Function, ::GaussianProcesses.GP) at /Users/XXXX/.julia/v0.5/GaussianProcesses/src/optimize.jl:74 in (::GaussianProcesses.#kw##get_optim_target)(::Array{Any,1}, ::GaussianProcesses.#get_optim_target, ::GaussianProcesses.GP) at ./<missing>:0 in #optimize!#19(::Bool, ::Bool, ::Bool, ::Optim.ConjugateGradient{Void,Optim.##29#31,LineSearches.#hagerzhang!}, ::Array{Any,1}, ::Function, ::GaussianProcesses.GP) at /Users/XXXX/.julia/v0.5/GaussianProcesses/src/optimize.jl:17 in optimize!(::GaussianProcesses.GP) at /Users/XXXX/.julia/v0.5/GaussianProcesses/src/optimize.jl:17 in eval(::Module, ::Any) at ./boot.jl:234 in eval(::Module, ::Any) at /Applications/Julia-0.5.app/Contents/Resources/julia/lib/julia/sys.dylib:? in eval_user_input(::Any, ::Base.REPL.REPLBackend) at ./REPL.jl:64 in macro expansion at ./REPL.jl:95 [inlined] in (::Base.REPL.##3#4{Base.REPL.REPLBackend})() at ./event.jl:68 while loading no file, in expression starting on line 0 ERROR: MethodError: no method matching set_params!(::GaussianProcesses.GP, ::Float64; noise=true, mean=true, kern=true) Closest candidates are: set_params!(::GaussianProcesses.GP, ::Array{Float64,1}; noise, mean, kern) at /Users/XXXX/.julia/v0.5/GaussianProcesses/src/GP.jl:275 set_params!{K<:GaussianProcesses.Kernel}(::GaussianProcesses.Masked{K<:GaussianProcesses.Kernel}, ::Any) at /Users/XXXX/.julia/v0.5/GaussianProcesses/src/kernels/masked_kernel.jl:55 got unsupported keyword arguments "noise", "mean", "kern" in #optimize!#19(::Bool, ::Bool, ::Bool, ::Optim.ConjugateGradient{Void,Optim.##29#31,LineSearches.#hagerzhang!}, ::Array{Any,1}, ::Function, ::GaussianProcesses.GP) at /Users/XXXX/.julia/v0.5/GaussianProcesses/src/optimize.jl:20 in optimize!(::GaussianProcesses.GP) at /Users/XXXX/.julia/v0.5/GaussianProcesses/src/optimize.jl:17
グフ… 色々古いのか。次回以降に続く。