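Several new probabilistic programming libraries have been released recently: Bean Machine, Gen, and Turing.
Bean Machine is implemented on top of PyTorch, while Gen and Turing are probabilistic programming libraries for Julia.
Julia is a new language designed mainly for data science, and Bayesian libraries are now becoming available for it just as they are for R and Python. Gen and Turing are two examples.
Turing was developed by Hong Ge, is maintained by volunteers, and is distributed under the MIT License. Gen is a Bayesian library developed and distributed by MIT.
Of the two Julia libraries, this post reviews the overview and features of Gen. Turing will be covered in a separate post.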
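Features of Gen
- Implemented in Julia, which runs on a JIT compiler. In this respect it resembles Stan, which provides R and Python interfaces.
- Probabilistic inference algorithms. Gen gives users an environment for developing hybrid algorithms that combine variational inference, sequential Monte Carlo, Markov chain Monte Carlo, and neural networks.
- A flexible API that includes automatic differentiation.
- Efficient inference based on probabilistic structure. Gen's generative models and inference models build their computation graphs dynamically, and reversible jump and involutive MCMC algorithms make inference efficient. For involutive MCMC, one of the more recent MCMC methods, see the following paper:
"Automating Involutive MCMC using Probabilistic and Differentiable Programming"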
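Setting up and installing Gen
Gen is a Julia library, so install Julia first if you do not already have a Julia environment.
Go to the following site and download the version that matches your host machine.
https://julialang.org/downloads/
On a Mac, simply drag the downloaded application into the Applications folder.
Double-clicking the application then starts the Julia REPL.
Julia REPL (REPL: read-eval-print loop)
Configuring the Julia environment
To launch julia from the command line, or to run it with a file name as an argument, add it to the PATH environment variable.
If your default login shell is bash, add the path of the installed application's executable to .bashrc.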
vi ~/.bashrc
export PATH=/usr/local/bin:$PATH
export PATH=/Applications/Julia-1.9.app/Contents/Resources/julia/bin:$PATH
The second line is the path added for julia.
Open a new shell session and check PATH.
echo $PATH
$ echo $PATH
/opt/homebrew/bin:/opt/homebrew/sbin:/usr/local/opt/php@7.3/sbin:/usr/local/opt/php@7.3/bin:/Users/username/anaconda3/bin:/Users/username/anaconda3/condabin:/Applications/Julia-1.9.app/Contents/Resources/julia/bin:/usr/local/bin:/usr/local/bin:/System/Cryptexes/App/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:/opt/X11/bin:/Applications/Server.app/Contents/ServerRoot/usr/bin:/Applications/Server.app/Contents/ServerRoot/usr/sbin:/Library/Apple/usr/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/local/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/appleinternal/bin
On a Mac the default shell is now zsh, so zsh users should add the path to .zshrc instead.
vi ~/.zshrc
# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/Users/username/anaconda3/bin/conda' 'shell.zsh' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
eval "$__conda_setup"
else
if [ -f "/Users/username/anaconda3/etc/profile.d/conda.sh" ]; then
. "/Users/username/anaconda3/etc/profile.d/conda.sh"
else
export PATH="/Users/username/anaconda3/bin:$PATH"
fi
fi
unset __conda_setup
# <<< conda initialize <<<
export PATH=/opt/R/arm64/bin:$PATH
export PATH=/Applications/Julia-1.9.app/Contents/Resources/julia/bin:$PATH
The last line is the path that was added.
The following shows julia launched from zsh:
(base)xxxxx~ % julia
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.9.2 (2023-07-05)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |
julia>
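Once the REPL is running, the PATH setting can also be checked from inside Julia. The following is a minimal sketch, assuming a Unix-style, colon-separated PATH; `on_path` is a hypothetical helper written for this post, not part of Julia.

```julia
# Return true when `dir` is one of the entries of a colon-separated PATH string.
function on_path(dir::AbstractString, path::AbstractString)
    entries = split(path, ':'; keepempty=false)
    # Compare without trailing slashes so "/a/bin/" and "/a/bin" match.
    return any(rstrip(e, '/') == rstrip(dir, '/') for e in entries)
end

# Sys.BINDIR is the directory that holds the running julia executable.
println("julia bin dir: ", Sys.BINDIR)
println("on PATH: ", on_path(Sys.BINDIR, get(ENV, "PATH", "")))
```

If the second line prints false, the shell that started Julia probably did not read the edited .bashrc or .zshrc.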
Adding Gen
Now that Julia is ready, let's install Gen.
Start Julia and type ']' at the Julia REPL (REPL: read-eval-print loop).
Julia's package manager starts and the prompt changes. At the package manager prompt, type add Gen.
(base) UsernoMacBook-Air:~ username$ julia
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.9.2 (2023-07-05)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |
(@v1.9) pkg>
(@v1.9) pkg> add Gen
Installation begins.
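The same installation can be scripted with the functional Pkg API instead of the pkg> prompt. This is a small sketch: `ensure_installed` is a hypothetical helper name, and it assumes that `Base.find_package` returning `nothing` means the package is not available in the current environment.

```julia
using Pkg

# Install `name` only if it cannot be found in the current environment.
# Returns true when an installation was triggered.
function ensure_installed(name::AbstractString)
    if Base.find_package(name) === nothing
        Pkg.add(name)   # same effect as `pkg> add name`; needs network access
        return true
    end
    return false
end

# A standard library is always found, so nothing is installed here.
println(ensure_installed("Random"))
# For Gen itself, call: ensure_installed("Gen")
```

A helper like this is handy at the top of a script that should run on a fresh machine without manual setup.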
After installation you can run the tests. Enter the following command at the julia REPL.
using Pkg; Pkg.test("Gen")
julia> using Pkg; Pkg.test("Gen")
Testing Gen
Status `/private/var/folders/z3/y2vb_4653kjcytsygg5bkv7w0000gn/T/jl_NeEvI0/Project.toml`
[34da2185] Compat v4.9.0
[864edb3b] DataStructures v0.18.15
[31c24e10] Distributions v0.25.100
[f6369f11] ForwardDiff v0.10.36
[de31a74c] FunctionalCollections v0.5.0
[ea4f424c] Gen v0.4.5
[682c06a0] JSON v0.21.4
[1914dd2f] MacroTools v0.5.11
[d96e819e] Parameters v0.12.3
[37e2e3b7] ReverseDiff v1.15.1
[276daf66] SpecialFunctions v2.3.1
[37e2e46d] LinearAlgebra `@stdlib/LinearAlgebra`
[9a3f8284] Random `@stdlib/Random`
[8dfed614] Test `@stdlib/Test`
Status `/private/var/folders/z3/y2vb_4653kjcytsygg5bkv7w0000gn/T/jl_NeEvI0/Manifest.toml`
[49dc2e85] Calculus v0.5.1
[d360d2e6] ChainRulesCore v1.16.0
[bbf7d656] CommonSubexpressions v0.3.0
[34da2185] Compat v4.9.0
[9a962f9c] DataAPI v1.15.0
[864edb3b] DataStructures v0.18.15
[163ba53b] DiffResults v1.1.0
[b552c78f] DiffRules v1.15.1
[31c24e10] Distributions v0.25.100
[ffbed154] DocStringExtensions v0.9.3
[fa6b7ba4] DualNumbers v0.6.8
[1a297f60] FillArrays v1.6.1
[f6369f11] ForwardDiff v0.10.36
[069b7b12] FunctionWrappers v1.1.3
[de31a74c] FunctionalCollections v0.5.0
[ea4f424c] Gen v0.4.5
[34004b35] HypergeometricFunctions v0.3.23
[92d709cd] IrrationalConstants v0.2.2
[692b3bcd] JLLWrappers v1.5.0
[682c06a0] JSON v0.21.4
[2ab3a3ac] LogExpFunctions v0.3.26
[1914dd2f] MacroTools v0.5.11
[e1d29d7a] Missings v1.1.0
[77ba4419] NaNMath v1.0.2
[bac558e1] OrderedCollections v1.6.2
[90014a1f] PDMats v0.11.17
[d96e819e] Parameters v0.12.3
[69de0a69] Parsers v2.7.2
[aea7be01] PrecompileTools v1.2.0
[21216c6a] Preferences v1.4.0
[1fd47b50] QuadGK v2.8.2
[189a3867] Reexport v1.2.2
[37e2e3b7] ReverseDiff v1.15.1
[79098fc4] Rmath v0.7.1
[a2af1166] SortingAlgorithms v1.1.1
[276daf66] SpecialFunctions v2.3.1
[90137ffa] StaticArrays v1.6.2
[1e83bf80] StaticArraysCore v1.4.2
[82ae8749] StatsAPI v1.6.0
[2913bbd2] StatsBase v0.34.0
[4c63d2b9] StatsFuns v1.3.0
[3a884ed6] UnPack v1.0.2
[efe28fd5] OpenSpecFun_jll v0.5.5+0
[f50d1b31] Rmath_jll v0.4.0+0
[0dad84c5] ArgTools v1.1.1 `@stdlib/ArgTools`
[56f22d72] Artifacts `@stdlib/Artifacts`
[2a0f44e3] Base64 `@stdlib/Base64`
[ade2ca70] Dates `@stdlib/Dates`
[f43a241f] Downloads v1.6.0 `@stdlib/Downloads`
[7b1f6079] FileWatching `@stdlib/FileWatching`
[b77e0a4c] InteractiveUtils `@stdlib/InteractiveUtils`
[b27032c2] LibCURL v0.6.3 `@stdlib/LibCURL`
[76f85450] LibGit2 `@stdlib/LibGit2`
[8f399da3] Libdl `@stdlib/Libdl`
[37e2e46d] LinearAlgebra `@stdlib/LinearAlgebra`
[56ddb016] Logging `@stdlib/Logging`
[d6f4376e] Markdown `@stdlib/Markdown`
[a63ad114] Mmap `@stdlib/Mmap`
[ca575930] NetworkOptions v1.2.0 `@stdlib/NetworkOptions`
[44cfe95a] Pkg v1.9.2 `@stdlib/Pkg`
[de0858da] Printf `@stdlib/Printf`
[3fa0cd96] REPL `@stdlib/REPL`
[9a3f8284] Random `@stdlib/Random`
[ea8e919c] SHA v0.7.0 `@stdlib/SHA`
[9e88b42a] Serialization `@stdlib/Serialization`
[6462fe0b] Sockets `@stdlib/Sockets`
[2f01184e] SparseArrays `@stdlib/SparseArrays`
[10745b16] Statistics v1.9.0 `@stdlib/Statistics`
[4607b0f0] SuiteSparse `@stdlib/SuiteSparse`
[fa267f1f] TOML v1.0.3 `@stdlib/TOML`
[a4e569a6] Tar v1.10.0 `@stdlib/Tar`
[8dfed614] Test `@stdlib/Test`
[cf7118a7] UUIDs `@stdlib/UUIDs`
[4ec0a83e] Unicode `@stdlib/Unicode`
[e66e0078] CompilerSupportLibraries_jll v1.0.5+0 `@stdlib/CompilerSupportLibraries_jll`
[deac9b47] LibCURL_jll v7.84.0+0 `@stdlib/LibCURL_jll`
[29816b5a] LibSSH2_jll v1.10.2+0 `@stdlib/LibSSH2_jll`
[c8ffd9c3] MbedTLS_jll v2.28.2+0 `@stdlib/MbedTLS_jll`
[14a3606d] MozillaCACerts_jll v2022.10.11 `@stdlib/MozillaCACerts_jll`
[4536629a] OpenBLAS_jll v0.3.21+4 `@stdlib/OpenBLAS_jll`
[05823500] OpenLibm_jll v0.8.1+0 `@stdlib/OpenLibm_jll`
[bea87d4a] SuiteSparse_jll v5.10.1+6 `@stdlib/SuiteSparse_jll`
[83775a58] Zlib_jll v1.2.13+0 `@stdlib/Zlib_jll`
[8e850b90] libblastrampoline_jll v5.8.0+0 `@stdlib/libblastrampoline_jll`
[8e850ede] nghttp2_jll v1.48.0+0 `@stdlib/nghttp2_jll`
[3f19e933] p7zip_jll v17.4.0+0 `@stdlib/p7zip_jll`
Precompiling project...
53 dependencies successfully precompiled in 23 seconds. 7 already precompiled.
Testing Running tests...
WARNING: method definition for process! at /Users/username/.julia/packages/Gen/Dne3u/src/modeling_library/switch/update.jl:85 declares type variable DV but does not use it.
WARNING: method definition for #importance_resampling#354 at /Users/username/.julia/packages/Gen/Dne3u/src/inference/importance.jl:70 declares type variable W but does not use it.
WARNING: method definition for #importance_resampling#354 at /Users/username/.julia/packages/Gen/Dne3u/src/inference/importance.jl:70 declares type variable V but does not use it.
Test Summary: | Pass Total Time
autodiff | 2 2 0.2s
Test Summary: | Pass Total Time
diff arithmetic | 12 12 0.0s
Test Summary: | Pass Total Time
diff vectors | 20 20 0.1s
Test Summary: |Time
diff dictionaries | None 0.0s
Test Summary: |Time
diff sets | None 0.0s
Test Summary: | Pass Total Time
dynamic selection | 18 18 0.1s
Test Summary: | Pass Total Time
all selection | 4 4 0.0s
Test Summary: | Pass Total Time
complement selection | 13 13 0.0s
Test Summary: | Pass Total Time
static assignment to/from array | 12 12 0.1s
Test Summary: | Pass Total Time
dynamic assignment to/from array | 12 12 0.2s
Test Summary: | Pass Total Time
dynamic assignment copy constructor | 4 4 0.0s
Test Summary: | Pass Total Time
internal vector assignment to/from array | 8 8 0.1s
Test Summary: | Pass Total Time
dynamic assignment merge | 10 10 0.0s
Test Summary: | Pass Total Time
dynamic assignment variadic merge | 2 2 0.1s
Test Summary: | Pass Total Time
static assignment merge | 10 10 0.2s
Test Summary: | Pass Total Time
static assignment variadic merge | 2 2 0.1s
Test Summary: | Pass Total Time
static assignment errors | 12 12 0.0s
Test Summary: | Pass Total Time
dynamic assignment errors | 7 7 0.0s
Test Summary: | Pass Total Time
dynamic assignment overwrite | 16 16 0.0s
Test Summary: | Pass Total Time
dynamic assignment constructor | 2 2 0.0s
Test Summary: | Pass Total Time
choice map equality | 6 6 0.0s
Test Summary: | Pass Total Time
choice map nested view | 8 8 0.2s
Test Summary: | Pass Total Time
filtering choicemaps with selections | 6 6 0.0s
Test Summary: | Pass Total Time
invalid choice map constructor | 1 1 0.0s
Test Summary: | Pass Total Time
choicemap hashing | 25 25 0.1s
Test Summary: | Pass Total Time
update(...) shorthand assuming unchanged args (dynamic modeling lang) | 3 3 2.9s
Test Summary: | Pass Total Time
regenerate(...) shorthand assuming unchanged args (dynamic modeling lang) | 2 2 0.1s
Test Summary: | Pass Total Time
update(...) shorthand assuming unchanged args (static modeling lang) | 3 3 0.6s
Test Summary: | Pass Total Time
regenerate(...) shorthand assuming unchanged args (static modeling lang) | 2 2 0.2s
Test Summary: | Pass Total Time
macros in dynamic DSL | 7 7 0.1s
Test Summary: | Pass Total Time
macros in static DSL | 7 7 0.2s
Test Summary: | Pass Total Time
Dynamic DSL | 148 148 3.5s
Test Summary: | Pass Total Time
static DSL | 172 172 1.2s
Test Summary: | Pass Total Time
optional positional args (calling + GFI) | 42 42 4.6s
Test Summary: | Pass Total Time
optional positional args (combinators) | 16 16 1m28.8s
Test Summary: | Pass Total Time
static IR | 138 138 3.6s
Test Summary: | Pass Total Time
tilde syntax | 12 12 0.3s
Test Summary: | Pass Total Time
importance sampling | 22 22 0.1s
Test Summary: | Pass Total Time
SGD training | 10 10 1.6s
Test Summary: | Pass Total Time
lecture! | 2 2 4.2s
Test Summary: | Pass Total Time
black box variational inference | 4 4 3.8s
opt_theta: -0.37567952986819614
true optimum log_marginal_likelihood: -14.79249158646238
4.845693 seconds (80.93 M allocations: 3.204 GiB, 5.75% gc time, 2.40% compilation time)
final theta: -0.3688737663341869
final elbo estimate: -14.83005893286215
8.781272 seconds (155.11 M allocations: 6.286 GiB, 6.19% gc time, 0.50% compilation time)
final theta: -0.3865867550388555
final elbo estimate: -14.771737102011121
Test Summary: | Pass Total Time
minimal vae | 3 3 13.8s
Test Summary: | Pass Total Time
hmm forward alg | 1 1 0.1s
Test Summary: | Pass Total Time
particle filtering | 8 8 3.3s
Test Summary: |Time
mala | None 0.1s
Test Summary: |Time
hmc | None 0.0s
Test Summary: |Time
map_optimize | None 0.0s
Test Summary: |Time
elliptical slice | None 1.6s
Test Summary: | Pass Total Time
Involutive MCMC | 6 6 2.0s
Test Summary: | Pass Total Time
trace translators | 5 5 1.0s
Test Summary: | Pass Total Time
Kernel DSL | 1 1 0.1s
Test Summary: | Pass Total Time
custom deterministic generative function with custom update and gradient | 23 23 0.1s
Test Summary: | Pass Total Time
custom gradient GF | 2 2 0.0s
Test Summary: | Pass Total Time
custom update GF | 12 12 0.1s
Test Summary: | Pass Total Time
bernoulli | 2 2 1.3s
Test Summary: | Pass Total Time
beta | 9 9 1.4s
Test Summary: | Pass Total Time
categorical | 7 7 0.1s
Test Summary: | Pass Total Time
gamma | 5 5 0.0s
Test Summary: | Pass Total Time
inv_gamma | 5 5 0.0s
Test Summary: | Pass Total Time
normal | 5 5 0.0s
Test Summary: | Pass Total Time
zero-dimensional broadcasted normal | 6 6 1.5s
Test Summary: | Pass Total Time
array normal (trivially broadcasted: all args have same shape) | 6 6 2.0s
Test Summary: | Pass Total Time
broadcasted normal | 14 14 0.8s
Test Summary: | Pass Total Time
multivariate normal | 9 9 0.2s
Test Summary: | Pass Total Time
uniform | 5 5 0.0s
Test Summary: | Pass Total Time
uniform_discrete | 3 3 0.0s
Test Summary: | Pass Total Time
piecewise_uniform | 8 8 1.4s
Test Summary: | Pass Total Time
beta uniform mixture | 6 6 1.3s
Test Summary: | Pass Total Time
geometric | 4 4 1.3s
Test Summary: | Pass Total Time
binom | 5 5 1.4s
Test Summary: | Pass Total Time
neg_binom | 5 5 0.0s
Test Summary: | Pass Total Time
exponential | 4 4 1.3s
Test Summary: | Pass Total Time
poisson | 4 4 0.0s
Test Summary: | Pass Total Time
laplace | 3 3 0.0s
Test Summary: | Pass Total Time
cauchy | 3 3 0.0s
Test Summary: | Pass Total Time
choice_at combinator | 74 74 0.2s
Test Summary: | Pass Total Time
call_at combinator | 74 74 0.2s
Test Summary: | Pass Total Time
map combinator | 139 139 1.2s
Test Summary: | Pass Total Time
unfold combinator | 199 199 2.0s
Test Summary: | Pass Total Time
recurse node numbering | 54 54 0.0s
Test Summary: | Pass Total Time
simple pcfg | 88 88 0.7s
Test Summary: | Pass Total Time
switch combinator | 72 72 0.9s
Test Summary: | Pass Total Time
dist DSL (untyped) | 12 12 8.2s
Test Summary: | Pass Total Time
dist DSL (typed) | 12 12 2.2s
Test Summary: | Pass Total Time
fixed mixture of different distributions | 11 11 4.2s
Test Summary: | Pass Total Time
mixture of normals | 20 20 3.3s
Test Summary: | Pass Total Time
mixture of binomial | 9 9 1.5s
Test Summary: | Pass Total Time
mixture of multivariate normals | 15 15 9.6s
Testing Gen tests passed
julia>
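Everything is ready. Let's actually run Gen in julia.
Gen example 1: linear regression
We run a linear regression on a fake dataset.
First we define the model. The model below has slope and intercept as its parameters. @trace marks each random choice and gives it an address, which the inference program refers to later.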
using Gen
@gen function my_model(xs::Vector{Float64})
    slope = @trace(normal(0, 2), :slope)
    intercept = @trace(normal(0, 10), :intercept)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 1), "y-$i")
    end
end
julia> using Gen
[ Info: Precompiling Gen [ea4f424c-a589-11e8-07c0-fd5c91b9da4a]
WARNING: method definition for process! at /Users/username/.julia/packages/Gen/Dne3u/src/modeling_library/switch/update.jl:85 declares type variable DV but does not use it.
WARNING: method definition for #importance_resampling#354 at /Users/username/.julia/packages/Gen/Dne3u/src/inference/importance.jl:70 declares type variable W but does not use it.
WARNING: method definition for #importance_resampling#354 at /Users/username/.julia/packages/Gen/Dne3u/src/inference/importance.jl:70 declares type variable V but does not use it.
julia> @gen function my_model(xs::Vector{Float64})
slope = @trace(normal(0, 2), :slope)
intercept = @trace(normal(0, 10), :intercept)
for (i, x) in enumerate(xs)
@trace(normal(slope * x + intercept, 1), "y-$i")
end
end
DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Vector{Float64}], false, Union{Nothing, Some{Any}}[nothing], var"##my_model#292", Bool[0], false)
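Before doing any inference, it can help to run the model forward once and look at the resulting trace. The sketch below repeats the model so that it runs on its own, and uses Gen's `simulate` and `get_score` together with address indexing on the trace; the three x values are arbitrary illustration data, not part of the original example.

```julia
using Gen

# Same model as above, repeated so that this snippet runs on its own.
@gen function my_model(xs::Vector{Float64})
    slope = @trace(normal(0, 2), :slope)
    intercept = @trace(normal(0, 10), :intercept)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 1), "y-$i")
    end
end

# Run the model forward once; every random choice is drawn without constraints.
trace = Gen.simulate(my_model, ([1.0, 2.0, 3.0],))

# Random choices are read back from the trace by their addresses.
println("slope     = ", trace[:slope])
println("intercept = ", trace[:intercept])
println("y-1       = ", trace["y-1"])
println("log joint = ", Gen.get_score(trace))
```

Each run prints different values, because nothing is constrained yet. The inference program in the next step fixes the "y-i" addresses to observed data.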
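Next we implement the inference program, which operates on traces of the model. It is written as regular Julia code using Gen's standard inference library. The inference program takes the dataset below and runs an MCMC algorithm to fit the slope and intercept parameters.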
function my_inference_program(xs::Vector{Float64}, ys::Vector{Float64}, num_iters::Int)
    # Create a set of constraints fixing the
    # y coordinates to the observed y values
    constraints = choicemap()
    for (i, y) in enumerate(ys)
        constraints["y-$i"] = y
    end
    # Run the model, constrained by `constraints`,
    # to get an initial execution trace
    (trace, _) = generate(my_model, (xs,), constraints)
    # Iteratively update the slope then the intercept,
    # using Gen's metropolis_hastings operator.
    for iter=1:num_iters
        (trace, _) = metropolis_hastings(trace, select(:slope))
        (trace, _) = metropolis_hastings(trace, select(:intercept))
    end
    # From the final trace, read out the slope and
    # the intercept.
    choices = get_choices(trace)
    return (choices[:slope], choices[:intercept])
end
julia> function my_inference_program(xs::Vector{Float64}, ys::Vector{Float64}, num_iters::Int)
# Create a set of constraints fixing the
# y coordinates to the observed y values
constraints = choicemap()
for (i, y) in enumerate(ys)
constraints["y-$i"] = y
end
# Run the model, constrained by `constraints`,
# to get an initial execution trace
(trace, _) = generate(my_model, (xs,), constraints)
# Iteratively update the slope then the intercept,
# using Gen's metropolis_hastings operator.
for iter=1:num_iters
(trace, _) = metropolis_hastings(trace, select(:slope))
(trace, _) = metropolis_hastings(trace, select(:intercept))
end
# From the final trace, read out the slope and
# the intercept.
choices = get_choices(trace)
return (choices[:slope], choices[:intercept])
end
my_inference_program (generic function with 1 method)
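To make the MCMC step less of a black box, here is a plain-Julia sketch of a comparable sampler that does not use Gen at all. It assumes that a selection-based metropolis_hastings update re-proposes the selected choice from its prior and accepts according to the likelihood ratio. That is my reading of Gen's behavior for this model, not a copy of its implementation, and the names `normal_logpdf`, `loglik`, and `prior_proposal_mh` are hypothetical.

```julia
using Random

# Log-density of a normal distribution N(mu, sigma) at x.
normal_logpdf(x, mu, sigma) = -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2pi)

# Log-likelihood of the observations under y ~ N(slope * x + intercept, 1).
function loglik(slope, intercept, xs, ys)
    return sum(normal_logpdf(y, slope * x + intercept, 1.0) for (x, y) in zip(xs, ys))
end

# Single-site Metropolis-Hastings that proposes each parameter from its prior.
# With a prior proposal the acceptance ratio reduces to the likelihood ratio.
function prior_proposal_mh(rng, xs, ys, num_iters)
    slope, intercept = 2.0 * randn(rng), 10.0 * randn(rng)   # initial draw from the prior
    for _ in 1:num_iters
        new_slope = 2.0 * randn(rng)                          # prior: N(0, 2)
        if log(rand(rng)) < loglik(new_slope, intercept, xs, ys) - loglik(slope, intercept, xs, ys)
            slope = new_slope
        end
        new_intercept = 10.0 * randn(rng)                     # prior: N(0, 10)
        if log(rand(rng)) < loglik(slope, new_intercept, xs, ys) - loglik(slope, intercept, xs, ys)
            intercept = new_intercept
        end
    end
    return slope, intercept
end

xs = collect(1.0:10.0)
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
slope, intercept = prior_proposal_mh(MersenneTwister(1), xs, ys, 20_000)
println("slope: $slope, intercept: $intercept")
```

Prior proposals are simple but are rarely accepted once the chain is near the posterior, which is why this sketch runs many iterations. The value returned is a single posterior sample, so it varies from run to run.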
Now let's feed the following x and y data to this inference program and get the result.
xs = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
(slope, intercept) = my_inference_program(xs, ys, 1000)
println("slope: $slope, intercept: $intercept")
julia> xs = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]
10-element Vector{Float64}:
1.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
10.0
julia> ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
10-element Vector{Float64}:
8.23
5.87
3.99
2.59
0.23
-0.66
-3.53
-6.91
-7.24
-9.9
julia> (slope, intercept) = my_inference_program(xs, ys, 1000)
(-1.9471036551893741, 10.134566406450329)
julia> println("slope: $slope, intercept: $intercept")
slope: -1.9471036551893741, intercept: 10.134566406450329
The intercept and the slope are printed.
That was a simple linear regression.
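As a sanity check on the MCMC output, the same data can be fitted with closed-form ordinary least squares. This sketch uses only base Julia; `least_squares` is a helper written for this post. It ignores the priors in the model, so the numbers are a reference point and not the exact posterior mean.

```julia
xs = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]

# Closed-form ordinary least squares for y = slope * x + intercept.
function least_squares(xs, ys)
    x_mean = sum(xs) / length(xs)
    y_mean = sum(ys) / length(ys)
    slope = sum((xs .- x_mean) .* (ys .- y_mean)) / sum((xs .- x_mean) .^ 2)
    intercept = y_mean - slope * x_mean
    return slope, intercept
end

ols_slope, ols_intercept = least_squares(xs, ys)
println("OLS slope: ", round(ols_slope, digits=3), ", intercept: ", round(ols_intercept, digits=3))
# prints: OLS slope: -1.992, intercept: 10.223
```

The sample returned by the inference program (slope about -1.95, intercept about 10.13) sits close to these values, which is what one would expect from a single draw from the posterior.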
Gen example 2: a neural network with TensorFlow
Next, let's look at a slightly more complex example. Gen has an interface to TensorFlow: the GenTF library.
We use GenTF from julia to build a neural network in TensorFlow. Julia also has a library called PyCall, which lets you run Python code from julia. TensorFlow provides a Python API, so julia can use TensorFlow by going through PyCall.
We implement a CNN that recognizes handwritten characters, using TensorFlow from julia.
As preparation, we add the PyCall and GenTF libraries.
First, create a virtual environment for Python 3.8.
conda create -n python38 python=3.8
conda activate python38
Activating python38 changes the command prompt from (base) to (python38).
Let's check the Python version.
$ python --version
Python 3.8.16
Install TensorFlow into this virtual environment.
conda install -c apple tensorflow-deps
python -m pip install tensorflow-macos
python -m pip install tensorflow-metal
TensorFlow is now available. In this environment, start julia from the command prompt.
(python38) UsernoMacBook-Air:~ username$ julia
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.9.2 (2023-07-05)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |
julia>
Check which Python environment PyCall uses. If PyCall is installed, this prints the path of the Python executable.
julia> using PyCall; println(PyCall.python)
/Users/username/anaconda3/envs/python38/bin/python3.8
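PyCall points at the Python 3.8 environment, where TensorFlow is installed.
Next, let's add GenTF and PyCall. After adding PyCall, check the Python environment.
Adding and configuring PyCall
Type ']' at the Julia REPL to start the package manager.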
(@v1.9) pkg> add PyCall
Resolving package versions...
Installed VersionParsing ─ v1.3.0
Installed Conda ────────── v1.9.1
Installed PyCall ───────── v1.96.1
Updating `~/.julia/environments/v1.9/Project.toml`
[438e738f] + PyCall v1.96.1
Updating `~/.julia/environments/v1.9/Manifest.toml`
[8f4d0f93] + Conda v1.9.1
[438e738f] + PyCall v1.96.1
[81def892] + VersionParsing v1.3.0
Building Conda ─→ `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/8c86e48c0db1564a1d49548d3515ced5d604c408/build.log`
Building PyCall → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/43d304ac6f0354755f1d60730ece8c499980f7ba/build.log`
Precompiling project...
3 dependencies successfully precompiled in 7 seconds. 171 already precompiled.
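Configuration
To change which Python PyCall uses, set the PYTHON environment variable to the path of the Python executable, for example ENV["PYTHON"] = "/usr/local/bin/python3", and then rebuild PyCall.
As confirmed with println(PyCall.python), the path should currently point to Python 3.8. If it does, nothing needs to change.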
julia> using Pkg; ENV["PYTHON"] = "<python>"; Pkg.build("PyCall")
Building Conda ─→ `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/8c86e48c0db1564a1d49548d3515ced5d604c408/build.log`
Building PyCall → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/43d304ac6f0354755f1d60730ece8c499980f7ba/build.log`
julia> using PyCall; println(PyCall.python)
/Users/username/anaconda3/envs/python38/bin/python3.8
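With PyCall pointing at the right interpreter, a quick way to confirm that Python calls work is to import a module from Python's standard library. This is a minimal sketch using `pyimport`; the `math` module is chosen only because it needs no extra installation.

```julia
using PyCall

# Import a module from Python's standard library through PyCall.
math = pyimport("math")

# Python functions are called like Julia functions; results are converted
# to native Julia types where possible.
root = math.sqrt(16.0)
println("math.sqrt(16.0) = ", root)

# Report which Python interpreter and version PyCall is linked against.
println("python  = ", PyCall.python)
println("version = ", PyCall.pyversion)
```

TensorFlow is imported the same way later on, so if this snippet works, the remaining risk is in the TensorFlow installation itself and not in the julia-to-Python bridge.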
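Adding GenTF
GenTF is still under development and is not in the package registry, so it cannot be added with add GenTF in the package manager.
Add it from the following URL instead.
https://github.com/probcomp/GenTF
Type ']' at the julia REPL to start the package manager.
Run the following command at the (@v1.9) pkg> prompt.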
(@v1.9) pkg> add https://github.com/probcomp/GenTF
Cloning git-repo `https://github.com/probcomp/GenTF`
Updating git-repo `https://github.com/probcomp/GenTF`
Resolving package versions...
Installed OffsetArrays ──────────────── v1.12.10
Installed ShowCases ─────────────────── v0.1.0
Installed ContextVariablesX ─────────── v0.1.3
Installed InlineStrings ─────────────── v1.4.0
Installed ZipFile ───────────────────── v0.10.1
Installed InitialValues ─────────────── v0.3.1
Installed CEnum ─────────────────────── v0.4.2
Installed InvertedIndices ───────────── v1.3.0
Installed FileIO ────────────────────── v1.16.1
Installed NNlib ─────────────────────── v0.9.4
Installed MAT ───────────────────────── v0.10.5
Installed BFloat16s ─────────────────── v0.4.2
Installed PrettyPrint ───────────────── v0.2.0
Installed DataFrames ────────────────── v1.6.1
Installed MLUtils ───────────────────── v0.4.3
Installed TupleTools ────────────────── v1.3.0
Installed LLVM ──────────────────────── v6.1.0
Installed Chemfiles ─────────────────── v0.10.41
Installed JLD2 ──────────────────────── v0.4.33
Installed GZip ──────────────────────── v0.6.1
Installed FLoopsBase ────────────────── v0.1.1
Installed AtomsBase ─────────────────── v0.3.4
Installed SentinelArrays ────────────── v1.4.0
Installed StringEncodings ───────────── v0.3.7
Installed MicroCollections ──────────── v0.1.4
Installed DefineSingletons ──────────── v0.1.2
Installed NameResolution ────────────── v0.1.5
Installed WorkerUtilities ───────────── v1.6.1
Installed UnsafeAtomicsLLVM ─────────── v0.1.3
Installed WeakRefStrings ────────────── v1.4.2
Installed MappedArrays ──────────────── v0.4.2
Installed PaddedViews ───────────────── v0.5.12
Installed BufferedStreams ───────────── v1.2.1
Installed HDF5_jll ──────────────────── v1.12.2+2
Installed Glob ──────────────────────── v1.3.1
Installed BangBang ──────────────────── v0.3.39
Installed LazyModules ───────────────── v0.3.1
Installed MLStyle ───────────────────── v0.4.17
Installed FLoops ────────────────────── v0.2.1
Installed MosaicViews ───────────────── v0.3.4
Installed PeriodicTable ─────────────── v1.1.4
Installed StructTypes ───────────────── v1.10.0
Installed ImageCore ─────────────────── v0.10.1
Installed ArgCheck ──────────────────── v2.3.0
Installed TableTraits ───────────────── v1.0.1
Installed JSON3 ─────────────────────── v1.13.2
Installed ImageShow ─────────────────── v0.3.8
Installed DataDeps ──────────────────── v0.7.11
Installed Setfield ──────────────────── v1.1.1
Installed DataValueInterfaces ───────── v1.0.0
Installed PooledArrays ──────────────── v1.4.2
Installed ConstructionBase ──────────── v1.5.3
Installed AbstractFFTs ──────────────── v1.5.0
Installed LLVMExtra_jll ─────────────── v0.0.23+0
Installed CSV ───────────────────────── v0.10.11
Installed StackViews ────────────────── v0.1.1
Installed CompositionsBase ──────────── v0.1.2
Installed UnsafeAtomics ─────────────── v0.2.1
Installed Strided ───────────────────── v1.2.3
Installed KernelAbstractions ────────── v0.9.8
Installed HDF5 ──────────────────────── v0.16.16
Installed MLDatasets ────────────────── v0.7.12
Installed JuliaVariables ────────────── v0.2.4
Installed PrettyTables ──────────────── v2.2.7
Installed GPUArraysCore ─────────────── v0.1.5
Installed SimpleTraits ──────────────── v0.9.4
Installed Atomix ────────────────────── v0.1.0
Installed Crayons ───────────────────── v4.1.1
Installed Adapt ─────────────────────── v3.6.2
Installed ImageBase ─────────────────── v0.1.7
Installed InternedStrings ───────────── v0.7.0
Installed Chemfiles_jll ─────────────── v0.10.4+0
Installed Transducers ───────────────── v0.4.78
Installed NPZ ───────────────────────── v0.4.3
Installed Tables ────────────────────── v1.10.1
Installed Baselet ───────────────────── v0.1.1
Installed IteratorInterfaceExtensions ─ v1.0.0
Installed UnitfulAtomic ─────────────── v1.0.0
Installed FilePathsBase ─────────────── v0.9.20
Installed SplittablesBase ───────────── v0.1.15
Installed StringManipulation ────────── v0.3.0
Installed Pickle ────────────────────── v0.3.3
Downloaded artifact: HDF5
Downloaded artifact: LLVMExtra
Downloaded artifact: Chemfiles
Updating `~/.julia/environments/v1.9/Project.toml`
[1956f2fc] + GenTF v0.1.0 `https://github.com/probcomp/GenTF#master`
Updating `~/.julia/environments/v1.9/Manifest.toml`
[621f4979] + AbstractFFTs v1.5.0
[79e6a3ab] + Adapt v3.6.2
[dce04be8] + ArgCheck v2.3.0
[a9b6321e] + Atomix v0.1.0
[a963bdd2] + AtomsBase v0.3.4
[ab4f0b2a] + BFloat16s v0.4.2
[198e06fe] + BangBang v0.3.39
[9718e550] + Baselet v0.1.1
[e1450e63] + BufferedStreams v1.2.1
[fa961155] + CEnum v0.4.2
[336ed68f] + CSV v0.10.11
[46823bd8] + Chemfiles v0.10.41
[a33af91c] + CompositionsBase v0.1.2
[187b0558] + ConstructionBase v1.5.3
[6add18c4] + ContextVariablesX v0.1.3
[a8cc5b0e] + Crayons v4.1.1
[124859b0] + DataDeps v0.7.11
[a93c6f00] + DataFrames v1.6.1
[e2d170a0] + DataValueInterfaces v1.0.0
[244e2a9f] + DefineSingletons v0.1.2
[cc61a311] + FLoops v0.2.1
[b9860ae5] + FLoopsBase v0.1.1
[5789e2e9] + FileIO v1.16.1
[48062228] + FilePathsBase v0.9.20
[46192b85] + GPUArraysCore v0.1.5
[92fee26a] + GZip v0.6.1
[1956f2fc] + GenTF v0.1.0 `https://github.com/probcomp/GenTF#master`
[c27321d9] + Glob v1.3.1
....
[f67ccb44] + HDF5 v0.16.16
[dad2f222] + LLVMExtra_jll v0.0.23+0
[8ba89e20] + Distributed
[9fa8497b] + Future
[4af54fe1] + LazyArtifacts
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`
Building HDF5 → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/114e20044677badbc631ee6fdc80a67920561a29/build.log`
Precompiling project...
✓ Plots
96 dependencies successfully precompiled in 77 seconds. 172 already precompiled.
1 dependency precompiled but a different version is currently loaded. Restart julia to access the new version
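With the libraries in place, let's build a neural network in julia using TensorFlow.
A CNN with GenTF and TensorFlow
The following is sample TensorFlow code that runs under Gen. It builds a convolutional neural network (CNN) for character recognition.
It uses MNIST, a well-known machine learning dataset of handwritten digit images, published at the following URL.
http://yann.lecun.com/exdb/mnist/
First, load the training dataset.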
import Random
Random.seed!(1)
import MLDatasets
train_x, train_y = MLDatasets.MNIST.traindata()
# Keeps track of the position in the training set
mutable struct DataLoader
    cur_id::Int
    order::Vector{Int}
end
DataLoader() = DataLoader(1, Random.shuffle(1:60000))
# Return the next batch_size images as a batch_size x 784 matrix,
# with labels shifted from 0..9 to 1..10
function next_batch(loader::DataLoader, batch_size)
    x = zeros(Float64, batch_size, 784)
    y = Vector{Int}(undef, batch_size)
    for i=1:batch_size
        x[i, :] = reshape(train_x[:,:,loader.cur_id], (28*28))
        y[i] = train_y[loader.cur_id] + 1
        loader.cur_id += 1
        if loader.cur_id > 60000
            loader.cur_id = 1
        end
    end
    x, y
end
# Load the whole test set in the same layout
function load_test_set()
    test_x, test_y = MLDatasets.MNIST.testdata()
    N = length(test_y)
    x = zeros(Float64, N, 784)
    y = Vector{Int}(undef, N)
    for i=1:N
        x[i, :] = reshape(test_x[:,:,i], (28*28))
        y[i] = test_y[i]+1
    end
    x, y
end
const loader = DataLoader()
(test_x, test_y) = load_test_set()
julia> import Random
julia> Random.seed!(1)
Random.TaskLocalRNG()
julia> import MLDatasets
julia> train_x, train_y = MLDatasets.MNIST.traindata()
┌ Warning: MNIST.traindata() is deprecated, use `MNIST(split=:train)[:]` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/Ob4aN/src/datasets/vision/mnist.jl:187
(features = [0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; … ;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8], targets = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8])
julia> mutable struct DataLoader
cur_id::Int
order::Vector{Int}
end
julia> DataLoader() = DataLoader(1, Random.shuffle(1:60000))
DataLoader
julia> function next_batch(loader::DataLoader, batch_size)
x = zeros(Float64, batch_size, 784)
y = Vector{Int}(undef, batch_size)
for i=1:batch_size
x[i, :] = reshape(train_x[:,:,loader.cur_id], (28*28))
y[i] = train_y[loader.cur_id] + 1
loader.cur_id += 1
if loader.cur_id > 60000
loader.cur_id = 1
end
end
x, y
end
next_batch (generic function with 1 method)
julia> function load_test_set()
test_x, test_y = MLDatasets.MNIST.testdata()
N = length(test_y)
x = zeros(Float64, N, 784)
y = Vector{Int}(undef, N)
for i=1:N
x[i, :] = reshape(test_x[:,:,i], (28*28))
y[i] = test_y[i]+1
end
x, y
end
load_test_set (generic function with 1 method)
julia> const loader = DataLoader()
WARNING: redefinition of constant loader. This may fail, cause incorrect answers, or produce other errors.
DataLoader(1, [16889, 35959, 56640, 30743, 23204, 59111, 51586, 55462, 8914, 52722 … 20038, 28424, 8312, 3503, 57045, 19635, 1025, 28306, 57085, 10336])
julia> (test_x, test_y) = load_test_set()
┌ Warning: MNIST.testdata() is deprecated, use `MNIST(split=:test)[:]` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/Ob4aN/src/datasets/vision/mnist.jl:195
([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [8, 3, 2, 1, 5, 2, 5, 10, 6, 10 … 8, 9, 10, 1, 2, 3, 4, 5, 6, 7])
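The two transformations in next_batch and load_test_set, flattening each 28x28 image into a row of 784 values and shifting the labels by one, are easier to see on a tiny array. The sketch below uses random numbers as a stand-in for MNIST; the data and the variable names are hypothetical and serve only to show the shapes.

```julia
# Synthetic stand-in for the MNIST arrays: three 28x28 "images" and 0-based labels.
# (Hypothetical data, used only to show the array shapes.)
images = rand(Float32, 28, 28, 3)
labels = [5, 0, 9]

# Flatten every image into one row of a 3 x 784 matrix, as next_batch does.
x = zeros(Float64, 3, 784)
for i in 1:3
    x[i, :] = reshape(images[:, :, i], 28 * 28)
end

# Shift digits 0..9 to 1..10, because Julia indices
# (and the outcomes of Gen's categorical distribution) start at 1.
y = labels .+ 1

println(size(x))   # (3, 784)
println(y)         # [6, 1, 10]
```

This is also why the labels in the transcript above run from 1 to 10: a 10 there corresponds to the digit 9.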
Create the neural network with TensorFlow.
using Gen
using GenTF
using PyCall
@pyimport tensorflow as tf
@pyimport tensorflow.nn as nn
# The original TF1 code is adapted to TF2 through the compat.v1 API
@pyimport tensorflow.compat.v1 as v1
xs = v1.placeholder(tf.float64) # N x 784
function weight_variable(shape)
    initial = 0.001 * randn(shape...)
    tf.Variable(initial)
end
function bias_variable(shape)
    initial = fill(.1, shape...)
    tf.Variable(initial)
end
function conv2d(x, W)
    nn.conv2d(x, W, (1, 1, 1, 1), "SAME")
end
function max_pool_2x2(x)
    nn.max_pool(x, (1, 2, 2, 1), (1, 2, 2, 1), "SAME")
end
# First convolution + pooling layer
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(xs, [-1, 28, 28, 1])
h_conv1 = nn.relu(tf.add(conv2d(x_image, W_conv1), b_conv1))
h_pool1 = max_pool_2x2(h_conv1)
# Second convolution + pooling layer
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = nn.relu(tf.add(conv2d(h_pool1, W_conv2), b_conv2))
h_pool2 = max_pool_2x2(h_conv2)
# Fully connected layers
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = nn.relu(tf.add(tf.matmul(h_pool2_flat, W_fc1), b_fc1))
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
probs = nn.softmax(tf.add(tf.matmul(h_fc1, W_fc2), b_fc2), axis=1) # N x 10
# tf.Session is not available at the top level in TF2, so use the compat.v1 session
const sess = v1.Session()
const params = [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2]
const net = TFFunction(params, [xs], probs, sess)
@gen function f(xs::Matrix{Float64})
(N, D) = size(xs)
@assert D == 784
probs = @trace(net(xs), :net)
@assert size(probs) == (N, 10)
ys = Vector{Int}(undef, N)
for i=1:N
ys[i] = @trace(categorical(probs[i,:]), (:y, i))
end
return ys
end
#########
# 学習 #
#########
@pyimport tensorflow.train as train
#update = ParamUpdate(ADAM(1e-4, 0.9, 0.999, 1e-08), net)
update = ParamUpdate(FixedStepGradientDescent(0.001), net)
for i=1:1000
(xs, ys) = next_batch(loader, 100)
@assert size(xs) == (100, 784)
@assert size(ys) == (100,)
constraints = choicemap()
for (i, y) in enumerate(ys)
constraints[(:y, i)] = y
end
(trace, weight) = generate(f, (xs,), constraints)
# increments gradient accumulators
accumulate_param_gradients!(trace, nothing)
# performs SGD update and then resets gradient accumulators
apply!(update)
println("i: $i, weight: $weight")
end
##################################
# テストデータによる推論 #
##################################
for i=1:length(test_y)
x = test_x[i:i,:]
@assert size(x) == (1, 784)
true_y = test_y[i]
pred_y = f(x)
@assert length(pred_y) == 1
println("true: $true_y, predicted: $(pred_y[1])")
end
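The 7*7*64 input size of the first fully connected layer follows from the two pooling steps. With "SAME" padding, the output side length of a convolution or pooling layer is the input side length divided by the stride, rounded up. The short check below is a sketch of that arithmetic only; same_out is a hypothetical helper name, not a TensorFlow or Gen function.

```julia
# Output side length of a conv/pool layer with "SAME" padding:
# ceil(input / stride), independent of the kernel size.
same_out(n::Int, stride::Int) = cld(n, stride)

side = 28
side = same_out(side, 1)   # conv1, stride 1: stays 28
side = same_out(side, 2)   # pool1, stride 2: 14
side = same_out(side, 1)   # conv2, stride 1: stays 14
side = same_out(side, 2)   # pool2, stride 2: 7
flat = side * side * 64    # 64 channels after the second convolution
println(flat)              # 3136
```

This agrees with the shape (3136, 1024) that TensorFlow reports for W_fc1.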
julia> xs = v1.placeholder(tf.float64)
PyObject <tf.Tensor 'Placeholder_9:0' shape=<unknown> dtype=float64>
julia> function weight_variable(shape)
           initial = 0.001 * randn(shape...)
           tf.Variable(initial)
       end
weight_variable (generic function with 1 method)
julia> function bias_variable(shape)
           initial = fill(.1, shape...)
           tf.Variable(initial)
       end
bias_variable (generic function with 1 method)
julia> function conv2d(x, W)
           nn.conv2d(x, W, (1, 1, 1, 1), "SAME")
       end
conv2d (generic function with 1 method)
julia> function max_pool_2x2(x)
           nn.max_pool(x, (1, 2, 2, 1), (1, 2, 2, 1), "SAME")
       end
max_pool_2x2 (generic function with 1 method)
julia> W_conv1 = weight_variable([5, 5, 1, 32])
PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64>
julia> b_conv1 = bias_variable([32])
PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64>
julia> x_image = tf.reshape(xs, [-1, 28, 28, 1])
PyObject <tf.Tensor 'Reshape_2:0' shape=(None, 28, 28, 1) dtype=float64>
julia> h_conv1 = nn.relu(tf.add(conv2d(x_image, W_conv1), b_conv1))
PyObject <tf.Tensor 'Relu_3:0' shape=(None, 28, 28, 32) dtype=float64>
julia> h_pool1 = max_pool_2x2(h_conv1)
PyObject <tf.Tensor 'MaxPool2d_2:0' shape=(None, 14, 14, 32) dtype=float64>
julia> W_conv2 = weight_variable([5, 5, 32, 64])
PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64>
julia> b_conv2 = bias_variable([64])
PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64>
julia> h_conv2 = nn.relu(tf.add(conv2d(h_pool1, W_conv2), b_conv2))
PyObject <tf.Tensor 'Relu_4:0' shape=(None, 14, 14, 64) dtype=float64>
julia> h_pool2 = max_pool_2x2(h_conv2)
PyObject <tf.Tensor 'MaxPool2d_3:0' shape=(None, 7, 7, 64) dtype=float64>
julia> W_fc1 = weight_variable([7*7*64, 1024])
PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64>
julia> b_fc1 = bias_variable([1024])
PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64>
julia> h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
PyObject <tf.Tensor 'Reshape_3:0' shape=(None, 3136) dtype=float64>
julia> h_fc1 = nn.relu(tf.add(tf.matmul(h_pool2_flat, W_fc1), b_fc1))
PyObject <tf.Tensor 'Relu_5:0' shape=(None, 1024) dtype=float64>
julia> W_fc2 = weight_variable([1024, 10])
PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64>
julia> b_fc2 = bias_variable([10])
PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64>
julia> probs = nn.softmax(tf.add(tf.matmul(h_fc1, W_fc2), b_fc2), axis=1) # N x 10
PyObject <tf.Tensor 'Softmax_3:0' shape=(None, 10) dtype=float64>
julia> const sess = v1.Session()
2023-08-24 23:16:44.334990: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-08-24 23:16:44.335655: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
WARNING: redefinition of constant sess. This may fail, cause incorrect answers, or produce other errors.
PyObject <tensorflow.python.client.session.Session object at 0x2c5416c40>
julia> const params = [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2]
8-element Vector{PyObject}:
PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64>
PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64>
PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64>
PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64>
PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64>
PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64>
PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64>
PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64>
julia> const net = TFFunction(params, [xs], probs, sess)
2023-08-24 23:17:08.518303: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2023-08-24 23:17:08.706835: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
TFFunction(PyObject <tensorflow.python.client.session.Session object at 0x2c5416c40>, PyObject[PyObject <tf.Tensor 'Placeholder_9:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Softmax_3:0' shape=(None, 10) dtype=float64>, Dict{PyObject, PyObject}(PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64> => PyObject <tf.Variable 'Variable_34:0' shape=(64,) dtype=float64>, PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64> => PyObject <tf.Variable 'Variable_31:0' shape=(5, 5, 1, 32) dtype=float64>, PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64> => PyObject <tf.Variable 'Variable_33:0' shape=(5, 5, 32, 64) dtype=float64>, PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64> => PyObject <tf.Variable 'Variable_37:0' shape=(1024, 10) dtype=float64>, PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64> => PyObject <tf.Variable 'Variable_36:0' shape=(1024,) dtype=float64>, PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64> => PyObject <tf.Variable 'Variable_38:0' shape=(10,) dtype=float64>, PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64> => PyObject <tf.Variable 'Variable_32:0' shape=(32,) dtype=float64>, PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64> => PyObject <tf.Variable 'Variable_35:0' shape=(3136, 1024) dtype=float64>), PyObject[PyObject <tf.Tensor 'gradients_6/Reshape_2_grad/Reshape:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Placeholder_10:0' shape=<unknown> dtype=float64>, PyObject <tf.Operation 'group_deps_2' type=NoOp>, PyObject <tf.Operation 'group_deps_3' type=NoOp>, PyObject <tf.Tensor 'Placeholder_11:0' shape=() dtype=float64>)
julia> @gen function f(xs::Matrix{Float64})
           (N, D) = size(xs)
           @assert D == 784
           probs = @trace(net(xs), :net)
           @assert size(probs) == (N, 10)
           ys = Vector{Int}(undef, N)
           for i=1:N
               ys[i] = @trace(categorical(probs[i,:]), (:y, i))
           end
           return ys
       end
DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Matrix{Float64}], false, Union{Nothing, Some{Any}}[nothing], var"##f#360", Bool[0], false)
Because the ADAM optimizer raises an error here, we instead construct an update based on fixed-step gradient descent and use it to update the parameters.
julia> @pyimport tensorflow.train as train
julia> update = ParamUpdate(FixedStepGradientDescent(0.001), net)
2023-08-25 22:19:28.296563: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
ParamUpdate(Dict{GenerativeFunction, Any}(TFFunction(PyObject <tensorflow.python.client.session.Session object at 0x2c5416c40>, PyObject[PyObject <tf.Tensor 'Placeholder_9:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Softmax_3:0' shape=(None, 10) dtype=float64>, Dict{PyObject, PyObject}(PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64> => PyObject <tf.Variable 'Variable_34:0' shape=(64,) dtype=float64>, PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64> => PyObject <tf.Variable 'Variable_31:0' shape=(5, 5, 1, 32) dtype=float64>, PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64> => PyObject <tf.Variable 'Variable_33:0' shape=(5, 5, 32, 64) dtype=float64>, PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64> => PyObject <tf.Variable 'Variable_37:0' shape=(1024, 10) dtype=float64>, PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64> => PyObject <tf.Variable 'Variable_36:0' shape=(1024,) dtype=float64>, PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64> => PyObject <tf.Variable 'Variable_38:0' shape=(10,) dtype=float64>, PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64> => PyObject <tf.Variable 'Variable_32:0' shape=(32,) dtype=float64>, PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64> => PyObject <tf.Variable 'Variable_35:0' shape=(3136, 1024) dtype=float64>), PyObject[PyObject <tf.Tensor 'gradients_6/Reshape_2_grad/Reshape:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Placeholder_10:0' shape=<unknown> dtype=float64>, PyObject <tf.Operation 'group_deps_2' type=NoOp>, PyObject <tf.Operation 'group_deps_3' type=NoOp>, PyObject <tf.Tensor 'Placeholder_11:0' shape=() dtype=float64>) => GenTF.FixedStepGradientDescentTFFunctionState(PyObject <tf.Operation 'GradientDescent_3' type=NoOp>, TFFunction(PyObject <tensorflow.python.client.session.Session object at 0x2c5416c40>, PyObject[PyObject <tf.Tensor 
'Placeholder_9:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Softmax_3:0' shape=(None, 10) dtype=float64>, Dict{PyObject, PyObject}(PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64> => PyObject <tf.Variable 'Variable_34:0' shape=(64,) dtype=float64>, PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64> => PyObject <tf.Variable 'Variable_31:0' shape=(5, 5, 1, 32) dtype=float64>, PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64> => PyObject <tf.Variable 'Variable_33:0' shape=(5, 5, 32, 64) dtype=float64>, PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64> => PyObject <tf.Variable 'Variable_37:0' shape=(1024, 10) dtype=float64>, PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64> => PyObject <tf.Variable 'Variable_36:0' shape=(1024,) dtype=float64>, PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64> => PyObject <tf.Variable 'Variable_38:0' shape=(10,) dtype=float64>, PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64> => PyObject <tf.Variable 'Variable_32:0' shape=(32,) dtype=float64>, PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64> => PyObject <tf.Variable 'Variable_35:0' shape=(3136, 1024) dtype=float64>), PyObject[PyObject <tf.Tensor 'gradients_6/Reshape_2_grad/Reshape:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Placeholder_10:0' shape=<unknown> dtype=float64>, PyObject <tf.Operation 'group_deps_2' type=NoOp>, PyObject <tf.Operation 'group_deps_3' type=NoOp>, PyObject <tf.Tensor 'Placeholder_11:0' shape=() dtype=float64>))), FixedStepGradientDescent(0.001))
julia> for i=1:1000
           (xs, ys) = next_batch(loader, 100)
           @assert size(xs) == (100, 784)
           @assert size(ys) == (100,)
           constraints = choicemap()
           for (i, y) in enumerate(ys)
               constraints[(:y, i)] = y
           end
           (trace, weight) = generate(f, (xs,), constraints)
           # increments gradient accumulators
           accumulate_param_gradients!(trace, nothing)
           # performs SGD update and then resets gradient accumulators
           apply!(update)
           println("i: $i, weight: $weight")
       end
i: 1, weight: -230.2418501131283
i: 2, weight: -230.25861867136845
i: 3, weight: -230.24646340719963
i: 4, weight: -230.25554056052002
i: 5, weight: -230.25868100738683
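Training is not working well: the weight stays at about -230.25 and does not improve. GenTF is still under development, so treat this as sample code rather than a working classifier.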
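The logged values can be read as follows. Since every label is constrained, the weight returned by generate should be the log-probability of the 100 observed labels under the network's output (this is my reading of Gen's generate, so treat it as an assumption). A network that assigns probability 1/10 to each of the 10 classes would give 100 * log(1/10), which is almost exactly the value in the log. The sketch below uses only plain Julia; batch_loglik is a hypothetical helper, not a Gen function.

```julia
batch_size = 100
num_classes = 10

# Log-probability of a batch when every class has probability 1/10
uniform_weight = batch_size * log(1 / num_classes)
println(uniform_weight)   # about -230.26

# Hypothetical helper: sum of the log-probabilities of the observed labels,
# given an N x K matrix of class probabilities and labels in 1:K.
function batch_loglik(probs::AbstractMatrix{<:Real}, ys::AbstractVector{<:Integer})
    return sum(log(probs[i, ys[i]]) for i in eachindex(ys))
end

uniform_probs = fill(1 / num_classes, batch_size, num_classes)
some_labels = rand(1:num_classes, batch_size)
println(batch_loglik(uniform_probs, some_labels))
```

In other words, the network is still predicting nearly uniform probabilities after these updates; a weight moving toward 0 would indicate that it is learning.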
julia> for i=1:length(test_y)
           x = test_x[i:i,:]
           @assert size(x) == (1, 784)
           true_y = test_y[i]
           pred_y = f(x)
           @assert length(pred_y) == 1
           println("true: $true_y, predicted: $(pred_y[1])")
       end
true: 8, predicted: 5
true: 3, predicted: 10
.
.
.
That concludes this example implementation of a convolutional neural network.
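One note on the test loop: f(x) draws each label from a categorical distribution, so its predictions are random samples even for a well-trained network. For evaluation it is common to take the most probable class instead and report accuracy. The sketch below shows this in plain Julia; predict_labels and accuracy are hypothetical helpers, and the small probs matrix is made-up data standing in for the network output.

```julia
# Hypothetical helpers, independent of Gen and GenTF.
# Pick the most probable class in each row of an N x K probability matrix.
predict_labels(probs::AbstractMatrix{<:Real}) =
    [argmax(view(probs, i, :)) for i in axes(probs, 1)]

# Fraction of predictions that match the true labels.
accuracy(pred, truth) = count(pred .== truth) / length(truth)

# Made-up probabilities for 3 samples and 3 classes
probs = [0.1 0.7 0.2;
         0.6 0.3 0.1;
         0.2 0.2 0.6]
truth = [2, 1, 1]

pred = predict_labels(probs)
println(pred)                    # [2, 1, 3]
println(accuracy(pred, truth))   # 2 of 3 correct
```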