mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-17 19:27:00 +00:00
server/internal: copy bmizerany/ollama-go to internal package (#9294)
This commit copies (without history) the bmizerany/ollama-go repository with the intention of integrating it into the ollama as a replacement for the pushing, and pulling of models, and management of the cache they are pushed and pulled from. New homes for these packages will be determined as they are integrated and we have a better understanding of proper package boundaries.
This commit is contained in:
48
server/internal/internal/backoff/backoff.go
Normal file
48
server/internal/internal/backoff/backoff.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
|
||||
var n int
|
||||
return func(yield func(int, error) bool) {
|
||||
var t *time.Timer
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
yield(n, ctx.Err())
|
||||
return
|
||||
}
|
||||
|
||||
if !yield(n, nil) {
|
||||
return
|
||||
}
|
||||
|
||||
n++
|
||||
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := time.Duration(n*n) * 10 * time.Millisecond
|
||||
if d > maxBackoff {
|
||||
d = maxBackoff
|
||||
}
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
|
||||
if t == nil {
|
||||
t = time.NewTimer(d)
|
||||
} else {
|
||||
t.Reset(d)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
case <-t.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
40
server/internal/internal/backoff/backoff_synctest_test.go
Normal file
40
server/internal/internal/backoff/backoff_synctest_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
//go:build goexperiment.synctest
|
||||
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoop(t *testing.T) {
|
||||
synctest.Run(func() {
|
||||
last := -1
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
for n, err := range Loop(ctx, 100*time.Millisecond) {
|
||||
if !errors.Is(err, ctx.Err()) {
|
||||
t.Errorf("err = %v, want nil", err)
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if n != last+1 {
|
||||
t.Errorf("n = %d, want %d", n, last+1)
|
||||
}
|
||||
last = n
|
||||
if n > 5 {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
if last != 6 {
|
||||
t.Errorf("last = %d, want 6", last)
|
||||
}
|
||||
})
|
||||
}
|
||||
38
server/internal/internal/backoff/backoff_test.go
Normal file
38
server/internal/internal/backoff/backoff_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoopAllocs(t *testing.T) {
|
||||
for i := range 3 {
|
||||
got := testing.AllocsPerRun(1000, func() {
|
||||
for tick := range Loop(t.Context(), 1) {
|
||||
if tick >= i {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
want := float64(0)
|
||||
if i > 0 {
|
||||
want = 3 // due to time.NewTimer
|
||||
}
|
||||
if got > want {
|
||||
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLoop(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
synctest.Run(func() {
|
||||
for n := range Loop(ctx, 100*time.Millisecond) {
|
||||
if n == b.N {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
229
server/internal/internal/names/name.go
Normal file
229
server/internal/internal/names/name.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package names
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/internal/stringsx"
|
||||
)
|
||||
|
||||
const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
|
||||
|
||||
type Name struct {
|
||||
// Make incomparable to enfoce use of Compare / Equal for
|
||||
// case-insensitive comparisons.
|
||||
_ [0]func()
|
||||
|
||||
h string
|
||||
n string
|
||||
m string
|
||||
t string
|
||||
}
|
||||
|
||||
// Parse parses and assembles a Name from a name string. The
|
||||
// format of a valid name string is:
|
||||
//
|
||||
// s:
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag }
|
||||
// { host } "/" { namespace } "/" { model } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { model }
|
||||
// { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { namespace } "/" { model } ":" { tag }
|
||||
// { namespace } "/" { model } "@" { digest }
|
||||
// { namespace } "/" { model }
|
||||
// { model } ":" { tag } "@" { digest }
|
||||
// { model } ":" { tag }
|
||||
// { model } "@" { digest }
|
||||
// { model }
|
||||
// "@" { digest }
|
||||
// host:
|
||||
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
// namespace:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
|
||||
// length: [1, 80]
|
||||
// model:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// tag:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// digest:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// The name returned is not guaranteed to be valid. If it is not valid, the
|
||||
// field values are left in an undefined state. Use [Name.IsValid] to check
|
||||
// if the name is valid.
|
||||
func Parse(s string) Name {
|
||||
if len(s) > MaxNameLength {
|
||||
return Name{}
|
||||
}
|
||||
|
||||
var n Name
|
||||
var tail string
|
||||
var c byte
|
||||
for {
|
||||
s, tail, c = cutLastAny(s, "/:")
|
||||
switch c {
|
||||
case ':':
|
||||
n.t = tail
|
||||
continue // look for model
|
||||
case '/':
|
||||
n.h, n.n, _ = cutLastAny(s, "/")
|
||||
n.m = tail
|
||||
return n
|
||||
case 0:
|
||||
n.m = tail
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ParseExtended parses and returns any scheme, Name, and digest from from s in
|
||||
// the the form [scheme://][name][@digest]. All parts are optional.
|
||||
//
|
||||
// If the scheme is present, it must be followed by "://". The digest is
|
||||
// prefixed by "@" and comes after the name. The name is parsed using [Parse].
|
||||
//
|
||||
// The scheme and digest are stripped before the name is parsed by [Parse].
|
||||
//
|
||||
// For convience, the scheme is never empty. If the scheme is not present, the
|
||||
// returned scheme is "https".
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// http://ollama.com/bmizerany/smol:latest@digest
|
||||
// https://ollama.com/bmizerany/smol:latest
|
||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
||||
func ParseExtended(s string) (scheme string, _ Name, digest string) {
|
||||
i := strings.Index(s, "://")
|
||||
if i >= 0 {
|
||||
scheme = s[:i]
|
||||
s = s[i+3:]
|
||||
}
|
||||
i = strings.LastIndex(s, "@")
|
||||
if i >= 0 {
|
||||
digest = s[i+1:]
|
||||
s = s[:i]
|
||||
}
|
||||
return scheme, Parse(s), digest
|
||||
}
|
||||
|
||||
func FormatExtended(scheme string, n Name, digest string) string {
|
||||
var b strings.Builder
|
||||
if scheme != "" {
|
||||
b.WriteString(scheme)
|
||||
b.WriteString("://")
|
||||
}
|
||||
b.WriteString(n.String())
|
||||
if digest != "" {
|
||||
b.WriteByte('@')
|
||||
b.WriteString(digest)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Merge merges two names into a single name. Non-empty host, namespace, and
|
||||
// tag parts of a take precedence over fields in b. The model field is left as
|
||||
// is.
|
||||
//
|
||||
// The returned name is not guaranteed to be valid. Use [Name.IsValid] to check
|
||||
// if the name is valid.
|
||||
func Merge(a, b Name) Name {
|
||||
a.h = cmp.Or(a.h, b.h)
|
||||
a.n = cmp.Or(a.n, b.n)
|
||||
a.t = cmp.Or(a.t, b.t)
|
||||
return a
|
||||
}
|
||||
|
||||
// IsValid returns true if the name is valid.
|
||||
func (n Name) IsValid() bool {
|
||||
if n.h != "" && !isValidHost(n.h) {
|
||||
return false
|
||||
}
|
||||
if n.n != "" && !isValidNamespace(n.n) {
|
||||
return false
|
||||
}
|
||||
if n.m != "" && !isValidModel(n.m) {
|
||||
return false
|
||||
}
|
||||
if n.t != "" && !isValidTag(n.t) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (n Name) IsFullyQualified() bool {
|
||||
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
|
||||
}
|
||||
|
||||
func isValidHost(_ string) bool {
|
||||
return true // TODO: implement
|
||||
}
|
||||
|
||||
func isValidNamespace(_ string) bool {
|
||||
return true // TODO: implement
|
||||
}
|
||||
|
||||
func isValidModel(_ string) bool {
|
||||
return true // TODO: implement
|
||||
}
|
||||
|
||||
func isValidTag(_ string) bool {
|
||||
return true // TODO: implement
|
||||
}
|
||||
|
||||
func (n Name) Host() string { return n.h }
|
||||
func (n Name) Namespace() string { return n.n }
|
||||
func (n Name) Model() string { return n.m }
|
||||
func (n Name) Tag() string { return n.t }
|
||||
|
||||
// Compare compares n and o case-insensitively. It returns 0 if n and o are
|
||||
// equal, -1 if n sorts before o, and 1 if n sorts after o.
|
||||
func (n Name) Compare(o Name) int {
|
||||
return cmp.Or(
|
||||
stringsx.CompareFold(n.h, o.h),
|
||||
stringsx.CompareFold(n.n, o.n),
|
||||
stringsx.CompareFold(n.m, o.m),
|
||||
stringsx.CompareFold(n.t, o.t),
|
||||
)
|
||||
}
|
||||
|
||||
// String returns the fully qualified name in the format
|
||||
// <namespace>/<model>:<tag>.
|
||||
func (n Name) String() string {
|
||||
var b strings.Builder
|
||||
if n.h != "" {
|
||||
b.WriteString(n.h)
|
||||
b.WriteByte('/')
|
||||
}
|
||||
if n.n != "" {
|
||||
b.WriteString(n.n)
|
||||
b.WriteByte('/')
|
||||
}
|
||||
b.WriteString(n.m)
|
||||
if n.t != "" {
|
||||
b.WriteByte(':')
|
||||
b.WriteString(n.t)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (n Name) GoString() string {
|
||||
return fmt.Sprintf("<Name %q %q %q %q>", n.h, n.n, n.m, n.t)
|
||||
}
|
||||
|
||||
// cutLastAny is like strings.Cut but scans in reverse for the last character
|
||||
// in chars. If no character is found, before is the empty string and after is
|
||||
// s. The returned sep is the byte value of the character in chars if one was
|
||||
// found; otherwise it is 0.
|
||||
func cutLastAny(s, chars string) (before, after string, sep byte) {
|
||||
i := strings.LastIndexAny(s, chars)
|
||||
if i >= 0 {
|
||||
return s[:i], s[i+1:], s[i]
|
||||
}
|
||||
return "", s, 0
|
||||
}
|
||||
152
server/internal/internal/names/name_test.go
Normal file
152
server/internal/internal/names/name_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package names
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want Name
|
||||
}{
|
||||
{"", Name{}},
|
||||
{"m:t", Name{m: "m", t: "t"}},
|
||||
{"m", Name{m: "m"}},
|
||||
{"/m", Name{m: "m"}},
|
||||
{"/n/m:t", Name{n: "n", m: "m", t: "t"}},
|
||||
{"n/m", Name{n: "n", m: "m"}},
|
||||
{"n/m:t", Name{n: "n", m: "m", t: "t"}},
|
||||
{"n/m", Name{n: "n", m: "m"}},
|
||||
{"n/m", Name{n: "n", m: "m"}},
|
||||
{strings.Repeat("m", MaxNameLength+1), Name{}},
|
||||
{"h/n/m:t", Name{h: "h", n: "n", m: "m", t: "t"}},
|
||||
{"ollama.com/library/_:latest", Name{h: "ollama.com", n: "library", m: "_", t: "latest"}},
|
||||
|
||||
// Invalids
|
||||
// TODO: {"n:t/m:t", Name{}},
|
||||
// TODO: {"/h/n/m:t", Name{}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
got := Parse(tt.in)
|
||||
if got.Compare(tt.want) != 0 {
|
||||
t.Errorf("parseName(%q) = %#v, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
cases := []string{
|
||||
"",
|
||||
"m:t",
|
||||
"m:t",
|
||||
"m",
|
||||
"n/m",
|
||||
"n/m:t",
|
||||
"n/m",
|
||||
"n/m",
|
||||
"h/n/m:t",
|
||||
"ollama.com/library/_:latest",
|
||||
|
||||
// Special cased to "round trip" without the leading slash.
|
||||
"/m",
|
||||
"/n/m:t",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
s = strings.TrimPrefix(s, "/")
|
||||
if g := Parse(s).String(); g != s {
|
||||
t.Errorf("parse(%q).String() = %q", s, g)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseExtended(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
|
||||
wantScheme string
|
||||
wantName Name
|
||||
wantDigest string
|
||||
}{
|
||||
{"", "", Name{}, ""},
|
||||
{"m", "", Name{m: "m"}, ""},
|
||||
{"http://m", "http", Name{m: "m"}, ""},
|
||||
{"http+insecure://m", "http+insecure", Name{m: "m"}, ""},
|
||||
{"http://m@sha256:deadbeef", "http", Name{m: "m"}, "sha256:deadbeef"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
scheme, name, digest := ParseExtended(tt.in)
|
||||
if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
|
||||
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
|
||||
}
|
||||
|
||||
// Round trip
|
||||
if got := FormatExtended(scheme, name, digest); got != tt.in {
|
||||
t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
cases := []struct {
|
||||
a, b string
|
||||
want string
|
||||
}{
|
||||
{"", "", ""},
|
||||
{"m", "", "m"},
|
||||
{"", "m", ""},
|
||||
{"x", "y", "x"},
|
||||
{"o.com/n/m:t", "o.com/n/m:t", "o.com/n/m:t"},
|
||||
{"o.com/n/m:t", "o.com/n/_:t", "o.com/n/m:t"},
|
||||
|
||||
{"bmizerany/smol", "ollama.com/library/_:latest", "ollama.com/bmizerany/smol:latest"},
|
||||
{"localhost:8080/bmizerany/smol", "ollama.com/library/_:latest", "localhost:8080/bmizerany/smol:latest"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
a, b := Parse(tt.a), Parse(tt.b)
|
||||
got := Merge(a, b)
|
||||
if got.Compare(Parse(tt.want)) != 0 {
|
||||
t.Errorf("merge(%q, %q) = %#v, want %q", tt.a, tt.b, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStringRoundTrip(t *testing.T) {
|
||||
cases := []string{
|
||||
"",
|
||||
"m",
|
||||
"m:t",
|
||||
"n/m",
|
||||
"n/m:t",
|
||||
"n/m:t",
|
||||
"n/m",
|
||||
"n/m",
|
||||
"h/n/m:t",
|
||||
"ollama.com/library/_:latest",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
if got := Parse(s).String(); got != s {
|
||||
t.Errorf("parse(%q).String() = %q", s, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var junkName Name
|
||||
|
||||
func BenchmarkParseName(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for range b.N {
|
||||
junkName = Parse("h/n/m:t")
|
||||
}
|
||||
}
|
||||
52
server/internal/internal/stringsx/stringsx.go
Normal file
52
server/internal/internal/stringsx/stringsx.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package stringsx provides additional string manipulation functions
|
||||
// that aren't in the standard library's strings package or go4.org/mem.
|
||||
package stringsx
|
||||
|
||||
import (
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b,
|
||||
// like cmp.Compare, but case insensitively.
|
||||
func CompareFold(a, b string) int {
|
||||
// Track our position in both strings
|
||||
ia, ib := 0, 0
|
||||
for ia < len(a) && ib < len(b) {
|
||||
ra, wa := nextRuneLower(a[ia:])
|
||||
rb, wb := nextRuneLower(b[ib:])
|
||||
if ra < rb {
|
||||
return -1
|
||||
}
|
||||
if ra > rb {
|
||||
return 1
|
||||
}
|
||||
ia += wa
|
||||
ib += wb
|
||||
if wa == 0 || wb == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If we've reached here, one or both strings are exhausted
|
||||
// The shorter string is "less than" if they match up to this point
|
||||
switch {
|
||||
case ia == len(a) && ib == len(b):
|
||||
return 0
|
||||
case ia == len(a):
|
||||
return -1
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// nextRuneLower returns the next rune in the string, lowercased, along with its
|
||||
// original (consumed) width in bytes. If the string is empty, it returns
|
||||
// (utf8.RuneError, 0)
|
||||
func nextRuneLower(s string) (r rune, width int) {
|
||||
r, width = utf8.DecodeRuneInString(s)
|
||||
return unicode.ToLower(r), width
|
||||
}
|
||||
78
server/internal/internal/stringsx/stringsx_test.go
Normal file
78
server/internal/internal/stringsx/stringsx_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package stringsx
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompareFold(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b string
|
||||
}{
|
||||
// Basic ASCII cases
|
||||
{"", ""},
|
||||
{"a", "a"},
|
||||
{"a", "A"},
|
||||
{"A", "a"},
|
||||
{"a", "b"},
|
||||
{"b", "a"},
|
||||
{"abc", "ABC"},
|
||||
{"ABC", "abc"},
|
||||
{"abc", "abd"},
|
||||
{"abd", "abc"},
|
||||
|
||||
// Length differences
|
||||
{"abc", "ab"},
|
||||
{"ab", "abc"},
|
||||
|
||||
// Unicode cases
|
||||
{"世界", "世界"},
|
||||
{"Hello世界", "hello世界"},
|
||||
{"世界Hello", "世界hello"},
|
||||
{"世界", "世界x"},
|
||||
{"世界x", "世界"},
|
||||
|
||||
// Special case folding examples
|
||||
{"ß", "ss"}, // German sharp s
|
||||
{"fi", "fi"}, // fi ligature
|
||||
{"Σ", "σ"}, // Greek sigma
|
||||
{"İ", "i\u0307"}, // Turkish dotted I
|
||||
|
||||
// Mixed cases
|
||||
{"HelloWorld", "helloworld"},
|
||||
{"HELLOWORLD", "helloworld"},
|
||||
{"helloworld", "HELLOWORLD"},
|
||||
{"HelloWorld", "helloworld"},
|
||||
{"helloworld", "HelloWorld"},
|
||||
|
||||
// Edge cases
|
||||
{" ", " "},
|
||||
{"1", "1"},
|
||||
{"123", "123"},
|
||||
{"!@#", "!@#"},
|
||||
}
|
||||
|
||||
wants := []int{}
|
||||
for _, tt := range tests {
|
||||
got := CompareFold(tt.a, tt.b)
|
||||
want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b))
|
||||
if got != want {
|
||||
t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want)
|
||||
}
|
||||
wants = append(wants, want)
|
||||
}
|
||||
|
||||
if n := testing.AllocsPerRun(1000, func() {
|
||||
for i, tt := range tests {
|
||||
if CompareFold(tt.a, tt.b) != wants[i] {
|
||||
panic("unexpected")
|
||||
}
|
||||
}
|
||||
}); n > 0 {
|
||||
t.Errorf("allocs = %v; want 0", int(n))
|
||||
}
|
||||
}
|
||||
201
server/internal/internal/syncs/line.go
Normal file
201
server/internal/internal/syncs/line.go
Normal file
@@ -0,0 +1,201 @@
|
||||
// Package syncs provides synchronization primitives.
|
||||
package syncs
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var closedChan = func() chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}()
|
||||
|
||||
// Ticket represents a ticket in a sequence of tickets. The zero value is
|
||||
// invalid. Use [Line.Take] to get a valid ticket.
|
||||
//
|
||||
// A Ticket is not safe for concurrent use.
|
||||
type Ticket struct {
|
||||
ahead chan struct{} // ticket ahead of this one
|
||||
ch chan struct{}
|
||||
}
|
||||
|
||||
// Ready returns a channel that is closed when the ticket before this one is
|
||||
// done.
|
||||
//
|
||||
// It is incorrect to wait on Ready after the ticket is done.
|
||||
func (t *Ticket) Ready() chan struct{} {
|
||||
return cmp.Or(t.ahead, closedChan)
|
||||
}
|
||||
|
||||
// Done signals that this ticket is done and that the next ticket in line can
|
||||
// proceed.
|
||||
//
|
||||
// The first call to [Done] unblocks the ticket after it, if any. Subsequent
|
||||
// calls are no-ops.
|
||||
func (t *Ticket) Done() {
|
||||
if t.ch != nil {
|
||||
close(t.ch)
|
||||
}
|
||||
t.ch = nil
|
||||
}
|
||||
|
||||
// Line is an ordered sequence of tickets waiting for their turn to proceed.
|
||||
//
|
||||
// To get a ticket use [Line.Take].
|
||||
// To signal that a ticket is done use [Ticket.Done].
|
||||
// To wait your turn use [Ticket.Ready].
|
||||
//
|
||||
// A Line is not safe for concurrent use.
|
||||
type Line struct {
|
||||
last chan struct{} // last ticket in line
|
||||
}
|
||||
|
||||
func (q *Line) Take() *Ticket {
|
||||
t := &Ticket{
|
||||
ahead: q.last,
|
||||
ch: make(chan struct{}),
|
||||
}
|
||||
q.last = t.ch
|
||||
return t
|
||||
}
|
||||
|
||||
// RelayReader implements an [io.WriterTo] that yields the passed
|
||||
// writer to its [WriteTo] method each [io.WriteCloser] taken from [Take], in
|
||||
// the order they are taken. Each [io.WriteCloser] blocks until the previous
|
||||
// one is closed, or a call to [RelayReader.CloseWithError] is made.
|
||||
//
|
||||
// The zero value is invalid. Use [NewWriteToLine] to get a valid RelayReader.
|
||||
//
|
||||
// It is not safe for concurrent use.
|
||||
type RelayReader struct {
|
||||
line Line
|
||||
t *Ticket
|
||||
w io.Writer
|
||||
n int64
|
||||
|
||||
mu sync.Mutex
|
||||
err error // set by CloseWithError
|
||||
closedCh chan struct{} // closed if err is set
|
||||
}
|
||||
|
||||
var (
|
||||
_ io.Closer = (*RelayReader)(nil)
|
||||
_ io.WriterTo = (*RelayReader)(nil)
|
||||
_ io.Reader = (*RelayReader)(nil)
|
||||
)
|
||||
|
||||
func NewRelayReader() *RelayReader {
|
||||
var q RelayReader
|
||||
q.closedCh = make(chan struct{})
|
||||
q.t = q.line.Take()
|
||||
return &q
|
||||
}
|
||||
|
||||
// CloseWithError terminates the line, unblocking any writer waiting for its
|
||||
// turn with the error, or [io.EOF] if err is nil. It is safe to call
|
||||
// [CloseWithError] multiple times and across multiple goroutines.
|
||||
//
|
||||
// If the line is already closed, [CloseWithError] is a no-op.
|
||||
//
|
||||
// It never returns an error.
|
||||
func (q *RelayReader) CloseWithError(err error) error {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
if q.err == nil {
|
||||
q.err = cmp.Or(q.err, err, io.EOF)
|
||||
close(q.closedCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the line. Any writer waiting for its turn will be unblocked
|
||||
// with an [io.ErrClosedPipe] error.
|
||||
//
|
||||
// It never returns an error.
|
||||
func (q *RelayReader) Close() error {
|
||||
return q.CloseWithError(nil)
|
||||
}
|
||||
|
||||
func (q *RelayReader) closed() <-chan struct{} {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
return q.closedCh
|
||||
}
|
||||
|
||||
func (q *RelayReader) Read(p []byte) (int, error) {
|
||||
panic("RelayReader.Read is for show only; use WriteTo")
|
||||
}
|
||||
|
||||
// WriteTo yields the writer w to the first writer in line and blocks until the
|
||||
// first call to [Close].
|
||||
//
|
||||
// It is safe to call [Take] concurrently with [WriteTo].
|
||||
func (q *RelayReader) WriteTo(dst io.Writer) (int64, error) {
|
||||
select {
|
||||
case <-q.closed():
|
||||
return 0, io.ErrClosedPipe
|
||||
default:
|
||||
}
|
||||
|
||||
// We have a destination writer; let the relay begin.
|
||||
q.w = dst
|
||||
q.t.Done()
|
||||
<-q.closed()
|
||||
return q.n, nil
|
||||
}
|
||||
|
||||
// Take returns a writer that will be passed to the next writer in line.
|
||||
//
|
||||
// It is not safe for use across multiple goroutines.
|
||||
func (q *RelayReader) Take() io.WriteCloser {
|
||||
return &relayWriter{q: q, t: q.line.Take()}
|
||||
}
|
||||
|
||||
type relayWriter struct {
|
||||
q *RelayReader
|
||||
t *Ticket
|
||||
ready bool
|
||||
}
|
||||
|
||||
var _ io.StringWriter = (*relayWriter)(nil)
|
||||
|
||||
// Write writes to the writer passed to [RelayReader.WriteTo] as soon as the
|
||||
// writer is ready. It returns io.ErrClosedPipe if the line is closed before
|
||||
// the writer is ready.
|
||||
func (w *relayWriter) Write(p []byte) (int, error) {
|
||||
if !w.awaitTurn() {
|
||||
return 0, w.q.err
|
||||
}
|
||||
n, err := w.q.w.Write(p)
|
||||
w.q.n += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *relayWriter) WriteString(s string) (int, error) {
|
||||
if !w.awaitTurn() {
|
||||
return 0, w.q.err
|
||||
}
|
||||
return io.WriteString(w.q.w, s)
|
||||
}
|
||||
|
||||
// Close signals that the writer is done, unblocking the next writer in line.
|
||||
func (w *relayWriter) Close() error {
|
||||
w.t.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *relayWriter) awaitTurn() (ok bool) {
|
||||
if t.ready {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-t.t.Ready():
|
||||
t.ready = true
|
||||
return true
|
||||
case <-t.q.closed():
|
||||
return false
|
||||
}
|
||||
}
|
||||
65
server/internal/internal/syncs/line_test.go
Normal file
65
server/internal/internal/syncs/line_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package syncs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
)
|
||||
|
||||
func TestPipelineReadWriterTo(t *testing.T) {
|
||||
for range 10 {
|
||||
synctest.Run(func() {
|
||||
q := NewRelayReader()
|
||||
|
||||
tickets := []struct {
|
||||
io.WriteCloser
|
||||
s string
|
||||
}{
|
||||
{q.Take(), "you"},
|
||||
{q.Take(), " say hi,"},
|
||||
{q.Take(), " and "},
|
||||
{q.Take(), "I say "},
|
||||
{q.Take(), "hello"},
|
||||
}
|
||||
|
||||
rand.Shuffle(len(tickets), func(i, j int) {
|
||||
tickets[i], tickets[j] = tickets[j], tickets[i]
|
||||
})
|
||||
|
||||
var g Group
|
||||
for i, t := range tickets {
|
||||
g.Go(func() {
|
||||
defer t.Close()
|
||||
if i%2 == 0 {
|
||||
// Use [relayWriter.WriteString]
|
||||
io.WriteString(t.WriteCloser, t.s)
|
||||
} else {
|
||||
t.Write([]byte(t.s))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var got bytes.Buffer
|
||||
var copyErr error // checked at end
|
||||
g.Go(func() {
|
||||
_, copyErr = io.Copy(&got, q)
|
||||
})
|
||||
|
||||
synctest.Wait()
|
||||
|
||||
q.Close()
|
||||
g.Wait()
|
||||
|
||||
if copyErr != nil {
|
||||
t.Fatal(copyErr)
|
||||
}
|
||||
|
||||
want := "you say hi, and I say hello"
|
||||
if got.String() != want {
|
||||
t.Fatalf("got %q, want %q", got.String(), want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
41
server/internal/internal/syncs/syncs.go
Normal file
41
server/internal/internal/syncs/syncs.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package syncs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Group is a [sync.WaitGroup] with a Go method.
|
||||
type Group struct {
|
||||
wg sync.WaitGroup
|
||||
n atomic.Int64
|
||||
}
|
||||
|
||||
func (g *Group) Go(f func()) {
|
||||
g.wg.Add(1)
|
||||
go func() {
|
||||
g.n.Add(1) // Now we are running
|
||||
defer func() {
|
||||
g.wg.Done()
|
||||
g.n.Add(-1) // Now we are done
|
||||
}()
|
||||
f()
|
||||
}()
|
||||
}
|
||||
|
||||
// Running returns the number of goroutines that are currently running.
|
||||
//
|
||||
// If a call to [Running] returns zero, and a call to [Wait] is made without
|
||||
// any calls to [Go], then [Wait] will return immediately. This is true even if
|
||||
// a goroutine is started and finishes between the two calls.
|
||||
//
|
||||
// It is possible for [Running] to return non-zero and for [Wait] to return
|
||||
// immediately. This can happen if the all running goroutines finish between
|
||||
// the two calls.
|
||||
func (g *Group) Running() int64 {
|
||||
return g.n.Load()
|
||||
}
|
||||
|
||||
func (g *Group) Wait() {
|
||||
g.wg.Wait()
|
||||
}
|
||||
74
server/internal/internal/testutil/testutil.go
Normal file
74
server/internal/internal/testutil/testutil.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Check calls t.Fatal(err) if err is not nil.
|
||||
func Check(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Helper()
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckFunc exists so other packages do not need to invent their own type for
|
||||
// taking a Check function.
|
||||
type CheckFunc func(err error)
|
||||
|
||||
// Checker returns a check function that
|
||||
// calls t.Fatal if err is not nil.
|
||||
func Checker(t *testing.T) (check func(err error)) {
|
||||
return func(err error) {
|
||||
if err != nil {
|
||||
t.Helper()
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopPanic runs f but silently recovers from any panic f causes.
|
||||
// The normal usage is:
|
||||
//
|
||||
// testutil.StopPanic(func() {
|
||||
// callThatShouldPanic()
|
||||
// t.Errorf("callThatShouldPanic did not panic")
|
||||
// })
|
||||
func StopPanic(f func()) {
|
||||
defer func() { recover() }()
|
||||
f()
|
||||
}
|
||||
|
||||
// CheckTime calls t.Fatalf if got != want. Included in the error message is
|
||||
// want.Sub(got) to help diagnose the difference, along with their values in
|
||||
// UTC.
|
||||
func CheckTime(t *testing.T, got, want time.Time) {
|
||||
t.Helper()
|
||||
if !got.Equal(want) {
|
||||
t.Fatalf("got %v, want %v (%v)", got.UTC(), want.UTC(), want.Sub(got))
|
||||
}
|
||||
}
|
||||
|
||||
// WriteFile writes data to a file named name. It makes the directory if it
|
||||
// doesn't exist and sets the file mode to perm.
|
||||
//
|
||||
// The name must be a relative path and must not contain .. or start with a /;
|
||||
// otherwise WriteFile will panic.
|
||||
func WriteFile[S []byte | string](t testing.TB, name string, data S) {
|
||||
t.Helper()
|
||||
|
||||
if filepath.IsAbs(name) {
|
||||
t.Fatalf("WriteFile: name must be a relative path, got %q", name)
|
||||
}
|
||||
name = filepath.Clean(name)
|
||||
dir := filepath.Dir(name)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(name, []byte(data), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user