open-vault/vendor/github.com/SAP/go-hdb/driver/driver.go
2017-08-02 12:53:45 -04:00

618 lines
14 KiB
Go

/*
Copyright 2014 SAP SE
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package driver
import (
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"regexp"
"sync"
"github.com/SAP/go-hdb/driver/sqltrace"
p "github.com/SAP/go-hdb/internal/protocol"
)
// Public identification constants of the hdb driver.
const (
	// DriverVersion is the version number of the hdb driver.
	DriverVersion = "0.9.1"

	// DriverName is the driver name to use with sql.Open for hdb databases.
	DriverName = "hdb"
)
// init registers the hdb driver with database/sql under DriverName, so that
// sql.Open(DriverName, dsn) resolves to this package.
func init() {
	sql.Register(DriverName, new(drv))
}
// reBulk recognizes a statement that starts (case-insensitively, after optional
// leading whitespace) with the keyword "bulk" followed by at least one space.
var reBulk = regexp.MustCompile("(?i)^(\\s)*(bulk +)(.*)")

// checkBulkInsert reports whether sql carries the "bulk" prefix. When it does,
// the statement is returned with the leading whitespace and the prefix removed;
// otherwise sql is returned unchanged.
func checkBulkInsert(sql string) (string, bool) {
	if !reBulk.MatchString(sql) {
		return sql, false
	}
	// Keep only capture group 3, i.e. the text following the "bulk" keyword.
	stripped := reBulk.ReplaceAllString(sql, "${3}")
	return stripped, true
}
// reCall recognizes a statement that starts (case-insensitively, after optional
// leading whitespace) with the keyword "call" followed by at least one space.
var reCall = regexp.MustCompile("(?i)^(\\s)*(call +)(.*)")

// checkCallProcedure reports whether sql is a "call ..." statement.
func checkCallProcedure(sql string) bool {
	return reCall.FindStringIndex(sql) != nil
}
// errProcTableQuery is a sentinel error value for an invalid procedure table query.
// NOTE(review): it is not referenced in the visible part of this file — confirm
// whether it is still used elsewhere in the package.
var errProcTableQuery = errors.New("Invalid procedure table query")
// drv is the stateless driver value registered with database/sql.
type drv struct{}

// Open creates a new connection for the given data source name by delegating
// to newConn.
func (d *drv) Open(dsn string) (driver.Conn, error) {
	c, err := newConn(dsn)
	return c, err
}
// conn is a database connection; all work is delegated to its protocol session.
type conn struct {
	session *p.Session
}

// newConn parses dsn into session parameters, opens a protocol session with
// them and wraps the session in a conn. A parse or session error is returned
// unchanged.
func newConn(dsn string) (driver.Conn, error) {
	prm, err := parseDSN(dsn)
	if err != nil {
		return nil, err
	}

	sess, err := p.NewSession(prm)
	if err != nil {
		return nil, err
	}

	c := new(conn)
	c.session = sess
	return c, nil
}
// Prepare prepares query on the session and returns a statement for it.
//
// A leading "bulk" keyword is stripped before the statement is sent to the
// session; in that case a bulk insert statement is returned, otherwise a
// regular statement. driver.ErrBadConn is returned when the session is bad.
func (c *conn) Prepare(query string) (driver.Stmt, error) {
	if c.session.IsBad() {
		return nil, driver.ErrBadConn
	}

	stripped, isBulk := checkBulkInsert(query)

	qt, id, prmFields, resFields, err := c.session.Prepare(stripped)
	if err != nil {
		return nil, err
	}

	if !isBulk {
		return newStmt(qt, c.session, stripped, id, prmFields, resFields)
	}
	return newBulkInsertStmt(c.session, stripped, id, prmFields)
}
// Close closes the underlying session and always reports success.
// NOTE(review): whatever session.Close may return is not inspected here.
func (c *conn) Close() error {
	c.session.Close()
	return nil
}
// Begin starts a transaction on the session.
//
// It returns driver.ErrBadConn when the session is bad and an error when the
// session is already inside a transaction. On success the session is flagged
// as being in a transaction.
func (c *conn) Begin() (driver.Tx, error) {
	switch {
	case c.session.IsBad():
		return nil, driver.ErrBadConn
	case c.session.InTx():
		return nil, fmt.Errorf("nested transactions are not supported")
	}

	c.session.SetInTx(true)
	return newTx(c.session), nil
}
// Exec implements the database/sql/driver Execer interface.
//
// Only statements without arguments are executed directly; any argument makes
// the method return driver.ErrSkip so that database/sql falls back to the
// prepare/execute path. driver.ErrBadConn is returned when the session is bad.
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
	switch {
	case c.session.IsBad():
		return nil, driver.ErrBadConn
	case len(args) != 0:
		// fast path not possible (prepare needed)
		return nil, driver.ErrSkip
	}

	sqltrace.Traceln(query)
	return c.session.ExecDirect(query)
}
// bug?: the args check is performed independently of the queryer raising ErrSkip or not
// - leads to different behavior compared to the default prepare - stmt - execute logic
// - seems to be the same for the Execer interface
// Query implements the database/sql/driver Queryer interface.
//
// driver.ErrSkip is returned (so that database/sql takes the prepare path)
// when arguments are supplied or when the statement is a procedure call.
// A query string produced by encodeTableQuery is resolved against the global
// procedure call result store instead of being sent to the database. Any other
// statement is executed directly; a result id of 0 denotes a statement that
// returned no result set.
func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
	switch {
	case c.session.IsBad():
		return nil, driver.ErrBadConn
	case len(args) != 0:
		// fast path not possible (prepare needed)
		return nil, driver.ErrSkip
	case checkCallProcedure(query):
		// direct execution of call procedure
		// - returns no parameter metadata (sps 82) but only field values
		// --> let's take the 'prepare way' for stored procedures
		return nil, driver.ErrSkip
	}

	sqltrace.Traceln(query)

	if storeID, tableIdx, isTableQuery := decodeTableQuery(query); isTableQuery {
		callResult := procedureCallResultStore.get(storeID)
		if callResult == nil {
			return nil, fmt.Errorf("invalid procedure table query %s", query)
		}
		return callResult.tableRows(int(tableIdx))
	}

	rsID, meta, values, attributes, err := c.session.QueryDirect(query)
	if err != nil {
		return nil, err
	}
	if rsID == 0 { // non select query
		return noResult, nil
	}
	return newQueryResult(c.session, rsID, meta, values, attributes)
}
// tx is a transaction handle bound to the session it was started on.
type tx struct {
	session *p.Session
}

// newTx returns a transaction handle for session.
func newTx(session *p.Session) *tx {
	t := new(tx)
	t.session = session
	return t
}
// Commit commits the transaction on the session. It returns
// driver.ErrBadConn without contacting the session when the session is bad.
func (t *tx) Commit() error {
	if !t.session.IsBad() {
		return t.session.Commit()
	}
	return driver.ErrBadConn
}
// Rollback rolls the transaction back on the session. It returns
// driver.ErrBadConn without contacting the session when the session is bad.
func (t *tx) Rollback() error {
	if !t.session.IsBad() {
		return t.session.Rollback()
	}
	return driver.ErrBadConn
}
// stmt is a prepared statement.
type stmt struct {
	qt             p.QueryType // query type; selects the execution path in Query
	session        *p.Session  // session the statement was prepared on
	query          string      // statement text, used for tracing
	id             uint64      // statement id handed to the session on execution and drop
	prmFieldSet    *p.FieldSet // parameter metadata
	resultFieldSet *p.FieldSet // result metadata
}

// newStmt bundles the given values into a stmt. The returned error is always nil.
func newStmt(qt p.QueryType, session *p.Session, query string, id uint64, prmFieldSet *p.FieldSet, resultFieldSet *p.FieldSet) (*stmt, error) {
	s := &stmt{
		qt:             qt,
		session:        session,
		query:          query,
		id:             id,
		prmFieldSet:    prmFieldSet,
		resultFieldSet: resultFieldSet,
	}
	return s, nil
}
// Close drops the statement id on the session and returns the session's result.
func (s *stmt) Close() error {
	err := s.session.DropStatementID(s.id)
	return err
}
// NumInput returns the number of input fields reported by the statement's
// parameter field set.
func (s *stmt) NumInput() int {
	n := s.prmFieldSet.NumInputField()
	return n
}
// Exec executes the prepared statement with args.
//
// It returns driver.ErrBadConn when the session is bad and an error when the
// number of arguments differs from the number of input fields.
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
	if s.session.IsBad() {
		return nil, driver.ErrBadConn
	}

	want := s.prmFieldSet.NumInputField()
	if got := len(args); got != want {
		return nil, fmt.Errorf("invalid number of arguments %d - %d expected", got, want)
	}

	sqltrace.Tracef("%s %v", s.query, args)
	return s.session.Exec(s.id, s.prmFieldSet, args)
}
// Query runs the prepared statement with args and returns its rows.
//
// Procedure calls are routed to procedureCall; every other query type is
// handled by defaultQuery. driver.ErrBadConn is returned when the session is bad.
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
	if s.session.IsBad() {
		return nil, driver.ErrBadConn
	}
	if s.qt == p.QtProcedureCall {
		return s.procedureCall(args)
	}
	return s.defaultQuery(args)
}
// defaultQuery executes the statement as a plain query. A result set id of 0
// denotes a statement that returned no result set, for which the shared
// noResult rows value is returned.
func (s *stmt) defaultQuery(args []driver.Value) (driver.Rows, error) {
	sqltrace.Tracef("%s %v", s.query, args)

	rsID, values, attributes, err := s.session.Query(s.id, s.prmFieldSet, s.resultFieldSet, args)
	switch {
	case err != nil:
		return nil, err
	case rsID == 0: // non select query
		return noResult, nil
	}
	return newQueryResult(s.session, rsID, s.resultFieldSet, values, attributes)
}
// procedureCall executes the statement as a procedure call and wraps the
// returned field values and table results in a procedure call result.
func (s *stmt) procedureCall(args []driver.Value) (driver.Rows, error) {
	sqltrace.Tracef("%s %v", s.query, args)

	outValues, tables, err := s.session.Call(s.id, s.prmFieldSet, args)
	if err == nil {
		return newProcedureCallResult(s.session, s.prmFieldSet, outValues, tables)
	}
	return nil, err
}
// ColumnConverter returns the value converter for the parameter at idx,
// chosen by the data type the parameter field set reports for that index.
func (s *stmt) ColumnConverter(idx int) driver.ValueConverter {
	dataType := s.prmFieldSet.DataType(idx)
	return columnConverter(dataType)
}
// bulkInsertStmt is a prepared statement that buffers the arguments of
// successive Exec calls and sends them to the session in one batch.
type bulkInsertStmt struct {
	session           *p.Session
	query             string         // statement text, used for tracing and error messages
	id                uint64         // statement id handed to the session
	parameterFieldSet *p.FieldSet    // parameter metadata
	numArg            int            // number of buffered Exec calls (rows)
	args              []driver.Value // flattened arguments of all buffered rows
}

// newBulkInsertStmt returns a bulk insert statement with an empty argument
// buffer. The returned error is always nil.
func newBulkInsertStmt(session *p.Session, query string, id uint64, parameterFieldSet *p.FieldSet) (*bulkInsertStmt, error) {
	s := &bulkInsertStmt{
		session:           session,
		query:             query,
		id:                id,
		parameterFieldSet: parameterFieldSet,
		args:              make([]driver.Value, 0),
	}
	return s, nil
}
// Close drops the statement id on the session. Buffered arguments are not
// flushed here.
func (s *bulkInsertStmt) Close() error {
	err := s.session.DropStatementID(s.id)
	return err
}
// NumInput returns -1. database/sql/driver documents -1 as "number of
// placeholders unknown", so the sql package does not check the argument count;
// this lets Exec be called without arguments to flush the buffer.
func (s *bulkInsertStmt) NumInput() int {
	const unknown = -1
	return unknown
}
// Exec buffers one row of arguments, or, when called without arguments,
// flushes everything buffered so far to the session.
// driver.ErrBadConn is returned when the session is bad.
func (s *bulkInsertStmt) Exec(args []driver.Value) (driver.Result, error) {
	if s.session.IsBad() {
		return nil, driver.ErrBadConn
	}

	sqltrace.Tracef("%s %v", s.query, args)

	if len(args) > 0 {
		return s.execBuffer(args)
	}
	// A nil or empty argument list is the flush request.
	return s.execFlush()
}
// execFlush sends all buffered arguments to the session in one execution.
// With nothing buffered it returns driver.ResultNoRows without contacting the
// session. The buffer is emptied afterwards whether or not the execution
// succeeded.
func (s *bulkInsertStmt) execFlush() (driver.Result, error) {
	if s.numArg == 0 {
		return driver.ResultNoRows, nil
	}

	res, err := s.session.Exec(s.id, s.parameterFieldSet, s.args)

	// Reset the buffer but keep its backing array for the next batch.
	s.numArg = 0
	s.args = s.args[:0]

	return res, err
}
// execBuffer appends one row of arguments to the buffer.
//
// The row must contain exactly as many values as the statement has input
// fields. When maxSmallint rows are already buffered they are flushed first,
// and the result of that flush is what gets returned; otherwise the result is
// driver.ResultNoRows. The new row is buffered even if that flush failed.
func (s *bulkInsertStmt) execBuffer(args []driver.Value) (driver.Result, error) {
	if want := s.parameterFieldSet.NumInputField(); len(args) != want {
		return nil, fmt.Errorf("invalid number of arguments %d - %d expected", len(args), want)
	}

	var (
		res driver.Result = driver.ResultNoRows
		err error
	)
	if s.numArg == maxSmallint { // TODO: check why bigArgument count does not work
		res, err = s.execFlush()
	}

	s.args = append(s.args, args...)
	s.numArg++

	return res, err
}
// Query is not supported for bulk insert statements and always returns an error.
func (s *bulkInsertStmt) Query(args []driver.Value) (driver.Rows, error) {
	err := fmt.Errorf("query not allowed in context of bulk insert statement %s", s.query)
	return nil, err
}
// ColumnConverter returns the value converter for the parameter at idx,
// chosen by the data type the parameter field set reports for that index.
func (s *bulkInsertStmt) ColumnConverter(idx int) driver.ValueConverter {
	dataType := s.parameterFieldSet.DataType(idx)
	return columnConverter(dataType)
}
// driver.Rows drop-in replacement used when Query or QueryRow is called for
// statements that don't return rows.
var (
	noColumns = make([]string, 0)
	noResult  = &noResultType{}
)

// noResultType is an empty row set: no columns, no rows.
type noResultType struct{}

// Columns returns the shared empty column list.
func (r *noResultType) Columns() []string {
	return noColumns
}

// Close does nothing and reports success.
func (r *noResultType) Close() error {
	return nil
}

// Next reports io.EOF immediately; dest is left untouched.
func (r *noResultType) Next(dest []driver.Value) error {
	return io.EOF
}
// queryResult iterates over the rows of a result set, fetching further
// batches from the session on demand.
type queryResult struct {
	session     *p.Session
	id          uint64           // result set id used for fetching and closing
	fieldSet    *p.FieldSet      // result metadata
	fieldValues *p.FieldValues   // current batch of rows
	pos         int              // index of the next row within the current batch
	attrs       p.PartAttributes // attributes of the current batch
	columns     []string         // output column names
	lastErr     error            // error of a failed fetch, reported again by Close
}

// newQueryResult builds a queryResult for the given result set. The column
// names are taken from fieldSet; an error while reading them is returned.
func newQueryResult(session *p.Session, id uint64, fieldSet *p.FieldSet, fieldValues *p.FieldValues, attrs p.PartAttributes) (driver.Rows, error) {
	names := make([]string, fieldSet.NumOutputField())
	err := fieldSet.OutputNames(names)
	if err != nil {
		return nil, err
	}

	qr := new(queryResult)
	qr.session = session
	qr.id = id
	qr.fieldSet = fieldSet
	qr.fieldValues = fieldValues
	qr.attrs = attrs
	qr.columns = names
	return qr, nil
}
// Columns returns the output column names determined when the result was created.
func (r *queryResult) Columns() []string {
	names := r.columns
	return names
}
// Close releases the result set.
//
// When a previous fetch failed, that error is returned and nothing else is
// done (attrs are nil in that case). Otherwise the result set is closed on the
// session unless the attributes report that it is closed already.
func (r *queryResult) Close() error {
	switch {
	case r.lastErr != nil:
		// if lastErr is set, attrs are nil
		return r.lastErr
	case r.attrs.ResultsetClosed():
		return nil
	default:
		return r.session.CloseResultsetID(r.id)
	}
}
// Next copies the next row into dest.
//
// When the current batch is exhausted, io.EOF is returned if the batch was the
// last packet; otherwise the next batch is fetched from the session. A fetch
// that reports no rows also ends the iteration with io.EOF.
//
// A fetch error is sticky: the failed fetch overwrites the current batch and
// its attributes (they are nil afterwards), so every later call returns the
// recorded error instead of touching them again.
// driver.ErrBadConn is returned when the session is bad.
func (r *queryResult) Next(dest []driver.Value) error {
	if r.session.IsBad() {
		return driver.ErrBadConn
	}
	if r.lastErr != nil {
		// fieldValues and attrs are nil after a failed fetch; do not use them.
		return r.lastErr
	}

	if r.pos >= r.fieldValues.NumRow() {
		if r.attrs.LastPacket() {
			return io.EOF
		}

		var err error
		if r.fieldValues, r.attrs, err = r.session.FetchNext(r.id, r.fieldSet); err != nil {
			r.lastErr = err // fieldValues and attrs are nil
			return err
		}
		if r.attrs.NoRows() {
			return io.EOF
		}
		r.pos = 0
	}

	r.fieldValues.Row(r.pos, dest)
	r.pos++
	return nil
}
// callResultStore is a mutex-guarded registry of procedure call results,
// addressed by a numeric key. Keys of removed entries are kept in a free list
// and handed out again before new keys are generated.
type callResultStore struct {
	mu    sync.RWMutex
	store map[uint64]*procedureCallResult // created lazily by add
	cnt   uint64                          // highest key generated so far
	free  []uint64                        // keys released by del, reused LIFO
}

// get returns the result stored under k, or nil when k is unknown.
func (s *callResultStore) get(k uint64) *procedureCallResult {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Indexing a nil or non-matching map yields the nil zero value.
	return s.store[k]
}

// add stores v and returns the key assigned to it. A released key is reused
// when one is available; otherwise the counter is advanced, so the first
// generated key is 1.
func (s *callResultStore) add(v *procedureCallResult) uint64 {
	s.mu.Lock()
	defer s.mu.Unlock()

	var k uint64
	if last := len(s.free) - 1; last >= 0 {
		k = s.free[last]
		s.free = s.free[:last]
	} else {
		s.cnt++
		k = s.cnt
	}

	if s.store == nil {
		s.store = make(map[uint64]*procedureCallResult)
	}
	s.store[k] = v
	return k
}

// del removes the entry stored under k and releases the key for reuse.
//
// Keys that are not currently stored are ignored. Releasing them anyway would
// put the same key on the free list more than once (e.g. when a result is
// closed twice), and two later add calls would then receive the same key, the
// second silently replacing the first entry.
func (s *callResultStore) del(k uint64) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, ok := s.store[k]; !ok {
		return
	}
	delete(s.store, k)
	s.free = append(s.free, k)
}

// procedureCallResultStore is the package-wide store that table queries are
// resolved against.
var procedureCallResultStore = new(callResultStore)
// procedureCallResult is the single-row result of a procedure call: the output
// field values followed by one table query reference per table result.
type procedureCallResult struct {
	id          uint64 // key in procedureCallResultStore
	session     *p.Session
	fieldSet    *p.FieldSet    // parameter metadata of the call
	fieldValues *p.FieldValues // output field values
	_tableRows  []driver.Rows  // rows of the table results, by table index
	columns     []string       // output field names followed by "table <i>" names
	eof         error          // set once the single row has been delivered
}

// newProcedureCallResult builds the result of a procedure call and registers
// it in procedureCallResultStore so that its table results can be looked up
// later by table query.
//
// The column list consists of the output field names followed by one
// "table <i>" column per table result. An error while reading the field names
// or while creating the rows of a table result is returned.
func newProcedureCallResult(session *p.Session, fieldSet *p.FieldSet, fieldValues *p.FieldValues, tableResults []*p.TableResult) (driver.Rows, error) {
	numOut := fieldSet.NumOutputField()

	columns := make([]string, numOut+len(tableResults))
	if err := fieldSet.OutputNames(columns); err != nil {
		return nil, err
	}

	tableRows := make([]driver.Rows, len(tableResults))
	for i, tr := range tableResults {
		rows, err := newQueryResult(session, tr.ID(), tr.FieldSet(), tr.FieldValues(), tr.Attrs())
		if err != nil {
			return nil, err
		}
		tableRows[i] = rows
		columns[numOut+i] = fmt.Sprintf("table %d", i)
	}

	result := &procedureCallResult{
		session:     session,
		fieldSet:    fieldSet,
		fieldValues: fieldValues,
		_tableRows:  tableRows,
		columns:     columns,
	}
	result.id = procedureCallResultStore.add(result)
	return result, nil
}
// Columns returns the column names assembled when the result was created.
func (r *procedureCallResult) Columns() []string {
	names := r.columns
	return names
}
// Close removes the result from procedureCallResultStore and always reports
// success. The rows of the table results are not closed here.
func (r *procedureCallResult) Close() error {
	procedureCallResultStore.del(r.id)
	return nil
}
// Next delivers the single row of the procedure call result.
//
// The row consists of the output field values (when there are any) followed
// by one table query string per table result, each of which can be passed to
// Query to obtain the rows of that table. Every later call returns io.EOF, as
// does the first call when there are neither field values nor table results.
// driver.ErrBadConn is returned when the session is bad.
func (r *procedureCallResult) Next(dest []driver.Value) error {
	if r.session.IsBad() {
		return driver.ErrBadConn
	}
	if r.eof != nil {
		return r.eof
	}

	hasFieldRow := r.fieldValues.NumRow() != 0
	if !hasFieldRow && len(r._tableRows) == 0 {
		r.eof = io.EOF
		return r.eof
	}

	if hasFieldRow {
		r.fieldValues.Row(0, dest)
	}

	// The table references occupy the columns after the output fields.
	base := r.fieldSet.NumOutputField()
	for j := range r._tableRows {
		dest[base+j] = encodeTableQuery(r.id, uint64(j))
	}

	r.eof = io.EOF
	return nil
}
// tableRows returns the rows of the table result at idx.
//
// An error is returned when idx is outside [0, len(tables)). The lower bound
// matters because the caller derives idx from an unsigned 64-bit value decoded
// from the query string; converted to int it can be negative, which would
// otherwise panic on the slice access.
func (r *procedureCallResult) tableRows(idx int) (driver.Rows, error) {
	if idx < 0 {
		return nil, fmt.Errorf("invalid table row index %d", idx)
	}
	if idx >= len(r._tableRows) {
		return nil, fmt.Errorf("table row index %d exceeds maximun %d", idx, len(r._tableRows)-1)
	}
	return r._tableRows[idx], nil
}
// Table query helpers.
//
// A table query is the prefix below followed by two little-endian uint64
// values: the id of a procedure call result and the index of one of its
// table results.
const tableQueryPrefix = "@tq"

// encodeTableQuery builds the table query string for the given result id and
// table index.
func encodeTableQuery(id, idx uint64) string {
	const word = 8 // encoded size of one uint64

	buf := make([]byte, len(tableQueryPrefix)+2*word)
	n := copy(buf, tableQueryPrefix)
	binary.LittleEndian.PutUint64(buf[n:n+word], id)
	binary.LittleEndian.PutUint64(buf[n+word:], idx)
	return string(buf)
}

// decodeTableQuery is the inverse of encodeTableQuery. It returns the result
// id and table index together with true when query has exactly the encoded
// length and starts with the prefix; otherwise it returns 0, 0, false.
func decodeTableQuery(query string) (uint64, uint64, bool) {
	const word = 8 // encoded size of one uint64

	prefixLen := len(tableQueryPrefix)
	switch {
	case len(query) != prefixLen+2*word:
		return 0, 0, false
	case query[:prefixLen] != tableQueryPrefix:
		return 0, 0, false
	}

	payload := []byte(query[prefixLen:])
	id := binary.LittleEndian.Uint64(payload[:word])
	idx := binary.LittleEndian.Uint64(payload[word:])
	return id, idx, true
}