CartPole 問題は、今までこのブログでもなんども取り上げてきた。

CartPole 問題は、今までこのブログでもなんども取り上げてきた。

今日、ついに解けたのだ。

課題

環境から取得できる情報は次の 4 つ。(observation_space)

  • 横軸座標(x)
  • 横軸速度(x_dot)
  • ポール角度(theta)
  • ポール角速度(theta_dot)

エージェントが取れる行動は環境状態によらず次の 2 つ。(acnton_space)

  • 左に動く
  • 右に動く

課題は、環境から取得できる値が連続値であること。 Q テーブル(状態、行動)対に対して一つの値を扱う。

つまり、連続値である状態を範囲で区切って離散の値に変換して、Q テーブルのインデックスに用いることが必要。

ここでは gdb さんのコード を参考に以下のように書いた。

def build_state(features):
    """get our features and put all together converting into an integer"""
    return int("".join(map(lambda feature: str(int(feature)), features)))

def to_bin(value, bins):
    return np.digitize(x=[value], bins=bins)[0]

cart_position_bins = np.linspace(-2.4, 2.4, 2)
cart_velocity_bins = np.linspace(-2, 2, 10)
pole_angle_bins = np.linspace(-0.4, 0.4, 50)
pole_velocity_bins = np.linspace(-3.5, 3.5, 20)

def transform(observation):
    # return an int
    cart_pos, cart_vel, pole_angle, pole_vel = observation
    return build_state([
        to_bin(cart_pos, cart_position_bins),
        to_bin(cart_vel, cart_velocity_bins),
        to_bin(pole_angle, pole_angle_bins),
        to_bin(pole_vel, pole_velocity_bins)
    ])

これで、transform を呼ぶと、ovservation に対して一意な離散値が出力されるので、それを状態値として扱う。

注目すべきは、ここ。横軸座標(x)は荒く、ポール角度は細かく値を区切っている。こうすると精度がよかったのでこうしている。

cart_position_bins = np.linspace(-2.4, 2.4, 2)
cart_velocity_bins = np.linspace(-2, 2, 10)
pole_angle_bins = np.linspace(-0.4, 0.4, 50)
pole_velocity_bins = np.linspace(-3.5, 3.5, 20)

Code

このコードは、Practical RL の week2 の宿題である。

以下を参考にした。

Result