safetensors

package module
v1.2.0 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Nov 15, 2024 License: BSD-2-Clause Imports: 9 Imported by: 3

README

Safetensors in Go

Features

  • High-performance and memory-efficient, with benchmarks (0.25 ms to decode GPT-2).
  • Keeps the tensor ordering as specified in the safetensors file.
  • Simple API

See the full documentation on Go Reference.

codecov

Credits

This package is inspired by the work of the NLP Odyssey authors, which was itself inspired by Hugging Face's original Rust implementation.

Documentation

Index

Examples

Constants

This section is empty.

Variables

View Source
// DTypeToWordSize maps every supported DType to the size in bytes of a single
// element of that type.
var DTypeToWordSize = map[DType]uint64{
	// One byte per element.
	BOOL: 1, U8: 1, I8: 1, F8_E5M2: 1, F8_E4M3: 1,
	// Two bytes per element.
	I16: 2, U16: 2, F16: 2, BF16: 2,
	// Four bytes per element.
	I32: 4, U32: 4, F32: 4,
	// Eight bytes per element.
	F64: 8, I64: 8, U64: 8,
}

DTypeToWordSize maps each DType to the size in bytes of one element of that type.

Functions

This section is empty.

Types

type DType

type DType string

DType identifies a data type.

It matches the DType type at https://github.com/huggingface/safetensors/blob/main/safetensors/src/tensor.rs.

// Supported element data types. Each value is the dtype name as written in a
// safetensors JSON header (e.g. "I32").
const (
	// BOOL is a boolean value.
	BOOL DType = "BOOL"
	// U8 is an unsigned 8-bit integer (byte).
	U8 DType = "U8"
	// I8 is a signed 8-bit integer.
	I8 DType = "I8"
	// F8_E5M2 is an 8-bit float with 5 exponent bits and 2 mantissa bits.
	// See FP8 <https://arxiv.org/pdf/2209.05433.pdf>.
	F8_E5M2 DType = "F8_E5M2"
	// F8_E4M3 is an 8-bit float with 4 exponent bits and 3 mantissa bits.
	// See FP8 <https://arxiv.org/pdf/2209.05433.pdf>.
	F8_E4M3 DType = "F8_E4M3"
	// I16 is a signed 16-bit integer.
	I16 DType = "I16"
	// U16 is an unsigned 16-bit integer.
	U16 DType = "U16"
	// F16 is a half-precision (16-bit) floating point number.
	F16 DType = "F16"
	// BF16 is a 16-bit "brain floating point" (bfloat16) number.
	BF16 DType = "BF16"
	// I32 is a signed 32-bit integer.
	I32 DType = "I32"
	// U32 is an unsigned 32-bit integer.
	U32 DType = "U32"
	// F32 is a single-precision (32-bit) floating point number.
	F32 DType = "F32"
	// F64 is a double-precision (64-bit) floating point number.
	F64 DType = "F64"
	// I64 is a signed 64-bit integer.
	I64 DType = "I64"
	// U64 is an unsigned 64-bit integer.
	U64 DType = "U64"
)

func (*DType) UnmarshalJSON

func (dt *DType) UnmarshalJSON(data []byte) error

UnmarshalJSON implements json.Unmarshaler.

func (DType) WordSize

func (dt DType) WordSize() uint64

WordSize returns the size in bytes of one element of this data type.

type File

// File is the deserialized form of a safetensors file: its tensors plus
// string metadata. When produced by Parse, tensor data keeps references to the
// buffer passed to Parse, so that buffer must not be modified afterwards.
type File struct {
	// Tensors lists the file's tensors, keeping the ordering specified in the
	// safetensors file.
	Tensors  []Tensor
	// Metadata holds free-form string key/value pairs; presumably the header's
	// "__metadata__" object (TODO confirm in Parse/Serialize).
	Metadata map[string]string
}

File is a structure owning the metadata needed to look up tensors in a shared `data` byte buffer.

func Parse

func Parse(buffer []byte) (*File, error)

Parse parses a byte-buffer representing the whole safetensors file and returns the deserialized form.

It keeps references to the buffer, so the buffer must not be modified afterwards.

Example
package main

import (
	"fmt"
	"log"

	"github.com/maruel/safetensors"
)

func main() {
	// Layout: an 8-byte little-endian header length (0x59 = 89 bytes), the
	// 89-byte JSON header, then the 16 data bytes of the 2x2 I32 tensor.
	serialized := []byte("\x59\x00\x00\x00\x00\x00\x00\x00" +
		`{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]},"__metadata__":{"foo":"bar"}}` +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")

	file, err := safetensors.Parse(serialized)
	if err != nil {
		log.Fatal(err)
	}

	names := make([]string, 0, len(file.Tensors))
	for i := range file.Tensors {
		names = append(names, file.Tensors[i].Name)
	}
	fmt.Printf("len = %d\n", len(file.Tensors))
	fmt.Printf("names = %+v\n", names)

	first := file.Tensors[0]
	fmt.Printf("tensor type = %s\n", first.DType)
	fmt.Printf("tensor shape = %+v\n", first.Shape)
	fmt.Printf("tensor data len = %d\n", len(first.Data))
}
Output:
len = 1
names = [test]
tensor type = I32
tensor shape = [2 2]
tensor data len = 16

func (*File) Serialize

func (f *File) Serialize(w io.Writer) error

Serialize writes the list of tensors to an io.Writer.

Example
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"log"
	"math"

	"github.com/maruel/safetensors"
)

func main() {
	// Encode six float32 values as little-endian bytes.
	values := []float32{0, 1, 2, 3, 4, 5}
	raw := make([]byte, 0, 4*len(values))
	for i := range values {
		raw = binary.LittleEndian.AppendUint32(raw, math.Float32bits(values[i]))
	}

	tensor := safetensors.Tensor{
		Name:  "foo",
		DType: safetensors.F32,
		Shape: []uint64{1, 2, 3},
		Data:  raw,
	}
	if err := tensor.Validate(); err != nil {
		log.Fatal(err)
	}
	file := safetensors.File{Tensors: []safetensors.Tensor{tensor}}

	var out bytes.Buffer
	if err := file.Serialize(&out); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("data len = %d\n", out.Len())
	// Skip the 8-byte length prefix to show the start of the JSON header.
	fmt.Printf("data excerpt: ...%s...\n", out.Bytes()[8:30])
}
Output:
data len = 96
data excerpt: ...{"foo":{"dtype":"F32",...

type Mapped added in v1.1.0

// Mapped is a read-only, memory-mapped safetensors file. Call Open to map a
// file and Close to release the memory region and the file handle.
type Mapped struct {
	// File exposes the parsed tensors and metadata once Open has succeeded.
	*File
	// contains filtered or unexported fields
}

Mapped is a read-only, memory-mapped safetensors file.

This is the fastest way to use a safetensors file.

Example

This is the recommended way to load safetensors files, as it is the most efficient.

package main

import (
	"fmt"
	"log"

	"github.com/maruel/safetensors"
)

func main() {
	var m safetensors.Mapped
	if err := m.Open("path/to/model.safetensors"); err != nil {
		log.Fatal(err)
	}
	// Release the mapping and file handle when done; the error is ignored here.
	defer m.Close()

	names := make([]string, 0, len(m.Tensors))
	for i := range m.Tensors {
		names = append(names, m.Tensors[i].Name)
	}
	fmt.Printf("len = %d\n", len(m.Tensors))
	fmt.Printf("names = %+v\n", names)

	first := m.Tensors[0]
	fmt.Printf("tensor type = %s\n", first.DType)
	fmt.Printf("tensor shape = %+v\n", first.Shape)
	fmt.Printf("tensor data len = %d\n", len(first.Data))
}

func (*Mapped) Close added in v1.1.0

func (s *Mapped) Close() error

Close releases the memory region and the file handle.

func (*Mapped) Open added in v1.1.0

func (s *Mapped) Open(name string) error

Open opens a file and memory maps it read-only.

type Tensor

// Tensor is a view of a single tensor within a safetensors file.
type Tensor struct {
	// Name is the tensor's key in the file header.
	Name  string
	// DType is the data type of each element.
	DType DType
	// Shape lists the size of each dimension.
	Shape []uint64
	// Data holds the tensor's raw element bytes. For tensors returned by
	// Parse it references the parsed buffer rather than a copy.
	Data  []byte
}

Tensor is a view of a single tensor within a file.

It contains references to data within the full byte-buffer and is thus a readable view of a single tensor.

func (*Tensor) Validate

func (t *Tensor) Validate() error

Validate validates the object.

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL