Probabilistic Programming Model: Gen

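 Several new probabilistic programming libraries have been released recently: Bean Machine, Gen, and Turing.

 Bean Machine is built on PyTorch, while Gen and Turing are probabilistic programming libraries for Julia.


 Turing was developed by Hong Ge and is maintained by volunteers; it is distributed under the MIT license. Gen is a Bayesian (probabilistic programming) library developed and distributed by MIT.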




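 Gen's main features:

  • Probabilistic inference algorithms: Gen gives users an environment for developing hybrid algorithms that combine variational inference, sequential Monte Carlo, Markov chain Monte Carlo, and neural networks.
  • A flexible API, including automatic differentiation.
  • Efficient inference that exploits probabilistic structure: Gen's generative models and inference models build their computation graphs dynamically, and inference is carried out efficiently with reversible-jump and involutive MCMC algorithms. For involutive MCMC, one of the newer MCMC methods, see the following paper:

 "Automating Involutive MCMC using Probabilistic and Differentiable Programming"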

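To give a feel for what "involutive MCMC" means, here is a minimal sketch in plain Julia using only the standard library. It does not use Gen's API. As we understand the scheme, each step does three things:

1. Draw an auxiliary variable v.
2. Apply a deterministic involution (x, v) → (x', v'), that is, a map that is its own inverse.
3. Accept with probability min(1, p(x')q(v') / (p(x)q(v)) · |det J|).

The target density (a standard normal), the involution (x, v) → (x + v, -v), and all function names below are illustrative assumptions.

```julia
using Random, Statistics

# Illustrative target: unnormalized log density of a standard normal.
logp(x) = -0.5 * x^2

# Auxiliary distribution q(v) = Normal(0, σ), log density up to a constant.
logq(v, σ) = -0.5 * (v / σ)^2

# Involution: applying it twice returns the original (x, v).
# Its Jacobian determinant has absolute value 1.
involution(x, v) = (x + v, -v)

function involutive_mh(logp, x0, n; σ = 1.0, rng = Random.default_rng())
    x = x0
    samples = Vector{Float64}(undef, n)
    for i in 1:n
        v = σ * randn(rng)                  # sample the auxiliary variable
        x_new, v_new = involution(x, v)     # deterministic involutive move
        # acceptance ratio p(x') q(v') / (p(x) q(v)) × |det J|, with |det J| = 1 here
        log_alpha = logp(x_new) + logq(v_new, σ) - logp(x) - logq(v, σ)
        if log(rand(rng)) < log_alpha
            x = x_new
        end
        samples[i] = x
    end
    return samples
end

rng = MersenneTwister(0)
samples = involutive_mh(logp, 3.0, 50_000; σ = 1.0, rng = rng)
println("mean ≈ ", mean(samples[5001:end]), ", std ≈ ", std(samples[5001:end]))
```

With this particular involution, the Jacobian is 1 and q(v') = q(v), so the kernel reduces to ordinary random-walk Metropolis-Hastings. The value of the general framework is that other involutions, for example dimension-changing ones as in reversible jump, fit the same accept/reject rule.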




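Julia REPL (read-eval-print loop)

To start Julia from the terminal, add Julia's bin directory to PATH in the shell startup file (~/.bashrc for bash, ~/.zshrc for zsh).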



vi ~/.bashrc
# add these lines to ~/.bashrc
export PATH=/usr/local/bin:$PATH
export PATH=/Applications/Julia-1.9.app/Contents/Resources/julia/bin:$PATH



$ source ~/.bashrc
$ echo $PATH



vi ~/.zshrc
# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/Users/username/anaconda3/bin/conda' 'shell.zsh' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__conda_setup"
else
    if [ -f "/Users/username/anaconda3/etc/profile.d/conda.sh" ]; then
        . "/Users/username/anaconda3/etc/profile.d/conda.sh"
    else
        export PATH="/Users/username/anaconda3/bin:$PATH"
    fi
fi
unset __conda_setup
# <<< conda initialize <<<
export PATH=/opt/R/arm64/bin:$PATH
export PATH=/Applications/Julia-1.9.app/Contents/Resources/julia/bin:$PATH



(base)xxxxx~ % julia
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.9.2 (2023-07-05)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |
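After editing the startup file, you can confirm that the intended Julia installation is the one being picked up from PATH by asking Julia itself. The snippet uses only Base (VERSION, Sys.BINDIR, Sys.which). It is a quick check, not part of the original setup.

```julia
# Which Julia is running, and where its executable lives.
println("Julia version: ", VERSION)
println("bin directory: ", Sys.BINDIR)

# The first `julia` found on PATH (Sys.which returns `nothing` if none is found).
println("julia on PATH: ", something(Sys.which("julia"), "not found"))
```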




Start Julia and type ']' at the Julia REPL (read-eval-print loop).
Julia's package manager starts and the prompt changes. At the package manager prompt, type add Gen.

(base) UsernoMacBook-Air:~ username$ julia
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.9.2 (2023-07-05)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |

(@v1.9) pkg> 

(@v1.9) pkg> add Gen



Run Gen's test suite to confirm the installation:
julia> using Pkg; Pkg.test("Gen")
     Testing Gen
      Status `/private/var/folders/z3/y2vb_4653kjcytsygg5bkv7w0000gn/T/jl_NeEvI0/Project.toml`
  [34da2185] Compat v4.9.0
  [864edb3b] DataStructures v0.18.15
  [31c24e10] Distributions v0.25.100
  [f6369f11] ForwardDiff v0.10.36
  [de31a74c] FunctionalCollections v0.5.0
  [ea4f424c] Gen v0.4.5
  [682c06a0] JSON v0.21.4
  [1914dd2f] MacroTools v0.5.11
  [d96e819e] Parameters v0.12.3
  [37e2e3b7] ReverseDiff v1.15.1
  [276daf66] SpecialFunctions v2.3.1
  [37e2e46d] LinearAlgebra `@stdlib/LinearAlgebra`
  [9a3f8284] Random `@stdlib/Random`
  [8dfed614] Test `@stdlib/Test`
      Status `/private/var/folders/z3/y2vb_4653kjcytsygg5bkv7w0000gn/T/jl_NeEvI0/Manifest.toml`
  [49dc2e85] Calculus v0.5.1
  [d360d2e6] ChainRulesCore v1.16.0
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v4.9.0
  [9a962f9c] DataAPI v1.15.0
  [864edb3b] DataStructures v0.18.15
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [31c24e10] Distributions v0.25.100
  [ffbed154] DocStringExtensions v0.9.3
  [fa6b7ba4] DualNumbers v0.6.8
  [1a297f60] FillArrays v1.6.1
  [f6369f11] ForwardDiff v0.10.36
  [069b7b12] FunctionWrappers v1.1.3
  [de31a74c] FunctionalCollections v0.5.0
  [ea4f424c] Gen v0.4.5
  [34004b35] HypergeometricFunctions v0.3.23
  [92d709cd] IrrationalConstants v0.2.2
  [692b3bcd] JLLWrappers v1.5.0
  [682c06a0] JSON v0.21.4
  [2ab3a3ac] LogExpFunctions v0.3.26
  [1914dd2f] MacroTools v0.5.11
  [e1d29d7a] Missings v1.1.0
  [77ba4419] NaNMath v1.0.2
  [bac558e1] OrderedCollections v1.6.2
  [90014a1f] PDMats v0.11.17
  [d96e819e] Parameters v0.12.3
  [69de0a69] Parsers v2.7.2
  [aea7be01] PrecompileTools v1.2.0
  [21216c6a] Preferences v1.4.0
  [1fd47b50] QuadGK v2.8.2
  [189a3867] Reexport v1.2.2
  [37e2e3b7] ReverseDiff v1.15.1
  [79098fc4] Rmath v0.7.1
  [a2af1166] SortingAlgorithms v1.1.1
  [276daf66] SpecialFunctions v2.3.1
  [90137ffa] StaticArrays v1.6.2
  [1e83bf80] StaticArraysCore v1.4.2
  [82ae8749] StatsAPI v1.6.0
  [2913bbd2] StatsBase v0.34.0
  [4c63d2b9] StatsFuns v1.3.0
  [3a884ed6] UnPack v1.0.2
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [f50d1b31] Rmath_jll v0.4.0+0
  [0dad84c5] ArgTools v1.1.1 `@stdlib/ArgTools`
  [56f22d72] Artifacts `@stdlib/Artifacts`
  [2a0f44e3] Base64 `@stdlib/Base64`
  [ade2ca70] Dates `@stdlib/Dates`
  [f43a241f] Downloads v1.6.0 `@stdlib/Downloads`
  [7b1f6079] FileWatching `@stdlib/FileWatching`
  [b77e0a4c] InteractiveUtils `@stdlib/InteractiveUtils`
  [b27032c2] LibCURL v0.6.3 `@stdlib/LibCURL`
  [76f85450] LibGit2 `@stdlib/LibGit2`
  [8f399da3] Libdl `@stdlib/Libdl`
  [37e2e46d] LinearAlgebra `@stdlib/LinearAlgebra`
  [56ddb016] Logging `@stdlib/Logging`
  [d6f4376e] Markdown `@stdlib/Markdown`
  [a63ad114] Mmap `@stdlib/Mmap`
  [ca575930] NetworkOptions v1.2.0 `@stdlib/NetworkOptions`
  [44cfe95a] Pkg v1.9.2 `@stdlib/Pkg`
  [de0858da] Printf `@stdlib/Printf`
  [3fa0cd96] REPL `@stdlib/REPL`
  [9a3f8284] Random `@stdlib/Random`
  [ea8e919c] SHA v0.7.0 `@stdlib/SHA`
  [9e88b42a] Serialization `@stdlib/Serialization`
  [6462fe0b] Sockets `@stdlib/Sockets`
  [2f01184e] SparseArrays `@stdlib/SparseArrays`
  [10745b16] Statistics v1.9.0 `@stdlib/Statistics`
  [4607b0f0] SuiteSparse `@stdlib/SuiteSparse`
  [fa267f1f] TOML v1.0.3 `@stdlib/TOML`
  [a4e569a6] Tar v1.10.0 `@stdlib/Tar`
  [8dfed614] Test `@stdlib/Test`
  [cf7118a7] UUIDs `@stdlib/UUIDs`
  [4ec0a83e] Unicode `@stdlib/Unicode`
  [e66e0078] CompilerSupportLibraries_jll v1.0.5+0 `@stdlib/CompilerSupportLibraries_jll`
  [deac9b47] LibCURL_jll v7.84.0+0 `@stdlib/LibCURL_jll`
  [29816b5a] LibSSH2_jll v1.10.2+0 `@stdlib/LibSSH2_jll`
  [c8ffd9c3] MbedTLS_jll v2.28.2+0 `@stdlib/MbedTLS_jll`
  [14a3606d] MozillaCACerts_jll v2022.10.11 `@stdlib/MozillaCACerts_jll`
  [4536629a] OpenBLAS_jll v0.3.21+4 `@stdlib/OpenBLAS_jll`
  [05823500] OpenLibm_jll v0.8.1+0 `@stdlib/OpenLibm_jll`
  [bea87d4a] SuiteSparse_jll v5.10.1+6 `@stdlib/SuiteSparse_jll`
  [83775a58] Zlib_jll v1.2.13+0 `@stdlib/Zlib_jll`
  [8e850b90] libblastrampoline_jll v5.8.0+0 `@stdlib/libblastrampoline_jll`
  [8e850ede] nghttp2_jll v1.48.0+0 `@stdlib/nghttp2_jll`
  [3f19e933] p7zip_jll v17.4.0+0 `@stdlib/p7zip_jll`
Precompiling project...
  53 dependencies successfully precompiled in 23 seconds. 7 already precompiled.
     Testing Running tests...
WARNING: method definition for process! at /Users/username/.julia/packages/Gen/Dne3u/src/modeling_library/switch/update.jl:85 declares type variable DV but does not use it.
WARNING: method definition for #importance_resampling#354 at /Users/username/.julia/packages/Gen/Dne3u/src/inference/importance.jl:70 declares type variable W but does not use it.
WARNING: method definition for #importance_resampling#354 at /Users/username/.julia/packages/Gen/Dne3u/src/inference/importance.jl:70 declares type variable V but does not use it.
Test Summary: | Pass  Total  Time
autodiff      |    2      2  0.2s
Test Summary:   | Pass  Total  Time
diff arithmetic |   12     12  0.0s
Test Summary: | Pass  Total  Time
diff vectors  |   20     20  0.1s
Test Summary:     |Time
diff dictionaries | None  0.0s
Test Summary: |Time
diff sets     | None  0.0s
Test Summary:     | Pass  Total  Time
dynamic selection |   18     18  0.1s
Test Summary: | Pass  Total  Time
all selection |    4      4  0.0s
Test Summary:        | Pass  Total  Time
complement selection |   13     13  0.0s
Test Summary:                   | Pass  Total  Time
static assignment to/from array |   12     12  0.1s
Test Summary:                    | Pass  Total  Time
dynamic assignment to/from array |   12     12  0.2s
Test Summary:                       | Pass  Total  Time
dynamic assignment copy constructor |    4      4  0.0s
Test Summary:                            | Pass  Total  Time
internal vector assignment to/from array |    8      8  0.1s
Test Summary:            | Pass  Total  Time
dynamic assignment merge |   10     10  0.0s
Test Summary:                     | Pass  Total  Time
dynamic assignment variadic merge |    2      2  0.1s
Test Summary:           | Pass  Total  Time
static assignment merge |   10     10  0.2s
Test Summary:                    | Pass  Total  Time
static assignment variadic merge |    2      2  0.1s
Test Summary:            | Pass  Total  Time
static assignment errors |   12     12  0.0s
Test Summary:             | Pass  Total  Time
dynamic assignment errors |    7      7  0.0s
Test Summary:                | Pass  Total  Time
dynamic assignment overwrite |   16     16  0.0s
Test Summary:                  | Pass  Total  Time
dynamic assignment constructor |    2      2  0.0s
Test Summary:       | Pass  Total  Time
choice map equality |    6      6  0.0s
Test Summary:          | Pass  Total  Time
choice map nested view |    8      8  0.2s
Test Summary:                        | Pass  Total  Time
filtering choicemaps with selections |    6      6  0.0s
Test Summary:                  | Pass  Total  Time
invalid choice map constructor |    1      1  0.0s
Test Summary:     | Pass  Total  Time
choicemap hashing |   25     25  0.1s
Test Summary:                                                         | Pass  Total  Time
update(...) shorthand assuming unchanged args (dynamic modeling lang) |    3      3  2.9s
Test Summary:                                                             | Pass  Total  Time
regenerate(...) shorthand assuming unchanged args (dynamic modeling lang) |    2      2  0.1s
Test Summary:                                                        | Pass  Total  Time
update(...) shorthand assuming unchanged args (static modeling lang) |    3      3  0.6s
Test Summary:                                                            | Pass  Total  Time
regenerate(...) shorthand assuming unchanged args (static modeling lang) |    2      2  0.2s
Test Summary:         | Pass  Total  Time
macros in dynamic DSL |    7      7  0.1s
Test Summary:        | Pass  Total  Time
macros in static DSL |    7      7  0.2s
Test Summary: | Pass  Total  Time
Dynamic DSL   |  148    148  3.5s
Test Summary: | Pass  Total  Time
static DSL    |  172    172  1.2s
Test Summary:                            | Pass  Total  Time
optional positional args (calling + GFI) |   42     42  4.6s
 Test Summary:                          | Pass  Total     Time
optional positional args (combinators) |   16     16  1m28.8s
Test Summary: | Pass  Total  Time
static IR     |  138    138  3.6s
Test Summary: | Pass  Total  Time
tilde syntax  |   12     12  0.3s
Test Summary:       | Pass  Total  Time
importance sampling |   22     22  0.1s
Test Summary: | Pass  Total  Time
SGD training  |   10     10  1.6s
Test Summary: | Pass  Total  Time
lecture!      |    2      2  4.2s
Test Summary:                   | Pass  Total  Time
black box variational inference |    4      4  3.8s
opt_theta: -0.37567952986819614
true optimum log_marginal_likelihood: -14.79249158646238
  4.845693 seconds (80.93 M allocations: 3.204 GiB, 5.75% gc time, 2.40% compilation time)
final theta: -0.3688737663341869
final elbo estimate: -14.83005893286215
  8.781272 seconds (155.11 M allocations: 6.286 GiB, 6.19% gc time, 0.50% compilation time)
final theta: -0.3865867550388555
final elbo estimate: -14.771737102011121
Test Summary: | Pass  Total   Time
minimal vae   |    3      3  13.8s
Test Summary:   | Pass  Total  Time
hmm forward alg |    1      1  0.1s
Test Summary:      | Pass  Total  Time
particle filtering |    8      8  3.3s
Test Summary: |Time
mala          | None  0.1s
Test Summary: |Time
hmc           | None  0.0s
Test Summary: |Time
map_optimize  | None  0.0s
Test Summary:    |Time
elliptical slice | None  1.6s
Test Summary:   | Pass  Total  Time
Involutive MCMC |    6      6  2.0s
Test Summary:     | Pass  Total  Time
trace translators |    5      5  1.0s
Test Summary: | Pass  Total  Time
Kernel DSL    |    1      1  0.1s
Test Summary:                                                            | Pass  Total  Time
custom deterministic generative function with custom update and gradient |   23     23  0.1s
Test Summary:      | Pass  Total  Time
custom gradient GF |    2      2  0.0s
Test Summary:    | Pass  Total  Time
custom update GF |   12     12  0.1s
Test Summary: | Pass  Total  Time
bernoulli     |    2      2  1.3s
Test Summary: | Pass  Total  Time
beta          |    9      9  1.4s
Test Summary: | Pass  Total  Time
categorical   |    7      7  0.1s
Test Summary: | Pass  Total  Time
gamma         |    5      5  0.0s
Test Summary: | Pass  Total  Time
inv_gamma     |    5      5  0.0s
Test Summary: | Pass  Total  Time
normal        |    5      5  0.0s
Test Summary:                       | Pass  Total  Time
zero-dimensional broadcasted normal |    6      6  1.5s
Test Summary:                                                  | Pass  Total  Time
array normal (trivially broadcasted: all args have same shape) |    6      6  2.0s
Test Summary:      | Pass  Total  Time
broadcasted normal |   14     14  0.8s
Test Summary:       | Pass  Total  Time
multivariate normal |    9      9  0.2s
Test Summary: | Pass  Total  Time
uniform       |    5      5  0.0s
Test Summary:    | Pass  Total  Time
uniform_discrete |    3      3  0.0s
Test Summary:     | Pass  Total  Time
piecewise_uniform |    8      8  1.4s
Test Summary:        | Pass  Total  Time
beta uniform mixture |    6      6  1.3s
Test Summary: | Pass  Total  Time
geometric     |    4      4  1.3s
Test Summary: | Pass  Total  Time
binom         |    5      5  1.4s
Test Summary: | Pass  Total  Time
neg_binom     |    5      5  0.0s
Test Summary: | Pass  Total  Time
exponential   |    4      4  1.3s
Test Summary: | Pass  Total  Time
poisson       |    4      4  0.0s
Test Summary: | Pass  Total  Time
laplace       |    3      3  0.0s
Test Summary: | Pass  Total  Time
cauchy        |    3      3  0.0s
Test Summary:        | Pass  Total  Time
choice_at combinator |   74     74  0.2s
Test Summary:      | Pass  Total  Time
call_at combinator |   74     74  0.2s
Test Summary:  | Pass  Total  Time
map combinator |  139    139  1.2s
Test Summary:     | Pass  Total  Time
unfold combinator |  199    199  2.0s
Test Summary:          | Pass  Total  Time
recurse node numbering |   54     54  0.0s
Test Summary: | Pass  Total  Time
simple pcfg   |   88     88  0.7s
Test Summary:     | Pass  Total  Time
switch combinator |   72     72  0.9s
Test Summary:      | Pass  Total  Time
dist DSL (untyped) |   12     12  8.2s
Test Summary:    | Pass  Total  Time
dist DSL (typed) |   12     12  2.2s
Test Summary:                            | Pass  Total  Time
fixed mixture of different distributions |   11     11  4.2s
Test Summary:      | Pass  Total  Time
mixture of normals |   20     20  3.3s
Test Summary:       | Pass  Total  Time
mixture of binomial |    9      9  1.5s
Test Summary:                   | Pass  Total  Time
mixture of multivariate normals |   15     15  9.6s
     Testing Gen tests passed 






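using Gen

# Bayesian linear regression: Gaussian priors on slope and intercept,
# Gaussian observation noise with standard deviation 1.
@gen function my_model(xs::Vector{Float64})
    slope = @trace(normal(0, 2), :slope)
    intercept = @trace(normal(0, 10), :intercept)
    for (i, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 1), "y-$i")
    end
end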


julia> using Gen
[ Info: Precompiling Gen [ea4f424c-a589-11e8-07c0-fd5c91b9da4a]
WARNING: method definition for process! at /Users/username/.julia/packages/Gen/Dne3u/src/modeling_library/switch/update.jl:85 declares type variable DV but does not use it.
WARNING: method definition for #importance_resampling#354 at /Users/username/.julia/packages/Gen/Dne3u/src/inference/importance.jl:70 declares type variable W but does not use it.
WARNING: method definition for #importance_resampling#354 at /Users/username/.julia/packages/Gen/Dne3u/src/inference/importance.jl:70 declares type variable V but does not use it.
julia> @gen function my_model(xs::Vector{Float64})
           slope = @trace(normal(0, 2), :slope)
           intercept = @trace(normal(0, 10), :intercept)
           for (i, x) in enumerate(xs)
               @trace(normal(slope * x + intercept, 1), "y-$i")
           end
       end
DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Vector{Float64}], false, Union{Nothing, Some{Any}}[nothing], var"##my_model#292", Bool[0], false)
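In plain probability terms, my_model says the following (Gen's normal(mu, std) takes a standard deviation):

- slope ~ Normal(0, 2)
- intercept ~ Normal(0, 10)
- y_i ~ Normal(slope·x_i + intercept, 1)

The sketch below restates that generative process in plain Julia, without Gen, purely as an illustration. The names simulate_linear_model, normal_logpdf and log_joint are hypothetical helpers. As far as we understand, log_joint is the quantity Gen reports as a trace's score (get_score) for a model like this one.

```julia
using Random

# Forward-simulate the generative process described by my_model.
function simulate_linear_model(rng::AbstractRNG, xs::Vector{Float64})
    slope = 2.0 * randn(rng)          # slope ~ Normal(0, 2)
    intercept = 10.0 * randn(rng)     # intercept ~ Normal(0, 10)
    ys = [slope * x + intercept + randn(rng) for x in xs]  # y_i ~ Normal(slope*x_i + intercept, 1)
    return (slope = slope, intercept = intercept, ys = ys)
end

# Log density of Normal(μ, σ) at z.
normal_logpdf(z, μ, σ) = -0.5 * ((z - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Log joint density log p(slope, intercept, ys; xs).
function log_joint(slope, intercept, xs, ys)
    lp = normal_logpdf(slope, 0.0, 2.0) + normal_logpdf(intercept, 0.0, 10.0)
    for (x, y) in zip(xs, ys)
        lp += normal_logpdf(y, slope * x + intercept, 1.0)
    end
    return lp
end

rng = MersenneTwister(42)
xs = collect(1.0:10.0)
sim = simulate_linear_model(rng, xs)
println(sim)
println("log joint: ", log_joint(sim.slope, sim.intercept, xs, sim.ys))
```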

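 Next, implement the inference program. It works with execution traces of the model and is written as ordinary Julia code using Gen's standard inference library. The inference program takes the dataset below and runs an MCMC algorithm (Metropolis-Hastings) to fit the slope and intercept parameters.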

function my_inference_program(xs::Vector{Float64}, ys::Vector{Float64}, num_iters::Int)
    # Create a set of constraints fixing the
    # y coordinates to the observed y values
    constraints = choicemap()
    for (i, y) in enumerate(ys)
        constraints["y-$i"] = y
    end
    # Run the model, constrained by `constraints`,
    # to get an initial execution trace
    (trace, _) = generate(my_model, (xs,), constraints)
    # Iteratively update the slope then the intercept,
    # using Gen's metropolis_hastings operator.
    for iter=1:num_iters
        (trace, _) = metropolis_hastings(trace, select(:slope))
        (trace, _) = metropolis_hastings(trace, select(:intercept))
    end
    # From the final trace, read out the slope and
    # the intercept.
    choices = get_choices(trace)
    return (choices[:slope], choices[:intercept])
end
julia> function my_inference_program(xs::Vector{Float64}, ys::Vector{Float64}, num_iters::Int)
           # Create a set of constraints fixing the
           # y coordinates to the observed y values
           constraints = choicemap()
           for (i, y) in enumerate(ys)
               constraints["y-$i"] = y
           end
           # Run the model, constrained by `constraints`,
           # to get an initial execution trace
           (trace, _) = generate(my_model, (xs,), constraints)
           # Iteratively update the slope then the intercept,
           # using Gen's metropolis_hastings operator.
           for iter=1:num_iters
               (trace, _) = metropolis_hastings(trace, select(:slope))
               (trace, _) = metropolis_hastings(trace, select(:intercept))
           end
           # From the final trace, read out the slope and
           # the intercept.
           choices = get_choices(trace)
           return (choices[:slope], choices[:intercept])
       end
my_inference_program (generic function with 1 method)
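For comparison, here is the same kind of inference written without Gen, as a hedged sketch in plain Julia. One difference is deliberate. As we understand it, Gen's metropolis_hastings(trace, select(:slope)) proposes a new value for the selected address from the model's own prior. This sketch swaps in a Gaussian random-walk proposal instead; the step sizes 0.05 and 0.3 are arbitrary choices.

Like my_inference_program, it alternates slope and intercept updates. Unlike it, it keeps every state, so the first half can be discarded as burn-in and the posterior mean estimated from the rest. The function names are hypothetical.

```julia
using Random, Statistics

# Log density of Normal(μ, σ) at z.
normal_logpdf(z, μ, σ) = -0.5 * ((z - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Unnormalized log posterior of (slope, intercept) under the model in my_model.
function log_post(slope, intercept, xs, ys)
    lp = normal_logpdf(slope, 0.0, 2.0) + normal_logpdf(intercept, 0.0, 10.0)
    for (x, y) in zip(xs, ys)
        lp += normal_logpdf(y, slope * x + intercept, 1.0)
    end
    return lp
end

# Alternate single-site random-walk MH updates of slope and intercept,
# recording the state after each sweep.
function rw_mh(xs, ys, num_iters; step_slope = 0.05, step_intercept = 0.3,
               rng = Random.default_rng())
    slope, intercept = 0.0, 0.0
    samples = Matrix{Float64}(undef, num_iters, 2)
    for iter in 1:num_iters
        # Symmetric proposal, so accept with the posterior ratio.
        prop = slope + step_slope * randn(rng)
        if log(rand(rng)) < log_post(prop, intercept, xs, ys) - log_post(slope, intercept, xs, ys)
            slope = prop
        end
        # Same for the intercept, given the current slope.
        prop = intercept + step_intercept * randn(rng)
        if log(rand(rng)) < log_post(slope, prop, xs, ys) - log_post(slope, intercept, xs, ys)
            intercept = prop
        end
        samples[iter, 1] = slope
        samples[iter, 2] = intercept
    end
    return samples
end

xs = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
chain = rw_mh(xs, ys, 20_000; rng = MersenneTwister(1))
kept = chain[10_001:end, :]   # discard the first half as burn-in
println("posterior mean slope ≈ ", mean(kept[:, 1]),
        ", intercept ≈ ", mean(kept[:, 2]))
```

Note that my_inference_program returns only the final state of its chain, which is a single approximate posterior sample. Repeated runs will therefore give slightly different slope and intercept values.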


xs = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
(slope, intercept) = my_inference_program(xs, ys, 1000)
println("slope: $slope, intercept: $intercept")
julia> xs = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]
10-element Vector{Float64}:
julia> ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]
10-element Vector{Float64}:
julia> (slope, intercept) = my_inference_program(xs, ys, 1000)
(-1.9471036551893741, 10.134566406450329)
julia> println("slope: $slope, intercept: $intercept")
slope: -1.9471036551893741, intercept: 10.134566406450329
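Both priors and the observation noise (sd 1) are Gaussian, so this model is conjugate. That means the exact posterior can be computed in closed form and used to sanity-check the MCMC result.

Take β = [intercept, slope] and let X be the design matrix with columns [1, x]. Then:

- posterior precision: Λ = XᵀX/σ² + diag(1/10², 1/2²)
- posterior mean: Λ⁻¹Xᵀy/σ²

A minimal sketch with the LinearAlgebra standard library:

```julia
using LinearAlgebra

xs = [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]
ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90]

# Design matrix with columns [1, x]; β = [intercept, slope].
X = hcat(ones(length(xs)), xs)

# Priors: intercept ~ Normal(0, 10), slope ~ Normal(0, 2); noise sd = 1.
prior_precision = Diagonal([1 / 10^2, 1 / 2^2])
noise_sd = 1.0

# Conjugate Gaussian posterior (prior means are zero).
post_precision = X' * X / noise_sd^2 + prior_precision
post_mean = post_precision \ (X' * ys / noise_sd^2)
post_sd = sqrt.(diag(inv(post_precision)))

println("intercept: mean ≈ ", post_mean[1], ", sd ≈ ", post_sd[1])
println("slope:     mean ≈ ", post_mean[2], ", sd ≈ ", post_sd[2])
```

By our arithmetic, the posterior means come out at roughly intercept ≈ 10.14 and slope ≈ -1.98. The posterior standard deviations are about 0.68 and 0.11. The values returned by the Gen program above are consistent with this exact posterior.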










conda create -n python38 python=3.8
conda activate python38


$ python --version
Python 3.8.16


conda install -c apple tensorflow-deps
python -m pip install tensorflow-macos
python -m pip install tensorflow-metal


(python38) UsernoMacBook-Air:~ username$ julia
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.9.2 (2023-07-05)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |



julia> using PyCall; println(PyCall.python)




Type ']' in the Julia REPL to start the package manager.

 (@v1.9) pkg> add PyCall 
  Resolving package versions...
   Installed VersionParsing ─ v1.3.0
   Installed Conda ────────── v1.9.1
   Installed PyCall ───────── v1.96.1
    Updating `~/.julia/environments/v1.9/Project.toml`
  [438e738f] + PyCall v1.96.1
    Updating `~/.julia/environments/v1.9/Manifest.toml`
  [8f4d0f93] + Conda v1.9.1
  [438e738f] + PyCall v1.96.1
  [81def892] + VersionParsing v1.3.0
    Building Conda ─→ `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/8c86e48c0db1564a1d49548d3515ced5d604c408/build.log`
    Building PyCall → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/43d304ac6f0354755f1d60730ece8c499980f7ba/build.log`
Precompiling project...
  3 dependencies successfully precompiled in 7 seconds. 171 already precompiled.


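 To change the Python that PyCall uses, set the environment variable ENV["PYTHON"] to the path of the python executable. An example would be ENV["PYTHON"] = "/usr/local/bin/python3"; note that the key is uppercase and the value is the executable, not its bin directory. Then rebuild PyCall with Pkg.build("PyCall") and restart Julia.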


 julia> using Pkg; ENV["PYTHON"] = "<python>"; Pkg.build("PyCall")
  Building Conda ─→ `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/8c86e48c0db1564a1d49548d3515ced5d604c408/build.log`
    Building PyCall → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/43d304ac6f0354755f1d60730ece8c499980f7ba/build.log`
julia> using PyCall; println(PyCall.python)
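If the conda environment you want PyCall to use is already active in the shell that launched Julia (for example python38 above), you can look its interpreter up on PATH instead of typing the path by hand. Below is a small sketch using Base's Sys.which. After setting ENV["PYTHON"], PyCall still has to be rebuilt and Julia restarted, as described above.

```julia
# Find the `python` of the active environment on PATH and point PyCall at it.
py = Sys.which("python")
if py === nothing
    @warn "No `python` executable found on PATH; activate the conda env first."
else
    ENV["PYTHON"] = py
    println("ENV[\"PYTHON\"] = ", py)
    # Then rebuild PyCall and restart Julia:
    # using Pkg; Pkg.build("PyCall")
end
```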


 GenTF is still under development and is not registered as a package, so it cannot be installed with add GenTF from the package manager.


Instead, run the following command at the (@v1.9) pkg> prompt to install it directly from its GitHub repository.

(@v1.9) pkg> add https://github.com/probcomp/GenTF
     Cloning git-repo `https://github.com/probcomp/GenTF`
    Updating git-repo `https://github.com/probcomp/GenTF`
   Resolving package versions...
   Installed OffsetArrays ──────────────── v1.12.10
   Installed ShowCases ─────────────────── v0.1.0
   Installed ContextVariablesX ─────────── v0.1.3
   Installed InlineStrings ─────────────── v1.4.0
   Installed ZipFile ───────────────────── v0.10.1
   Installed InitialValues ─────────────── v0.3.1
   Installed CEnum ─────────────────────── v0.4.2
   Installed InvertedIndices ───────────── v1.3.0
   Installed FileIO ────────────────────── v1.16.1
   Installed NNlib ─────────────────────── v0.9.4
   Installed MAT ───────────────────────── v0.10.5
   Installed BFloat16s ─────────────────── v0.4.2
   Installed PrettyPrint ───────────────── v0.2.0
   Installed DataFrames ────────────────── v1.6.1
   Installed MLUtils ───────────────────── v0.4.3
   Installed TupleTools ────────────────── v1.3.0
   Installed LLVM ──────────────────────── v6.1.0
   Installed Chemfiles ─────────────────── v0.10.41
   Installed JLD2 ──────────────────────── v0.4.33
   Installed GZip ──────────────────────── v0.6.1
   Installed FLoopsBase ────────────────── v0.1.1
   Installed AtomsBase ─────────────────── v0.3.4
   Installed SentinelArrays ────────────── v1.4.0
   Installed StringEncodings ───────────── v0.3.7
   Installed MicroCollections ──────────── v0.1.4
   Installed DefineSingletons ──────────── v0.1.2
   Installed NameResolution ────────────── v0.1.5
   Installed WorkerUtilities ───────────── v1.6.1
   Installed UnsafeAtomicsLLVM ─────────── v0.1.3
   Installed WeakRefStrings ────────────── v1.4.2
   Installed MappedArrays ──────────────── v0.4.2
   Installed PaddedViews ───────────────── v0.5.12
   Installed BufferedStreams ───────────── v1.2.1
   Installed HDF5_jll ──────────────────── v1.12.2+2
   Installed Glob ──────────────────────── v1.3.1
   Installed BangBang ──────────────────── v0.3.39
   Installed LazyModules ───────────────── v0.3.1
   Installed MLStyle ───────────────────── v0.4.17
   Installed FLoops ────────────────────── v0.2.1
   Installed MosaicViews ───────────────── v0.3.4
   Installed PeriodicTable ─────────────── v1.1.4
   Installed StructTypes ───────────────── v1.10.0
   Installed ImageCore ─────────────────── v0.10.1
   Installed ArgCheck ──────────────────── v2.3.0
   Installed TableTraits ───────────────── v1.0.1
   Installed JSON3 ─────────────────────── v1.13.2
   Installed ImageShow ─────────────────── v0.3.8
   Installed DataDeps ──────────────────── v0.7.11
   Installed Setfield ──────────────────── v1.1.1
   Installed DataValueInterfaces ───────── v1.0.0
   Installed PooledArrays ──────────────── v1.4.2
   Installed ConstructionBase ──────────── v1.5.3
   Installed AbstractFFTs ──────────────── v1.5.0
   Installed LLVMExtra_jll ─────────────── v0.0.23+0
   Installed CSV ───────────────────────── v0.10.11
   Installed StackViews ────────────────── v0.1.1
   Installed CompositionsBase ──────────── v0.1.2
   Installed UnsafeAtomics ─────────────── v0.2.1
   Installed Strided ───────────────────── v1.2.3
   Installed KernelAbstractions ────────── v0.9.8
   Installed HDF5 ──────────────────────── v0.16.16
   Installed MLDatasets ────────────────── v0.7.12
   Installed JuliaVariables ────────────── v0.2.4
   Installed PrettyTables ──────────────── v2.2.7
   Installed GPUArraysCore ─────────────── v0.1.5
   Installed SimpleTraits ──────────────── v0.9.4
   Installed Atomix ────────────────────── v0.1.0
   Installed Crayons ───────────────────── v4.1.1
   Installed Adapt ─────────────────────── v3.6.2
   Installed ImageBase ─────────────────── v0.1.7
   Installed InternedStrings ───────────── v0.7.0
   Installed Chemfiles_jll ─────────────── v0.10.4+0
   Installed Transducers ───────────────── v0.4.78
   Installed NPZ ───────────────────────── v0.4.3
   Installed Tables ────────────────────── v1.10.1
   Installed Baselet ───────────────────── v0.1.1
   Installed IteratorInterfaceExtensions ─ v1.0.0
   Installed UnitfulAtomic ─────────────── v1.0.0
   Installed FilePathsBase ─────────────── v0.9.20
   Installed SplittablesBase ───────────── v0.1.15
   Installed StringManipulation ────────── v0.3.0
   Installed Pickle ────────────────────── v0.3.3
  Downloaded artifact: HDF5
  Downloaded artifact: LLVMExtra
  Downloaded artifact: Chemfiles
    Updating `~/.julia/environments/v1.9/Project.toml`
  [1956f2fc] + GenTF v0.1.0 `https://github.com/probcomp/GenTF#master`
    Updating `~/.julia/environments/v1.9/Manifest.toml`
  [621f4979] + AbstractFFTs v1.5.0
  [79e6a3ab] + Adapt v3.6.2
  [dce04be8] + ArgCheck v2.3.0
  [a9b6321e] + Atomix v0.1.0
  [a963bdd2] + AtomsBase v0.3.4
  [ab4f0b2a] + BFloat16s v0.4.2
  [198e06fe] + BangBang v0.3.39
  [9718e550] + Baselet v0.1.1
  [e1450e63] + BufferedStreams v1.2.1
  [fa961155] + CEnum v0.4.2
  [336ed68f] + CSV v0.10.11
  [46823bd8] + Chemfiles v0.10.41
  [a33af91c] + CompositionsBase v0.1.2
  [187b0558] + ConstructionBase v1.5.3
  [6add18c4] + ContextVariablesX v0.1.3
  [a8cc5b0e] + Crayons v4.1.1
  [124859b0] + DataDeps v0.7.11
  [a93c6f00] + DataFrames v1.6.1
  [e2d170a0] + DataValueInterfaces v1.0.0
  [244e2a9f] + DefineSingletons v0.1.2
  [cc61a311] + FLoops v0.2.1
  [b9860ae5] + FLoopsBase v0.1.1
  [5789e2e9] + FileIO v1.16.1
  [48062228] + FilePathsBase v0.9.20
  [46192b85] + GPUArraysCore v0.1.5
  [92fee26a] + GZip v0.6.1
  [1956f2fc] + GenTF v0.1.0 `https://github.com/probcomp/GenTF#master`
  [c27321d9] + Glob v1.3.1
  [f67ccb44] + HDF5 v0.16.16
  [dad2f222] + LLVMExtra_jll v0.0.23+0
  [8ba89e20] + Distributed
  [9fa8497b] + Future
  [4af54fe1] + LazyArtifacts
        Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`
    Building HDF5 → `~/.julia/scratchspaces/44cfe95a-1eb2-52ea-b672-e2afdf69b78f/114e20044677badbc631ee6fdc80a67920561a29/build.log`
Precompiling project...
  ✓ Plots
  96 dependencies successfully precompiled in 77 seconds. 172 already precompiled.
  1 dependency precompiled but a different version is currently loaded. Restart julia to access the new version


A CNN with GenTF and TensorFlow





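import Random
import MLDatasets

Random.seed!(1)   # fix the RNG seed for a reproducible shuffle

# MNIST training set: 28×28×60000 images, labels 0-9.
# MNIST.traindata() is deprecated; this is the replacement its warning suggests.
train_x, train_y = MLDatasets.MNIST(split=:train)[:]

mutable struct DataLoader
    cur_id::Int          # current position in `order`
    order::Vector{Int}   # shuffled sample indices (field name assumed)
end

DataLoader() = DataLoader(1, Random.shuffle(1:60000))

# Return a batch_size×784 matrix of flattened images and labels shifted to 1-10.
function next_batch(loader::DataLoader, batch_size)
    x = zeros(Float64, batch_size, 784)
    y = Vector{Int}(undef, batch_size)
    for i=1:batch_size
        idx = loader.order[loader.cur_id]   # draw samples in shuffled order
        x[i, :] = reshape(train_x[:,:,idx], (28*28))
        y[i] = train_y[idx] + 1
        loader.cur_id += 1
        if loader.cur_id > 60000
            loader.cur_id = 1
        end
    end
    x, y
end

# Load the whole MNIST test set in the same layout.
function load_test_set()
    test_x, test_y = MLDatasets.MNIST(split=:test)[:]
    N = length(test_y)
    x = zeros(Float64, N, 784)
    y = Vector{Int}(undef, N)
    for i=1:N
        x[i, :] = reshape(test_x[:,:,i], (28*28))
        y[i] = test_y[i]+1
    end
    x, y
end

const loader = DataLoader()
(test_x, test_y) = load_test_set()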
julia> import Random

julia> Random.seed!(1)
julia> import MLDatasets

julia> train_x, train_y = MLDatasets.MNIST.traindata()
┌ Warning: MNIST.traindata() is deprecated, use `MNIST(split=:train)[:]` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/Ob4aN/src/datasets/vision/mnist.jl:187
(features = [0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; … ;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8;;; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; … ; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8; 0.0N0f8 0.0N0f8 … 0.0N0f8 0.0N0f8], targets = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4  …  9, 2, 9, 5, 1, 8, 3, 5, 6, 8])
julia> mutable struct DataLoader
           cur_id::Int
           order::Vector{Int}  # shuffled sample indices (field name assumed)
       end

julia> DataLoader() = DataLoader(1, Random.shuffle(1:60000))
julia> function next_batch(loader::DataLoader, batch_size)
           x = zeros(Float64, batch_size, 784)
           y = Vector{Int}(undef, batch_size)
           for i=1:batch_size
               x[i, :] = reshape(train_x[:,:,loader.cur_id], (28*28))
               y[i] = train_y[loader.cur_id] + 1
               loader.cur_id += 1
               if loader.cur_id > 60000
                   loader.cur_id = 1
               end
           end
           x, y
       end
next_batch (generic function with 1 method)
julia> function load_test_set()
           test_x, test_y = MLDatasets.MNIST.testdata()
           N = length(test_y)
           x = zeros(Float64, N, 784)
           y = Vector{Int}(undef, N)
           for i=1:N
               x[i, :] = reshape(test_x[:,:,i], (28*28))
               y[i] = test_y[i]+1
           end
           x, y
       end
load_test_set (generic function with 1 method)
julia> const loader = DataLoader()
WARNING: redefinition of constant loader. This may fail, cause incorrect answers, or produce other errors.
DataLoader(1, [16889, 35959, 56640, 30743, 23204, 59111, 51586, 55462, 8914, 52722  …  20038, 28424, 8312, 3503, 57045, 19635, 1025, 28306, 57085, 10336])
julia> (test_x, test_y) = load_test_set()
┌ Warning: MNIST.testdata() is deprecated, use `MNIST(split=:test)[:]` instead.
└ @ MLDatasets ~/.julia/packages/MLDatasets/Ob4aN/src/datasets/vision/mnist.jl:195
([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [8, 3, 2, 1, 5, 2, 5, 10, 6, 10  …  8, 9, 10, 1, 2, 3, 4, 5, 6, 7])
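`load_test_set` fills the N×784 matrix one image at a time. The same matrix can also be built in one step: reshape the 28×28×N image array to 784×N, which keeps Julia's column-major order, and then transpose it. The sketch below is an alternative to the author's code, not the author's method.

`images_to_rows` is a hypothetical helper, and the check runs on random data. The stand-in labels 7, 2, 1, 0, 4 were chosen so that, after the +1 shift, they match the first test labels printed above (8, 3, 2, 1, 5).

```julia
# Hypothetical helper: H×W×N image stack -> N×(H*W) Float64 matrix,
# where row k equals reshape(imgs[:, :, k], H*W) as in load_test_set.
function images_to_rows(imgs::AbstractArray{<:Real,3})
    h, w, n = size(imgs)
    # reshape keeps column-major order, so column k of the (H*W)×N view is image k
    return permutedims(Float64.(reshape(imgs, h * w, n)))
end

# Loop version mirroring load_test_set, kept for comparison.
function images_to_rows_loop(imgs::AbstractArray{<:Real,3})
    h, w, n = size(imgs)
    x = zeros(Float64, n, h * w)
    for i in 1:n
        x[i, :] = reshape(imgs[:, :, i], h * w)
    end
    return x
end

imgs = rand(Float32, 28, 28, 5)   # stand-in for MNIST test images
labels = [7, 2, 1, 0, 4]          # stand-in 0-based digit labels
x = images_to_rows(imgs)
y = labels .+ 1                   # 1-based classes for categorical
```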


using Gen
using GenTF
using PyCall

@pyimport tensorflow as tf
@pyimport tensorflow.nn as nn

# Port the TensorFlow 1.x code to TF2 through the tf.compat.v1 API
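@pyimport tensorflow.compat.v1 as v1
# TF2 runs eagerly by default; placeholders and Sessions need graph mode
v1.disable_eager_execution()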

xs = v1.placeholder(tf.float64) # N x 784

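function weight_variable(shape)
    # small random initial weights wrapped in a TF variable
    initial = 0.001 * randn(shape...)
    tf.Variable(initial)
end

function bias_variable(shape)
    # constant 0.1 initial biases wrapped in a TF variable
    initial = fill(.1, shape...)
    tf.Variable(initial)
end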

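function conv2d(x, W)
    # stride 1 with "SAME" padding: height and width are preserved
    nn.conv2d(x, W, (1, 1, 1, 1), "SAME")
end

function max_pool_2x2(x)
    # 2x2 window with stride 2: height and width are halved
    nn.max_pool(x, (1, 2, 2, 1), (1, 2, 2, 1), "SAME")
end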

W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])

x_image = tf.reshape(xs, [-1, 28, 28, 1])

h_conv1 = nn.relu(tf.add(conv2d(x_image, W_conv1), b_conv1))
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])

h_conv2 = nn.relu(tf.add(conv2d(h_pool1, W_conv2), b_conv2))
h_pool2 = max_pool_2x2(h_conv2)

W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = nn.relu(tf.add(tf.matmul(h_pool2_flat, W_fc1), b_fc1))

W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

probs = nn.softmax(tf.add(tf.matmul(h_fc1, W_fc2), b_fc2), axis=1) # N x 10

const sess = v1.Session()  # tf.Session is not available in TF2; use the compat.v1 Session
const params = [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2]
const net = TFFunction(params, [xs], probs, sess)

@gen function f(xs::Matrix{Float64})
    (N, D) = size(xs)
    @assert D == 784
    probs = @trace(net(xs), :net)
    @assert size(probs) == (N, 10)
    ys = Vector{Int}(undef, N)
    for i=1:N
        ys[i] = @trace(categorical(probs[i,:]), (:y, i))
    end
    return ys
end

# Training #

@pyimport tensorflow.train as train

#update = ParamUpdate(ADAM(1e-4, 0.9, 0.999, 1e-08), net)
update = ParamUpdate(FixedStepGradientDescent(0.001), net)
for i=1:1000

    (xs, ys) = next_batch(loader, 100)

    @assert size(xs) == (100, 784)
    @assert size(ys) == (100,)
    constraints = choicemap()
    for (j, y) in enumerate(ys)
        constraints[(:y, j)] = y
    end

    (trace, weight) = generate(f, (xs,), constraints)

    # increments gradient accumulators
    accumulate_param_gradients!(trace, nothing)

    # performs SGD update and then resets gradient accumulators
    # (Gen's ParamUpdate step; the name apply! is assumed and may differ between Gen versions)
    apply!(update)

    println("i: $i, weight: $weight")
end

# Inference on the test data #

for i=1:length(test_y)
    x = test_x[i:i,:]
    @assert size(x) == (1, 784)
    true_y = test_y[i]
    pred_y = f(x)
    @assert length(pred_y) == 1
    println("true: $true_y, predicted: $(pred_y[1])")
end
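The training loop splits each step in two:

1. `accumulate_param_gradients!` adds the gradient of the trace's log-probability to per-parameter accumulators.
2. According to the comment in the loop, the ParamUpdate step then changes the parameters and resets those accumulators.

Below is a plain-Julia analogue of that pattern under stated assumptions; it is not Gen's implementation. The names `Param`, `accumulate_gradient!` and `apply_fixed_step!` are hypothetical. A single Gaussian mean is fitted so that the result can be checked. Fixed-step gradient descent on the negative log-likelihood is written here as a step along the log-likelihood gradient.

```julia
# Hypothetical parameter with its own gradient accumulator.
mutable struct Param
    value::Float64
    grad_acc::Float64
end
Param(v::Real) = Param(Float64(v), 0.0)

# Add the gradient of the log-likelihood of xs under Normal(value, 1)
# with respect to the mean: sum(x - value).
function accumulate_gradient!(p::Param, xs::AbstractVector{<:Real})
    p.grad_acc += sum(x - p.value for x in xs)
    return p
end

# Fixed step along the accumulated gradient (ascent on the log-likelihood,
# i.e. descent on its negative), then reset the accumulator.
function apply_fixed_step!(p::Param, step::Real)
    p.value += step * p.grad_acc
    p.grad_acc = 0.0
    return p
end

data = collect(1.0:10.0)   # sample mean 5.5

mu = Param(0.0)
for iter in 1:200
    accumulate_gradient!(mu, data)
    apply_fixed_step!(mu, 0.01)
end

# Accumulating without ever applying the step leaves the parameter unchanged.
frozen = Param(0.0)
for iter in 1:3
    accumulate_gradient!(frozen, data)
end
```

The second loop mirrors a training loop in which gradients are accumulated but no update step is applied: the accumulator keeps growing while the parameter stays where it started.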
julia> xs = v1.placeholder(tf.float64)
PyObject <tf.Tensor 'Placeholder_9:0' shape=<unknown> dtype=float64>
julia> function weight_variable(shape)
           initial = 0.001 * randn(shape...)
           tf.Variable(initial)
       end
weight_variable (generic function with 1 method)
julia> function bias_variable(shape)
           initial = fill(.1, shape...)
           tf.Variable(initial)
       end
bias_variable (generic function with 1 method)
julia> function conv2d(x, W)
           nn.conv2d(x, W, (1, 1, 1, 1), "SAME")
       end
conv2d (generic function with 1 method)
julia> function max_pool_2x2(x)
           nn.max_pool(x, (1, 2, 2, 1), (1, 2, 2, 1), "SAME")
       end
max_pool_2x2 (generic function with 1 method)
julia> W_conv1 = weight_variable([5, 5, 1, 32])
PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64>
julia> b_conv1 = bias_variable([32])
PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64>
julia> x_image = tf.reshape(xs, [-1, 28, 28, 1])
PyObject <tf.Tensor 'Reshape_2:0' shape=(None, 28, 28, 1) dtype=float64>
julia> h_conv1 = nn.relu(tf.add(conv2d(x_image, W_conv1), b_conv1))
PyObject <tf.Tensor 'Relu_3:0' shape=(None, 28, 28, 32) dtype=float64>
julia> h_pool1 = max_pool_2x2(h_conv1)
PyObject <tf.Tensor 'MaxPool2d_2:0' shape=(None, 14, 14, 32) dtype=float64>
julia> W_conv2 = weight_variable([5, 5, 32, 64])
PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64>
julia> b_conv2 = bias_variable([64])
PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64>
julia> h_conv2 = nn.relu(tf.add(conv2d(h_pool1, W_conv2), b_conv2))
PyObject <tf.Tensor 'Relu_4:0' shape=(None, 14, 14, 64) dtype=float64>
julia> h_pool2 = max_pool_2x2(h_conv2)
PyObject <tf.Tensor 'MaxPool2d_3:0' shape=(None, 7, 7, 64) dtype=float64>
julia> W_fc1 = weight_variable([7*7*64, 1024])
PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64>
julia> b_fc1 = bias_variable([1024])
PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64>
julia> h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
PyObject <tf.Tensor 'Reshape_3:0' shape=(None, 3136) dtype=float64>
julia> h_fc1 = nn.relu(tf.add(tf.matmul(h_pool2_flat, W_fc1), b_fc1))
PyObject <tf.Tensor 'Relu_5:0' shape=(None, 1024) dtype=float64>
julia> W_fc2 = weight_variable([1024, 10])
PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64>
julia> b_fc2 = bias_variable([10])
PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64>
julia> probs = nn.softmax(tf.add(tf.matmul(h_fc1, W_fc2), b_fc2), axis=1) # N x 10
PyObject <tf.Tensor 'Softmax_3:0' shape=(None, 10) dtype=float64>
julia> const sess = v1.Session()
2023-08-24 23:16:44.334990: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-08-24 23:16:44.335655: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
WARNING: redefinition of constant sess. This may fail, cause incorrect answers, or produce other errors.
PyObject <tensorflow.python.client.session.Session object at 0x2c5416c40>
julia> const params = [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2]
8-element Vector{PyObject}:
 PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64>
 PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64>
 PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64>
 PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64>
 PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64>
 PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64>
 PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64>
 PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64>
julia> const net = TFFunction(params, [xs], probs, sess)
2023-08-24 23:17:08.518303: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2023-08-24 23:17:08.706835: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
TFFunction(PyObject <tensorflow.python.client.session.Session object at 0x2c5416c40>, PyObject[PyObject <tf.Tensor 'Placeholder_9:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Softmax_3:0' shape=(None, 10) dtype=float64>, Dict{PyObject, PyObject}(PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64> => PyObject <tf.Variable 'Variable_34:0' shape=(64,) dtype=float64>, PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64> => PyObject <tf.Variable 'Variable_31:0' shape=(5, 5, 1, 32) dtype=float64>, PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64> => PyObject <tf.Variable 'Variable_33:0' shape=(5, 5, 32, 64) dtype=float64>, PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64> => PyObject <tf.Variable 'Variable_37:0' shape=(1024, 10) dtype=float64>, PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64> => PyObject <tf.Variable 'Variable_36:0' shape=(1024,) dtype=float64>, PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64> => PyObject <tf.Variable 'Variable_38:0' shape=(10,) dtype=float64>, PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64> => PyObject <tf.Variable 'Variable_32:0' shape=(32,) dtype=float64>, PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64> => PyObject <tf.Variable 'Variable_35:0' shape=(3136, 1024) dtype=float64>), PyObject[PyObject <tf.Tensor 'gradients_6/Reshape_2_grad/Reshape:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Placeholder_10:0' shape=<unknown> dtype=float64>, PyObject <tf.Operation 'group_deps_2' type=NoOp>, PyObject <tf.Operation 'group_deps_3' type=NoOp>, PyObject <tf.Tensor 'Placeholder_11:0' shape=() dtype=float64>)
julia> @gen function f(xs::Matrix{Float64})
           (N, D) = size(xs)
           @assert D == 784
           probs = @trace(net(xs), :net)
           @assert size(probs) == (N, 10)
           ys = Vector{Int}(undef, N)
           for i=1:N
               ys[i] = @trace(categorical(probs[i,:]), (:y, i))
           end
           return ys
       end
DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Matrix{Float64}], false, Union{Nothing, Some{Any}}[nothing], var"##f#360", Bool[0], false)
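The tensor shapes printed above follow from the "SAME" padding rule, under which a layer's output length is ceil(input / stride). The stride-1 convolutions keep the spatial size, and each 2×2, stride-2 max-pool halves it: 28 → 14 → 7. That is why `W_fc1` takes 7*7*64 = 3136 inputs.

The sketch below reproduces that arithmetic. `same_out` and `cnn_shapes` are hypothetical helper names.

```julia
# Output length along one spatial axis with "SAME" padding: ceil(n / stride).
same_out(n::Integer, stride::Integer) = cld(n, stride)

# Spatial shapes through conv (stride 1) -> max-pool (2x2, stride 2), twice.
function cnn_shapes(side::Integer = 28)
    c1 = same_out(side, 1)   # conv1: side × side × 32
    p1 = same_out(c1, 2)     # max_pool_2x2
    c2 = same_out(p1, 1)     # conv2: p1 × p1 × 64
    p2 = same_out(c2, 2)     # max_pool_2x2
    return (conv1 = (c1, c1, 32), pool1 = (p1, p1, 32),
            conv2 = (c2, c2, 64), pool2 = (p2, p2, 64),
            flat = p2 * p2 * 64)
end

shapes = cnn_shapes()
```

With an odd input size the rounding up becomes visible. For example, a 29×29 input would give 15×15 and then 8×8 feature maps.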


julia> @pyimport tensorflow.train as train

julia> ParamUpdate(FixedStepGradientDescent(0.001), net)
2023-08-25 22:19:28.296563: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
ParamUpdate(Dict{GenerativeFunction, Any}(TFFunction(PyObject <tensorflow.python.client.session.Session object at 0x2c5416c40>, PyObject[PyObject <tf.Tensor 'Placeholder_9:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Softmax_3:0' shape=(None, 10) dtype=float64>, Dict{PyObject, PyObject}(PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64> => PyObject <tf.Variable 'Variable_34:0' shape=(64,) dtype=float64>, PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64> => PyObject <tf.Variable 'Variable_31:0' shape=(5, 5, 1, 32) dtype=float64>, PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64> => PyObject <tf.Variable 'Variable_33:0' shape=(5, 5, 32, 64) dtype=float64>, PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64> => PyObject <tf.Variable 'Variable_37:0' shape=(1024, 10) dtype=float64>, PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64> => PyObject <tf.Variable 'Variable_36:0' shape=(1024,) dtype=float64>, PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64> => PyObject <tf.Variable 'Variable_38:0' shape=(10,) dtype=float64>, PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64> => PyObject <tf.Variable 'Variable_32:0' shape=(32,) dtype=float64>, PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64> => PyObject <tf.Variable 'Variable_35:0' shape=(3136, 1024) dtype=float64>), PyObject[PyObject <tf.Tensor 'gradients_6/Reshape_2_grad/Reshape:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Placeholder_10:0' shape=<unknown> dtype=float64>, PyObject <tf.Operation 'group_deps_2' type=NoOp>, PyObject <tf.Operation 'group_deps_3' type=NoOp>, PyObject <tf.Tensor 'Placeholder_11:0' shape=() dtype=float64>) => GenTF.FixedStepGradientDescentTFFunctionState(PyObject <tf.Operation 'GradientDescent_3' type=NoOp>, TFFunction(PyObject <tensorflow.python.client.session.Session object at 0x2c5416c40>, PyObject[PyObject <tf.Tensor 
'Placeholder_9:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Softmax_3:0' shape=(None, 10) dtype=float64>, Dict{PyObject, PyObject}(PyObject <tf.Variable 'Variable_26:0' shape=(64,) dtype=float64> => PyObject <tf.Variable 'Variable_34:0' shape=(64,) dtype=float64>, PyObject <tf.Variable 'Variable_23:0' shape=(5, 5, 1, 32) dtype=float64> => PyObject <tf.Variable 'Variable_31:0' shape=(5, 5, 1, 32) dtype=float64>, PyObject <tf.Variable 'Variable_25:0' shape=(5, 5, 32, 64) dtype=float64> => PyObject <tf.Variable 'Variable_33:0' shape=(5, 5, 32, 64) dtype=float64>, PyObject <tf.Variable 'Variable_29:0' shape=(1024, 10) dtype=float64> => PyObject <tf.Variable 'Variable_37:0' shape=(1024, 10) dtype=float64>, PyObject <tf.Variable 'Variable_28:0' shape=(1024,) dtype=float64> => PyObject <tf.Variable 'Variable_36:0' shape=(1024,) dtype=float64>, PyObject <tf.Variable 'Variable_30:0' shape=(10,) dtype=float64> => PyObject <tf.Variable 'Variable_38:0' shape=(10,) dtype=float64>, PyObject <tf.Variable 'Variable_24:0' shape=(32,) dtype=float64> => PyObject <tf.Variable 'Variable_32:0' shape=(32,) dtype=float64>, PyObject <tf.Variable 'Variable_27:0' shape=(3136, 1024) dtype=float64> => PyObject <tf.Variable 'Variable_35:0' shape=(3136, 1024) dtype=float64>), PyObject[PyObject <tf.Tensor 'gradients_6/Reshape_2_grad/Reshape:0' shape=<unknown> dtype=float64>], PyObject <tf.Tensor 'Placeholder_10:0' shape=<unknown> dtype=float64>, PyObject <tf.Operation 'group_deps_2' type=NoOp>, PyObject <tf.Operation 'group_deps_3' type=NoOp>, PyObject <tf.Tensor 'Placeholder_11:0' shape=() dtype=float64>))), FixedStepGradientDescent(0.001))
julia> for i=1:1000

           (xs, ys) = next_batch(loader, 100)

           @assert size(xs) == (100, 784)
           @assert size(ys) == (100,)
           constraints = choicemap()
           for (i, y) in enumerate(ys)
               constraints[(:y, i)] = y
           end

           (trace, weight) = generate(f, (xs,), constraints)

           # increments gradient accumulators
           accumulate_param_gradients!(trace, nothing)

           # performs SGD update and then resets gradient accumulators

           println("i: $i, weight: $weight")
       end
i: 1, weight: -230.2418501131283
i: 2, weight: -230.25861867136845
i: 3, weight: -230.24646340719963
i: 4, weight: -230.25554056052002
i: 5, weight: -230.25868100738683
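Every label `(:y, i)` is constrained, so the weight returned by `generate` is the log-likelihood of the 100 observed labels: the sum of log probs[i, y_i]. A network that assigns 1/10 to every class would give 100·log(1/10) ≈ -230.2585. The logged weights all lie within about 0.02 of that value, which indicates near-uniform predictions over these iterations.

This is consistent with the loop above. It accumulates gradients but, apart from the comment, applies no parameter update, and the `ParamUpdate` created just before it is not bound to a variable.

The sketch below reproduces the arithmetic in plain Julia. `batch_loglik` is a hypothetical name.

```julia
# Log-likelihood of 1-based labels y under row-wise class probabilities,
# i.e. the sum of log(probs[i, y[i]]) over the batch.
function batch_loglik(probs::AbstractMatrix{<:Real}, y::AbstractVector{<:Integer})
    @assert size(probs, 1) == length(y)
    return sum(log(probs[i, y[i]]) for i in eachindex(y))
end

labels = rand(1:10, 100)            # any labels; the uniform case ignores them

uniform = fill(0.1, 100, 10)        # a network that has learned nothing
ll_uniform = batch_loglik(uniform, labels)

# A network that puts probability 0.9 on the correct class of every sample.
confident = fill(0.1 / 9, 100, 10)
for i in 1:100
    confident[i, labels[i]] = 0.9
end
ll_confident = batch_loglik(confident, labels)
```

The closer the weight gets to 0, the more probability the network places on the correct labels. Putting 0.9 on every correct class would give 100·log(0.9), about -10.5.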


julia> for i=1:length(test_y)
           x = test_x[i:i,:]
           @assert size(x) == (1, 784)
           true_y = test_y[i]
           pred_y = f(x)
           @assert length(pred_y) == 1
           println("true: $true_y, predicted: $(pred_y[1])")
       end
true: 8, predicted: 5
true: 3, predicted: 10
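`f(x)` draws each label from `categorical(probs[i,:])`, so the "predicted" values printed above are random samples, not the network's most probable class. Printing one line per test image also makes the overall result hard to judge.

A deterministic summary is to take the argmax of each probability row and report the fraction that matches the 1-based labels. The sketch below assumes the N×10 probability matrix is available as a Julia matrix. `accuracy` is a hypothetical helper, and the 4×3 matrix is made-up illustration data.

```julia
# Fraction of rows whose most probable class equals the 1-based label.
function accuracy(probs::AbstractMatrix{<:Real}, y::AbstractVector{<:Integer})
    @assert size(probs, 1) == length(y)
    pred = [argmax(view(probs, i, :)) for i in axes(probs, 1)]
    return count(pred .== y) / length(y)
end

# Tiny example: 4 samples, 3 classes.
probs = [0.7 0.2 0.1;
         0.1 0.8 0.1;
         0.3 0.3 0.4;
         0.5 0.4 0.1]
y = [1, 2, 3, 2]
acc = accuracy(probs, y)
```

Here the last row's most probable class is 1 while its label is 2, so three of the four rows match.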


