/* Copyright 2014 SAP SE Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package driver import ( "database/sql/driver" "errors" "fmt" "math" "reflect" "time" p "github.com/SAP/go-hdb/internal/protocol" ) const ( minTinyint = 0 maxTinyint = math.MaxUint8 minSmallint = math.MinInt16 maxSmallint = math.MaxInt16 minInteger = math.MinInt32 maxInteger = math.MaxInt32 minBigint = math.MinInt64 maxBigint = math.MaxInt64 maxReal = math.MaxFloat32 maxDouble = math.MaxFloat64 ) // ErrIntegerOutOfRange means that an integer exceeds the size of the hdb integer field. var ErrIntegerOutOfRange = errors.New("integer out of range error") // ErrFloatOutOfRange means that a float exceeds the size of the hdb float field. var ErrFloatOutOfRange = errors.New("float out of range error") var typeOfTime = reflect.TypeOf((*time.Time)(nil)).Elem() var typeOfBytes = reflect.TypeOf((*[]byte)(nil)).Elem() func checkNamedValue(prmFieldSet *p.ParameterFieldSet, nv *driver.NamedValue) error { idx := nv.Ordinal - 1 if idx >= prmFieldSet.NumInputField() { return nil } f := prmFieldSet.Field(idx) dt := f.TypeCode().DataType() value, err := convertNamedValue(idx, f, dt, nv.Value) if err != nil { return err } nv.Value = value return nil } func convertNamedValue(idx int, f *p.ParameterField, dt p.DataType, v driver.Value) (driver.Value, error) { var err error // let fields with own Value converter convert themselves first (e.g. NullInt64, ...) if _, ok := v.(driver.Valuer); ok { if v, err = driver.DefaultParameterConverter.ConvertValue(v); err != nil { return nil, err } } switch dt { default: return nil, fmt.Errorf("convert named value datatype error: %[1]d - %[1]s", dt) case p.DtTinyint: return convertNvInteger(v, minTinyint, maxTinyint) case p.DtSmallint: return convertNvInteger(v, minSmallint, maxSmallint) case p.DtInteger: return convertNvInteger(v, minInteger, maxInteger) case p.DtBigint: return convertNvInteger(v, minBigint, maxBigint) case p.DtReal: return convertNvFloat(v, maxReal) case p.DtDouble: return convertNvFloat(v, maxDouble) case p.DtTime: return convertNvTime(v) case p.DtDecimal: return convertNvDecimal(v) case p.DtString: return convertNvString(v) case p.DtBytes: return convertNvBytes(v) case p.DtLob: return convertNvLob(idx, f, v) } } // integer types func convertNvInteger(v interface{}, min, max int64) (driver.Value, error) { if v == nil { return v, nil } rv := reflect.ValueOf(v) switch rv.Kind() { // bool is represented in HDB as tinyint case reflect.Bool: return rv.Bool(), nil case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i64 := rv.Int() if i64 > max || i64 < min { return nil, ErrIntegerOutOfRange } return i64, nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: u64 := rv.Uint() if u64 > uint64(max) { return nil, ErrIntegerOutOfRange } return int64(u64), nil case reflect.Ptr: // indirect pointers if rv.IsNil() { return nil, nil } return convertNvInteger(rv.Elem().Interface(), min, max) } return nil, fmt.Errorf("unsupported integer conversion type error %[1]T %[1]v", v) } // float types func convertNvFloat(v interface{}, max float64) (driver.Value, error) { if v == nil { return v, nil } rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Float32, reflect.Float64: f64 := rv.Float() if math.Abs(f64) > max { return nil, ErrFloatOutOfRange } return f64, nil case reflect.Ptr: // indirect pointers if rv.IsNil() { return nil, nil } return convertNvFloat(rv.Elem().Interface(), max) } return nil, fmt.Errorf("unsupported float conversion type error %[1]T %[1]v", v) } // time func convertNvTime(v interface{}) (driver.Value, error) { if v == nil { return nil, nil } switch v := v.(type) { case time.Time: return v, nil } rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: // indirect pointers if rv.IsNil() { return nil, nil } return convertNvTime(rv.Elem().Interface()) } if rv.Type().ConvertibleTo(typeOfTime) { tv := rv.Convert(typeOfTime) return tv.Interface().(time.Time), nil } return nil, fmt.Errorf("unsupported time conversion type error %[1]T %[1]v", v) } // decimal func convertNvDecimal(v interface{}) (driver.Value, error) { if v == nil { return nil, nil } if v, ok := v.([]byte); ok { return v, nil } return nil, fmt.Errorf("unsupported decimal conversion type error %[1]T %[1]v", v) } // string func convertNvString(v interface{}) (driver.Value, error) { if v == nil { return v, nil } switch v := v.(type) { case string, []byte: return v, nil } rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.String: return rv.String(), nil case reflect.Slice: if rv.Type() == typeOfBytes { return rv.Bytes(), nil } case reflect.Ptr: // indirect pointers if rv.IsNil() { return nil, nil } return convertNvString(rv.Elem().Interface()) } if rv.Type().ConvertibleTo(typeOfBytes) { bv := rv.Convert(typeOfBytes) return bv.Interface().([]byte), nil } return nil, fmt.Errorf("unsupported character conversion type error %[1]T %[1]v", v) } // bytes func convertNvBytes(v interface{}) (driver.Value, error) { if v == nil { return v, nil } if v, ok := v.([]byte); ok { return v, nil } rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Slice: if rv.Type() == typeOfBytes { return rv.Bytes(), nil } case reflect.Ptr: // indirect pointers if rv.IsNil() { return nil, nil } return convertNvBytes(rv.Elem().Interface()) } if rv.Type().ConvertibleTo(typeOfBytes) { bv := rv.Convert(typeOfBytes) return bv.Interface().([]byte), nil } return nil, fmt.Errorf("unsupported bytes conversion type error %[1]T %[1]v", v) } // Lob func convertNvLob(idx int, f *p.ParameterField, v interface{}) (driver.Value, error) { if v == nil { return v, nil } switch v := v.(type) { case Lob: if v.rd == nil { return nil, fmt.Errorf("lob error: initial reader %[1]T %[1]v", v) } f.SetLobReader(v.rd) return fmt.Sprintf("