/*
Copyright 2014 SAP SE

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package driver
|
||
|
|
||
|
import (
|
||
|
"database/sql/driver"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"math"
|
||
|
"math/big"
|
||
|
"sync"
|
||
|
)
|
||
|
|
||
|
// Size of a big.Word, derived at compile time with the same technique as
// src/pkg/math/big/arith.go.
const (
	// _m has every bit of a big.Word set.
	_m = ^big.Word(0)
	// _logS is log2 of the word size in bytes: each probe contributes 1 when
	// the word is wider than 8, 16 and 32 bits respectively.
	_logS = (_m>>8)&1 + (_m>>16)&1 + (_m>>32)&1
	// _S is the size of a big.Word in bytes.
	_S = 1 << _logS
)
|
||
|
|
||
|
// Parameters of the decimal128 format, see
// http://en.wikipedia.org/wiki/Decimal128_floating-point_format
const (
	dec128Digits = 34    // maximum number of significand digits
	dec128Bias   = 6176  // bias added to the exponent in the encoded form
	dec128MinExp = -6176 // smallest exponent
	dec128MaxExp = 6111  // largest exponent
)

// decimalSize is the number of bytes of an encoded decimal field.
const decimalSize = 16
|
||
|
|
||
|
// Frequently used big.Int constants. They are shared and must not be modified.
var (
	natZero = big.NewInt(0)
	natOne  = big.NewInt(1)
	natTen  = big.NewInt(10)
)

// nat holds the precomputed powers of ten 10^0 .. 10^10, indexed by exponent.
var nat = []*big.Int{
	natOne,
	natTen,
	big.NewInt(1e2),
	big.NewInt(1e3),
	big.NewInt(1e4),
	big.NewInt(1e5),
	big.NewInt(1e6),
	big.NewInt(1e7),
	big.NewInt(1e8),
	big.NewInt(1e9),
	big.NewInt(1e10),
}
|
||
|
|
||
|
// lg10 is approximately log2(10).
const lg10 = math.Ln10 / math.Ln2

// maxDecimal is built from the big-endian bytes of 10^34 - 1.
var maxDecimal = new(big.Int).SetBytes([]byte{
	0x01, 0xED, 0x09, 0xBE, 0xAD,
	0x87, 0xC0, 0x37, 0x8D, 0x8E,
	0x63, 0xFF, 0xFF, 0xFF, 0xFF,
})
|
||
|
|
||
|
// decFlags is a bit set describing the outcome of a decimal conversion.
type decFlags byte

// Bits of decFlags.
const (
	dfNotExact  decFlags = 0x01 // the conversion left a remainder
	dfOverflow  decFlags = 0x02 // the exponent is above the allowed maximum
	dfUnderflow decFlags = 0x04 // the exponent is below the allowed minimum
)
|
||
|
|
||
|
// ErrDecimalOutOfRange is returned when a big.Rat does not fit into an hdb
// decimal field.
var ErrDecimalOutOfRange = errors.New("decimal out of range error")
|
||
|
|
||
|
// bigIntFree is a free list of *big.Int scratch values.
var bigIntFree = sync.Pool{
	New: func() interface{} {
		return &big.Int{}
	},
}

// bigRatFree is a free list of *big.Rat scratch values.
var bigRatFree = sync.Pool{
	New: func() interface{} {
		return &big.Rat{}
	},
}
|
||
|
|
||
|
// Decimal is the driver representation of a database decimal field value.
// Its underlying type is big.Rat.
type Decimal big.Rat
|
||
|
|
||
|
// Scan implements the database/sql/Scanner interface.
|
||
|
func (d *Decimal) Scan(src interface{}) error {
|
||
|
|
||
|
b, ok := src.([]byte)
|
||
|
if !ok {
|
||
|
return fmt.Errorf("decimal: invalid data type %T", src)
|
||
|
}
|
||
|
|
||
|
if len(b) != decimalSize {
|
||
|
return fmt.Errorf("decimal: invalid size %d of %v - %d expected", len(b), b, decimalSize)
|
||
|
}
|
||
|
|
||
|
if (b[15] & 0x60) == 0x60 {
|
||
|
return fmt.Errorf("decimal: format (infinity, nan, ...) not supported : %v", b)
|
||
|
}
|
||
|
|
||
|
v := (*big.Rat)(d)
|
||
|
p := v.Num()
|
||
|
q := v.Denom()
|
||
|
|
||
|
neg, exp := decodeDecimal(b, p)
|
||
|
|
||
|
switch {
|
||
|
case exp < 0:
|
||
|
q.Set(exp10(exp * -1))
|
||
|
case exp == 0:
|
||
|
q.Set(natOne)
|
||
|
case exp > 0:
|
||
|
p.Mul(p, exp10(exp))
|
||
|
q.Set(natOne)
|
||
|
}
|
||
|
|
||
|
if neg {
|
||
|
v.Neg(v)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Value implements the database/sql/Valuer interface.
|
||
|
func (d Decimal) Value() (driver.Value, error) {
|
||
|
m := bigIntFree.Get().(*big.Int)
|
||
|
neg, exp, df := convertRatToDecimal((*big.Rat)(&d), m, dec128Digits, dec128MinExp, dec128MaxExp)
|
||
|
|
||
|
var v driver.Value
|
||
|
var err error
|
||
|
|
||
|
switch {
|
||
|
default:
|
||
|
v, err = encodeDecimal(m, neg, exp)
|
||
|
case df&dfUnderflow != 0: // set to zero
|
||
|
m.Set(natZero)
|
||
|
v, err = encodeDecimal(m, false, 0)
|
||
|
case df&dfOverflow != 0:
|
||
|
err = ErrDecimalOutOfRange
|
||
|
}
|
||
|
|
||
|
// performance (avoid expensive defer)
|
||
|
bigIntFree.Put(m)
|
||
|
|
||
|
return v, err
|
||
|
}
|
||
|
|
||
|
func convertRatToDecimal(x *big.Rat, m *big.Int, digits, minExp, maxExp int) (bool, int, decFlags) {
|
||
|
|
||
|
neg := x.Sign() < 0 //store sign
|
||
|
|
||
|
if x.Num().Cmp(natZero) == 0 { // zero
|
||
|
m.Set(natZero)
|
||
|
return neg, 0, 0
|
||
|
}
|
||
|
|
||
|
c := bigRatFree.Get().(*big.Rat).Abs(x) // copy && abs
|
||
|
a := c.Num()
|
||
|
b := c.Denom()
|
||
|
|
||
|
exp, shift := 0, 0
|
||
|
|
||
|
if c.IsInt() {
|
||
|
exp = digits10(a) - 1
|
||
|
} else {
|
||
|
shift = digits10(a) - digits10(b)
|
||
|
switch {
|
||
|
case shift < 0:
|
||
|
a.Mul(a, exp10(shift*-1))
|
||
|
case shift > 0:
|
||
|
b.Mul(b, exp10(shift))
|
||
|
}
|
||
|
if a.Cmp(b) == -1 {
|
||
|
exp = shift - 1
|
||
|
} else {
|
||
|
exp = shift
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var df decFlags
|
||
|
|
||
|
switch {
|
||
|
default:
|
||
|
exp = max(exp-digits+1, minExp)
|
||
|
case exp < minExp:
|
||
|
df |= dfUnderflow
|
||
|
exp = exp - digits + 1
|
||
|
}
|
||
|
|
||
|
if exp > maxExp {
|
||
|
df |= dfOverflow
|
||
|
}
|
||
|
|
||
|
shift = exp - shift
|
||
|
switch {
|
||
|
case shift < 0:
|
||
|
a.Mul(a, exp10(shift*-1))
|
||
|
case exp > 0:
|
||
|
b.Mul(b, exp10(shift))
|
||
|
}
|
||
|
|
||
|
m.QuoRem(a, b, a) // reuse a as rest
|
||
|
if a.Cmp(natZero) != 0 {
|
||
|
// round (business >= 0.5 up)
|
||
|
df |= dfNotExact
|
||
|
if a.Add(a, a).Cmp(b) >= 0 {
|
||
|
m.Add(m, natOne)
|
||
|
if m.Cmp(exp10(digits)) == 0 {
|
||
|
shift := min(digits, maxExp-exp)
|
||
|
if shift < 1 { // overflow -> shift one at minimum
|
||
|
df |= dfOverflow
|
||
|
shift = 1
|
||
|
}
|
||
|
m.Set(exp10(digits - shift))
|
||
|
exp += shift
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// norm
|
||
|
for exp < maxExp {
|
||
|
a.QuoRem(m, natTen, b) // reuse a, b
|
||
|
if b.Cmp(natZero) != 0 {
|
||
|
break
|
||
|
}
|
||
|
m.Set(a)
|
||
|
exp++
|
||
|
}
|
||
|
|
||
|
// performance (avoid expensive defer)
|
||
|
bigRatFree.Put(c)
|
||
|
|
||
|
return neg, exp, df
|
||
|
}
|
||
|
|
||
|
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||
|
|
||
|
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
|
||
|
|
||
|
// performance: tested with reference work variable
|
||
|
// - but int.Set is expensive, so let's live with big.Int creation for n >= len(nat)
|
||
|
func exp10(n int) *big.Int {
|
||
|
if n < len(nat) {
|
||
|
return nat[n]
|
||
|
}
|
||
|
r := big.NewInt(int64(n))
|
||
|
return r.Exp(natTen, r, nil)
|
||
|
}
|
||
|
|
||
|
func digits10(p *big.Int) int {
|
||
|
k := p.BitLen() // 2^k <= p < 2^(k+1) - 1
|
||
|
//i := int(float64(k) / lg10) //minimal digits base 10
|
||
|
//i := int(float64(k) / lg10) //minimal digits base 10
|
||
|
i := k * 100 / 332
|
||
|
if i < 1 {
|
||
|
i = 1
|
||
|
}
|
||
|
|
||
|
for ; ; i++ {
|
||
|
if p.Cmp(exp10(i)) < 0 {
|
||
|
return i
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func decodeDecimal(b []byte, m *big.Int) (bool, int) {
|
||
|
|
||
|
neg := (b[15] & 0x80) != 0
|
||
|
exp := int((((uint16(b[15])<<8)|uint16(b[14]))<<1)>>2) - dec128Bias
|
||
|
|
||
|
b14 := b[14] // save b[14]
|
||
|
b[14] &= 0x01 // keep the mantissa bit (rest: sign and exp)
|
||
|
|
||
|
//most significand byte
|
||
|
msb := 14
|
||
|
for msb > 0 {
|
||
|
if b[msb] != 0 {
|
||
|
break
|
||
|
}
|
||
|
msb--
|
||
|
}
|
||
|
|
||
|
//calc number of words
|
||
|
numWords := (msb / _S) + 1
|
||
|
w := make([]big.Word, numWords)
|
||
|
|
||
|
k := numWords - 1
|
||
|
d := big.Word(0)
|
||
|
for i := msb; i >= 0; i-- {
|
||
|
d |= big.Word(b[i])
|
||
|
if k*_S == i {
|
||
|
w[k] = d
|
||
|
k--
|
||
|
d = 0
|
||
|
}
|
||
|
d <<= 8
|
||
|
}
|
||
|
b[14] = b14 // restore b[14]
|
||
|
m.SetBits(w)
|
||
|
return neg, exp
|
||
|
}
|
||
|
|
||
|
func encodeDecimal(m *big.Int, neg bool, exp int) (driver.Value, error) {
|
||
|
|
||
|
b := make([]byte, decimalSize)
|
||
|
|
||
|
// little endian bigint words (significand) -> little endian db decimal format
|
||
|
j := 0
|
||
|
for _, d := range m.Bits() {
|
||
|
for i := 0; i < 8; i++ {
|
||
|
b[j] = byte(d)
|
||
|
d >>= 8
|
||
|
j++
|
||
|
}
|
||
|
}
|
||
|
|
||
|
exp += dec128Bias
|
||
|
b[14] |= (byte(exp) << 1)
|
||
|
b[15] = byte(uint16(exp) >> 7)
|
||
|
|
||
|
if neg {
|
||
|
b[15] |= 0x80
|
||
|
}
|
||
|
|
||
|
return b, nil
|
||
|
}
|
||
|
|
||
|
// NullDecimal represents an Decimal that may be null.
|
||
|
// NullDecimal implements the Scanner interface so
|
||
|
// it can be used as a scan destination, similar to NullString.
|
||
|
type NullDecimal struct {
|
||
|
Decimal *Decimal
|
||
|
Valid bool // Valid is true if Decimal is not NULL
|
||
|
}
|
||
|
|
||
|
// Scan implements the Scanner interface.
|
||
|
func (n *NullDecimal) Scan(value interface{}) error {
|
||
|
var b []byte
|
||
|
|
||
|
b, n.Valid = value.([]byte)
|
||
|
if !n.Valid {
|
||
|
return nil
|
||
|
}
|
||
|
if n.Decimal == nil {
|
||
|
return fmt.Errorf("invalid decimal value %v", n.Decimal)
|
||
|
}
|
||
|
return n.Decimal.Scan(b)
|
||
|
}
|
||
|
|
||
|
// Value implements the driver Valuer interface.
|
||
|
func (n NullDecimal) Value() (driver.Value, error) {
|
||
|
if !n.Valid {
|
||
|
return nil, nil
|
||
|
}
|
||
|
if n.Decimal == nil {
|
||
|
return nil, fmt.Errorf("invalid decimal value %v", n.Decimal)
|
||
|
}
|
||
|
return n.Decimal.Value()
|
||
|
}
|