Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion compiler/fory_compiler/generators/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def message_has_unions(self, message: Message) -> bool:
PrimitiveKind.UINT64: "uint64",
PrimitiveKind.VAR_UINT64: "uint64",
PrimitiveKind.TAGGED_UINT64: "uint64",
PrimitiveKind.FLOAT16: "float32",
PrimitiveKind.FLOAT16: "float16.Float16",
PrimitiveKind.FLOAT32: "float32",
PrimitiveKind.FLOAT64: "float64",
PrimitiveKind.STRING: "string",
Expand Down Expand Up @@ -1077,6 +1077,8 @@ def collect_imports(self, field_type: FieldType, imports: Set[str]):
if isinstance(field_type, PrimitiveType):
if field_type.kind in (PrimitiveKind.DATE, PrimitiveKind.TIMESTAMP):
imports.add('"time"')
elif field_type.kind == PrimitiveKind.FLOAT16:
imports.add('float16 "github.com/apache/fory/go/fory/float16"')

elif isinstance(field_type, ListType):
self.collect_imports(field_type.element_type, imports)
Expand Down
77 changes: 77 additions & 0 deletions go/fory/array_primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package fory
import (
"reflect"
"unsafe"

"github.com/apache/fory/go/fory/float16"
)

// ============================================================================
Expand Down Expand Up @@ -794,3 +796,78 @@ func (s uint64ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType
func (s uint64ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
s.Read(ctx, refMode, false, false, value)
}

// ============================================================================
// float16ArraySerializer - optimized [N]float16.Float16 serialization
// ============================================================================

type float16ArraySerializer struct {
arrayType reflect.Type
}

func (s float16ArraySerializer) WriteData(ctx *WriteContext, value reflect.Value) {
buf := ctx.Buffer()
length := value.Len()
size := length * 2
buf.WriteLength(size)
if length > 0 {
if value.CanAddr() && isLittleEndian {
ptr := value.Addr().UnsafePointer()
buf.WriteBinary(unsafe.Slice((*byte)(ptr), size))
} else {
for i := 0; i < length; i++ {
// We can't easily cast the whole array if not addressable/little-endian
// So we iterate.
// value.Index(i) is Float16, we cast to uint16
val := value.Index(i).Interface().(float16.Float16)
buf.WriteUint16(val.Bits())
}
}
}
}

func (s float16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
writeArrayRefAndType(ctx, refMode, writeType, value, FLOAT16_ARRAY)
if ctx.HasError() {
return
}
s.WriteData(ctx, value)
}

func (s float16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
size := buf.ReadLength(ctxErr)
length := size / 2
if ctx.HasError() {
return
}
if length != value.Type().Len() {
ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type()))
return
}

if length > 0 {
if isLittleEndian {
ptr := value.Addr().UnsafePointer()
raw := buf.ReadBinary(size, ctxErr)
copy(unsafe.Slice((*byte)(ptr), size), raw)
} else {
for i := 0; i < length; i++ {
value.Index(i).Set(reflect.ValueOf(float16.Float16FromBits(buf.ReadUint16(ctxErr))))
}
}
}
}

func (s float16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) {
done := readArrayRefAndType(ctx, refMode, readType, value)
if done || ctx.HasError() {
return
}
s.ReadData(ctx, value)
}

func (s float16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
s.Read(ctx, refMode, false, false, value)
}
20 changes: 20 additions & 0 deletions go/fory/array_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package fory
import (
"testing"

"github.com/apache/fory/go/fory/float16"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -130,3 +131,22 @@ func TestArraySliceInteroperability(t *testing.T) {
assert.Contains(t, err.Error(), "array length")
})
}

func TestFloat16Array(t *testing.T) {
f := NewFory()

t.Run("float16_array", func(t *testing.T) {
arr := [3]float16.Float16{
float16.Float16FromFloat32(1.0),
float16.Float16FromFloat32(2.5),
float16.Float16FromFloat32(-0.5),
}
data, err := f.Serialize(arr)
assert.NoError(t, err)

var result [3]float16.Float16
err = f.Deserialize(data, &result)
assert.NoError(t, err)
assert.Equal(t, arr, result)
})
}
Loading
Loading