R版の因果推論パッケージCausalImpactの評価のために、軍事侵攻が外国為替レートに与えた影響を、counterfactual推論(注1)の例として紹介しました。このCausalImpactのPython実装版がgoogleから提供されました。これはTensorFlow Probabilityの上に実装されています。その他、Python用の因果推論パッケージが多くリリースされています。その中でPyMCの因果推論ライブラリCausalPyを取り上げ、前述のR版CausalImpactで用いたデータセットを使ってこの実装を評価してみます。
因果推論ライブラリ
ここで評価するツール以外にもPythonの因果推論パッケージが多くリリースされいます。西海岸の企業だけでなく、ハイテク企業やコンピュータサイエンスに力を入れている大学などが、最近、AIへの投資を増加させているので、Googleの他、複数の企業や研究機関からベイジアンと機械学習に関連した因果推論ライブラリが提供されています。別の機会にいくつか紹介したいと思います。
tfCausalImpact
TensorFlow Probabilityは、当初、EdwordとしてTensorFlowを使って実装されていたPythonのベイジアンライブラリをTensorFlowが本体に取り込んで、機能を追加したものです。TensorFlowがKerasを内部に取り込んで本体のAPIとして提供しているのと同様です。以前のKerasやEdwordは外部ライブラリでなくTensorFlow内のAPIとして提供されています。
- インストール
pip install tfCausalimpact
APIの引数の仕様はR版を踏襲しており、ほぼ同じ仕様です。結果の出力も以下の記事で解説しているR版のCαusalImpactと同一の仕様になっています。
CausalPy
CausalPyはPyMCをベースにした因果推論パッケージです。
APIの引数の仕様は以下の”割り込まれた時系列データの例”を参照してください。結果の表示におけるHDIはHighest Density Intervalsを意味し0を含む、94%の信頼区間として定義してあります。
- インストール
pip install CausalPy
割り込まれた時系列データの例
R版のCausalImpactの評価で用いたデータセット(期間2021-05-12 ~ 2023-03-17のドル円為替レート)を使って、CausalPyの時系列データに対する割り込み(intervention)を評価します。最初にCausalPyをインポートします。
import arviz as az
import pandas as pd
import causalpy as cp
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'
seed = 42
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
CausalImpactで用いた為替レートのデータ(USDJPY.csv)を読み込みます。
pd.readcsv(index_colum = 0)
でインデックスカラムを指定して読み込んだ場合、dateがindexとして認識されますがデータ型が文字列になります。このままCausalPyのAPIに渡すとindexがTimestamp型でないため、エラーになります。そのため、index_columを設定せずにcsvを読み込み、データフレームのインデックスをTimestamp型に設定します。
urlpath = "./../ファイルパス/USDJPY.csv"
#df_x = pd.read_csv(urlpath,index_col=0)
df_x = pd.read_csv(urlpath)
df_x.index = pd.to_datetime(df_x['date'])
df_x = df_x.drop('date',axis=1)
割り込みの日時を設定し、新しいデータフレームを作ります。
treatment_time = pd.to_datetime("2022-02-25")
df = pd.DataFrame({ 'x':df_x['t'],'y':df_x['close']})
df.head()
データ系列は以下のようになります。
モデルを定義し、PyMCのサンプラーを実行します。fomula
の定義"y ~ 1 + x"
の最初の項の1は、線形回帰においてインターセプトを適合させるかどうか(0または1)を意味します。
result = cp.pymc_experiments.InterruptedTimeSeries(
df,
treatment_time,
formula="y ~ 1 + x",
model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
)
サンプリングが終了しました。結果を表示します。
fig, ax = result.plot()
R版CasalImpactと比較して,CausalPy版のCausalImpactも同じ三つのパネルを表示します。同様の結果を確認することができます。
result.summary()
総合的な影響による推論 - Synthetic controlの例
CausalPyを使った複数のユニットの中の特定のユニットに処置を施す合成コントロール(Synthtic control)の例として、複数の通貨バスケットの中での出来事の影響によるJPYの推移を調べてみます。Synthetic controlのcounterfactualの例を外国為替レートに応用して検証してみましょう。
日本円以外の主要国の通貨のクロスレートを用いて調べてみます。他通貨との比較をわかりやすくするために、円の為替レートはUSDJPYでなくJPYUSDを使います。通貨のクロスレートの単位を合わせるために、USD 1 : JPY 100でスケールを変換しています。ここで用いるJPYUSDの指標は、実際のJPYUSDレートを100倍しています。
次の通貨、AUD,CHF,EUR,GBPを使います。counterfactual予測の前に、線形回帰と同様の考えで、これらの通貨を用いてJPYUSDを近似します。これらの通貨をdonor poolとみなし、JPYの結果を予測します。treatmentの時期は、前の割り込み時系列の例と同じ日程の2022年2月末日で設定します。シミュレーション後にtreatment以後のJPYのcounterfactual予測と実際のデータを比較します。
主要通貨の対ドルレートのデータを’’crossUSD.csv"を読み込みます。
urlpath = "./ファイルパス/crossUSD.csv"
df_x = pd.read_csv(urlpath)
df_x.index = pd.to_datetime(df_x['date'])
df_x = df_x.drop('date',axis=1)
df2 = pd.DataFrame({ 'aud':df_x['Audusd'],'chf':df_x['Chfusd'],'eur':df_x['Eurusd'],'gbp':df_x['Gbpusd'],'y':df_x['Jpyusd']})
df2
データフレームは以下のようになります。
Synthetic control estimatorの予測には、pymc_modelsモジュールからWeightedSumFitter()を用います。
model2 = cp.pymc_models.WeightedSumFitter()
formulaの構造は、以下の定義になります。interceptはfitさせません。
formula2 = 'y ~ 0 + aud + chf + gbp + eur'
モデルを定義して、サンプラーを動かします。
treatment_time = pd.to_datetime("2022-02-28")
model2 = cp.pymc_models.WeightedSumFitter()
formula2 = 'y ~ 0 + aud + chf + gbp + eur'
results = cp.pymc_experiments.SyntheticControl(
df2,
treatment_time,
formula = formula2,
model = model2
)
サンプリングが終了しました。結果を表示します。
fig, ax = results.plot()
results.summary()
結果を見ると、図のSynthetic Controlとの間で幾らかの乖離があります。出来事の前まで、地域的に近いAUDとの連動の割合が大きかった分、実際の観測値は、AUDの推移より大きな変動をもたらしていることを意味しています。欧州通貨同様にエネルギー需給の影響を強く受けているため、乖離は大きくなってます。
Synthetic controlの各通貨の重みの係数を表示します。
ax = az.plot_forest(results.idata, var_names="beta", figsize=(6, 3))
ax[0].set(title="Estimated weighting coefficients");