
It's not true that einsum converts to a for loop, and it is sometimes faster than other numpy built-in functions.

https://stackoverflow.com/questions/18365073/why-is-numpys-e...

Did I misunderstand your comment?
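A minimal illustration of the kind of case where einsum can beat a chain of built-ins: row-wise dot products. The built-in route allocates the full elementwise product before reducing it, while einsum can multiply and accumulate in one pass. Whether that actually wins depends on your NumPy version, array sizes, and hardware, so treat this as something to time yourself rather than a guarantee. The shapes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((1000, 64))
b = rng.random((1000, 64))

# Row-wise dot products: out[i] = sum_j a[i, j] * b[i, j]
via_einsum = np.einsum("ij,ij->i", a, b)

# The "built-in" route first materializes the full (1000, 64) array a * b,
# then reduces it along axis 1.
via_builtins = (a * b).sum(axis=1)

assert np.allclose(via_einsum, via_builtins)
```

To compare speed, run `%timeit np.einsum('ij,ij->i', a, b)` and `%timeit (a * b).sum(axis=1)` in IPython.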



If I understand https://github.com/numpy/numpy/blob/v1.22.0/numpy/core/einsu... and https://github.com/numpy/numpy/blob/v1.22.0/numpy/core/src/m... correctly, using einsum without the optimize flag seems to use a for loop in C to do the multiplication.
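For reference, here is what the naive `bik,bkj->bij` contraction computes, spelled out as explicit loops. This is only a sketch of the arithmetic: NumPy's real C code uses its own specialized sum-of-products kernels and is far faster than pure Python. The point is that, as far as I can tell, nothing in the unoptimized path hands the work to a BLAS GEMM, which is plausibly where most of the gap against `x@y` comes from. The function name `einsum_bik_bkj_naive` is made up for illustration.

```python
import numpy as np

def einsum_bik_bkj_naive(x, y):
    """Hypothetical helper: 'bik,bkj->bij' written as plain nested loops.

    Illustration only; it shows the arithmetic, not NumPy's implementation.
    """
    B, I, K = x.shape
    B2, K2, J = y.shape
    assert B == B2 and K == K2, "batch and contraction sizes must match"
    out = np.zeros((B, I, J), dtype=np.result_type(x, y))
    for b in range(B):
        for i in range(I):
            for j in range(J):
                acc = 0.0
                # Scalar multiply-accumulate over the contracted index k
                for k in range(K):
                    acc += x[b, i, k] * y[b, k, j]
                out[b, i, j] = acc
    return out

rng = np.random.default_rng(0)
x = rng.random((3, 4, 5))
y = rng.random((3, 5, 2))
loops = einsum_bik_bkj_naive(x, y)
assert np.allclose(loops, np.einsum("bik,bkj->bij", x, y))
```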

The optimizer clearly tries to improve performance, but in many cases it doesn't seem to change anything. Let's simply multiply a batch of matrices:

  import numpy as np
  x, y = np.random.rand(200, 200, 200), np.random.rand(200, 200, 200)
I can do

  %timeit x@y
  40.3 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
or a naive

  %timeit np.einsum('bik,bkj->bij',x,y)
  1.53 s ± 21.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
But even with optimization, I see

  %timeit np.einsum('bik,bkj->bij',x,y, optimize=True)
  1.54 s ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
I'm not sure if I'm doing something wrong.
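If the pattern really is a batched matmul, the practical workaround is to call `np.matmul` (or `@`) directly. From my reading of `numpy/core/einsumfunc.py` (details may vary by version), the optimizer only replaces a pairwise contraction with a BLAS-backed `tensordot` when no index is shared by both inputs and also kept in the output. The batch index `b` is exactly such an index, so `optimize=True` appears to fall back to the same C loop here. `np.einsum_path` also shows that with only two operands there is no contraction order to choose. Sizes below are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((20, 30, 40))
y = rng.random((20, 40, 50))

# Same contraction two ways; matmul sends each batch item to a BLAS GEMM.
ref = np.einsum("bik,bkj->bij", x, y)
fast = np.matmul(x, y)  # equivalent to x @ y
assert np.allclose(ref, fast)

# The path the optimizer would use: a single pairwise contraction,
# so there is nothing to reorder.
path, report = np.einsum_path("bik,bkj->bij", x, y, optimize="greedy")
print(path)  # ['einsum_path', (0, 1)]
print(report)
```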


The looped matrix multiply you show is very hard to optimize for in the general case of einsum. The looped GEMM often appears in permuted form, such as `kbi,kjb->bij`. In that case, heuristics are needed to decide whether calling GEMM is worth the cost of the memory copies needed to realign the operands.
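A sketch of what rewriting the permuted case as a GEMM involves, with made-up sizes: transpose the operands so the batch index leads and the contracted index sits in the matmul position, then copy them into contiguous buffers. Those copies are the overhead a heuristic would have to weigh against the GEMM speedup. This snippet does not implement any such heuristic; it only shows the transformation.

```python
import numpy as np

rng = np.random.default_rng(0)
K, B, I, J = 6, 4, 3, 5
x = rng.random((K, B, I))  # indexed as x[k, b, i]
y = rng.random((K, J, B))  # indexed as y[k, j, b]

ref = np.einsum("kbi,kjb->bij", x, y)

# Reorder axes so the pattern becomes a plain batched GEMM:
#   x[k, b, i] -> xt[b, i, k]   and   y[k, j, b] -> yt[b, k, j]
xt = x.transpose(1, 2, 0)
yt = y.transpose(2, 0, 1)

# transpose() only returns strided views; making them contiguous is the
# extra memory copy a GEMM-based strategy has to pay for.
xt = np.ascontiguousarray(xt)
yt = np.ascontiguousarray(yt)

gemm = xt @ yt  # (B, I, K) @ (B, K, J) -> (B, I, J)
assert np.allclose(ref, gemm)
```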

`optimize=True` is generally best when there are more than two tensors in the expression.
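A small three-operand example, with shapes chosen (arbitrarily) so that contraction order matters: contracting `A` with `B` first yields a tiny 10 x 5 intermediate, whereas contracting `B` with `C` first would yield a 300 x 300 one. Without optimization, einsum loops over every combination of i, j, k, l in one pass. `np.einsum_path` reports the pairwise order the optimizer picked along with naive vs. optimized FLOP estimates.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((10, 300))
B = rng.random((300, 5))
C = rng.random((5, 300))

# One pass over all index combinations (i, j, k, l).
naive = np.einsum("ij,jk,kl->il", A, B, C)

# Picks a pairwise contraction order first; individual pairs may be
# handed to BLAS-backed routines.
optimized = np.einsum("ij,jk,kl->il", A, B, C, optimize=True)
assert np.allclose(naive, optimized)

path, report = np.einsum_path("ij,jk,kl->il", A, B, C, optimize="greedy")
print(path)    # 'einsum_path' followed by the chosen pairwise contractions
print(report)  # includes the naive and optimized FLOP counts
```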


I just tried the same experiment with JAX's NumPy API. einsum is still slower, but at least within the same order of magnitude:

  %timeit (x_jax @ y_jax).block_until_ready()
  579 µs ± 4.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

  %timeit jnp.einsum('bik,bkj->bij',x_jax,y_jax, optimize=True).block_until_ready()
  658 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

  %timeit jnp.einsum('bik,bkj->bij',x_jax,y_jax).block_until_ready()
  660 µs ± 2.82 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
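My understanding (not verified against the JAX source for every version) is that `jnp.einsum` lowers pairwise contractions to `jax.lax.dot_general`, which has first-class batch dimensions, so XLA can map this pattern onto a batched GEMM. That would explain why the gap is small here. The sketch below spells out that single call for `bik,bkj->bij`; shapes are arbitrary, and the arrays are float32 because JAX disables 64-bit types by default.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

rng = np.random.default_rng(0)
x_jax = jnp.asarray(rng.random((8, 16, 32)))  # float32 unless x64 is enabled
y_jax = jnp.asarray(rng.random((8, 32, 24)))

# dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)):
# contract x axis 2 (k) with y axis 1 (k); axis 0 (b) of both is a batch dim.
dims = (((2,), (1,)), ((0,), (0,)))
out = lax.dot_general(x_jax, y_jax, dimension_numbers=dims)

ref = jnp.einsum("bik,bkj->bij", x_jax, y_jax)
# Output layout is (batch dims, lhs free dims, rhs free dims) -> (b, i, j)
assert out.shape == (8, 16, 24)
assert np.allclose(np.asarray(out), np.asarray(ref), rtol=1e-3, atol=1e-3)
```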


IIRC, PyTorch's einsum used to be very slow.



