numpy - 在numpy中,相乘多个矩阵

假设你有n个方阵A1 ... An,有没有好的办法相乘这些矩阵?就我所知numpy中的点只接受两个参数,一个明显的方法是定义一个函数来调用它本身并得到结果,有没有更好的方法来完成它?

时间:

 
A.dot(B).dot(C)

 


reduce(numpy.dot, [A1, A2, ..., An])


>>> A = [np.random.random((5, 5)) for i in xrange(4)]
>>> product1 = A[0].dot(A[1]).dot(A[2]).dot(A[3])
>>> product2 = reduce(numpy.dot, A)
>>> numpy.all(product1 == product2)
True

如果你先计算所有矩阵,那么你应该使用矩阵链乘法的优化方案,请参见这篇维基百科文章


A_list = [np.random.randn(100, 100) for i in xrange(10)]
B = np.eye(A_list[0].shape[0])
for A in A_list:
 B = np.dot(B, A)

C = reduce(np.dot, A_list)

assert(B == C)

使用更新恢复旧问题:

2014年11月13日,现在有一个 np.linalg.multi_dot 函数,它可以精确地完成你想要的。 它还具有优化调用顺序的优点,但这在你的情况下是不必要的。

注意,这还没有达到稳定的numpy版本,但是版本 1.10应该包含它。

...