fix token type

This commit is contained in:
Michael Yang
2025-04-23 12:40:05 -07:00
committed by Michael Yang
parent 8d376acc9b
commit d26c18e25c
13 changed files with 36 additions and 25 deletions

View File

@@ -177,7 +177,7 @@ type TextDecoder struct {
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers {
layerType := selfAttentionLayer
if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
if slices.Contains(opts.crossAttentionLayers, int32(i)) {
layerType = crossAttentionLayer
}
@@ -202,7 +202,7 @@ type TextModelOptions struct {
eps, ropeBase, ropeScale float32
ropeDim uint32
crossAttentionLayers []uint32
crossAttentionLayers []int32
}
type TextModel struct {
@@ -225,7 +225,7 @@ func newTextModel(c fs.Config) *TextModel {
var decoderLayers []TextDecoderLayer
for i := range c.Uint("block_count") {
var textDecoderLayer TextDecoderLayer
if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) {
textDecoderLayer = &TextCrossAttentionDecoderLayer{}
} else {
textDecoderLayer = &TextSelfAttentionDecoderLayer{}
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
},
}
}