WIP: Dp4MatMulNBits accuracy level 4 matmul for WebGPU EP #23365

Open · sushraja-msft wants to merge 16 commits into main from user/sushraja/dp4_matmul
Conversation

@sushraja-msft (Contributor) commented on Jan 14, 2025

Description

This change implements accuracy level 4 (quantize A to int8) matmul for the WebGPU EP. The matmul kernel uses DP4A instructions for the multiplication; to keep the DP4A units fed, a co-operative matrix multiplication is implemented that preloads the rows/columns into local variables before the multiply operation.
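
For reference, a single DP4A step multiplies four packed int8 lanes and accumulates the products into a 32-bit integer, so a K-deep dot product costs K/4 such steps. Below is a minimal scalar sketch of that semantic in C++ (what WGSL exposes as dot4I8Packed); `Dp4a` is an illustrative helper name, not code from this PR.

```cpp
#include <cstdint>

// Scalar reference for one DP4A step: interpret each u32 as four packed
// int8 lanes, multiply lane-wise, and accumulate into an i32.
// Illustrative only; the PR's shader uses the hardware instruction.
int32_t Dp4a(uint32_t a, uint32_t b, int32_t acc) {
  for (int lane = 0; lane < 4; ++lane) {
    const auto av = static_cast<int8_t>((a >> (8 * lane)) & 0xFFu);
    const auto bv = static_cast<int8_t>((b >> (8 * lane)) & 0xFFu);
    acc += static_cast<int32_t>(av) * static_cast<int32_t>(bv);
  }
  return acc;
}
```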

Credits to @qjia7 for help with the quantizer shader.
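
For intuition on the quantizer: accuracy level 4 quantizes each block of A to int8 with a per-block scale before the DP4A matmul, and the int32 accumulators would then be rescaled by the A and weight block scales to recover float outputs. A minimal host-side sketch follows; the symmetric scaling, round-to-nearest behavior, and `QuantizeBlockToInt8` helper are assumptions for illustration, not necessarily what the shader does.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch of symmetric per-block int8 quantization of the A matrix.
// Assumptions: symmetric range [-127, 127], round-to-nearest; the
// actual quantizer shader in this PR may differ.
void QuantizeBlockToInt8(const float* a, int block_size,
                         int8_t* out, float* scale) {
  float max_abs = 0.0f;
  for (int i = 0; i < block_size; ++i) {
    max_abs = std::max(max_abs, std::fabs(a[i]));
  }
  *scale = max_abs / 127.0f;
  const float inv = (*scale > 0.0f) ? 1.0f / *scale : 0.0f;
  for (int i = 0; i < block_size; ++i) {
    const float q = std::clamp(a[i] * inv, -127.0f, 127.0f);
    out[i] = static_cast<int8_t>(std::lround(q));
  }
}
```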

Performance metrics on an Intel ADL/TGL GPU:

```
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       2.76762e+06
        avg (tokens/s): 181.022   <<< prefill speed
        p50 (us):       2.74843e+06
        stddev (us):    41756.4
        n:              5 * 501 token(s)
Token generation:
        avg (us):       81500.7
        avg (tokens/s): 12.2698
        p50 (us):       81104.1
        stddev (us):    2961.31
        n:              635 * 1 token(s)
Token sampling:
        avg (us):       13.1836
        avg (tokens/s): 75851.9
        p50 (us):       12
        stddev (us):    6.47085
        n:              640 * 1 token(s)
E2E generation (entire generation loop):
        avg (ms):       13120
        p50 (ms):       13081.6
        stddev (ms):    114.689
        n:              5
Peak working set size (bytes): 5467533312
WebGPU device lost (2): Device was destroyed.
```

For a 500-token prefill, this kernel is 2.10x faster than its F16 counterpart: 181 tokens/s versus the previous prefill record of 86 tokens/s.

To support devices with subgroup sizes of 8 or 32, a no-subgroup version of the same shader is included. It is slower than the subgroup version on ADL:

```
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       4.11989e+06
        avg (tokens/s): 121.605
        p50 (us):       4.11847e+06
        stddev (us):    2147.48
        n:              5 * 501 token(s)
Token generation:
        avg (us):       81174.9
        avg (tokens/s): 12.3191
        p50 (us):       81301.1
        stddev (us):    2177.2
        n:              635 * 1 token(s)
Token sampling:
        avg (us):       14.7998
        avg (tokens/s): 67568.3
        p50 (us):       12.3
        stddev (us):    11.5481
        n:              640 * 1 token(s)
E2E generation (entire generation loop):
        avg (ms):       14431.1
        p50 (ms):       14433.8
        stddev (ms):    5.02473
        n:              5
Peak working set size (bytes): 5466480640
WebGPU device lost (2): Device was destroyed.
```

@sushraja-msft changed the title from "Dp4MatMulNBits low accuracy matmul for WebGPU EP" to "WIP: Dp4MatMulNBits low accuracy matmul for WebGPU EP" on Jan 14, 2025
@guschmue added the ep:WebGPU (ort-web webgpu provider) label on Jan 16, 2025
@sushraja-msft changed the title from "WIP: Dp4MatMulNBits low accuracy matmul for WebGPU EP" to "WIP: Dp4MatMulNBits accuracy level 4 matmul for WebGPU EP" on Jan 17, 2025
@sushraja-msft force-pushed the user/sushraja/dp4_matmul branch from 0dd9e67 to 73ee5d1 on Jan 17, 2025 at 20:23
Labels
ep:WebGPU ort-web webgpu provider