【AI + pygame】pygameで作るインベーダー風ゲーム 第4回 強化学習編その1(必要なメソッドの実装)

2019年04月26日

(2019年5月16日更新:コード修正)
_get_observation() の2箇所

  • ビーム位置情報の処理
  • self.observation_space の dtype=float

エンジニアのMです。
pygame によるインベーダー風ゲームの改造記事の第4回です。

いよいよ強化学習編です。
pygame 製のゲームに強化学習を適用するには、

  • ゲームからゲームプレイに必要な情報を得る(画面、変数の状態など)
  • 学習器側の入力をゲームに反映する
  • 強化学習に必要な報酬を設定する

といった課題をまずクリアする必要があります。
その上でパラメータやモデルのネットワーク構造を吟味して、時間をかけて学習させなければなりません。

今回のコードと画像データは、こちらからダウンロードできます。
-> pygame_invader_4

どのように pygame 製のゲームに対して強化学習をするか

まずは出来そうかどうかについて情報を収集

まず、どうすれば pygame 製のゲームに強化学習を適用できるのでしょうか?
ATARI のゲームを攻略した「Deep Q-Network(DQN)」について調べているとき、
OpenAI Gym」という強化学習シミュレーション環境と、
Keras-RL」という強化学習ライブラリについて知りました。
(参考:Keras-RLを用いた深層強化学習コト始め – Qiita

これらの情報から、pygame 製ゲームが「OpenAI Gym」の枠組みに沿ったパラメータを「Keras-RL」に渡してやれば、 DQN による学習が出来そうだという結論に行き着きました。

pygame でステップ実行をする

Gym は1ステップ毎に入力を受け付けて、その結果(報酬・状態など)を返すようになっています。
pygame でも同じことができれば、Keras-RL と連携して学習させることができるはずです。

まず通常の動作について確認してみましょう。

# メインループ開始
clock = pygame.time.Clock()
while True:
    clock.tick(60)
    self.update()
    self.draw(screen)
    pygame.display.update()
    self.key_handler()

インベーダー風ゲームのメインループ部分です。
毎秒60回のペースで、update 以下が実行されるようになっています。
つまり、「update 1回分」が「1フレーム=1ステップ」に相当するということなので、
Keras-RL 側から update 以下の処理を制御できれば良いわけです。

例えば以下のようなメソッドを Invader クラスに実装します。

def step(self, action):
    reward = 0
    done = False
    self._key_action(action)
    self.update()
    self.draw(self.screen)
    pygame.display.update()
    self._get_observation()

    # ここに報酬の計算処理が入る

    return self.observation_space, reward, done, dict()

key_handler の代わりに、Keras-RLからの入力を反映する _key_action というメソッドが入っています。
_get_observation で Keras-RL に渡す、ゲームの状態(self.observation_space)を取得します。
「update -> draw -> display.update」の流れは同様に含まれているので、pygame 側の1フレーム分の処理が進み、画面もしっかり更新されます。

返り値は gym に関する例(参考)を踏襲しています。

action = 0
observation, reward, done, info = env.step(action)

info は辞書型が想定されていたので、Invader.step も空辞書を返してエラーを回避するようにします。
このようにメソッドの動作と名前を揃えて実行することで、Keras-RL には gym が動いているように見せるわけです。

これでステップ実行と必要な値を渡す処理まで pygame 側でできそうということが確認できました。

報酬を決定する

強化学習では、設定された報酬値を最大化する方向に学習していきます。
このとき、適切な報酬を設定しなければ学習が全く進まない恐れもあります。

報酬を決めるにあたって、ゲームの目的を再確認しましょう。
インベーダー風ゲームの目的は、弾を避けながらエイリアンを全滅させることです。
つまり最低限設定しなければならない報酬は、

  • プラス報酬 :エイリアンの撃破
  • マイナス報酬:自機の被弾

ということになります。
これに加えて、他の好ましい行動にはプラス報酬を、好ましくない行動にマイナス報酬を与えることにし、各報酬値の大小についても考えていく必要があります。

報酬については、「エイリアン1体撃破の報酬」を「+1」とし、これを基準に考えていきます。
しかし学習に非常に時間がかかることもあり、未だに適切な報酬のバランスが見つかっていません。
以下に示すのは、現在実験中の報酬設定です。2つ以上同時に起こった場合足し合わされます。

  • 「+ 1 」:エイリアン撃破
  • 「+300」:ステージクリア
  • 「+ 15」:UFO撃破
  • 「- 15」:自機被弾
  • 「- 20」:ゲームオーバー
  • 「- 1 」:画面端付近に自機がいる時(1ステップ毎)
  • 「- 1 」:エイリアンを撃破できない時間が15秒を超えたとき(1ステップ毎)
  • 「-1000」:開始から75秒経過したとき(タイムオーバー)

あまり画面端で待機されてしまっても面白くないので、座標によるペナルティも設けています。
実験の中である程度の目処がついている設定ですが、どこまで適切かどうかは判断できていません。
今後のDQNはこれらの報酬値を使って進めていきます。

強化学習向けに細かく修正

画像の修正

このゲームでは画像から当たり判定を取っています。
見た目よりビームの当たり判定が広いなど気になる点がいくつかあったので、この機会に修正しています。

  • ビームのグラフィック修正と幅縮小
  • 自機画像の余白削除
  • ミサイルの幅拡大

この記事のzipファイルに新しい画像が含まれているので、上書きしてください。

既存コードの修正

画像変更に伴うもの以外の修正もあります。

  • 後で必要な変数を作成
  • リロードタイムの管理を修正
  • 情報表示の追加(自機の横座標)
  • 自機位置の調整
  • 無敵時間の削除
  • エイリアンの移動範囲拡大
  • 自機の移動範囲縮小
# 略
SCR_RECT = Rect(0, 0, 640, 480)
PLAYER_MOVE_RECT = Rect(40, 0, 560, 480)


class Invader:
    def __init__(self, step=False, image=False):
        self.lives = 5  # 残機数
        self.wave = 1  # Wave数
        self.counter = 0  # タイムカウンター(60カウント=1秒)
        self.score = 0  # スコア(エイリアン 10点、UFO 50点にWaveを乗算)
        self.step_flag = step  # 学習用にstep実行するかどうか
        self.image_flag = image  # 学習用データを画像にするかどうか
        self.observation_space = None  # 学習環境の情報
        self.num_aliens = None  # エイリアンの数
        self.prev_lives = None  # 直前のステップの残機
        self.prev_num_aliens = None  # 直前のステップのエイリアン数
        pygame.init()
        self.screen = pygame.display.set_mode(SCR_RECT.size, pygame.DOUBLEBUF)
        pygame.display.set_caption("Invader Part4")
    def update(self):
        """ゲーム状態の更新"""
        if self.game_state == PLAY:
            # タイムカウンターを進める
            self.counter += 1
            # リロード時間を減らす
            if self.player.reload_timer > 0:
                self.player.reload_timer -= 1
            # UFOの出現判定(15秒後に出現する)
            if self.counter == 900:
                UFO((10, 30), self.wave)

    def draw(self, screen):
        """描画"""
        # 略
        elif self.game_state == PLAY:  # ゲームプレイ画面
            # 略
            # wave数と残機数を描画
            stat_font = pygame.font.SysFont(None, 20)
            stat = stat_font.render("Wave:{:2d}  Lives:{:2d}  Score:{:05d}  pos_X:{:3d}".format(
                                        self.wave, self.lives, self.score,
                                        self.player.rect.center[0]), False, (255,255,255))

    def collision_detection(self):
        """衝突判定"""
        # 略
        if beam_collided:  # プレイヤーと衝突したビームがあれば
            Player.bomb_sound.play()
            Explosion(self.player.rect.center)
            self.lives -= 1
            self.player.invisible = 0 # DQN用は無敵時間無し
            if self.lives < 0:
                self.game_state = GAMEOVER  # ゲームオーバー!


class Player(pygame.sprite.Sprite):
    """自機"""
    # 略
    def update(self):
        # 押されているキーをチェック
        pressed_keys = pygame.key.get_pressed()
        # 押されているキーに応じてプレイヤーを移動
        if pressed_keys[K_LEFT]:
            self.rect.move_ip(-self.speed, 0)
        elif pressed_keys[K_RIGHT]:
            self.rect.move_ip(self.speed, 0)
        #self.rect.clamp_ip(SCR_RECT)
        self.rect.clamp_ip(PLAYER_MOVE_RECT)
        # ミサイルの発射
        if pressed_keys[K_SPACE]:
            # リロード時間が0になるまで再発射できない
            '''
            if self.reload_timer > 0:
                # リロード中
                self.reload_timer -= 1
            '''
            if self.reload_timer == 0:
                # 発射!!!
                Player.shot_sound.play()
                Shot(self.rect.center)  # 作成すると同時にallに追加される
                self.reload_timer = self.reload_time

元のリロードタイム管理だと、実はスペースキーを押している間しか、リロードタイムが減っていませんでした。
まず、これを常に減るようにします。
それと同時に、外部から自機の行動を管理に備えてリロードタイムの管理を Player クラスの外に出しています。

無敵時間の削除については、学習の簡略化を意図しています。
入力として画像だけを使った場合、無敵時間の認識は困難だと思われます。
仮に自機の点滅を無くしたとしても、敵弾に当たったり当たらなかったりでは学習に支障が出そうです。

「screen -> self.screen」となっているのは、後で状態取得の際にも必要だからです。
また、pygame.DOUBLEBUF という引数を追加して、ハードウェア描画に切り替えています。
これにより若干の高速化が見込めます。

自機の移動範囲は SCR_RECT(画面サイズと同じ大きさの四角形)で制御されていたのですが、
移動範囲を狭くするため新たに PLAYER_MOVE_RECT を用意しています。
これにより、強化学習時により中央に近い場所で戦うことを強制します。

強化学習に必要なメソッドの実装

必要なメソッドは以下の通りです。

  • reset:環境初期化
  • render:環境の視覚化
  • step:学習器からの操作に従っての1ステップ実行

render については、mode を引数にとって現環境を視覚化するメソッドです。
何もしなくても pygame が画面表示するので、エラー回避用で空のメソッドを用意します。

    def render(self, mode=None):
        pass

resetstep について見ていく前に、

  • 学習器から受け取った action をゲームに反映するメソッド
  • 学習環境の状態を取得するメソッド

について確認しましょう。

受け取った action を反映するメソッド

可能な行動は、

  • ミサイル発射
  • ミサイル発射+左移動
  • ミサイル発射+右移動

の3つとします。
ミサイルを発射しない、という操作を省いて学習を高速化します。

class Invader:
    def _key_action(self, action): # キー入力を代替
        # 常にミサイル発射
        if self.player.reload_timer == 0:
            Shot(self.player.rect.center)
            self.player.reload_timer = self.player.reload_time
        # 左移動
        if action == 1:
            self.player.rect.move_ip(-self.player.speed, 0)
            self.player.rect.clamp_ip(PLAYER_MOVE_RECT)
        # 右移動
        elif action == 2:
            self.player.rect.move_ip(self.player.speed, 0)
            self.player.rect.clamp_ip(PLAYER_MOVE_RECT)
        pass

操作が反映されればいいので、値は返しません。
中身を見ると分かりますが、Player クラスから必要な処理をそのまま持ってきています。
常にミサイルを発射するようにしたおかげで、コードが短く済んでいます。

状態を取得するメソッド

このメソッドの実装にあたって、NumPyPillow という2つのライブラリを使用します。
NumPy は数値演算ライブラリで、Pillow は画像処理ライブラリです。
どちらも pip からインストール可能です。

このメソッドは、環境の状態として「画像」と「指定のパラメータ」のどちらかを返すようにします。
後で両方の場合について強化学習をします。
指定のパラメータとしては以下のものを取得することにします。

  • エイリアンとUFOの位置情報と速度
  • ミサイルとビームの位置情報
  • 自機の位置情報とリロードタイマーの情報

学習時には壁を削除するつもりなので、壁の情報は取得しません。

エイリアンの情報を取得する際には、並び順を固定したほうが学習に都合が良いと思われるので、スプライト生成時にリストのほうに情報を追加するようにします。
ミサイルの情報もエイリアンと同様に並びを固定したほうが良いとは思うのですが、それほど重要では無いと考えているので self.shots からそのまま取得します
ビームは自機周辺の情報のみを取得するようにして、回避させます。

import numpy as np
from PIL import Image

class Invader:
    def _get_observation(self):
        if self.image_flag:
            resize_x = 160
            resize_y = 120
            cut_y_rate = 0.06
            pilImg = Image.fromarray(pygame.surfarray.array3d(self.screen))
            resizedImg = pilImg.resize((resize_x, resize_y), Image.LANCZOS)
            self.observation_space = np.asarray(resizedImg)[:][int(resize_y*cut_y_rate):]
            return None

        observation_list = list()
        for index, alien in enumerate(self.alien_list):
            if alien.alive:
                observation_list.append((alien.rect.center[0] - self.player.rect.center[0])/640)
                observation_list.append((alien.rect.center[1] - self.player.rect.center[1])/480)
                observation_list.append(alien.speed/2)
            else:
                observation_list.extend([0, 0, 0])
        observation_list.append(len(self.alien_list))

        # ビーム位置情報
        beam_check_range = np.array([64, 64, 192, 24])  # L,R,U,D(学習に使う範囲を指定)
        margin = 60  # ビーム位置用行列の余白
        compressed_rate = 5  # ビーム位置行列の圧縮

        comp_beam_check_range = beam_check_range // compressed_rate
        beam_pos_all = np.zeros(((640 + margin * 2) // compressed_rate
                                 , (480 + margin * 2) // compressed_rate), dtype=int)
        for index, beam in enumerate(self.beams):
            beam_pos_all[(beam.rect.center[0] + margin) // compressed_rate] \
                [(beam.rect.center[1] + margin) // compressed_rate] = 1

        comp_x_pos = (self.player.rect.center[0] + margin) // compressed_rate
        comp_y_pos = (self.player.rect.center[1] + margin) // compressed_rate
        beam_pos = beam_pos_all[
                   comp_x_pos - comp_beam_check_range[0]:comp_x_pos + comp_beam_check_range[1],
                   comp_y_pos - comp_beam_check_range[2]:comp_y_pos + comp_beam_check_range[3]
                   ]

        observation_list.extend(beam_pos.flatten())

        for index, shot in enumerate(self.shots):
            observation_list.append((shot.rect.center[0] - self.player.rect.center[0])/640)
            observation_list.append((shot.rect.center[1] - self.player.rect.center[1])/480)
        observation_list.extend([0 for _ in range(10- len(self.shots)*2)])

        for ufo in self.ufos:
            observation_list.append((ufo.rect.center[0] - self.player.rect.center[0])/640)
            observation_list.append((ufo.rect.center[1] - self.player.rect.center[1])/480)
            observation_list.append(ufo.speed)
        observation_list.extend([0 for _ in range(3 - len(self.ufos) * 3)])

        observation_list.append((self.player.rect.center[0]-320)/320)
        observation_list.append((self.player.rect.center[1]-240)/240)
        observation_list.append(self.player.reload_timer/15)
        self.observation_space = np.array(observation_list, dtype=float)
        pass

画像を渡す場合は、そのままだと(640, 480, 3)という大きなデータになってしまうので、Pillow を使って画像を縮小しています。
縮小された画像
横160 x 縦120 というかなり小さい画像でも状況は把握できそうですね。
そのあとはcut_y_rate で指定された割合に従って、画面上部の情報表示部を取り除いています。

指定のパラメータを渡す場合には、全ての値の取りうる範囲が近くなるようにそれぞれ適当な値で割っています。
ビーム情報の部分が少し分かりにくいかもしれません。ざっくりと図にすると以下のような意味です。
ビーム情報
margin は、自機が画面端にいても同様に取得するための予備領域です。
行列の範囲外指定を防止しています。

compressed_rate は、行列サイズを小さくするための係数です。
得られたビームの座標をこの値で割ることで、学習に渡すデータ量を削減します。
(beam_check_range や画面サイズも同時にこの値で割っています)
ビームのある位置を「1」として表現している beam_pos_all は元々非常に疎な行列なので、
このような処理をしてもビームの位置情報への影響は軽微です。


class Invader:
    def init_game(self):
        """ゲームオブジェクトを初期化"""
        # 略
        # エイリアンを作成
        self.alien_list = list()
        for i in range(0, 50):
            x = 10 + (i % 10) * 40
            y = 50 + (i // 10) * 40
            self.alien_list.append(Alien((x,y), self.wave))
        # 壁を作成
        '''
        self.wall_list = list()
        for i in range(4):
            x = 95 + i * 150
            y = 400
            self.wall_list.append(Wall((x, y), self.wave))
        '''
        self.num_aliens = len(self.aliens)
        self.prev_num_aliens = self.num_aliens
        self.prev_lives = self.lives

init_game にエイリアンのリストを取得する部分と、あとで報酬の決定に必要なパラメータ num_aliens 以下を追加します。
これで学習環境の情報が用意できるようになりました。

reset(学習環境のリセット)を実装する

学習器から reset の指示があったときの処理を実装します。

class Invader:
    def reset(self):
        self.score = 0  # スコア初期化
        self.wave = 1
        self.lives = 5
        self.counter = 0
        self.no_kill_counter = 0  # エイリアンを撃破できていない時間(60カウント=1秒)
        self.init_game()  # ゲームを初期化して再開
        self.game_state = PLAY
        self._get_observation()
        return self.observation_space

ほとんどゲームオーバー時の処理と一緒です。
no_kill_counter は後で報酬決定に使う変数です。
reset の返り値で self.observation_space(環境の情報)が必要なので、
先程実装した self._get_observation が早速使われています。

step(学習環境を1ステップ分操作)を実装する

step については最初の方で軽く触れました。
報酬についても確認済みですので、まずはコードを見ましょう。

    def step(self, action):
        if not self.step_flag:
            sys.stdout.write("Not running in stepping mode!\n")
            return None

        reward = 0
        done = False
        self._key_action(action)
        self.update()
        self.draw(self.screen)
        pygame.display.update()
        self._get_observation()
        self.no_kill_counter += 1

        if self.player.rect.center[0] < 80 or self.player.rect.center[0] > 560:
            reward += -1
        if self.prev_lives > self.lives:
            reward += -15
            self.prev_lives = self.lives
        elif self.prev_lives < self.lives:
            reward += 15
            self.prev_lives = self.lives

        num_aliens = len(self.aliens)
        if self.prev_num_aliens > num_aliens:
            reward += 1
            self.no_kill_counter = 0
            self.prev_num_aliens = num_aliens
        if self.no_kill_counter > 900:
            reward += -1

        if num_aliens == 0:
            reward += 300
            done = True
            sys.stdout.write("Cleared!\n")
        elif self.lives < 0:
            reward += -20
            done = True
        elif self.counter == 4500:
            reward += -1000
            done = True
            sys.stdout.write("Time up!\n")
        return self.observation_space, reward, done, dict()

エイリアンが全滅、残機0で撃墜される、タイムオーバーのどれかで1周終了ということで、done=True を返します。
ほとんどが報酬の決定処理なので、難しいところは無いでしょう。

ちなみに演算子を「+=」に統一し、右辺にマイナスの値が来るのは意図的なものです。
個人的には「-=」と「+=」を使い分けるより把握しやすいと思うのですが、どうでしょうか。

最後に Invader クラスを修正する

最後に、ステップ実行をするかどうか、環境の情報は画像か、を指定して使い分けられるように修正すれば完成です。

class Invader:
    def __init__(self, step=False, image=False):
        self.lives = 5  # 残機数
        self.wave = 1  # Wave数
        self.counter = 0  # タイムカウンター(60カウント=1秒)
        self.score = 0  # スコア(エイリアン 10点、UFO 50点にWaveを乗算)
        self.step_flag = step  # 学習用にstep実行するかどうか
        self.image_flag = image  # 学習用データを画像にするかどうか
        self.observation_space = None  # 学習環境の情報
        self.num_aliens = None  # エイリアンの数
        self.prev_lives = None  # 直前のステップの残機
        self.prev_num_aliens = None  # 直前のステップのエイリアン数
        # 略
        # メインループ開始
        if not self.step_flag: # 通常モード
            clock = pygame.time.Clock()
            while True:
                clock.tick(60)
                self.update()
                self.draw(self.screen)
                pygame.display.update()
                self.key_handler()
        else: # step実行時はタイトル画面を飛ばす
            self.game_state = PLAY

ステップ実行の場合は、外からゲームを制御するのでメインループは不要です。

まとめ

残りのコード変更を全て詰め込んだので、かなりの長さになってしまいました。
何はともあれ、これで強化学習の準備ができました。

今回のデータはこちらから(再掲) -> pygame_invader_4.zip

#!/usr/bin/env python
#coding: utf-8
import pygame
from pygame.locals import *
import os
import random
import sys

import numpy as np
from PIL import Image

START, PLAY, GAMEOVER, STAGECLEAR = (0, 1, 2, 3)  # ゲーム状態
SCR_RECT = Rect(0, 0, 640, 480)
PLAYER_MOVE_RECT = Rect(40, 0, 560, 480)


class Invader:
    def __init__(self, step=False, image=False):
        self.lives = 5  # 残機数
        self.wave = 1  # Wave数
        self.counter = 0  # タイムカウンター(60カウント=1秒)
        self.score = 0  # スコア(エイリアン 10点、UFO 50点にWaveを乗算)
        self.step_flag = step  # 学習用にstep実行するかどうか
        self.image_flag = image  # 学習用データを画像にするかどうか
        self.observation_space = None  # 学習環境の情報
        self.num_aliens = None  # エイリアンの数
        self.prev_lives = None  # 直前のステップの残機
        self.prev_num_aliens = None  # 直前のステップのエイリアン数
        pygame.init()
        self.screen = pygame.display.set_mode(SCR_RECT.size, pygame.DOUBLEBUF)
        pygame.display.set_caption("Invader Part4")
        # 素材のロード
        self.load_images()
        self.load_sounds()
        # ゲームオブジェクトを初期化
        self.init_game()
        self._get_observation()
        # メインループ開始
        if not self.step_flag: # 通常モード
            clock = pygame.time.Clock()
            while True:
                clock.tick(60)
                self.update()
                self.draw(self.screen)
                pygame.display.update()
                self.key_handler()
        else: # step実行時はタイトル画面を飛ばす
            self.game_state = PLAY

    # ここからgym互換関数群
    def reset(self):
        self.score = 0  # スコア初期化
        self.wave = 1
        self.lives = 5
        self.counter = 0
        self.no_kill_counter = 0  # エイリアンを撃破できていない時間(60カウント=1秒)
        self.init_game()  # ゲームを初期化して再開
        self.game_state = PLAY
        self._get_observation()
        return self.observation_space

    def render(self, mode=None):
        pass

    def step(self, action):
        if not self.step_flag:
            sys.stdout.write("Not running in stepping mode!\n")
            return None

        reward = 0
        done = False
        self._key_action(action)
        self.update()
        self.draw(self.screen)
        pygame.display.update()
        self._get_observation()
        self.no_kill_counter += 1

        if self.player.rect.center[0] < 80 or self.player.rect.center[0] > 560:
            reward += -1
        if self.prev_lives > self.lives:
            reward += -15
            self.prev_lives = self.lives
        elif self.prev_lives < self.lives:
            reward += 15
            self.prev_lives = self.lives

        num_aliens = len(self.aliens)
        if self.prev_num_aliens > num_aliens:
            reward += 1
            self.no_kill_counter = 0
            self.prev_num_aliens = num_aliens
        if self.no_kill_counter > 900:
            reward += -1

        if num_aliens == 0:
            reward += 300
            done = True
            sys.stdout.write("Cleared!\n")
        elif self.lives < 0:
            reward += -20
            done = True
        elif self.counter == 4500:
            reward += -1000
            done = True
            sys.stdout.write("Time up!\n")
        return self.observation_space, reward, done, dict()

    def _key_action(self, action): # キー入力を代替
        # 常にミサイル発射
        if self.player.reload_timer == 0:
            Shot(self.player.rect.center)
            self.player.reload_timer = self.player.reload_time
        # 左移動
        if action == 1:
            self.player.rect.move_ip(-self.player.speed, 0)
            self.player.rect.clamp_ip(PLAYER_MOVE_RECT)
        # 右移動
        elif action == 2:
            self.player.rect.move_ip(self.player.speed, 0)
            self.player.rect.clamp_ip(PLAYER_MOVE_RECT)
        pass

    def _get_observation(self):
        if self.image_flag:
            resize_x = 160
            resize_y = 120
            cut_y_rate = 0.06
            pilImg = Image.fromarray(pygame.surfarray.array3d(self.screen))
            resizedImg = pilImg.resize((resize_x, resize_y), Image.LANCZOS)
            self.observation_space = np.asarray(resizedImg)[:][int(resize_y*cut_y_rate):]
            return None

        observation_list = list()
        for index, alien in enumerate(self.alien_list):
            if alien.alive:
                observation_list.append((alien.rect.center[0] - self.player.rect.center[0])/640)
                observation_list.append((alien.rect.center[1] - self.player.rect.center[1])/480)
                observation_list.append(alien.speed/2)
            else:
                observation_list.extend([0, 0, 0])
        observation_list.append(len(self.alien_list))

        # ビーム位置情報
        beam_check_range = np.array([64, 64, 192, 24])  # L,R,U,D(学習に使う範囲を指定)
        margin = 60  # ビーム位置用行列の余白
        compressed_rate = 5  # ビーム位置行列の圧縮

        comp_beam_check_range = beam_check_range // compressed_rate
        beam_pos_all = np.zeros(((640 + margin * 2) // compressed_rate
                                 , (480 + margin * 2) // compressed_rate), dtype=int)
        for index, beam in enumerate(self.beams):
            beam_pos_all[(beam.rect.center[0] + margin) // compressed_rate] \
                [(beam.rect.center[1] + margin) // compressed_rate] = 1

        comp_x_pos = (self.player.rect.center[0] + margin) // compressed_rate
        comp_y_pos = (self.player.rect.center[1] + margin) // compressed_rate
        beam_pos = beam_pos_all[
                   comp_x_pos - comp_beam_check_range[0]:comp_x_pos + comp_beam_check_range[1],
                   comp_y_pos - comp_beam_check_range[2]:comp_y_pos + comp_beam_check_range[3]
                   ]

        observation_list.extend(beam_pos.flatten())

        for index, shot in enumerate(self.shots):
            observation_list.append((shot.rect.center[0] - self.player.rect.center[0])/640)
            observation_list.append((shot.rect.center[1] - self.player.rect.center[1])/480)
        observation_list.extend([0 for _ in range(10- len(self.shots)*2)])

        for ufo in self.ufos:
            observation_list.append((ufo.rect.center[0] - self.player.rect.center[0])/640)
            observation_list.append((ufo.rect.center[1] - self.player.rect.center[1])/480)
            observation_list.append(ufo.speed)
        observation_list.extend([0 for _ in range(3 - len(self.ufos) * 3)])

        observation_list.append((self.player.rect.center[0]-320)/320)
        observation_list.append((self.player.rect.center[1]-240)/240)
        observation_list.append(self.player.reload_timer/15)
        self.observation_space = np.array(observation_list, dtype=float)

        pass

    # ここまでgym互換関数群

    def init_game(self):
        """ゲームオブジェクトを初期化"""
        # ゲーム状態
        self.game_state = START
        # スプライトグループを作成して登録
        self.all = pygame.sprite.RenderUpdates()
        self.invisible = pygame.sprite.RenderUpdates()
        self.aliens = pygame.sprite.Group()  # エイリアングループ
        self.shots = pygame.sprite.Group()   # ミサイルグループ
        self.beams = pygame.sprite.Group()   # ビームグループ
        self.walls = pygame.sprite.Group()  # 壁グループ
        self.ufos = pygame.sprite.Group()  # UFOグループ
        # デフォルトスプライトグループを登録
        Player.containers = self.all
        Shot.containers = self.all, self.shots, self.invisible
        Alien.containers = self.all, self.aliens, self.invisible
        Beam.containers = self.all, self.beams, self.invisible
        Wall.containers = self.all, self.walls, self.invisible
        UFO.containers = self.all, self.ufos, self.invisible
        Explosion.containers = self.all, self.invisible
        ExplosionWall.containers = self.all, self.invisible
        # 自機を作成
        self.player = Player()
        # エイリアンを作成
        self.alien_list = list()
        for i in range(0, 50):
            x = 10 + (i % 10) * 40
            y = 50 + (i // 10) * 40
            self.alien_list.append(Alien((x,y), self.wave))
        # 壁を作成
        '''
        self.wall_list = list()
        for i in range(4):
            x = 95 + i * 150
            y = 400
            self.wall_list.append(Wall((x, y), self.wave))
        '''
        self.num_aliens = len(self.aliens)
        self.prev_num_aliens = self.num_aliens
        self.prev_lives = self.lives

    def update(self):
        """ゲーム状態の更新"""
        if self.game_state == PLAY:
            # タイムカウンターを進める
            self.counter += 1
            # リロード時間を減らす
            if self.player.reload_timer > 0:
                self.player.reload_timer -= 1
            # UFOの出現判定(15秒後に出現する)
            if self.counter == 900:
                UFO((10, 30), self.wave)

            self.all.update()
            # エイリアンの方向転換判定
            turn_flag = False
            for alien in self.aliens:
                if (alien.rect.center[0] < 10 and alien.speed < 0) or \
                        (alien.rect.center[0] > SCR_RECT.width-10 and alien.speed > 0):
                    turn_flag = True
                    break
            if turn_flag:
                for alien in self.aliens:
                    alien.speed *= -1
            # エイリアンの追加ビーム判定(プレイヤーが近くにいると反応する)
            for alien in self.aliens:
                alien.shoot_extra_beam(self.player.rect.center[0], 32, 2)
            # ミサイルとエイリアン、壁の衝突判定
            self.collision_detection()
            # エイリアンをすべて倒したら次のステージへ
            if len(self.aliens.sprites()) == 0:
                self.game_state = STAGECLEAR

    def draw(self, screen):
        """描画"""
        screen.fill((0, 0, 0))
        if self.game_state == START:  # スタート画面
            # タイトルを描画
            title_font = pygame.font.SysFont(None, 80)
            title = title_font.render("INVADER GAME", False, (255,0,0))
            screen.blit(title, ((SCR_RECT.width-title.get_width())//2, 100))
            # エイリアンを描画
            alien_image = Alien.images[0]
            screen.blit(alien_image, ((SCR_RECT.width-alien_image.get_width())//2, 200))
            # PUSH STARTを描画
            push_font = pygame.font.SysFont(None, 40)
            push_space = push_font.render("PUSH SPACE KEY", False, (255,255,255))
            screen.blit(push_space, ((SCR_RECT.width-push_space.get_width())//2, 300))
            # クレジットを描画
            credit_font = pygame.font.SysFont(None, 20)
            credit = credit_font.render("2019 http://pygame.skr.jp", False, (255,255,255))
            screen.blit(credit, ((SCR_RECT.width-credit.get_width())//2, 380))
        elif self.game_state == PLAY:  # ゲームプレイ画面
            # 無敵時間中は自機が点滅する
            if self.player.invisible % 10 > 4:
                self.invisible.draw(screen)
            else:
                self.all.draw(screen)
            # wave数と残機数を描画
            stat_font = pygame.font.SysFont(None, 20)
            stat = stat_font.render("Wave:{:2d}  Lives:{:2d}  Score:{:05d}  pos_X:{:3d}".format(
                                        self.wave, self.lives, self.score,
                                        self.player.rect.center[0]), False, (255,255,255))
            screen.blit(stat, ((SCR_RECT.width - stat.get_width()) // 2, 10))
            # 壁の耐久力描画
            shield_font = pygame.font.SysFont(None, 30)
            for wall in self.walls:
                shield = shield_font.render(str(wall.shield), False, (0,0,0))
                text_size = shield_font.size(str(wall.shield))
                screen.blit(shield, (wall.rect.center[0]-text_size[0]//2,
                                     wall.rect.center[1]-text_size[1]//2))
        elif self.game_state == GAMEOVER:  # ゲームオーバー画面
            # GAME OVERを描画
            gameover_font = pygame.font.SysFont(None, 80)
            gameover = gameover_font.render("GAME OVER", False, (255,0,0))
            screen.blit(gameover, ((SCR_RECT.width-gameover.get_width())//2, 100))
            # エイリアンを描画
            alien_image = Alien.images[0]
            screen.blit(alien_image, ((SCR_RECT.width-alien_image.get_width())//2, 200))
            # PUSH SPACEを描画
            push_font = pygame.font.SysFont(None, 40)
            push_space = push_font.render("PUSH SPACE KEY", False, (255,255,255))
            screen.blit(push_space, ((SCR_RECT.width-push_space.get_width())//2, 300))

        elif self.game_state == STAGECLEAR:  # ステージクリア画面
            # wave数と残機数を描画
            stat_font = pygame.font.SysFont(None, 20)
            stat = stat_font.render("Wave:{:2d}  Lives:{:2d}  Score:{:05d}".format(
                self.wave, self.lives, self.score), False, (255,255,255))
            screen.blit(stat, ((SCR_RECT.width - stat.get_width()) // 2, 10))
            # STAGE CLEARを描画
            gameover_font = pygame.font.SysFont(None, 80)
            gameover = gameover_font.render("STAGE CLEAR", False, (255,0,0))
            screen.blit(gameover, ((SCR_RECT.width-gameover.get_width())//2, 100))
            # エイリアンを描画
            alien_image = Alien.images[0]
            screen.blit(alien_image, ((SCR_RECT.width-alien_image.get_width())//2, 200))
            # PUSH SPACEを描画
            push_font = pygame.font.SysFont(None, 40)
            push_space = push_font.render("PUSH SPACE KEY", False, (255,255,255))
            screen.blit(push_space, ((SCR_RECT.width-push_space.get_width())//2, 300))

    def key_handler(self):
        """キーハンドラー"""
        for event in pygame.event.get():
            if event.type == QUIT:
                pygame.quit()
                sys.exit()
            elif event.type == KEYDOWN and event.key == K_ESCAPE:
                pygame.quit()
                sys.exit()
            elif event.type == KEYDOWN and event.key == K_SPACE:
                if self.game_state == START:  # スタート画面でスペースを押したとき
                    self.game_state = PLAY
                elif self.game_state == GAMEOVER:  # ゲームオーバー画面でスペースを押したとき
                    self.score = 0  # スコア初期化
                    self.wave = 1
                    self.lives = 5
                    self.counter = 0
                    self.init_game()  # ゲームを初期化して再開
                    self.game_state = PLAY
                elif self.game_state == STAGECLEAR:
                    self.wave += 1
                    self.lives += 1
                    self.counter = 0
                    self.init_game()  # ゲームを初期化して再開
                    self.game_state = PLAY


    def collision_detection(self):
        """衝突判定"""
        # エイリアンとミサイルの衝突判定
        alien_collided = pygame.sprite.groupcollide(self.aliens, self.shots, True, True)
        for alien in alien_collided.keys():
            Alien.kill_sound.play()
            self.score += 10 * self.wave
            Explosion(alien.rect.center)  # エイリアンの中心で爆発
        # UFOとミサイルの衝突判定
        ufo_collided = pygame.sprite.groupcollide(self.ufos, self.shots, True, True)
        for ufo in ufo_collided.keys():
            Alien.kill_sound.play()
            self.score += 50 * self.wave
            Explosion(ufo.rect.center)
            self.lives += 1
        # プレイヤーとビームの衝突判定
        # 無敵時間中なら判定せずに無敵時間を1減らす
        if self.player.invisible > 0:
            beam_collided = False
            self.player.invisible -= 1
        else:
            beam_collided = pygame.sprite.spritecollide(self.player, self.beams, True)
        if beam_collided:  # プレイヤーと衝突したビームがあれば
            Player.bomb_sound.play()
            Explosion(self.player.rect.center)
            self.lives -= 1
            self.player.invisible = 0 # DQN用は無敵時間無し
            if self.lives < 0:
                self.game_state = GAMEOVER  # ゲームオーバー!
        # 壁とミサイル、ビームの衝突判定
        hit_dict = pygame.sprite.groupcollide(self.walls, self.shots, False, True)
        hit_dict.update(pygame.sprite.groupcollide(self.walls, self.beams, False, True))
        for hit_wall in hit_dict:
            hit_wall.shield -= len(hit_dict[hit_wall])
            for hit_beam in hit_dict[hit_wall]:
                Alien.kill_sound.play()
                Explosion(hit_beam.rect.center)  # ミサイル・ビームの当たった場所で爆発
            if hit_wall.shield <= 0:
                hit_wall.kill()
                Alien.kill_sound.play()
                ExplosionWall(hit_wall.rect.center)  # 壁の中心で爆発

    def load_images(self):
        """イメージのロード"""
        # スプライトの画像を登録
        Player.image = load_image("player.png")
        Shot.image = load_image("shot.png")
        Alien.images = split_image(load_image("alien.png"), 2)
        UFO.images = split_image(load_image("ufo.png"), 2)
        Beam.image = load_image("beam.png")
        Wall.image = load_image("wall.png")
        Explosion.images = split_image(load_image("explosion.png"), 16)
        ExplosionWall.images = split_image(load_image("explosion2.png"), 16)

    def load_sounds(self):
        """サウンドのロード"""
        Alien.kill_sound = load_sound("kill.wav")
        Player.shot_sound = load_sound("shot.wav")
        Player.bomb_sound = load_sound("bomb.wav")


class Player(pygame.sprite.Sprite):
    """自機"""
    speed = 5  # 移動速度
    reload_time = 15  # リロード時間
    invisible = 0  # 無敵時間
    def __init__(self):
        # imageとcontainersはmain()でセット
        pygame.sprite.Sprite.__init__(self, self.containers)
        self.rect = self.image.get_rect()
        self.rect.center = (SCR_RECT.width//2, SCR_RECT.bottom - 9)
        self.reload_timer = 0
    def update(self):
        # 押されているキーをチェック
        pressed_keys = pygame.key.get_pressed()
        # 押されているキーに応じてプレイヤーを移動
        if pressed_keys[K_LEFT]:
            self.rect.move_ip(-self.speed, 0)
        elif pressed_keys[K_RIGHT]:
            self.rect.move_ip(self.speed, 0)
        #self.rect.clamp_ip(SCR_RECT)
        self.rect.clamp_ip(PLAYER_MOVE_RECT)
        # ミサイルの発射
        if pressed_keys[K_SPACE]:
            # リロード時間が0になるまで再発射できない
            '''
            if self.reload_timer > 0:
                # リロード中
                self.reload_timer -= 1
            '''
            if self.reload_timer == 0:
                # 発射!!!
                Player.shot_sound.play()
                Shot(self.rect.center)  # 作成すると同時にallに追加される
                self.reload_timer = self.reload_time


class Alien(pygame.sprite.Sprite):
    """エイリアン"""
    def __init__(self, pos, wave):
        self.speed = 1 + wave  # 移動速度
        self.animcycle = 18  # アニメーション速度
        self.frame = 0
        self.prob_beam = (1.5 + wave) * 0.002  # ビームを発射する確率
        # imagesとcontainersはmain()でセット
        pygame.sprite.Sprite.__init__(self, self.containers)
        self.image = self.images[0]
        self.rect = self.image.get_rect()
        self.rect.center = pos

    def update(self):
        # 横方向への移動
        self.rect.move_ip(self.speed, 0)
        # ビームを発射
        if random.random() < self.prob_beam:
            Beam(self.rect.center)
        # キャラクターアニメーション
        self.frame += 1
        self.image = self.images[self.frame//self.animcycle%2]

    def shoot_extra_beam(self, target_x_pos, border_dist, rate):
        if random.random() < self.prob_beam*rate and \
                abs(self.rect.center[0] - target_x_pos) < border_dist:
            Beam(self.rect.center)


class UFO(pygame.sprite.Sprite):
    """UFO"""
    def __init__(self, pos, wave):
        self.speed = 1 + wave//2  # 移動速度
        # side => 0: left, 1: right
        side = 0 if random.random() < 0.5 else 1
        if side:
            self.speed *= -1 # 右から出現する場合、速度を反転する
        self.animcycle = 18  # アニメーション速度
        self.frame = 0
        # imagesとcontainersはmain()でセット
        pygame.sprite.Sprite.__init__(self, self.containers)
        self.image = self.images[0]
        self.rect = self.image.get_rect()
        self.rect.center = (SCR_RECT.width - pos[0] if side else pos[0], pos[1]) # 開始位置(x)
        self.pos_kill = pos[0] if side else SCR_RECT.width - pos[0] # 消滅位置(x)
    def update(self):
        # 横方向への移動
        self.rect.move_ip(self.speed, 0)
        # 指定位置まで来たら消滅
        if (self.rect.center[0] > self.pos_kill and self.speed > 0) or \
            (self.rect.center[0] < self.pos_kill and self.speed < 0):
            self.kill()
        # キャラクターアニメーション
        self.frame += 1
        self.image = self.images[int(self.frame//self.animcycle%2)]


class Shot(pygame.sprite.Sprite):
    """プレイヤーが発射するミサイル"""
    speed = 12  # 移動速度
    def __init__(self, pos):
        # imageとcontainersはmain()でセット
        pygame.sprite.Sprite.__init__(self, self.containers)
        self.rect = self.image.get_rect()
        self.rect.center = pos  # 中心座標をposに
    def update(self):
        self.rect.move_ip(0, -self.speed)  # 上へ移動
        if self.rect.top < 0:  # 上端に達したら除去
            self.kill()


class Beam(pygame.sprite.Sprite):
    """エイリアンが発射するビーム"""
    speed = 5  # 移動速度
    def __init__(self, pos):
        # imagesとcontainersはmain()でセット
        pygame.sprite.Sprite.__init__(self, self.containers)
        self.rect = self.image.get_rect()
        self.rect.center = pos
    def update(self):
        self.rect.move_ip(0, self.speed)  # 下へ移動
        if self.rect.bottom > SCR_RECT.height:  # 下端に達したら除去
            self.kill()


class Explosion(pygame.sprite.Sprite):
    """爆発エフェクト"""
    animcycle = 2  # アニメーション速度
    frame = 0

    def __init__(self, pos):
        # imagesとcontainersはmain()でセット
        pygame.sprite.Sprite.__init__(self, self.containers)
        self.image = self.images[0]
        self.rect = self.image.get_rect()
        self.rect.center = pos
        self.max_frame = len(self.images) * self.animcycle  # 消滅するフレーム

    def update(self):
        # キャラクターアニメーション
        self.image = self.images[self.frame//self.animcycle]
        self.frame += 1
        if self.frame == self.max_frame:
            self.kill()  # 消滅


class ExplosionWall(Explosion):
    pass


class Wall(pygame.sprite.Sprite):
    """ミサイル・ビームを防ぐ壁"""

    def __init__(self, pos, wave):
        self.shield = 80 + 20 * wave  # 耐久力
        # imagesとcontainersはmain()でセット
        pygame.sprite.Sprite.__init__(self, self.containers)
        self.rect = self.image.get_rect()
        self.rect.center = pos

    def update(self):
        pass


def load_image(filename, colorkey=None):
    """画像をロードして画像と矩形を返す"""
    filename = os.path.join("data", filename)
    try:
        image = pygame.image.load(filename)
    except pygame.error as message:
        print("Cannot load image:", filename)
        raise SystemExit(message)
    image = image.convert()
    if colorkey is not None:
        if colorkey is -1:
            colorkey = image.get_at((0,0))
        image.set_colorkey(colorkey, RLEACCEL)
    return image


def split_image(image, n):
    """横に長いイメージを同じ大きさのn枚のイメージに分割
    分割したイメージを格納したリストを返す"""
    image_list = []
    w = image.get_width()
    h = image.get_height()
    w1 = w // n
    for i in range(0, w, w1):
        surface = pygame.Surface((w1,h))
        surface.blit(image, (0,0), (i,0,w1,h))
        surface.set_colorkey(surface.get_at((0,0)), RLEACCEL)
        surface.convert()
        image_list.append(surface)
    return image_list


def load_sound(filename):
    """サウンドをロード"""
    filename = os.path.join("data", filename)
    return pygame.mixer.Sound(filename)


if __name__ == "__main__":
    Invader()

この記事が気に入ったら
いいね ! しよう

Twitter で
フォームズ編集部
オフィスで働く方、ネットショップやホームページを運営されている皆様へ、ネットを使った仕事の効率化、Webマーケティングなど役立つ情報をお送りしています。