diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index 483cb56796..34ee3410ac 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -189,7 +189,7 @@ def message_has_unions(self, message: Message) -> bool: PrimitiveKind.UINT64: "uint64", PrimitiveKind.VAR_UINT64: "uint64", PrimitiveKind.TAGGED_UINT64: "uint64", - PrimitiveKind.FLOAT16: "float32", + PrimitiveKind.FLOAT16: "float16.Float16", PrimitiveKind.FLOAT32: "float32", PrimitiveKind.FLOAT64: "float64", PrimitiveKind.STRING: "string", @@ -1077,6 +1077,8 @@ def collect_imports(self, field_type: FieldType, imports: Set[str]): if isinstance(field_type, PrimitiveType): if field_type.kind in (PrimitiveKind.DATE, PrimitiveKind.TIMESTAMP): imports.add('"time"') + elif field_type.kind == PrimitiveKind.FLOAT16: + imports.add('float16 "github.com/apache/fory/go/fory/float16"') elif isinstance(field_type, ListType): self.collect_imports(field_type.element_type, imports) diff --git a/go/fory/array_primitive.go b/go/fory/array_primitive.go index a503672879..9777733590 100644 --- a/go/fory/array_primitive.go +++ b/go/fory/array_primitive.go @@ -20,6 +20,8 @@ package fory import ( "reflect" "unsafe" + + "github.com/apache/fory/go/fory/float16" ) // ============================================================================ @@ -794,3 +796,78 @@ func (s uint64ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType func (s uint64ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { s.Read(ctx, refMode, false, false, value) } + +// ============================================================================ +// float16ArraySerializer - optimized [N]float16.Float16 serialization +// ============================================================================ + +type float16ArraySerializer struct { + arrayType reflect.Type +} + +func (s float16ArraySerializer) WriteData(ctx *WriteContext, value reflect.Value) { + buf := ctx.Buffer() + length := value.Len() + size := length * 2 + buf.WriteLength(size) + if length > 0 { + if value.CanAddr() && isLittleEndian { + ptr := value.Addr().UnsafePointer() + buf.WriteBinary(unsafe.Slice((*byte)(ptr), size)) + } else { + for i := 0; i < length; i++ { + // We can't easily cast the whole array if not addressable/little-endian + // So we iterate. + // value.Index(i) is Float16, we cast to uint16 + val := value.Index(i).Interface().(float16.Float16) + buf.WriteUint16(val.Bits()) + } + } + } +} + +func (s float16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + writeArrayRefAndType(ctx, refMode, writeType, value, FLOAT16_ARRAY) + if ctx.HasError() { + return + } + s.WriteData(ctx, value) +} + +func (s float16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { + buf := ctx.Buffer() + ctxErr := ctx.Err() + size := buf.ReadLength(ctxErr) + length := size / 2 + if ctx.HasError() { + return + } + if length != value.Type().Len() { + ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + return + } + + if length > 0 { + if isLittleEndian { + ptr := value.Addr().UnsafePointer() + raw := buf.ReadBinary(size, ctxErr) + copy(unsafe.Slice((*byte)(ptr), size), raw) + } else { + for i := 0; i < length; i++ { + value.Index(i).Set(reflect.ValueOf(float16.Float16FromBits(buf.ReadUint16(ctxErr)))) + } + } + } +} + +func (s float16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + done := readArrayRefAndType(ctx, refMode, readType, value) + if done || ctx.HasError() { + return + } + s.ReadData(ctx, value) +} + +func (s float16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} diff --git a/go/fory/array_primitive_test.go b/go/fory/array_primitive_test.go index a55b6ede86..c2e684af4c 100644 --- a/go/fory/array_primitive_test.go +++ b/go/fory/array_primitive_test.go @@ -20,6 +20,7 @@ package fory import ( "testing" + "github.com/apache/fory/go/fory/float16" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -130,3 +131,22 @@ func TestArraySliceInteroperability(t *testing.T) { assert.Contains(t, err.Error(), "array length") }) } + +func TestFloat16Array(t *testing.T) { + f := NewFory() + + t.Run("float16_array", func(t *testing.T) { + arr := [3]float16.Float16{ + float16.Float16FromFloat32(1.0), + float16.Float16FromFloat32(2.5), + float16.Float16FromFloat32(-0.5), + } + data, err := f.Serialize(arr) + assert.NoError(t, err) + + var result [3]float16.Float16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Equal(t, arr, result) + }) +} diff --git a/go/fory/float16/float16.go b/go/fory/float16/float16.go new file mode 100644 index 0000000000..1a7ca38d53 --- /dev/null +++ b/go/fory/float16/float16.go @@ -0,0 +1,363 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package float16 + +import ( + "fmt" + "math" +) + +// Float16 represents a half-precision floating point number (IEEE 754-2008 binary16). +// It is stored as a uint16. +type Float16 uint16 + +// Constants for half-precision floating point +const ( + uvNan = 0x7e00 // 0 11111 1000000000 (standard quiet NaN) + uvInf = 0x7c00 // 0 11111 0000000000 (+Inf) + uvNegInf = 0xfc00 // 1 11111 0000000000 (-Inf) + uvNegZero = 0x8000 // 1 00000 0000000000 (-0) + uvMax = 0x7bff // 65504 + uvMinNorm = 0x0400 // 2^-14 (highest subnormal is 0x03ff, lowest normal is 0x0400) + uvMinSub = 0x0001 // 2^-24 + uvOne = 0x3c00 // 1.0 + maskSign = 0x8000 + maskExp = 0x7c00 + maskMant = 0x03ff +) + +// Common values +var ( + NaN = Float16(uvNan) + Inf = Float16(uvInf) + NegInf = Float16(uvNegInf) + Zero = Float16(0) + NegZero = Float16(uvNegZero) + Max = Float16(uvMax) + Smallest = Float16(uvMinSub) // Smallest non-zero + One = Float16(uvOne) +) + +// Float16FromBits returns the Float16 corresponding to the given bit pattern. +func Float16FromBits(b uint16) Float16 { + return Float16(b) +} + +// Bits returns the raw bit pattern of the floating point number. +func (f Float16) Bits() uint16 { + return uint16(f) +} + +// Float16FromFloat32 converts a float32 to a Float16. +// Rounds to nearest, ties to even. +func Float16FromFloat32(f32 float32) Float16 { + bits := math.Float32bits(f32) + sign := (bits >> 31) & 0x1 + exp := (bits >> 23) & 0xff + mant := bits & 0x7fffff + + var outSign uint16 = uint16(sign) << 15 + var outExp uint16 + var outMant uint16 + + if exp == 0xff { + // NaN or Inf + outExp = 0x1f + if mant != 0 { + // NaN - preserve top bit of mantissa for quiet/signaling if possible, but simplest is canonical QNaN + outMant = 0x200 | (uint16(mant>>13) & 0x1ff) + if outMant == 0 { + outMant = 0x200 // Ensure at least one bit + } + } else { + // Inf + outMant = 0 + } + } else if exp == 0 { + // Signed zero or subnormal float32 (which becomes zero in float16 usually) + outExp = 0 + outMant = 0 + } else { + // Normalized + newExp := int(exp) - 127 + 15 + if newExp >= 31 { + // Overflow to Inf + outExp = 0x1f + outMant = 0 + } else if newExp <= 0 { + // Underflow to subnormal or zero + // Shift mantissa to align with float16 subnormal range + // float32 mantissa has implicit 1. + fullMant := mant | 0x800000 + shift := 1 - newExp // 1 for implicit bit alignment + // We need to round. + // Mantissa bits: 23. Subnormal 16 mant bits: 10. + // We want to shift right by (13 + shift). + + // Let's do a more precise soft-float rounding + // Re-assemble float value to handle subnormal rounding correctly is hard with just bit shifts + // But since we have hardware float32... + // Actually pure bit manipulation is robust if careful. + + // Shift right amount + netShift := 13 + shift // 23 - 10 + shift + + if netShift >= 24 { + // Too small, becomes zero + outExp = 0 + outMant = 0 + } else { + outExp = 0 + roundBit := (fullMant >> (netShift - 1)) & 1 + sticky := (fullMant & ((1 << (netShift - 1)) - 1)) + outMant = uint16(fullMant >> netShift) + + if roundBit == 1 { + if sticky != 0 || (outMant&1) == 1 { + outMant++ + } + } + } + } else { + // Normal range + outExp = uint16(newExp) + // Mantissa: float32 has 23 bits, float16 has 10. + // We need to round based on the dropped 13 bits. + // Last kept bit at index 13 (0-indexed from LSB of float32 mant) + // Round bit is index 12. + + // Using helper to round + outMant = uint16(mant >> 13) + roundBit := (mant >> 12) & 1 + sticky := mant & 0xfff + + if roundBit == 1 { + // Round to nearest, ties to even + if sticky != 0 || (outMant&1) == 1 { + outMant++ + if outMant > 0x3ff { + // Overflow mantissa, increment exponent + outMant = 0 + outExp++ + if outExp >= 31 { + outExp = 0x1f // Inf + } + } + } + } + } + } + + return Float16(outSign | (outExp << 10) | outMant) +} + +// Float32 returns the float32 representation of the Float16. +func (f Float16) Float32() float32 { + bits := uint16(f) + sign := (bits >> 15) & 0x1 + exp := (bits >> 10) & 0x1f + mant := bits & 0x3ff + + var outBits uint32 + outBits = uint32(sign) << 31 + + if exp == 0x1f { + // NaN or Inf + outBits |= 0xff << 23 + if mant != 0 { + // NaN - promote mantissa + outBits |= uint32(mant) << 13 + } + } else if exp == 0 { + if mant == 0 { + // Signed zero + outBits |= 0 + } else { + // Subnormal + // Convert to float32 normal + // Normalize the subnormal + shift := 0 + m := uint32(mant) + for (m & 0x400) == 0 { + m <<= 1 + shift++ + } + // m now has bit 10 set (implicit 1 for float32) + // discard implicit bit + m &= 0x3ff + // new float32 exponent + // subnormal 16 is 2^-14 * 0.mant + // = 2^-14 * 2^-10 * mant_integer + // = 2^-24 * mant_integer + // Normalized float32 is 1.mant * 2^(E-127) + // We effectively shift left until we hit the 1. + // The effective exponent is (1 - 15) - shift = -14 - shift? + // Simpler: + // value = mant * 2^-24 + // Reconstruct using float32 operations to avoid bit headaches? + // No, bit ops are faster. + + // Float16 subnormal: (-1)^S * 2^(1-15) * (mant / 1024) + // = (-1)^S * 2^-14 * (mant * 2^-10) + // = (-1)^S * 2^-24 * mant + + // Float32: (-1)^S * 2^(E-127) * (1 + M/2^23) + + // Let's use the magic number method or just float32 arithmetic if lazy + val := float32(mant) * float32(math.Pow(2, -24)) + if sign == 1 { + val = -val + } + return float32(val) + } + } else { + // Normal + outBits |= (uint32(exp) - 15 + 127) << 23 + outBits |= uint32(mant) << 13 + } + + return math.Float32frombits(outBits) +} + +// IsNaN reports whether f is an IEEE 754 “not-a-number” value. +func (f Float16) IsNaN() bool { + return (f&maskExp) == maskExp && (f&maskMant) != 0 +} + +// IsInf reports whether f is an infinity, according to sign. +// If sign > 0, IsInf reports whether f is positive infinity. +// If sign < 0, IsInf reports whether f is negative infinity. +// If sign == 0, IsInf reports whether f is either infinity. +func (f Float16) IsInf(sign int) bool { + isInf := (f&maskExp) == maskExp && (f&maskMant) == 0 + if !isInf { + return false + } + if sign == 0 { + return true + } + hasSign := (f & maskSign) != 0 + if sign > 0 { + return !hasSign + } + return hasSign +} + +// IsZero reports whether f is +0 or -0. +func (f Float16) IsZero() bool { + return (f & (maskExp | maskMant)) == 0 +} + +// IsFinite reports whether f is neither NaN nor an infinity. +func (f Float16) IsFinite() bool { + return (f & maskExp) != maskExp +} + +// IsNormal reports whether f is a normal value (not zero, subnormal, infinite, or NaN). +func (f Float16) IsNormal() bool { + exp := f & maskExp + return exp != 0 && exp != maskExp +} + +// IsSubnormal reports whether f is a subnormal value. +func (f Float16) IsSubnormal() bool { + exp := f & maskExp + mant := f & maskMant + return exp == 0 && mant != 0 +} + +// Signbit reports whether f is negative or negative zero. +func (f Float16) Signbit() bool { + return (f & maskSign) != 0 +} + +// String returns the string representation of f. +func (f Float16) String() string { + return fmt.Sprintf("%g", f.Float32()) +} + +// Arithmetic operations (promoted to float32) + +func (f Float16) Add(other Float16) Float16 { + return Float16FromFloat32(f.Float32() + other.Float32()) +} + +func (f Float16) Sub(other Float16) Float16 { + return Float16FromFloat32(f.Float32() - other.Float32()) +} + +func (f Float16) Mul(other Float16) Float16 { + return Float16FromFloat32(f.Float32() * other.Float32()) +} + +func (f Float16) Div(other Float16) Float16 { + return Float16FromFloat32(f.Float32() / other.Float32()) +} + +func (f Float16) Neg() Float16 { + return f ^ maskSign +} + +func (f Float16) Abs() Float16 { + return f &^ maskSign +} + +// Comparison + +func (f Float16) Equal(other Float16) bool { + // IEEE 754: NaN != NaN + if f.IsNaN() || other.IsNaN() { + return false + } + // +0 == -0 + if f.IsZero() && other.IsZero() { + return true + } + // Direct bit comparison works for typical normals with same sign + // But mixed signs or negative numbers need caear + return f.Float32() == other.Float32() +} + +func (f Float16) Less(other Float16) bool { + if f.IsNaN() || other.IsNaN() { + return false + } + // Handle signed zero: -0 is not less than +0 + if f.IsZero() && other.IsZero() { + return false + } + return f.Float32() < other.Float32() +} + +func (f Float16) LessEq(other Float16) bool { + if f.IsNaN() || other.IsNaN() { + return false + } + if f.IsZero() && other.IsZero() { + return true + } + return f.Float32() <= other.Float32() +} + +func (f Float16) Greater(other Float16) bool { + return other.Less(f) +} + +func (f Float16) GreaterEq(other Float16) bool { + return other.LessEq(f) +} diff --git a/go/fory/float16/float16_test.go b/go/fory/float16/float16_test.go new file mode 100644 index 0000000000..969c04c6e0 --- /dev/null +++ b/go/fory/float16/float16_test.go @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package float16_test + +import ( + "math" + "testing" + + "github.com/apache/fory/go/fory/float16" + "github.com/stretchr/testify/assert" +) + +func TestFloat16_Conversion(t *testing.T) { + tests := []struct { + name string + f32 float32 + want uint16 // bits + check bool // if true, check exact bits, else check float32 roundtrip within epsilon + }{ + {"Zero", 0.0, 0x0000, true}, + {"NegZero", float32(math.Copysign(0, -1)), 0x8000, true}, + {"One", 1.0, 0x3c00, true}, + {"MinusOne", -1.0, 0xbc00, true}, + {"Max", 65504.0, 0x7bff, true}, + {"Inf", float32(math.Inf(1)), 0x7c00, true}, + {"NegInf", float32(math.Inf(-1)), 0xfc00, true}, + // Smallest normal: 2^-14 = 0.000061035156 + {"SmallestNormal", float32(math.Pow(2, -14)), 0x0400, true}, + // Largest subnormal: 2^-14 - 2^-24 = 6.09756...e-5 + {"LargestSubnormal", float32(6.097555e-5), 0x03ff, true}, + // Smallest subnormal: 2^-24 + {"SmallestSubnormal", float32(math.Pow(2, -24)), 0x0001, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f16 := float16.Float16FromFloat32(tt.f32) + if tt.check { + assert.Equal(t, tt.want, f16.Bits(), "Bits match") + } + + // Round trip check + roundTrip := f16.Float32() + if math.IsInf(float64(tt.f32), 0) { + assert.True(t, math.IsInf(float64(roundTrip), 0)) + assert.Equal(t, math.Signbit(float64(tt.f32)), math.Signbit(float64(roundTrip))) + } else if math.IsNaN(float64(tt.f32)) { + assert.True(t, math.IsNaN(float64(roundTrip))) + } else { + // Allow small error due to precision loss + // Epsilon for float16 is 2^-10 ~= 0.001 relative error + // But we check consistency + if tt.check { + // bit exact means round trip should map back to similar float (precision loss expected) + // Verify that converting back to f16 gives same bits + f16back := float16.Float16FromFloat32(roundTrip) + assert.Equal(t, tt.want, f16back.Bits()) + } + } + }) + } +} + +func TestFloat16_NaN(t *testing.T) { + nan := float16.NaN + assert.True(t, nan.IsNaN()) + assert.False(t, nan.IsInf(0)) + assert.False(t, nan.IsZero()) + + // Comparison + assert.False(t, nan.Equal(nan)) + + // Conversion + f32 := nan.Float32() + assert.True(t, math.IsNaN(float64(f32))) +} + +func TestFloat16_Arithmetic(t *testing.T) { + one := float16.Float16FromFloat32(1.0) + two := float16.Float16FromFloat32(2.0) + three := float16.Float16FromFloat32(3.0) + + assert.Equal(t, "3", one.Add(two).String()) + assert.Equal(t, "2", three.Sub(one).String()) + assert.Equal(t, "6", two.Mul(three).String()) + assert.Equal(t, "1.5", three.Div(two).String()) +} diff --git a/go/fory/primitive.go b/go/fory/primitive.go index 63f0c53acb..3057d0dfae 100644 --- a/go/fory/primitive.go +++ b/go/fory/primitive.go @@ -605,3 +605,61 @@ func (s float64Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool func (s float64Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { s.Read(ctx, refMode, false, false, value) } + +// ============================================================================ +// float16Serializer - optimized float16 serialization +// ============================================================================ + +// float16Serializer handles float16 type +type float16Serializer struct{} + +var globalFloat16Serializer = float16Serializer{} + +func (s float16Serializer) WriteData(ctx *WriteContext, value reflect.Value) { + // Value is effectively uint16 (alias) + // We can use WriteUint16, but we check if it is indeed float16 compatible + // The value comes from reflection, likely an interface or concrete type + // Since Float16 is uint16, value.Uint() works. + ctx.buffer.WriteUint16(uint16(value.Uint())) +} + +func (s float16Serializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + if refMode != RefModeNone { + ctx.buffer.WriteInt8(NotNullValueFlag) + } + if writeType { + ctx.buffer.WriteVarUint32Small7(uint32(FLOAT16)) + } + s.WriteData(ctx, value) +} + +func (s float16Serializer) ReadData(ctx *ReadContext, value reflect.Value) { + err := ctx.Err() + // Read uint16 bits + bits := ctx.buffer.ReadUint16(err) + if ctx.HasError() { + return + } + // Set the value. Since Float16 is uint16, SetUint works. + value.SetUint(uint64(bits)) +} + +func (s float16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + err := ctx.Err() + if refMode != RefModeNone { + if ctx.buffer.ReadInt8(err) == NullFlag { + return + } + } + if readType { + _ = ctx.buffer.ReadVarUint32Small7(err) + } + if ctx.HasError() { + return + } + s.ReadData(ctx, value) +} + +func (s float16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} diff --git a/go/fory/primitive_test.go b/go/fory/primitive_test.go new file mode 100644 index 0000000000..978a81d46b --- /dev/null +++ b/go/fory/primitive_test.go @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "testing" + + "github.com/apache/fory/go/fory/float16" + "github.com/stretchr/testify/require" +) + +func TestFloat16Primitive(t *testing.T) { + f := New(WithXlang(true)) + f16 := float16.Float16FromFloat32(3.14) + + // Directly serialize a float16 value + data, err := f.Serialize(f16) + require.NoError(t, err) + + var res float16.Float16 + err = f.Deserialize(data, &res) + require.NoError(t, err) + + require.True(t, f16.Equal(res)) + + // Value check (approximate) + require.InDelta(t, 3.14, res.Float32(), 0.01) +} + +func TestFloat16PrimitiveSliceDirect(t *testing.T) { + // Tests serializing a slice as a root object + f := New(WithXlang(true)) + f16 := float16.Float16FromFloat32(3.14) + + slice := []float16.Float16{f16, float16.Zero} + data, err := f.Serialize(slice) + require.NoError(t, err) + + var resSlice []float16.Float16 + err = f.Deserialize(data, &resSlice) + require.NoError(t, err) + require.Equal(t, slice, resSlice) +} diff --git a/go/fory/reader.go b/go/fory/reader.go index bcbd1048c3..dd9a837bb0 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -222,6 +222,8 @@ func (c *ReadContext) readFast(ptr unsafe.Pointer, ct DispatchId) { *(*float32)(ptr) = c.buffer.ReadFloat32(err) case PrimitiveFloat64DispatchId: *(*float64)(ptr) = c.buffer.ReadFloat64(err) + case PrimitiveFloat16DispatchId: + *(*uint16)(ptr) = c.buffer.ReadUint16(err) case StringDispatchId: *(*string)(ptr) = readString(c.buffer, err) } diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 36818de3e1..cf83ab76b8 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -21,6 +21,8 @@ import ( "reflect" "strconv" "unsafe" + + "github.com/apache/fory/go/fory/float16" ) // isNilSlice checks if a value is a nil slice. Safe to call on any value type. @@ -828,6 +830,91 @@ func ReadFloat64Slice(buf *ByteBuffer, err *Error) []float64 { return result } +// ============================================================================ +// float16SliceSerializer - optimized []float16.Float16 serialization +// ============================================================================ + +type float16SliceSerializer struct{} + +func (s float16SliceSerializer) WriteData(ctx *WriteContext, value reflect.Value) { + // Cast to []float16.Float16 + v := value.Interface().([]float16.Float16) + buf := ctx.Buffer() + length := len(v) + size := length * 2 + buf.WriteLength(size) + if length > 0 { + // Float16 is uint16 underneath, so we can cast slice pointer + ptr := unsafe.Pointer(&v[0]) + if isLittleEndian { + buf.WriteBinary(unsafe.Slice((*byte)(ptr), size)) + } else { + // Big-endian architectures need explicit byte swapping + for i := 0; i < length; i++ { + // We can just write as uint16, WriteUint16 handles endianness for us + // Float16.Bits() returns uint16 + buf.WriteUint16(v[i].Bits()) + } + } + } +} + +func (s float16SliceSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + done := writeSliceRefAndType(ctx, refMode, writeType, value, FLOAT16_ARRAY) + if done || ctx.HasError() { + return + } + s.WriteData(ctx, value) +} + +func (s float16SliceSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + done, typeId := readSliceRefAndType(ctx, refMode, readType, value) + if done || ctx.HasError() { + return + } + if readType && typeId != uint32(FLOAT16_ARRAY) { + ctx.SetError(DeserializationErrorf("slice type mismatch: expected FLOAT16_ARRAY (%d), got %d", FLOAT16_ARRAY, typeId)) + return + } + s.ReadData(ctx, value) +} + +func (s float16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} + +func (s float16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { + buf := ctx.Buffer() + ctxErr := ctx.Err() + size := buf.ReadLength(ctxErr) + length := size / 2 + if ctx.HasError() { + return + } + + // Ensure capacity + ptr := (*[]float16.Float16)(value.Addr().UnsafePointer()) + if length == 0 { + *ptr = make([]float16.Float16, 0) + return + } + + result := make([]float16.Float16, length) + + if isLittleEndian { + raw := buf.ReadBinary(size, ctxErr) + // unsafe copy + targetPtr := unsafe.Pointer(&result[0]) + copy(unsafe.Slice((*byte)(targetPtr), size), raw) + } else { + for i := 0; i < length; i++ { + // ReadUint16 handles endianness + result[i] = float16.Float16FromBits(buf.ReadUint16(ctxErr)) + } + } + *ptr = result +} + // WriteIntSlice writes []int to buffer using ARRAY protocol // //go:inline diff --git a/go/fory/slice_primitive_test.go b/go/fory/slice_primitive_test.go new file mode 100644 index 0000000000..61ff7b89c8 --- /dev/null +++ b/go/fory/slice_primitive_test.go @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "testing" + + "github.com/apache/fory/go/fory/float16" + "github.com/stretchr/testify/assert" +) + +func TestFloat16Slice(t *testing.T) { + f := NewFory() + + t.Run("float16_slice", func(t *testing.T) { + slice := []float16.Float16{ + float16.Float16FromFloat32(1.0), + float16.Float16FromFloat32(2.5), + float16.Float16FromFloat32(-0.5), + } + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []float16.Float16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Equal(t, slice, result) + }) + + t.Run("float16_slice_empty", func(t *testing.T) { + slice := []float16.Float16{} + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []float16.Float16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Empty(t, result) + }) + + t.Run("float16_slice_nil", func(t *testing.T) { + var slice []float16.Float16 = nil + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []float16.Float16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Nil(t, result) + }) +} diff --git a/go/fory/struct.go b/go/fory/struct.go index 828dd28cd5..ff4aaece94 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -1389,6 +1389,17 @@ func (s *structSerializer) WriteData(ctx *WriteContext, value reflect.Value) { } else { binary.LittleEndian.PutUint64(data[bufOffset:], math.Float64bits(v)) } + case PrimitiveFloat16DispatchId: + v, ok := loadFieldValue[uint16](field.Kind, fieldPtr, optInfo) + if !ok { + v = 0 + } + if isLittleEndian { + *(*uint16)(unsafe.Pointer(&data[bufOffset])) = v + } else { + binary.LittleEndian.PutUint16(data[bufOffset:], v) + } + } } // Update writer index ONCE after all fixed fields @@ -2562,6 +2573,16 @@ func (s *structSerializer) ReadData(ctx *ReadContext, value reflect.Value) { v = math.Float64frombits(binary.LittleEndian.Uint64(data[bufOffset:])) } storeFieldValue(field.Kind, fieldPtr, optInfo, v) + case PrimitiveFloat16DispatchId: + var v uint16 + if isLittleEndian { + v = *(*uint16)(unsafe.Pointer(&data[bufOffset])) + } else { + v = binary.LittleEndian.Uint16(data[bufOffset:]) + } + // Float16 is underlying uint16, so we can store it directly + storeFieldValue(field.Kind, fieldPtr, optInfo, v) + } } // Update reader index ONCE after all fixed fields diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index 51fb806f3c..7f345b7240 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/apache/fory/go/fory/float16" "github.com/apache/fory/go/fory/optional" "github.com/stretchr/testify/require" ) @@ -474,3 +475,36 @@ func TestSkipAnyValueReadsSharedTypeMeta(t *testing.T) { require.True(t, ok) require.Equal(t, "ok", result.Name) } + +func TestFloat16StructField(t *testing.T) { + type StructWithFloat16 struct { + F16 float16.Float16 + SliceF16 []float16.Float16 + ArrayF16 [3]float16.Float16 + } + + f := New(WithXlang(true)) + require.NoError(t, f.RegisterStruct(StructWithFloat16{}, 3001)) + + val := &StructWithFloat16{ + F16: float16.Float16FromFloat32(1.5), + SliceF16: []float16.Float16{float16.Float16FromFloat32(1.0), float16.Float16FromFloat32(2.5)}, + ArrayF16: [3]float16.Float16{float16.Zero, float16.One, float16.NegZero}, + } + + data, err := f.Serialize(val) + require.NoError(t, err) + + // Create new instance + res := &StructWithFloat16{} + err = f.Deserialize(data, res) + require.NoError(t, err) + + // Verify + require.Equal(t, val.F16, res.F16) + require.Equal(t, val.SliceF16, res.SliceF16) + require.Equal(t, val.ArrayF16, res.ArrayF16) + + // Specific value check + require.Equal(t, float32(1.5), res.F16.Float32()) +} diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index e3ef79823b..0069dfc0da 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -29,6 +29,7 @@ import ( "time" "unsafe" + "github.com/apache/fory/go/fory/float16" "github.com/apache/fory/go/fory/meta" ) @@ -68,6 +69,7 @@ var ( uintSliceType = reflect.TypeOf((*[]uint)(nil)).Elem() float32SliceType = reflect.TypeOf((*[]float32)(nil)).Elem() float64SliceType = reflect.TypeOf((*[]float64)(nil)).Elem() + float16SliceType = reflect.TypeOf((*[]float16.Float16)(nil)).Elem() interfaceSliceType = reflect.TypeOf((*[]any)(nil)).Elem() interfaceMapType = reflect.TypeOf((*map[any]any)(nil)).Elem() stringStringMapType = reflect.TypeOf((*map[string]string)(nil)).Elem() @@ -93,6 +95,7 @@ var ( intType = reflect.TypeOf((*int)(nil)).Elem() float32Type = reflect.TypeOf((*float32)(nil)).Elem() float64Type = reflect.TypeOf((*float64)(nil)).Elem() + float16Type = reflect.TypeOf((*float16.Float16)(nil)).Elem() dateType = reflect.TypeOf((*Date)(nil)).Elem() timestampType = reflect.TypeOf((*time.Time)(nil)).Elem() genericSetType = reflect.TypeOf((*Set[any])(nil)).Elem() @@ -243,6 +246,7 @@ func newTypeResolver(fory *Fory) *TypeResolver { int64Type, float32Type, float64Type, + float16Type, stringType, dateType, timestampType, @@ -396,6 +400,7 @@ func (r *TypeResolver) initialize() { {uintSliceType, INT64_ARRAY, uintSliceSerializer{}}, {float32SliceType, FLOAT32_ARRAY, float32SliceSerializer{}}, {float64SliceType, FLOAT64_ARRAY, float64SliceSerializer{}}, + {float16SliceType, FLOAT16_ARRAY, float16SliceSerializer{}}, // Register common map types for fast path with optimized serializers {stringStringMapType, MAP, stringStringMapSerializer{}}, {stringInt64MapType, MAP, stringInt64MapSerializer{}}, @@ -419,6 +424,7 @@ func (r *TypeResolver) initialize() { {intType, VARINT64, intSerializer{}}, // int maps to int64 for xlang {float32Type, FLOAT32, float32Serializer{}}, {float64Type, FLOAT64, float64Serializer{}}, + {float16Type, FLOAT16, float16Serializer{}}, {dateType, DATE, dateSerializer{}}, {timestampType, TIMESTAMP, timeSerializer{}}, {genericSetType, SET, setSerializer{}}, @@ -426,7 +432,7 @@ func (r *TypeResolver) initialize() { for _, elem := range serializers { _, err := r.registerType(elem.Type, uint32(elem.TypeId), "", "", elem.Serializer, true) if err != nil { - fmt.Errorf("init type error: %v", err) + panic(fmt.Errorf("init type error: %v", err)) } } @@ -1610,6 +1616,11 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s } return int32ArraySerializer{arrayType: type_}, nil case reflect.Uint16: + // Check for fory.Float16 (aliased to uint16) + // Check name first to avoid slow PkgPath call + if elem.Name() == "Float16" && (elem.PkgPath() == "github.com/apache/fory/go/fory/float16" || strings.HasSuffix(elem.PkgPath(), "/float16")) { + return float16ArraySerializer{arrayType: type_}, nil + } return uint16ArraySerializer{arrayType: type_}, nil case reflect.Uint32: return uint32ArraySerializer{arrayType: type_}, nil diff --git a/go/fory/types.go b/go/fory/types.go index df91124514..5227cdbad2 100644 --- a/go/fory/types.go +++ b/go/fory/types.go @@ -17,7 +17,10 @@ package fory -import "reflect" +import ( + "reflect" + "strings" +) type TypeId = int16 @@ -315,6 +318,7 @@ const ( PrimitiveUint64DispatchId // 17 - uint64 with fixed encoding PrimitiveFloat32DispatchId // 18 PrimitiveFloat64DispatchId // 19 + PrimitiveFloat16DispatchId // 20 // ========== NULLABLE DISPATCH IDs ========== NullableBoolDispatchId @@ -327,6 +331,7 @@ const ( NullableTaggedInt64DispatchId NullableFloat32DispatchId NullableFloat64DispatchId + NullableFloat16DispatchId NullableUint8DispatchId NullableUint16DispatchId NullableUint32DispatchId @@ -350,6 +355,7 @@ const ( UintSliceDispatchId Float32SliceDispatchId Float64SliceDispatchId + Float16SliceDispatchId BoolSliceDispatchId StringSliceDispatchId @@ -390,6 +396,10 @@ func GetDispatchId(t reflect.Type) DispatchId { case reflect.Uint8: return PrimitiveUint8DispatchId case reflect.Uint16: + // Check for fory.Float16 (aliased to uint16) + if t.Name() == "Float16" && (t.PkgPath() == "github.com/apache/fory/go/fory/float16" || strings.HasSuffix(t.PkgPath(), "/float16")) { + return PrimitiveFloat16DispatchId + } return PrimitiveUint16DispatchId case reflect.Uint32: // Default to varint encoding (VAR_UINT32) for xlang compatibility @@ -426,6 +436,15 @@ func GetDispatchId(t reflect.Type) DispatchId { return Float32SliceDispatchId case reflect.Float64: return Float64SliceDispatchId + case reflect.Uint16: + // Check if it's float16 slice + if t.Elem().Name() == "Float16" && (t.Elem().PkgPath() == "github.com/apache/fory/go/fory/float16" || strings.HasSuffix(t.Elem().PkgPath(), "/float16")) { + return Float16SliceDispatchId + } + // Use Int16SliceDispatchId for Uint16 as they share the same 2-byte size + // and serialization logic in many cases, or it falls back to generic if needed. + return Int16SliceDispatchId + case reflect.Bool: return BoolSliceDispatchId case reflect.String: @@ -480,7 +499,8 @@ func isFixedSizePrimitive(dispatchId DispatchId) bool { PrimitiveInt16DispatchId, PrimitiveUint16DispatchId, PrimitiveInt32DispatchId, PrimitiveUint32DispatchId, PrimitiveInt64DispatchId, PrimitiveUint64DispatchId, - PrimitiveFloat32DispatchId, PrimitiveFloat64DispatchId: + PrimitiveFloat32DispatchId, PrimitiveFloat64DispatchId, + PrimitiveFloat16DispatchId: return true default: return false @@ -495,7 +515,8 @@ func isNullableFixedSizePrimitive(dispatchId DispatchId) bool { NullableInt16DispatchId, NullableUint16DispatchId, NullableInt32DispatchId, NullableUint32DispatchId, NullableInt64DispatchId, NullableUint64DispatchId, - NullableFloat32DispatchId, NullableFloat64DispatchId: + NullableFloat32DispatchId, NullableFloat64DispatchId, + NullableFloat16DispatchId: return true default: return false @@ -536,7 +557,7 @@ func isPrimitiveDispatchId(dispatchId DispatchId) bool { case PrimitiveBoolDispatchId, PrimitiveInt8DispatchId, PrimitiveInt16DispatchId, PrimitiveInt32DispatchId, PrimitiveInt64DispatchId, PrimitiveIntDispatchId, PrimitiveUint8DispatchId, PrimitiveUint16DispatchId, PrimitiveUint32DispatchId, PrimitiveUint64DispatchId, PrimitiveUintDispatchId, - PrimitiveFloat32DispatchId, PrimitiveFloat64DispatchId: + PrimitiveFloat32DispatchId, PrimitiveFloat64DispatchId, PrimitiveFloat16DispatchId: return true default: return false @@ -558,7 +579,7 @@ func isPrimitiveDispatchKind(kind reflect.Kind) bool { switch kind { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64: + reflect.Float32, reflect.Float64: // Note: Float16 is uint16 kind, handled by dispatch ID logic return true default: return false @@ -589,6 +610,8 @@ func getDispatchIdFromTypeId(typeId TypeId, nullable bool) DispatchId { return NullableTaggedInt64DispatchId case FLOAT32: return NullableFloat32DispatchId + case FLOAT16: + return NullableFloat16DispatchId case FLOAT64: return NullableFloat64DispatchId case UINT8: @@ -631,6 +654,8 @@ func getDispatchIdFromTypeId(typeId TypeId, nullable bool) DispatchId { return PrimitiveTaggedInt64DispatchId case FLOAT32: return PrimitiveFloat32DispatchId + case FLOAT16: + return PrimitiveFloat16DispatchId case FLOAT64: return PrimitiveFloat64DispatchId case UINT8: @@ -674,6 +699,8 @@ func getFixedSizeByDispatchId(dispatchId DispatchId) int { return 2 case PrimitiveInt32DispatchId, PrimitiveUint32DispatchId, PrimitiveFloat32DispatchId: return 4 + case PrimitiveFloat16DispatchId: + return 2 case PrimitiveInt64DispatchId, PrimitiveUint64DispatchId, PrimitiveFloat64DispatchId: return 8 default: @@ -721,7 +748,8 @@ func isPrimitiveFixedDispatchId(id DispatchId) bool { // Fixed-size int32/int64/uint32/uint64 - only when explicitly specified via TypeId PrimitiveInt32DispatchId, PrimitiveUint32DispatchId, PrimitiveInt64DispatchId, PrimitiveUint64DispatchId, - PrimitiveFloat32DispatchId, PrimitiveFloat64DispatchId: + PrimitiveFloat32DispatchId, PrimitiveFloat64DispatchId, + PrimitiveFloat16DispatchId: return true default: return false @@ -737,6 +765,8 @@ func getFixedSizeByPrimitiveDispatchId(id DispatchId) int { return 2 case PrimitiveInt32DispatchId, PrimitiveUint32DispatchId, PrimitiveFloat32DispatchId: return 4 + case PrimitiveFloat16DispatchId: + return 2 case PrimitiveInt64DispatchId, PrimitiveUint64DispatchId, PrimitiveFloat64DispatchId: return 8 default: diff --git a/go/fory/writer.go b/go/fory/writer.go index e49bca8419..27907a5870 100644 --- a/go/fory/writer.go +++ b/go/fory/writer.go @@ -203,6 +203,9 @@ func (c *WriteContext) writeFast(ptr unsafe.Pointer, ct DispatchId) { c.buffer.WriteFloat32(*(*float32)(ptr)) case PrimitiveFloat64DispatchId: c.buffer.WriteFloat64(*(*float64)(ptr)) + case PrimitiveFloat16DispatchId: + // Float16 is uint16 in Go + c.buffer.WriteUint16(*(*uint16)(ptr)) case StringDispatchId: writeString(c.buffer, *(*string)(ptr)) }