additional review comments

This commit is contained in:
Jesse Gross
2025-03-07 11:19:03 -08:00
committed by Michael Yang
parent b27e8f3f10
commit 98272fbd58
2 changed files with 32 additions and 16 deletions

View File

@@ -402,7 +402,10 @@ func (b *Backend) NewContext() ml.Context {
}
func (b *Backend) NewContextSize(n int) ml.Context {
n = min(n, b.maxGraphNodes)
if n > b.maxGraphNodes {
panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
}
return &Context{
b: b,
maxGraphNodes: n,
@@ -534,7 +537,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
panic("unsupported dtype")
}
if len(shape) < 1 {
if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
} else if len(shape) > 4 {
@@ -565,6 +568,11 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func checkShape[S ~[]E, E any](s S, shape ...int) error {
n := len(s)
if n == 0 {
return nil
}
for _, v := range shape {
n /= v
}
@@ -577,22 +585,28 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
}
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
if err := checkShape(s, shape...); err != nil {
return nil, err
}
t := c.newTensor(ml.DTypeF32, shape)
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
return t, nil
}
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
if err := checkShape(s, shape...); err != nil {
return nil, err
}
t := c.newTensor(ml.DTypeI32, shape)
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
return t, nil
}