add new gemma model (#11204)

* update patches

* cherry pick metal mean kernel

* cherry pick cuda mean kernel

* gemma3n
This commit is contained in:
Michael Yang
2025-06-25 21:47:09 -07:00
committed by GitHub
parent ad118d8b13
commit 73b642e6f3
25 changed files with 6084 additions and 54 deletions

View File

@@ -253,6 +253,7 @@ type Tensor interface {
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Div(ctx Context, t2 Tensor) Tensor
@@ -276,6 +277,7 @@ type Tensor interface {
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor
Reshape(ctx Context, shape ...int) Tensor
@@ -297,6 +299,12 @@ type Tensor interface {
TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
Mean(ctx Context) Tensor
Variance(ctx Context) Tensor
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Clamp(ctx Context, min, max float32) Tensor
}
// ScaledDotProductAttention implements a fused attention