Jaward
posted an update 3 days ago
Triton nanoGPT now has a custom cross-entropy loss kernel 🚀
Next up: matmul, as I gradually overthrow all the major PyTorch ops :)

Simplified pseudocode for the parallel cross-entropy loss computation:
- init program: get the program id (pid), compute offsets, load targets.
- initialize row_max and row_sum.
- loop 1 (find max logit): update row_max with the running max of the logits.
- loop 2 (compute softmax and loss): accumulate row_sum, update loss.
- add log(row_sum) and store the loss.
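The steps above can be sketched on the CPU in plain Python. This is not the notebook's Triton kernel: each outer-loop iteration plays the role of one Triton program (one row), `block_size` is a hypothetical stand-in for the kernel's block size, and folding the target logit into `loss` during loop 2 is my assumption about how the "update loss" step works.

```python
import math
from typing import List


def cross_entropy_rows(logits: List[List[float]], targets: List[int],
                       block_size: int = 4) -> List[float]:
    """Per-row cross-entropy following the two-loop structure above (CPU sketch)."""
    losses = []
    for pid in range(len(logits)):             # one "program" per row (pid = row index)
        row = logits[pid]
        n_cols = len(row)
        target = targets[pid]                  # load this row's target class

        row_max = -math.inf                    # running max of the logits
        row_sum = 0.0                          # running sum of exp(logit - row_max)
        loss = 0.0

        # Loop 1: scan the row block by block to find the max logit.
        for start in range(0, n_cols, block_size):
            block = row[start:start + block_size]   # stand-in for a masked block load
            row_max = max(row_max, max(block))

        # Loop 2: accumulate the softmax denominator and pick up the target term.
        # (Assumption: the target logit is handled in whichever block contains it.)
        for start in range(0, n_cols, block_size):
            block = row[start:start + block_size]
            row_sum += sum(math.exp(x - row_max) for x in block)
            if start <= target < start + block_size:
                loss += row_max - row[target]       # -(x[target] - row_max)

        # -log(softmax(x)[target]) = log(row_sum) - (x[target] - row_max)
        loss += math.log(row_sum)
        losses.append(loss)
    return losses


print(cross_entropy_rows([[1.0, 2.0, 3.0], [1000.0, 1000.0]], [2, 0]))
```

Because both loops subtract row_max before exponentiating, the row [1000.0, 1000.0] yields log 2 instead of overflowing (a naive `math.exp(1000.0)` raises `OverflowError`), which is why the max pass comes first. Splitting a row into blocks does not change the result; it only mirrors how a kernel walks a vocabulary larger than one block.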

Code: https://github.com/Jaykef/ai-algorithms/blob/main/triton_nanoGPT.ipynb