import os
import sys
from os import path
from contextlib import contextmanager

from llama_cpp import Llama as LLM

import ollama.prompt
from ollama.model import models_home

@contextmanager
def suppress_stderr():
    # Redirect stderr at the file-descriptor level so output from native
    # code (llama.cpp) is silenced as well, not just Python-level writes.
    stderr = os.dup(sys.stderr.fileno())
    with open(os.devnull, "w") as devnull:
        os.dup2(devnull.fileno(), sys.stderr.fileno())
        try:
            yield
        finally:
            # restore stderr even if the body raised, and close the saved
            # descriptor so it does not leak
            os.dup2(stderr, sys.stderr.fileno())
            os.close(stderr)
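
# A minimal sketch of suppress_stderr() in isolation, assuming a POSIX-style
# environment where sys.stderr is backed by a real file descriptor:
#
#     with suppress_stderr():
#         print("hidden", file=sys.stderr)
#     print("visible again", file=sys.stderr)
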
def generate(model_name, prompt, models={}, *args, **kwargs):
    # Note: the mutable default `models` dict deliberately acts as a
    # process-wide cache shared across calls that do not pass their own.
    if "max_tokens" not in kwargs:
        kwargs.update({"max_tokens": 16384})

    if "stop" not in kwargs:
        kwargs.update({"stop": ["Q:"]})

    if "stream" not in kwargs:
        kwargs.update({"stream": True})

    prompt = ollama.prompt.template(model_name, prompt)

    model = load(model_name, models=models)
    for output in model.create_completion(prompt, *args, **kwargs):
        yield output
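
# A hedged sketch of consuming generate(); the chunk layout mirrors
# llama-cpp-python's streaming create_completion() output, and the model
# name is illustrative rather than part of this module:
#
#     for chunk in generate("orca-mini-3b", "Q: why is the sky blue? A:"):
#         print(chunk["choices"][0]["text"], end="", flush=True)
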
def load(model_name, models={}):
    model = models.get(model_name, None)
    if not model:
        # treat the name as a filesystem path first, then fall back to
        # a ".bin" file under models_home
        model_path = path.expanduser(model_name)
        if not path.exists(model_path):
            model_path = path.join(models_home, model_name + ".bin")

        try:
            # suppress llama.cpp's output while loading
            with suppress_stderr():
                model = LLM(model_path, verbose=False)
            models.update({model_name: model})
        except Exception:
            # the underlying error output went to devnull, so raise a generic
            # exception naming the path (model is still unset at this point)
            raise Exception(f"Failed to load model: {model_path}")

    return model
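
# A sketch of the lookup order load() implements (paths illustrative):
#
#     load("~/models/wizard.bin")  # existing path is used as-is after expanduser
#     load("orca-mini-3b")         # else resolves to <models_home>/orca-mini-3b.bin
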
def unload(model_name, models={}):
    if model_name in models:
        models.pop(model_name)
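

if __name__ == "__main__":
    # A hedged smoke test, not part of the original module. The model name is
    # illustrative and must exist under models_home. Each function above has
    # its own default `models` dict, so a shared cache is passed explicitly
    # here so unload() can see the model that load() cached.
    cache = {}
    for part in generate("orca-mini-3b", "Q: why is the sky blue? A:", models=cache):
        print(part["choices"][0]["text"], end="", flush=True)
    print()
    unload("orca-mini-3b", models=cache)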