ランダムウォークをできるだけ短く速く書きたい。
ランダムウォークとは
ランダムウォークとは、各ステップごとに上下左右をランダムに決めて1歩進む、という操作を繰り返すことでできる軌跡のことです。
今回はできるだけ短いコードでランダムウォークを表せないか考えてみたいと思います。
普通に書くと
普通にPythonで書くと以下のようになります。
import numpy as np import matplotlib.pylab as plt n = 10000
def randomwalk1(n): a = [[0,0]] for i in range(n): rand = np.random.rand() x = a[-1][0] y = a[-1][1] if rand < 0.25: a.append([x+1,y]) elif 0.25<= rand < 0.5: a.append([x-1,y]) elif 0.5<= rand < 0.75: a.append([x,y+1]) else: a.append([x,y-1]) return np.array(a).T
もうちょっと短く書く方法がありそうな気がします。
Pythonで書く以上、for文を使わない方法を取りたいですね。
numpyの関数を活用してみる
def randomwalk2(): x = [np.random.choice([-1,0,1]),np.random.choice([-1,0,1])] if x[0]*x[1] == 0 and (x[0] != 0 or x[1] != 0): return x else: return randomwalk2() a = np.vstack([np.array([0,0]),np.array([randomwalk2() for i in range(n)])]).T a = [np.cumsum(a[0],dtype=int),np.cumsum(a[1],dype=int)]
前述のコードでは、0から1の間の実数をランダムに返すnp.random.rand()を使いましたが、
今は指定した配列の中からランダムで1つ選ぶnp.random.choice()を使います。
返る配列がxy平面の基底(とその‐1倍)になるようにif文で調節しています。
numpy.random.cumsum()は入力した配列を前から順番に足していく関数です。
あんまり短くなっていないような感じもしますが(可読性も悪いですね)、for文を使ってないだけ良しとしましょう。
処理時間計測
number | time |
1 | 18.9 ms ± 484 µs per loop |
2 | 353 ms ± 66.4 ms per loop |
普通に書いた方が約20倍も速いという結果になりました。
なぜだ!!
randomwalk2()でほしい配列が得られないときは始めに戻ってやり直すという操作をしているのですが、
おそらくそこでタイムロスが生まれているのかなー?
もっといい書き方はありそうですね。
一次元の場合
まず普通に書こうとするとこうなります。
n=10000 a=[0] for i in range(n): rand = np.random.rand() x = a[-1] if rand<0.5: a.append(x+1) else: a.append(x-1)
次にnumpyをフル活用すると
n=10000 a = np.cumsum(np.random.choice([-1,1],size=n))
このように、1行で書けてしまいました。
処理時間も
number | time |
1 | 10.3 ms ± 186 µs per loop |
2 | 211 µs ± 4.34 µs per loop |
約500倍速く計算できました。
一次元ランダムウォークの場合はnumpyの関数をフル活用したほうがいいですね。
6/18追記
すこーし調べたらpython自体のライブラリrandomにchoiceという関数があるみたいです。
numpyの方のchoiceは一次元の要素しか選べないようですが、randomの方は要素がlistでも大丈夫。
つまり[1,0],[-1,0],[0,1],[0,-1]の中から一つ選ぶには、
random.choice([[1,0],[-1,0],[0,1],[0,-1]])
こう書けばいいわけです。
これを使って書き直すと以下のようになります。
def randomwalk2(): a = np.vstack([np.array([0,0]),np.array([random.choice([[1,0],[-1,0],[0,1],[0,-1]]) for i in range(10000)])]).T a = [np.cumsum(a[0],dtype=int),np.cumsum(a[1],dtype=int)] return a
ずいぶん短くすることができました。
ちょっと可読性が悪いので説明しますと、1行目では初期位置(0,0)とステップごとに進む方向を配列にしてくっつけて、
2行目ではその進む方向たちを足し合わせて各ステップごとの位置に直しています。
問題は処理時間ですが、、、
number | time |
1 | 18.9 ms ± 484 µs per loop |
2 | 353 ms ± 66.4 ms per loop |
2' | 150 ns ± 7.96 ns per loop |
はや!!!
速すぎる!!!
普通に書いた時と比べても、12万倍くらい速いことになります。
今までは速くしたければnumpyだと思ってましたが、そんなことないですね。