add safetensors version

This commit is contained in:
Patrick Devine
2024-04-24 18:32:01 -07:00
committed by Michael Yang
parent d88582dffd
commit 4730762e5c
2 changed files with 20 additions and 4 deletions

View File

@@ -20,7 +20,7 @@ type LlamaModel struct {
ModelData
}
func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
data := r.storage.(*pytorch.HalfStorage).Data
@@ -105,9 +105,16 @@ func (m *LlamaModel) GetTensors() error {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaLayerHandler
l.WriterTo = wt
switch l.WriterTo.(type) {
case torchWriterTo:
wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaTorchLayerHandler
l.WriterTo = wt
case safetensorWriterTo:
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
l.WriterTo = wt
}
}
m.Tensors = append(m.Tensors, l)
}