I thought I'd do something smart: inline all the matrix multiplications into the einsums of the vectorized multi-head attention implementation from the article and set optimize="optimal", so the optimal matrix chain multiplication algorithm (https://en.wikipedia.org/wiki/Matrix_chain_multiplication) could give a nice performance boost.
This is indeed twice as fast as the vectorized implementation, but, disappointingly, the naive implementation with loops is even faster. Here is the code if anyone wants to figure out why the performance turns out this way: https://pastebin.com/raw/peptFyCw
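For context, here is a minimal sketch (not the pastebin code) of what "inlining the projections into the einsums" can look like. The shapes and weight names (W_q, W_k, W_v, W_o) are assumptions for illustration:

```python
import numpy as np

def mha_inlined(x, W_q, W_k, W_v, W_o):
    # x:             (batch, seq, d_model)
    # W_q, W_k, W_v: (heads, d_model, d_head)
    # W_o:           (heads * d_head, d_model)
    d_head = W_q.shape[-1]

    # Q @ K^T with both projections folded into a single contraction;
    # optimize="optimal" lets NumPy pick the cheapest contraction order.
    scores = np.einsum("bsd,hde,btf,hfe->bhst", x, W_q, x, W_k,
                       optimize="optimal") / np.sqrt(d_head)

    # Softmax over the key axis.
    scores -= scores.max(axis=-1, keepdims=True)
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)

    # Value projection folded into the second contraction.
    context = np.einsum("bhst,btf,hfe->bshe", attn, x, W_v,
                        optimize="optimal")
    return context.reshape(*context.shape[:2], -1) @ W_o

# Tiny smoke test with random weights.
rng = np.random.default_rng(0)
b, s, d_model, h, d_head = 2, 16, 64, 4, 16
x = rng.standard_normal((b, s, d_model))
W_q, W_k, W_v = (rng.standard_normal((h, d_model, d_head)) for _ in range(3))
W_o = rng.standard_normal((h * d_head, d_model))
print(mha_inlined(x, W_q, W_k, W_v, W_o).shape)  # (2, 16, 64)
```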
My guess is that einsum could do a better job of accounting for cache locality when evaluating the contraction.
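One way to check part of this is np.einsum_path, which reports the planned contraction order and a FLOP estimate but says nothing about memory layout, consistent with the guess above. The shapes below are placeholders:

```python
import numpy as np

x   = np.random.rand(8, 128, 64)   # (batch, seq, d_model)
W_q = np.random.rand(4, 64, 16)    # (heads, d_model, d_head)
W_k = np.random.rand(4, 64, 16)

path, report = np.einsum_path("bsd,hde,btf,hfe->bhst", x, W_q, x, W_k,
                              optimize="optimal")
print(report)  # pairwise contraction order plus theoretical FLOP count
```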
To be fair, you could replace `import numpy as np` with `import cupy as np` and it would run on the GPU without further changes. It still isn't great, though: PyTorch is roughly 12 times faster.
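A hedged sketch of that drop-in swap, falling back to NumPy when CuPy isn't installed; the rest of the code stays identical because CuPy mirrors the NumPy einsum API:

```python
try:
    import cupy as np   # same einsum code, but executed on the GPU
except ImportError:
    import numpy as np  # CPU fallback

a = np.random.rand(256, 256)
b = np.random.rand(256, 256)
c = np.einsum("ij,jk->ik", a, b, optimize="optimal")
print(type(c))  # cupy.ndarray on GPU, numpy.ndarray on CPU
```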