backend: API to support full precision matmul

Most tensor backends try to optimize performance by using a lower
precision for matmuls. However, some operations (such as kq) on
some models are sensitive to this and require full precision.
This commit is contained in:
Jesse Gross
2025-02-13 10:01:14 -08:00
committed by Jesse Gross
parent 4d4463b2bd
commit d773b7d671
4 changed files with 12 additions and 2 deletions

View File

@@ -421,6 +421,15 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
t: mul,
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
if b != nil {