mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 07:46:59 +00:00
ml: update Context.Forward interface
update Context.Forward to accept multiple tensors to match Context.Compute signature update Context.Forward to return Context such that it can be chained with Context.Compute
This commit is contained in:
@@ -65,7 +65,7 @@ type Context interface {
|
||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||
|
||||
Forward(Tensor)
|
||||
Forward(...Tensor) Context
|
||||
Compute(...Tensor)
|
||||
MaxTensors() int
|
||||
Close()
|
||||
@@ -186,8 +186,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
||||
|
||||
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
|
||||
if t.Bytes() == nil {
|
||||
ctx.Forward(t)
|
||||
ctx.Compute(t)
|
||||
ctx.Forward(t).Compute(t)
|
||||
}
|
||||
|
||||
s := make(S, mul(t.Shape()...))
|
||||
|
||||
@@ -256,12 +256,16 @@ type Context struct {
|
||||
nodes int
|
||||
}
|
||||
|
||||
func (c *Context) Forward(t ml.Tensor) {
|
||||
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
||||
if c.graph == nil {
|
||||
c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
|
||||
}
|
||||
|
||||
C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
|
||||
for _, tensor := range tensors {
|
||||
C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
|
||||
Reference in New Issue
Block a user