simplify loading

Bruce MacDonald
2023-06-27 14:50:23 -04:00
parent 3b4f45f6bf
commit ecfb4abafb
5 changed files with 40 additions and 40 deletions

@@ -27,47 +27,46 @@ def models_directory():
     return models_dir
-def load(model=None, path=None):
+def load(model):
     """
     Load a model.
-    The model can be specified by providing either the path or the model name,
-    but not both. If both are provided, this function will raise a ValueError.
+    If the model does not exist or could not be loaded, this function returns an error.
     Args:
-        model (str, optional): The name of the model to load.
-        path (str, optional): The path to the model file.
+        model (str): The name or path of the model to load.
     Returns:
+        str or None: The name of the model
         dict or None: If the model cannot be loaded, a dictionary with an 'error' key is returned.
                       If the model is successfully loaded, None is returned.
     """
     with lock:
-        if path is not None and model is not None:
-            raise ValueError(
-                "Both path and model are specified. Please provide only one of them."
-            )
-        elif path is not None:
-            name = os.path.basename(path)
+        load_from = ""
+        if os.path.exists(model) and model.endswith(".bin"):
+            # model is being referenced by path rather than name directly
+            path = os.path.abspath(model)
+            base = os.path.basename(path)
             load_from = path
-        elif model is not None:
-            name = model
-            dir = models_directory()
-            load_from = str(dir / f"{model}.bin")
+            name = os.path.splitext(base)[0]  # Split the filename and extension
         else:
-            raise ValueError("Either path or model must be specified.")
+            # model is being loaded from the ollama models directory
+            dir = models_directory()
+            # TODO: download model from a repository if it does not exist
+            load_from = str(dir / f"{model}.bin")
+            name = model
+        if load_from == "":
+            return None, {"error": "Model not found."}
         if not os.path.exists(load_from):
-            return {"error": f"The model at {load_from} does not exist."}
+            return None, {"error": f"The model {load_from} does not exist."}
         if name not in llms:
-            # TODO: download model from a repository if it does not exist
             llms[name] = Llama(model_path=load_from)
-        return None
+        # TODO: this should start a persistent instance of ollama with the model loaded
+        return name, None
 def unload(model):
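
The heart of the change: `load` now infers from its single argument whether it was given a model name or a path to a `.bin` file, instead of taking two mutually exclusive keyword arguments. Below is a minimal standalone sketch of that resolution step; `resolve` is a hypothetical extraction for illustration, not a function in the commit, and `models_directory()` is stubbed here (the real helper appears earlier in this file and may point elsewhere):

import os
from pathlib import Path

def models_directory():
    # Stub for the module's real helper; the actual directory is configured elsewhere.
    return Path.home() / ".ollama" / "models"

def resolve(model):
    # Mirrors the new load(): an existing .bin file is treated as a path...
    if os.path.exists(model) and model.endswith(".bin"):
        path = os.path.abspath(model)
        name = os.path.splitext(os.path.basename(path))[0]  # "dir/llama.bin" -> "llama"
        return name, path
    # ...anything else is treated as a name inside the models directory.
    return model, str(models_directory() / f"{model}.bin")

name, load_from = resolve("orca-mini")
# name == "orca-mini"; load_from ends with "models/orca-mini.bin"

Either way the function produces a cache key (`name`) plus a concrete file path (`load_from`), which is what lets path-based and name-based invocations share a single entry in `llms`.
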
@@ -84,10 +83,10 @@ def unload(model):
 def generate(model, prompt):
     # auto load
-    error = load(model)
+    name, error = load(model)
     if error is not None:
         return error
-    generated = llms[model](
+    generated = llms[name](
         str(prompt), # TODO: optimize prompt based on model
         max_tokens=4096,
         stop=["Q:", "\n"],