Introduce /api/embed endpoint supporting batch embedding (#5127)

* Initial Batch Embedding

* Revert "Initial Batch Embedding"

This reverts commit c22d54895a280b54c727279d85a5fc94defb5a29.

* Initial Draft

* mock up notes

* api/embed draft

* add server function

* check normalization

* clean up

* normalization

* playing around with truncate stuff

* Truncation

* Truncation

* move normalization to go

* Integration Test Template

* Truncation Integration Tests

* Clean up

* use float32

* move normalize

* move normalize test

* refactoring

* integration float32

* input handling and handler testing

* Refactoring of legacy and new

* clear comments

* merge conflicts

* touches

* embedding type 64

* merge conflicts

* fix hanging on single string

* refactoring

* test values

* set context length

* clean up

* testing clean up

* testing clean up

* remove function closure

* Revert "remove function closure"

This reverts commit 55d48c6ed17abe42e7a122e69d603ef0c1506787.

* remove function closure

* remove redundant error check

* clean up

* more clean up

* clean up
This commit is contained in:
royjhan
2024-07-15 12:14:24 -07:00
committed by GitHub
parent e9f7f36029
commit b9f5e16c80
8 changed files with 452 additions and 30 deletions

View File

@@ -3188,26 +3188,33 @@ int main(int argc, char **argv) {
prompt = "";
}
json image_data;
if (body.count("image_data") != 0) {
image_data = body["image_data"];
}
else
{
image_data = "";
if (prompt.size() == 1) {
prompt = prompt[0];
}
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
json responses;
{
const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
// get the result
task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json embedding_res = json{{"embedding", embeddings}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View File

@@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Embed(ctx context.Context, input []string) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -867,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil
}
type EmbeddingRequest struct {
Content string `json:"content"`
type EmbedRequest struct {
Content []string `json:"content"`
}
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
}
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@@ -890,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
@@ -917,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("%s", body)
}
var embedding EmbeddingResponse
var embedding EmbedResponse
if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}