set non-causal attention

This commit is contained in:
Michael Yang
2025-03-07 13:52:45 -08:00
parent 631fecc6d9
commit 0df1800436
6 changed files with 57 additions and 25 deletions

View File

@@ -954,6 +954,20 @@ func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
}
}
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
var tt *C.struct_ggml_tensor
switch len(strides) {
case 0:
tt = C.ggml_set_1d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
case 1:
tt = C.ggml_set_2d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
default:
panic("unsupported number of dimensions")
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {