Documentation
¶
Index ¶
Examples ¶
Constants ¶
This section is empty.
Variables ¶
// DTypeToWordSize gives, for every supported DType, the size in bytes of a
// single element of that type.
var DTypeToWordSize = map[DType]uint64{
	BOOL:    1,
	U8:      1,
	I8:      1,
	F8_E5M2: 1,
	F8_E4M3: 1,
	I16:     2,
	U16:     2,
	F16:     2,
	BF16:    2,
	I32:     4,
	U32:     4,
	F32:     4,
	F64:     8,
	I64:     8,
	U64:     8,
}
DTypeToWordSize maps each DType to the size, in bytes, of a single element of that type.
Functions ¶
This section is empty.
Types ¶
type DType ¶
// DType is a string-backed identifier for a tensor element data type.
type DType string
DType identifies a data type.
It matches the DType type at https://github.com/huggingface/safetensors/blob/main/safetensors/src/tensor.rs.
const ( // Boolan type BOOL DType = "BOOL" // Unsigned byte U8 DType = "U8" // Signed byte I8 DType = "I8" // FP8 <https://arxiv.org/pdf/2209.05433.pdf> F8_E5M2 DType = "F8_E5M2" // FP8 <https://arxiv.org/pdf/2209.05433.pdf> F8_E4M3 DType = "F8_E4M3" // Signed integer (16-bit) I16 DType = "I16" // Unsigned integer (16-bit) U16 DType = "U16" // Half-precision floating point F16 DType = "F16" // Brain floating point BF16 DType = "BF16" // Signed integer (32-bit) I32 DType = "I32" // Unsigned integer (32-bit) U32 DType = "U32" // Floating point (32-bit) F32 DType = "F32" // Floating point (64-bit) F64 DType = "F64" // Signed integer (64-bit) I64 DType = "I64" // Unsigned integer (64-bit) U64 DType = "U64" )
func (*DType) UnmarshalJSON ¶
UnmarshalJSON implements json.Unmarshaler.
type File ¶
File is a structure owning some metadata to look up tensors on a shared `data` byte-buffer.
func Parse ¶
Parse parses a byte-buffer representing the whole safetensors file and returns the deserialized form.
It keeps references to the buffer so the buffer must not be modified afterwards.
Example ¶
package main
import (
"fmt"
"log"
"github.com/maruel/safetensors"
)
// main parses a small in-memory safetensors payload and prints a summary of
// the tensors found in it.
func main() {
	const (
		// 8-byte length prefix; 0x59 = 89, the byte length of jsonHeader.
		lengthPrefix = "\x59\x00\x00\x00\x00\x00\x00\x00"
		jsonHeader   = `{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]},"__metadata__":{"foo":"bar"}}`
		// 16 zero bytes, matching the [0,16] data_offsets declared above.
		tensorBytes = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
	)
	payload := []byte(lengthPrefix + jsonHeader + tensorBytes)

	file, err := safetensors.Parse(payload)
	if err != nil {
		log.Fatal(err)
	}

	var names []string
	for i := range file.Tensors {
		names = append(names, file.Tensors[i].Name)
	}
	fmt.Printf("len = %d\n", len(file.Tensors))
	fmt.Printf("names = %+v\n", names)

	first := file.Tensors[0]
	fmt.Printf("tensor type = %s\n", first.DType)
	fmt.Printf("tensor shape = %+v\n", first.Shape)
	fmt.Printf("tensor data len = %+v\n", len(first.Data))
}
Output:
len = 1
names = [test]
tensor type = I32
tensor shape = [2 2]
tensor data len = 16
func (*File) Serialize ¶
Serialize writes the list of tensors to an io.Writer.
Example ¶
package main
import (
"bytes"
"encoding/binary"
"fmt"
"log"
"math"
"github.com/maruel/safetensors"
)
// main builds a single F32 tensor from six float32 values, serializes a File
// containing it into memory, and prints the total size plus a short excerpt
// of the output.
func main() {
	values := []float32{0, 1, 2, 3, 4, 5}

	// Encode each float32 as 4 little-endian bytes.
	raw := make([]byte, 0, 4*len(values))
	for i := range values {
		raw = binary.LittleEndian.AppendUint32(raw, math.Float32bits(values[i]))
	}

	t := safetensors.Tensor{
		Name:  "foo",
		DType: safetensors.F32,
		Shape: []uint64{1, 2, 3},
		Data:  raw,
	}
	if err := t.Validate(); err != nil {
		log.Fatal(err)
	}

	file := safetensors.File{Tensors: []safetensors.Tensor{t}}
	var out bytes.Buffer
	if err := file.Serialize(&out); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("data len = %d\n", out.Len())
	// Skip the first 8 bytes and show the next 22 bytes of the output.
	fmt.Printf("data excerpt: ...%s...\n", out.Bytes()[8:30])
}
Output:
data len = 96
data excerpt: ...{"foo":{"dtype":"F32",...
type Mapped ¶ added in v1.1.0
// Mapped is a read-only memory mapped SafeTensors file. It embeds *File, so
// the File's fields and methods are promoted onto Mapped.
type Mapped struct {
	*File
	// contains filtered or unexported fields
}
Mapped is a read-only memory mapped SafeTensors file.
This is the fastest way to use a safetensors file.
Example ¶
This is the recommended way to load safetensors files as it is the most efficient.
package main
import (
"fmt"
"log"
"github.com/maruel/safetensors"
)
// main opens a safetensors file as a Mapped, lists the tensor names, and
// prints details of the first tensor.
func main() {
	var mapped safetensors.Mapped
	// "path/to/model.safetensors" is a placeholder; point it at a real file.
	if err := mapped.Open("path/to/model.safetensors"); err != nil {
		log.Fatal(err)
	}
	defer mapped.Close()

	var names []string
	for i := range mapped.Tensors {
		names = append(names, mapped.Tensors[i].Name)
	}
	fmt.Printf("len = %d\n", len(mapped.Tensors))
	fmt.Printf("names = %+v\n", names)

	first := mapped.Tensors[0]
	fmt.Printf("tensor type = %s\n", first.DType)
	fmt.Printf("tensor shape = %+v\n", first.Shape)
	fmt.Printf("tensor data len = %+v\n", len(first.Data))
}
Output: