From b70fc4d51e76fc023afcd005c467d415c0c62750 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 5 Mar 2025 13:27:53 -0800 Subject: [PATCH] model: Don't unconditionally add special tokens We sometimes tokenize partial strings. For example, with multimodal inputs, we split the input string around the images and then tokenize each piece. In these cases, we should only add the special tokens on the first piece. --- llm/server.go | 2 +- model/process_text.go | 6 +++--- model/process_text_test.go | 14 +++++++------- runner/ollamarunner/runner.go | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/llm/server.go b/llm/server.go index 09690a5f..9553ba8f 100644 --- a/llm/server.go +++ b/llm/server.go @@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) return s.llamaModel.Tokenize(content, false, true) } if s.textProcessor != nil { - tokens, err := s.textProcessor.Encode(content) + tokens, err := s.textProcessor.Encode(content, false) if err != nil { return nil, err } diff --git a/model/process_text.go b/model/process_text.go index 7083f36f..bfb0a5f2 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -19,7 +19,7 @@ const ( ) type TextProcessor interface { - Encode(string) ([]int32, error) + Encode(s string, addSpecial bool) ([]int32, error) Decode([]int32) (string, error) Is(int32, Special) bool } @@ -144,7 +144,7 @@ type merge struct { runes []rune } -func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { +func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { fragments := []fragment{{value: s}} for _, special := range bpe.vocab.SpecialVocabulary() { // TODO: process special tokens concurrently @@ -282,7 +282,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { } } - if len(ids) > 0 { + if addSpecial && len(ids) > 0 { if bpe.vocab.AddBOS { if ids[0] == bpe.vocab.BOS { slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) diff --git a/model/process_text_test.go b/model/process_text_test.go index cad1f94f..f4830321 100644 --- a/model/process_text_test.go +++ b/model/process_text_test.go @@ -74,7 +74,7 @@ func TestLlama(t *testing.T) { t.Run("simple", func(t *testing.T) { t.Parallel() - ids, err := tokenizer.Encode("hello world") + ids, err := tokenizer.Encode("hello world", true) if err != nil { t.Error(err) } @@ -92,7 +92,7 @@ func TestLlama(t *testing.T) { t.Errorf("got %q, want hello world", s) } - ids, err = tokenizer.Encode("hello <|end_of_text|>") + ids, err = tokenizer.Encode("hello <|end_of_text|>", true) if err != nil { t.Error(err) } @@ -126,7 +126,7 @@ func TestLlama(t *testing.T) { } for s, want := range cases { - ids, err := tokenizer.Encode(s) + ids, err := tokenizer.Encode(s, true) if err != nil { t.Error(err) } @@ -152,7 +152,7 @@ func TestLlama(t *testing.T) { } for _, want := range cases { - ids, err := tokenizer.Encode(want) + ids, err := tokenizer.Encode(want, true) if err != nil { t.Error(err) } @@ -176,7 +176,7 @@ func TestLlama(t *testing.T) { } for s, want := range cases { - ids, err := tokenizer.Encode(s) + ids, err := tokenizer.Encode(s, true) if err != nil { t.Fatal(err) } @@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { b.ResetTimer() for range b.N { - _, err := tokenizer.Encode(string(bts)) + _, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) } @@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { }) b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { - ids, err := tokenizer.Encode(string(bts)) + ids, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 1a4bbf19..9ba6563f 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -161,7 +161,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { for i, part := range parts { // text - tokenize - tokens, err := s.model.(model.TextProcessor).Encode(part) + tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) if err != nil { return nil, err }