
ChainerRL について / Introduction to ChainerRL (Ja)

tos

Chainer Meetup #05

June 10, 2017

Transcript

  1. Self-introduction • Toshiki Kataoka / toslunar • 2016.12– Preferred
     Networks
     • Recently added Chainer v2 support to ChainerRL
  2. ChainerRL github.com/chainer/chainerrl
     • RL = Reinforcement Learning
     • An extension package for Chainer
     • 2017.03.27 ChainerRL v0.1.0 • 2017.06.08 ChainerRL v0.2.0
  3. Reinforcement learning algorithms (1/2) • Q-learning
     • Function to learn: Q*(s, a), the sum of rewards obtained by taking
       action a in state s and then acting optimally afterwards
     • Watkins '89 • Mnih+ '13 (deep)
     [Diagram: a network maps (s, a) to Q*(s, a) ≈ −5.7; the agent picks
      a = argmax Q*(s, _) … or a random action]
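     As a reminder of what "learning Q*" means, the classic tabular update is
     sketched below (illustrative only; alpha and gamma denote a learning
     rate and a discount factor, neither of which appears on the slide):

       import numpy as np

       def q_learning_update(Q, s, a, reward, s_next, alpha=0.1, gamma=0.98):
           # Move Q[s, a] toward the one-step bootstrapped target
           # reward + gamma * max_a' Q[s_next, a']. Deep Q-learning (DQN)
           # replaces the table Q with a neural network trained toward
           # the same target.
           target = reward + gamma * np.max(Q[s_next])
           Q[s, a] += alpha * (target - Q[s, a])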
  4. Reinforcement learning algorithms (2/2) • Policy gradient methods
     • Function to optimize: π(s), the action to take in state s
     • Learned jointly with, e.g., Qπ(s, a), the sum of rewards obtained
       when following π afterwards
     • Williams '92, Sutton+ '99 • Lillicrap+ '15 (deep)
     [Diagram: a network maps (s, a) to Qπ(s, a) ≈ −5.7; the agent picks
      a = π(s) + ε]
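     The "learned jointly" part can be summarized as two coupled objectives,
     roughly as in the sketch below (the deterministic actor-critic idea
     behind Lillicrap+ '15; q_func and policy are hypothetical function
     approximators, not ChainerRL API):

       def critic_loss(q_func, policy, s, a, reward, s_next, gamma=0.98):
           # Critic: fit Q toward the one-step target under the current policy.
           target = reward + gamma * q_func(s_next, policy(s_next))
           return (q_func(s, a) - target) ** 2

       def actor_objective(q_func, policy, s):
           # Actor: increase the critic's value of the policy's own action.
           return q_func(s, policy(s))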
  5. In ChainerRL (env and agent are constructed on the next two slides):

       for _ in range(1000):                    # episodes
           obs = env.reset()
           reward = 0.0
           done = False
           while not done:
               # act_and_train picks an action and performs a learning update
               action = agent.act_and_train(obs, reward)
               obs, reward, done, _ = env.step(action)
           # feed the final transition of the episode to the agent
           agent.stop_episode_and_train(obs, reward, done)
       agent.save('final_agent')                # write the trained model to disk
  6. How to build an environment (1/2) • When defining your own:
     • Environment initialization: env.reset • Environment execution: env.step
     • Example: a number-guessing game
     • A secret number is chosen at initialization
     • Whether the agent's guess (the "action") is too high or too low is
       returned as the "observation"

       import numpy as np

       class GuessNumberEnv(object):
           def reset(self):
               self._state = np.random.uniform(-1, 1)
               obs = np.array([0, 0, 0, 1], dtype=np.float32)
               return obs

           def step(self, action):
               assert action.shape == (1,)
               diff = action[0] - self._state
               obs = np.array([0, 0, 0, 0], dtype=np.float32)
               obs[1 + int(np.sign(diff))] = 1
               reward = np.random.normal(0, 1) - abs(diff)
               return obs, reward, False, None  # not done, no info
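     A quick way to sanity-check the environment (an illustrative snippet,
     not from the slides):

       env = GuessNumberEnv()
       obs = env.reset()  # the one-hot "no feedback yet" observation
       obs, reward, done, info = env.step(np.array([0.3], dtype=np.float32))
       print(obs, reward)  # one-hot sign of the guess error; noisy reward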
  7. How to build an agent (example: Deep Q-Network)
     (env is assumed to expose gym-style observation_space / action_space)

       import chainer
       import chainerrl

       model = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(
           env.observation_space.low.size, env.action_space.n,
           n_hidden_channels=64, n_hidden_layers=1)
       opt = chainer.optimizers.Adam()
       opt.setup(model)
       rbuf = chainerrl.replay_buffer.ReplayBuffer(None)
       explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(
           1.0, 0.1, 10**4, random_action_func=env.action_space.sample)
       agent = chainerrl.agents.DQN(
           model, opt, rbuf, gamma=0.98, explorer=explorer,
           target_update_interval=100, replay_start_size=10**3)
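     Once trained with the loop from slide 5, the same agent can be run
     greedily for evaluation (a sketch using the agent interface; act and
     stop_episode perform no learning):

       obs = env.reset()
       done = False
       while not done:
           action = agent.act(obs)  # greedy action, no exploration or update
           obs, reward, done, _ = env.step(action)
       agent.stop_episode()         # clear the agent's per-episode state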
  8. An agent is composed of • model • optimizer • replay buffer • explorer,
     and each of them can be changed independently of the others
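     For instance, only the explorer from slide 7 could be swapped out while
     everything else stays the same (an illustrative substitution;
     ConstantEpsilonGreedy is another explorer shipped with ChainerRL):

       explorer = chainerrl.explorers.ConstantEpsilonGreedy(
           epsilon=0.1, random_action_func=env.action_space.sample)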
  9. Changing the model (1/2). In DQN, the model takes the state as input
     and outputs a Q-value for each action.
     • Use a model provided by ChainerRL:

       model = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(
           dim_obs, n_action, n_hidden_channels, n_hidden_layers)

     • Or convert a chain you wrote yourself for use with ChainerRL:

       model = chainerrl.q_functions.SingleModelStateQFunctionWithDiscreteAction(
           MyChain(dim_obs, n_action))

     [Diagram: the model maps state s to Q(s, a1), Q(s, a2), …]
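     MyChain is not defined on the slide; a minimal sketch of such a chain,
     assuming Chainer v2 style and one output Q-value per discrete action:

       import chainer
       import chainer.functions as F
       import chainer.links as L

       class MyChain(chainer.Chain):
           """Hypothetical custom model: observation in, one value per action out."""

           def __init__(self, dim_obs, n_action, n_hidden=64):
               super(MyChain, self).__init__()
               with self.init_scope():
                   self.l1 = L.Linear(dim_obs, n_hidden)
                   self.l2 = L.Linear(n_hidden, n_action)

           def __call__(self, x):
               h = F.relu(self.l1(x))
               return self.l2(h)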
  10. Changing the model (2/2)
      • In an environment with continuous-valued actions, use a Normalized
        Advantage Function:

        model = chainerrl.q_functions.FCQuadraticStateQFunction(
            dim_obs, dim_action, n_hidden_channels, n_hidden_layers,
            action_space)

      • (Incidentally, the explorer should then use an Ornstein-Uhlenbeck
        process, as in DDPG:)

        explorer = chainerrl.explorers.AdditiveOU(...)
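      For reference, one step of Ornstein-Uhlenbeck noise, the idea behind
      AdditiveOU, can be sketched as follows (illustrative constants;
      chainerrl's AdditiveOU has its own parameters and defaults):

        import numpy as np

        def ou_step(x, mu=0.0, theta=0.15, sigma=0.3):
            # Noise that drifts back toward mu while staying temporally
            # correlated, which suits continuous control better than
            # independent noise drawn at every step.
            return x + theta * (mu - x) + sigma * np.random.normal(size=np.shape(x))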
  11. Changing the replay buffer
      Replay buffer: a technique that stores experience and samples from it
      for training, so that the data within a minibatch is not biased
      • The size can be configured and the sampling algorithm can be changed:

        rbuf = chainerrl.replay_buffer.ReplayBuffer(5 * 10**5)
        rbuf = chainerrl.replay_buffer.EpisodicReplayBuffer(10**4)
        rbuf = chainerrl.replay_buffer.PrioritizedReplayBuffer(5 * 10**5)
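      Internally, the agent appends transitions and samples minibatches
      roughly as below (a sketch of the buffer interface; keyword names may
      differ slightly between versions):

        rbuf.append(state=obs, action=action, reward=reward,
                    next_state=obs_next, is_state_terminal=done)
        minibatch = rbuf.sample(32)  # 32 stored transitions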
  12. Changing the algorithm
      • To switch to an improved DQN variant (e.g. Double DQN), it is enough
        to replace agent = chainerrl.agents.DQN(...) with
        agent = chainerrl.agents.DoubleDQN(...), as in the sketch below
      • Even when switching to an algorithm that uses a different kind of
        model (e.g. DDPG), the same replay buffer and explorer can be reused
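      Concretely, the constructor arguments from slide 7 carry over
      unchanged; only the agent class differs:

        agent = chainerrl.agents.DoubleDQN(
            model, opt, rbuf, gamma=0.98, explorer=explorer,
            target_update_interval=100, replay_start_size=10**3)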
  13. Implemented algorithms
      • Q-learning algorithms: Deep Q-Network, Double DQN, Normalized
        Advantage Function, (Persistent) Advantage Learning, Asynchronous
        Advantage Actor-Critic, Asynchronous N-step Q-learning, Path
        Consistency Learning
      • Policy gradient methods: Deep Deterministic Policy Gradient, SVG(0),
        Actor-Critic with Experience Replay
  14. ChainerRL's training loop
      • What chainerrl.experiments.train_agent offers:
      • Run the agent in a test environment every fixed number of iterations
        and output the data needed to draw a learning curve
      • Save the model automatically
      • It is not integrated with Chainer's Trainer • how to do so is
        undecided
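      The evaluation behavior described above corresponds to the
      train_agent_with_evaluation entry point; a rough sketch of a call
      (keyword names varied across early ChainerRL versions, so treat these
      as assumptions):

        chainerrl.experiments.train_agent_with_evaluation(
            agent=agent, env=env,
            steps=10**5,           # total training steps
            eval_n_runs=10,        # episodes per evaluation
            eval_frequency=10**4,  # evaluate every 10**4 steps (assumed name)
            outdir='results')      # learning-curve data and models go here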
  15. Parallelization
      • Algorithms such as A3C parallelize the agents
      • In reinforcement learning, async updates are currently the mainstream
      • In ChainerRL, calling train_agent_async runs training across multiple
        processes (ChainerRL: async; ChainerMN: sync)
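      A very rough sketch of such a call (the exact signature varies by
      version; make_env is a hypothetical factory, needed because each
      worker process must construct its own environment):

        def make_env(process_idx, test):
            # Each worker process gets a fresh environment instance.
            return GuessNumberEnv()

        chainerrl.experiments.train_agent_async(
            outdir='results', processes=8, make_env=make_env,
            agent=agent, steps=10**6)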