mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 15:57:04 +00:00
additional review comments
This commit is contained in:
committed by
Michael Yang
parent
b27e8f3f10
commit
98272fbd58
@@ -402,7 +402,10 @@ func (b *Backend) NewContext() ml.Context {
|
||||
}
|
||||
|
||||
func (b *Backend) NewContextSize(n int) ml.Context {
|
||||
n = min(n, b.maxGraphNodes)
|
||||
if n > b.maxGraphNodes {
|
||||
panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
|
||||
}
|
||||
|
||||
return &Context{
|
||||
b: b,
|
||||
maxGraphNodes: n,
|
||||
@@ -534,7 +537,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
|
||||
if len(shape) < 1 {
|
||||
if len(shape) < 1 || shape[0] == 0 {
|
||||
var shape C.int64_t = 0
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
|
||||
} else if len(shape) > 4 {
|
||||
@@ -565,6 +568,11 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
|
||||
func checkShape[S ~[]E, E any](s S, shape ...int) error {
|
||||
n := len(s)
|
||||
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, v := range shape {
|
||||
n /= v
|
||||
}
|
||||
@@ -577,22 +585,28 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
|
||||
}
|
||||
|
||||
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
|
||||
if err := checkShape(s, shape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := c.newTensor(ml.DTypeF32, shape)
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
|
||||
if err := checkShape(s, shape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := c.newTensor(ml.DTypeI32, shape)
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user