2. 問題1:売上予測

2.1. 取り組む問題

M5 Forecasting - Accuracy Competition

2.1.1. 背景

ウォルマートはアメリカ最大のスーパーマーケットです.同社が直面する主な課題のひとつは,どの商品を,いつ,どのくらいの量仕入れるかを決定することです.間違った商品を仕入れると,潜在的な売上を逃し,在庫コストの増加につながります.

青果物のような特定の商品の需要は,収穫期やその商品の供給量と連動している場合があります.そのような商品の場合,供給が徐々に変化するにつれて,需要もスムーズに変化することが期待できます.しかし,他の商品では同じ傾向は見られず,休日や行事によって需要が急激に変化することもあります.商品需要の傾向を把握し予測することは,スーパーマーケットが適切な量の商品を仕入れ,損失を最小限に抑えながら利益を最大化するのに役立ちます.

2.1.2. タスク

ウォルマートの様々な品目に関する過去の売上データが与えられたとき,個々の店舗における次の月の売上を予測してください.

2.1.3. データ

  • calendar.csv - 商品が販売される日付に関する情報.

  • sample_submission.csv - 投稿用の正しいフォーマット.

  • sell_prices.csv - 各店舗で販売された商品の価格と日付に関する情報.

  • sales_train_evaluation.csv - [d_1 - d_1941]` 日間における,各商品と各店舗の過去の日次販売個数データ.

  • sales_train_validation.csv - [d_1 - d_1913]`日の各商品と店舗の過去の日次販売台数データ.#このファイルは使用しないでください.

当初のコンペティションは2つのステージで構成されていました.第1ステージでは,1日目から1913日目までのデータが発表され,1914日目から1941日目までの予測に基づいて得点が決定されました.第2段階では,1914~1941年の日数のデータが発表され,1942~1969年の日数の予測に基づいて得点が決定されました.ここでは簡単のため,sales_train_evaluation.csvのみを使用して,このコンペティションのセカンドステージに取り組みます.

2.2. セットアップ

以下のセルは必要なデータをダウンロードし,ノートブックで使用する環境を設定するためのものです.

2.2.1. データセット

Kaggleはコンペティションと簡単にやり取りできるAPIを提供しています.このAPIを使って自動的にデータをダウンロードし,予測をアップロードします.

このAPIを使用する最初のステップは,自分のユーザーとして認証することです.APIトークンはユーザー名とKaggleが生成したキーを含むファイルです.トークンはアカウントページからダウンロードすることができ,通常 kaggle.json と呼ばれます.APIトークンはユーザーとしてAPIにアクセスするために必要なものなので,個人のGoogle Driveフォルダに安全に保管してください.

このノートブックはGoogle Driveフォルダ内のkaggle.jsonというKaggle APIトークンを検索します.トークンをGoogle Driveに置いたことを確認し,プロンプトが表示されたらこのノートブックがトークンにアクセスすることを許可してください.

from google.colab import drive
import os
import json

drive.mount("/content/drive", force_remount=True)
fin = open("/content/drive/MyDrive/kaggle/token/kaggle.json", "r")
json_data = json.load(fin)
fin.close()
os.environ["KAGGLE_USERNAME"] = json_data["username"]
os.environ["KAGGLE_KEY"] = json_data["key"]

認証後,参加したすべてのコンペティションにアクセスできます.データのダウンロードにエラーが発生した場合は,Kaggle API トークンが有効であること,コンペティションのルールに同意していることを確認してください.

%%bash
kaggle competitions download -c m5-forecasting-accuracy --force
if [ $? -ne 0 ]; then
    echo "データのダウンロードに問題があったようです."
    echo "競技規則に同意し,APIキーが有効であることを確認してください."
else
    mkdir -p /content/kaggle
    unzip -o /content/m5-forecasting-accuracy.zip -d /content/kaggle
fi
wget -q -P /tmp https://noto-website-2.storage.googleapis.com/pkgs/NotoSansCJKjp-hinted.zip
unzip -u /tmp/NotoSansCJKjp-hinted.zip -d /usr/share/fonts/NotoSansCJKjp

2.2.2. 計算環境

import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.tsa.seasonal import STL
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.forecasting.stl import STLForecast

font_path = '/usr/share/fonts/NotoSansCJKjp/NotoSansMonoCJKjp-Regular.otf'
matplotlib.font_manager.fontManager.addfont(font_path)
prop = matplotlib.font_manager.FontProperties(fname=font_path)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = prop.get_name()
os.chdir('/content/kaggle')

2.3. 探索的データ解析

いくつかの異なるファイルがあるので,各ファイルにどのような種類のデータがあるのか視覚化してみると良いです.

2.3.1. sales_train_evaluation.csv

sales = pd.read_csv('sales_train_evaluation.csv')
sales.sample(10)

このファイルには item_id で指定されたアイテムの販売履歴が含まれます.アイテムには関連する部門とカテゴリが dept_idcat_id に格納されています.さらに,場所に関する情報が state_id に,特定の店舗に関する情報が store_id に格納されています.残りの d_x というパターンにマッチする列は,x 日の売上数を示しています.データは 2011-01-29 から始まるので,d_1 には2011-01-29の売上が格納されています.

トレーニングデータは3つのカテゴリーに大別されます.

  • 食べ物

  • 趣味

  • 家庭用品

以下のプロットは,ほとんどの売上が食品であることを示しています.

total_sales = sales.copy()
day_cols = [col for col in sales.columns if col.startswith('d_')]
total_sales['total'] = total_sales[day_cols].sum(axis=1)
total_sales = total_sales.groupby('cat_id', as_index=False).agg({'total': 'sum'})
ax = sns.barplot(total_sales, x='cat_id', y='total');

これは時系列データです.予測の際にはデータが時間とともにどのように変化するかを意識する必要があります.

sales_long = sales.groupby('cat_id', as_index=False).agg({col: 'sum' for col in day_cols})
sales_long = sales_long.melt(id_vars=['cat_id'], var_name='day', value_name='sales')
sales_long['day'] = sales_long['day'].str.replace('d_', '').astype(int)
ax = sns.lineplot(sales_long, x='day', y='sales', hue='cat_id')
ax.grid()
ax.legend();

この図によると,すべての売上高はおおむね増加していますが,停滞あるいは減少している時期もあります.

時系列を分析するための一般的な手法のひとつに,トレンドと季節的要素(シーズナリティ)でデータを分解してモデル化することがあります.このデータでは,トレンドは明らかに上向きですが,売上減の変動も見られる.これらの変動は,周期的,あるいは季節的な何らかの影響によるものかもしれません.分解を行うには,statsmodels パッケージを使用します.

sales_long = sales.groupby('cat_id', as_index=False).agg({col: 'sum' for col in day_cols})
sales_long = sales_long.melt(id_vars=['cat_id'], var_name='day', value_name='sales')
sales_long['day'] = sales_long['day'].str.replace('d_', '').astype(int)
sales_long = sales_long.pivot(index='day', columns='cat_id', values='sales')
sales_long = sales_long.sort_index()
sales_long = sales_long.set_index(pd.to_datetime(pd.read_csv('calendar.csv')['date'].values[:1941]))
sales_long = sales_long[sales_long.index.year==2012]
stl = STL(sales_long['FOODS'], period=30)
res = stl.fit()
ax = res.plot()
fig = ax.get_figure()
fig.set_size_inches(14, 7)
ax.tight_layout()

上の図は,季節期間を30日に設定した2012年のFOODSの分解です.トレンドと季節を分けることで,それぞれが元のデータにどのように寄与しているかを明確に見ることができます.これからわかることは,シーズンの初め,つまりこの場合は月の初めに,売上が一般的に高くなることです.

2.3.2. calendar.csv

calendar = pd.read_csv('calendar.csv')
calendar

このファイルにはデータセットの各日に関する情報が含まれており, d_1 のような日を 2011-01-29 のような実際の日付に変換するのに使うことができます.データセットの開始日と終了日は1年単位で揃っていないことに注意してください.これを下図に示します.

sns.catplot(data=calendar, x='year', kind='count', color='tab:blue');

このファイルには特別なイベントに関する情報も含まれています.特別なイベントは比較的頻繁に発生しないものの,売上に影響を与える可能性があります.

event_idx = calendar['event_name_1'].notna() | calendar['event_name_2'].notna()
n_events = event_idx.sum()
plt.bar(['イベントない日', 'イベントある日'], [calendar.shape[0] - n_events, n_events])

2.3.3. sell_prices.csv

sell_prices = pd.read_csv('sell_prices.csv')
sell_prices.head()

このデータは,各店舗の様々な商品の価格を時系列で示したものです.価格は場所によって異なる場合があります.

wm_yr_wk_to_date = calendar[['wm_yr_wk', 'date']].groupby('wm_yr_wk').min()
sell_prices_dt = sell_prices.groupby(['wm_yr_wk', 'item_id'], as_index=False)['sell_price'].agg('mean')
sell_prices_dt = sell_prices_dt.merge(wm_yr_wk_to_date, how='inner', on='wm_yr_wk')
sell_prices_dt['date'] = pd.to_datetime(sell_prices_dt['date'])

ax = sns.lineplot(sell_prices_dt[sell_prices_dt['item_id']=='FOODS_1_001'], x='date', y='sell_price', errorbar=None, label='FOODS_1_001')
ax = sns.lineplot(sell_prices_dt[sell_prices_dt['item_id']=='HOUSEHOLD_1_001'], x='date', y='sell_price', errorbar=None, label='HOUSEHOLD_1_001')
ax = sns.lineplot(sell_prices_dt[sell_prices_dt['item_id']=='HOBBIES_1_001'], x='date', y='sell_price', errorbar=None, label='HOBBIES_1_001')
ax.legend()
ax.figure.autofmt_xdate()

この図では3個の商品の価格を示しています.一時的な売れ行きを示すと思われる小規模で短期的な価格下落が見られます.また,長期的な価格の変化も見られますが,これは市場の需要や生産コストに対するメーカーの対応と考えられます.

2.4. モデリング

2.4.1. 分解と予測

将来の売上を予測するために,上で紹介した時系列分解法と,時系列予測によく使われるモデルであるARIMAモデルを組み合わせます.単純なベースラインモデルとして,カテゴリーごとに全国的な売上高を予測します.このモデルは特定の商品の動向を無視し,立地も考慮しないため,あまり良いスコアは得られないかもしれません.

sales_modeling = sales.groupby('cat_id', as_index=False).agg({col: 'sum' for col in sales.columns if col.startswith('d_')})
sales_modeling = sales_modeling.melt(id_vars=['cat_id'], var_name='day', value_name='sales')
sales_modeling['day'] = sales_modeling['day'].str.replace('d_', '').astype(int)
sales_modeling = sales_modeling.pivot(index='day', columns='cat_id', values='sales')
sales_modeling = sales_modeling.sort_index()
sales_modeling = sales_modeling.set_index(pd.to_datetime(pd.read_csv('calendar.csv')['date'].values[:1941]))
sales_modeling.index.freq = 'D'

forecast = {}
for cat_id in sales['cat_id'].unique():
    stlf = STLForecast(sales_modeling[cat_id], ARIMA, model_kwargs=dict(order=(1, 1, 0), trend="t"), period=30)
    stlf_res = stlf.fit()
    forecast[cat_id] = stlf_res.forecast(28)
fig, ax = plt.subplots()
for i, cat_id in enumerate(sales_modeling.columns):
  cmap = matplotlib.color_sequences['tab10']
  ax.plot(sales_modeling[cat_id].iloc[-100:].index, sales_modeling[cat_id].iloc[-100:].values, color=cmap[i], label=cat_id)
  ax.plot(forecast[cat_id].iloc[-100:].index, forecast[cat_id].iloc[-100:].values, color=cmap[i], linestyle='dashed', label=f'{cat_id}予測')
  ax.figure.autofmt_xdate()
  ax.legend(ncols=3)

最後に,提出ファイルを作成し,そこに予測値を記入します.総売上高を予測するので,各店舗の平均売上高を得るために店舗数で割らなければなりません.

def extract_cat_id(string):
    return string.split('_')[0]

def extract_split(string):
    return string.split('_')[-1]
submission = pd.read_csv('sample_submission.csv')
submission['cat_id'] = submission['id'].apply(extract_cat_id)
submission['split'] = submission['id'].apply(extract_split)

# fill validation with correct data
validation_idx = submission['id'].str.endswith('validation')
forecast_cols = [f'F{i}' for i in range(1, 29)]
d_1914_1941_cols = [f'd_{i}' for i in range(1914, 1942)]
submission.loc[validation_idx, forecast_cols] = sales.loc[submission.loc[validation_idx].index, d_1914_1941_cols].values

# fill evaluation with forecasted data
evaluation_idx = submission['id'].str.endswith('evaluation')
for cat_id, forecasted_sales in forecast.items():
    cat_idx = submission['cat_id'] == cat_id
    num_stores = sum(evaluation_idx & cat_idx)
    submission.loc[(evaluation_idx & cat_idx), forecast_cols] = (forecasted_sales / num_stores).values
submission.drop(columns=['cat_id', 'split']).to_csv('submission.csv', index=False)
submission

2.4.2. Kaggleへのアップロード

予測はKaggleのAPIを通して直接提出されるため,ファイルを手動でダウンロードし,Kaggleのウェブサイトに再アップロードする必要はありません.

! kaggle competitions submit -c m5-forecasting-accuracy -f submission.csv -m "upload"

提出後は提出ページでスコアを確認します.

2.5. 性能向上のための提案

2.5.1. 異なる季節性期間

時系列分解で,一般的に月の初めに売上が高くなることが示されました.時系列分解に週単位を使用することで,他の傾向を示す可能性があります.

2.5.2. 店舗を個別にモデル化する

このノートブックでは現在,各店舗の売上が同じであると仮定しています.店舗間の売り上げは普通異なる場合が多いです.

2.5.3. より複雑なモデル

ここで使われているモデルは比較的単純です.LSTMやコンペティションの勝者が使ったモデル等,他のタイプのモデルを使ってみてください.

2.5.4. 追加データ

現在のモデルは販売データしか使っていません.しかし,価格や休日などのデータもあり,予測を改善できるかもしれません.