您好,登錄后才能下訂單哦!
本篇內容介紹了“怎么用Java訓練出一只不死鳥”的有關知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領大家學習一下如何處理這些情況吧!希望大家仔細閱讀,能夠學有所成!
在這一節會介紹主要用到的算法以及神經網絡,幫助你更好的了解如何進行訓練。本項目與 DeepLearningFlappyBird 使用了類似的方法進行訓練。算法整體的架構是 Q-Learning + 卷積神經網絡(CNN),把游戲每一幀的狀態存儲起來,即小鳥采用的動作和采用動作之后的效果,這些將作為卷積神經網絡的訓練數據。
CNN 的輸入數據為連續的 4 幀圖像,我們將這圖像 stack 起來作為小鳥當前的“observation”,圖像會轉換成灰度圖以減少所需的訓練資源。圖像存儲的矩陣形式是 (batch size, 4 (frames), 80 (width), 80 (height))
數組里的元素就是當前幀的像素值,這些數據將輸入到 CNN 后將輸出 (batch size, 2)
的矩陣,矩陣的第二個維度就是小鳥 (振翅不采取動作) 對應的收益。
在小鳥采取動作后,我們會得到 preObservation and currentObservation
即是兩組 4 幀的連續的圖像表示小鳥動作前和動作后的狀態。然后我們將 preObservation, currentObservation, action, reward, terminal
組成的五元組作為一個 step 存進 replayBuffer 中。它是一個有限大小的訓練數據集,他會隨著最新的操作動態更新內容。
public void step(NDList action, boolean training) { if (action.singletonOrThrow().getInt(1) == 1) { bird.birdFlap(); } stepFrame(); NDList preObservation = currentObservation; currentObservation = createObservation(currentImg); FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(), preObservation, currentObservation, action, currentReward, currentTerminal); if (training) { replayBuffer.addStep(step); } if (gameState == GAME_OVER) { restartGame(); } }
訓練分為 3 個不同的周期以更好地生成訓練數據:
Observe(觀察) 周期:隨機產生訓練數據
Explore (探索) 周期:隨機與推理動作結合更新訓練數據
Training (訓練) 周期:推理動作主導產生新數據
通過這種訓練模式,我們可以更好的達到預期效果。
處于 Explore 周期時,我們會根據權重選取隨機的動作或使用模型推理出的動作來作為小鳥的動作。訓練前期,隨機動作的權重會非常大,因為模型的決策十分不準確 (甚至不如隨機)。在訓練后期時,隨著模型學習的動作逐步增加,我們會不斷增加模型推理動作的權重并最終使它成為主導動作。調節隨機動作的參數叫做 epsilon 它會隨著訓練的過程不斷變化。
public NDList chooseAction(RlEnv env, boolean training) { if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) { return env.getActionSpace().randomAction(); } else return baseAgent.chooseAction(env, training); }
首先,我們會從 replayBuffer 中隨機抽取一批數據作為作為訓練集。然后將 preObservation 輸入到神經網絡得到所有行為的 reward(Q)作為預測值:
NDList QReward = trainer.forward(preInput); NDList Q = new NDList(QReward.singletonOrThrow() .mul(actionInput.singletonOrThrow()) .sum(new int[]{1}));
postObservation 同樣會輸入到神經網絡,根據馬爾科夫決策過程以及貝爾曼價值函數計算出所有行為的 reward(targetQ)作為真實值:
// 將 postInput 輸入到神經網絡中得到 targetQReward 是 (batchsize,2) 的矩陣。根據 Q-learning 的算法,每一次的 targetQ 需要根據當前環境是否結束算出不同的值,因此需要將每一個 step 的 targetQ 單獨算出后再將 targetQ 堆積成 NDList。 NDList targetQReward = trainer.forward(postInput); NDArray[] targetQValue = new NDArray[batchSteps.length]; for (int i = 0; i < batchSteps.length; i++) { if (batchSteps[i].isTerminal()) { targetQValue[i] = batchSteps[i].getReward(); } else { targetQValue[i] = targetQReward.singletonOrThrow().get(i) .max() .mul(rewardDiscount) .add(rewardInput.singletonOrThrow().get(i)); } } NDList targetQBatch = new NDList(); Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value))); NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));
在訓練結束時,計算 Q 和 targetQ 的損失值,并在 CNN 中更新權重。
我們采用了采用了 3 個卷積層,4 個 relu 激活函數以及 2 個全連接層的神經網絡架構。
layer | input shape | output shape |
---|---|---|
conv2d | (batchSize, 4, 80, 80) | (batchSize,4,20,20) |
conv2d | (batchSize, 4, 20 ,20) | (batchSize, 32, 9, 9) |
conv2d | (batchSize, 32, 9, 9) | (batchSize, 64, 7, 7) |
linear | (batchSize, 3136) | (batchSize, 512) |
linear | (batchSize, 512) | (batchSize, 2) |
DJL 的 RL 庫中提供了非常方便的用于實現強化學習的接口:(RlEnv, RlAgent, ReplayBuffer)。
實現 RlAgent 接口即可構建一個可以進行訓練的智能體。
在現有的游戲環境中實現 RlEnv 接口即可生成訓練所需的數據。
創建 ReplayBuffer 可以存儲并動態更新訓練數據。
在實現這些接口后,只需要調用 step 方法:
RlEnv.step(action, training);
這個方法會將 RlAgent 決策出的動作輸入到游戲環境中獲得反饋。我們可以在 RlEnv 中提供的 runEnviroment
方法中調用 step 方法,然后只需要重復執行 runEnvironment
方法,即可不斷地生成用于訓練的數據。
public Step[] runEnvironment(RlAgent agent, boolean training) { // run the game NDList action = agent.chooseAction(this, training); step(action, training); if (training) { batchSteps = this.getBatch(); } return batchSteps; }
我們將 ReplayBuffer 可存儲的 step 數量設置為 50000,在 observe 周期我們會先向 replayBuffer 中存儲 1000 個使用隨機動作生成的 step,這樣可以使智能體更快地從隨機動作中學習。
在 explore 和 training 周期,神經網絡會隨機從 replayBuffer 中生成訓練集并將它們輸入到模型中訓練。我們使用 Adam 優化器和 MSE 損失函數迭代神經網絡。
首先將圖像大小 resize 成 80x80
并轉為灰度圖,這有助于在不丟失信息的情況下提高訓練速度。
public static NDArray imgPreprocess(BufferedImage observation) { return NDImageUtils.toTensor( NDImageUtils.resize( ImageFactory.getInstance().fromImage(observation) .toNDArray(NDManager.newBaseManager(), Image.Flag.GRAYSCALE) ,80,80)); }
然后我們把連續的四幀圖像作為一個輸入,為了獲得連續四幀的連續圖像,我們維護了一個全局的圖像隊列保存游戲線程中的圖像,每一次動作后替換掉最舊的一幀,然后把隊列里的圖像 stack 成一個單獨的 NDArray。
public NDList createObservation(BufferedImage currentImg) { NDArray observation = GameUtil.imgPreprocess(currentImg); if (imgQueue.isEmpty()) { for (int i = 0; i < 4; i++) { imgQueue.offer(observation); } return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1)); } else { imgQueue.remove(); imgQueue.offer(observation); NDArray[] buf = new NDArray[4]; int i = 0; for (NDArray nd : imgQueue) { buf[i++] = nd; } return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1)); } }
一旦以上部分完成,我們就可以開始訓練了。訓練優化為了獲得最佳的訓練性能,我們關閉了 GUI 以加快樣本生成速度。并使用 Java 多線程將訓練循環和樣本生成循環分別在不同的線程中運行。
List<Callable<Object>> callables = new ArrayList<>(numOfThreads); callables.add(new GeneratorCallable(game, agent, training)); if(training) { callables.add(new TrainerCallable(model, agent)); }
“怎么用Java訓練出一只不死鳥”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業相關的知識可以關注億速云網站,小編將為大家輸出更多高質量的實用文章!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。