
You should be able to make a few small changes to support "mps".

In TrainingConfig, set the device to "mps", then run training.
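A minimal sketch of that change, assuming TrainingConfig is a dataclass with a plain string `device` field. The class below is a hypothetical stand-in for the project's real config, and `resolve_device` is a hypothetical helper, not part of the project. It falls back to "cpu" when PyTorch, or its MPS backend, isn't available:

```python
from dataclasses import dataclass

try:
    import torch  # only used to check which backends are available
except ImportError:  # keeps the sketch runnable without PyTorch installed
    torch = None


@dataclass
class TrainingConfig:
    # Hypothetical stand-in for the project's TrainingConfig;
    # only the device field matters here.
    device: str = "mps"


def resolve_device(requested: str) -> str:
    """Hypothetical helper: return the requested device if usable, else "cpu"."""
    if requested == "mps":
        # torch.backends.mps exists in PyTorch 1.12+; getattr guards older builds
        mps = getattr(getattr(torch, "backends", None), "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"
    if requested == "cuda":
        if torch is not None and torch.cuda.is_available():
            return "cuda"
        return "cpu"
    return requested


config = TrainingConfig(device="mps")
config.device = resolve_device(config.device)
print(config.device)  # "mps" on Apple Silicon with MPS-enabled PyTorch, else "cpu"
```

If an op turns out to have no MPS kernel, PyTorch raises an error. Setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before launching makes those ops fall back to the CPU instead (slower, but the run continues).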

In sample.py, modify parse_args() to accept mps as a possible value for the --device argument.
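A rough sketch of what that might look like, assuming parse_args() uses argparse and restricts --device with `choices`. The argument details and default here are guesses, not the script's actual definition:

```python
import argparse


def parse_args(argv=None):
    # Hypothetical sketch of sample.py's parse_args(); the real script
    # likely defines more arguments than shown here.
    parser = argparse.ArgumentParser(description="Sample from a trained model")
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda", "mps"],  # "mps" added alongside the existing values
        help="device to run sampling on",
    )
    return parser.parse_args(argv)


args = parse_args(["--device", "mps"])
print(args.device)  # mps
```

The model and input tensors then still need to be moved with `.to(args.device)` wherever the script currently does that for "cuda".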



Thanks! I'll try. I hadn't bothered trying, because I assumed that since this was developed heavily on CUDA, it would likely use kernels that are missing in MPS.



