/* Copyright 2014 SAP SE Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package driver import ( "context" "database/sql" "database/sql/driver" "errors" "fmt" "io" "reflect" "time" "github.com/SAP/go-hdb/driver/sqltrace" p "github.com/SAP/go-hdb/internal/protocol" ) // DriverVersion is the version number of the hdb driver. const DriverVersion = "0.12.0" // DriverName is the driver name to use with sql.Open for hdb databases. const DriverName = "hdb" // Transaction isolation levels supported by hdb. const ( LevelReadCommitted = "READ COMMITTED" LevelRepeatableRead = "REPEATABLE READ" LevelSerializable = "SERIALIZABLE" ) // Access modes supported by hdb. const ( modeReadOnly = "READ ONLY" modeReadWrite = "READ WRITE" ) // map sql isolation level to hdb isolation level. var isolationLevel = map[driver.IsolationLevel]string{ driver.IsolationLevel(sql.LevelDefault): LevelReadCommitted, driver.IsolationLevel(sql.LevelReadCommitted): LevelReadCommitted, driver.IsolationLevel(sql.LevelRepeatableRead): LevelRepeatableRead, driver.IsolationLevel(sql.LevelSerializable): LevelSerializable, } // map sql read only flag to hdb access mode. var readOnly = map[bool]string{ true: modeReadOnly, false: modeReadWrite, } // ErrUnsupportedIsolationLevel is the error raised if a transaction is started with a not supported isolation level. var ErrUnsupportedIsolationLevel = errors.New("Unsupported isolation level") // ErrNestedTransaction is the error raised if a tranasction is created within a transaction as this is not supported by hdb. var ErrNestedTransaction = errors.New("Nested transactions are not supported") // needed for testing const driverDataFormatVersion = 1 // queries const ( pingQuery = "select 1 from dummy" isolationLevelStmt = "set transaction isolation level %s" accessModeStmt = "set transaction %s" ) // bulk statement const noFlush = "$nf" // NoFlush is to be used as parameter in bulk inserts. var NoFlush = sql.Named(noFlush, nil) var drv = &hdbDrv{} func init() { sql.Register(DriverName, drv) } // driver // check if driver implements all required interfaces var ( _ driver.Driver = (*hdbDrv)(nil) ) type hdbDrv struct{} func (d *hdbDrv) Open(dsn string) (driver.Conn, error) { connector, err := NewDSNConnector(dsn) if err != nil { return nil, err } return connector.Connect(context.Background()) } // database connection // check if conn implements all required interfaces var ( _ driver.Conn = (*conn)(nil) _ driver.ConnPrepareContext = (*conn)(nil) _ driver.Pinger = (*conn)(nil) _ driver.ConnBeginTx = (*conn)(nil) _ driver.ExecerContext = (*conn)(nil) //go 1.9 issue (ExecerContext is only called if Execer is implemented) _ driver.Execer = (*conn)(nil) _ driver.QueryerContext = (*conn)(nil) //go 1.9 issue (QueryerContext is only called if Queryer is implemented) // QueryContext is needed for stored procedures with table output parameters. _ driver.Queryer = (*conn)(nil) _ driver.NamedValueChecker = (*conn)(nil) ) type conn struct { session *p.Session } func newConn(ctx context.Context, c *Connector) (driver.Conn, error) { session, err := p.NewSession(ctx, c) if err != nil { return nil, err } return &conn{session: session}, nil } func (c *conn) Prepare(query string) (driver.Stmt, error) { panic("deprecated") } func (c *conn) Close() error { c.session.Close() return nil } func (c *conn) Begin() (driver.Tx, error) { panic("deprecated") } func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { if c.session.IsBad() { return nil, driver.ErrBadConn } if c.session.InTx() { return nil, ErrNestedTransaction } level, ok := isolationLevel[opts.Isolation] if !ok { return nil, ErrUnsupportedIsolationLevel } done := make(chan struct{}) go func() { // set isolation level if _, err = c.ExecContext(ctx, fmt.Sprintf(isolationLevelStmt, level), nil); err != nil { goto done } // set access mode if _, err = c.ExecContext(ctx, fmt.Sprintf(accessModeStmt, readOnly[opts.ReadOnly]), nil); err != nil { goto done } c.session.SetInTx(true) tx = newTx(c.session) done: close(done) }() select { case <-ctx.Done(): return nil, ctx.Err() case <-done: return tx, err } } // Exec implements the database/sql/driver/Execer interface. // delete after go 1.9 compatibility is given up. func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { panic("deprecated") } // ExecContext implements the database/sql/driver/ExecerContext interface. func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { if c.session.IsBad() { return nil, driver.ErrBadConn } if len(args) != 0 { return nil, driver.ErrSkip //fast path not possible (prepare needed) } sqltrace.Traceln(query) done := make(chan struct{}) go func() { r, err = c.session.ExecDirect(query) close(done) }() select { case <-ctx.Done(): return nil, ctx.Err() case <-done: return r, err } } // Queryer implements the database/sql/driver/Queryer interface. // delete after go 1.9 compatibility is given up. func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { panic("deprecated") } func (c *conn) Ping(ctx context.Context) (err error) { if c.session.IsBad() { return driver.ErrBadConn } done := make(chan struct{}) go func() { _, err = c.QueryContext(ctx, pingQuery, nil) close(done) }() select { case <-ctx.Done(): return ctx.Err() case <-done: return err } } // CheckNamedValue implements NamedValueChecker interface. // implemented for conn: // if querier or execer is called, sql checks parameters before in case of // parameters the method can be 'skipped' and force the prepare path // --> guarantee that a valid driver value is returned // --> if not implemented, Lob need to have a pseudo Value method to return a valid driver value func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { switch nv.Value.(type) { case Lob, *Lob: nv.Value = nil } return nil } //transaction // check if tx implements all required interfaces var ( _ driver.Tx = (*tx)(nil) ) type tx struct { session *p.Session } func newTx(session *p.Session) *tx { return &tx{ session: session, } } func (t *tx) Commit() error { if t.session.IsBad() { return driver.ErrBadConn } return t.session.Commit() } func (t *tx) Rollback() error { if t.session.IsBad() { return driver.ErrBadConn } return t.session.Rollback() } //statement // check if stmt implements all required interfaces var ( _ driver.Stmt = (*stmt)(nil) _ driver.StmtExecContext = (*stmt)(nil) _ driver.StmtQueryContext = (*stmt)(nil) _ driver.NamedValueChecker = (*stmt)(nil) ) type stmt struct { qt p.QueryType session *p.Session query string id uint64 prmFieldSet *p.ParameterFieldSet resultFieldSet *p.ResultFieldSet } func newStmt(qt p.QueryType, session *p.Session, query string, id uint64, prmFieldSet *p.ParameterFieldSet, resultFieldSet *p.ResultFieldSet) (*stmt, error) { return &stmt{qt: qt, session: session, query: query, id: id, prmFieldSet: prmFieldSet, resultFieldSet: resultFieldSet}, nil } func (s *stmt) Close() error { return s.session.DropStatementID(s.id) } func (s *stmt) NumInput() int { return s.prmFieldSet.NumInputField() } func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { panic("deprecated") } func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { if s.session.IsBad() { return nil, driver.ErrBadConn } numField := s.prmFieldSet.NumInputField() if len(args) != numField { return nil, fmt.Errorf("invalid number of arguments %d - %d expected", len(args), numField) } sqltrace.Tracef("%s %v", s.query, args) done := make(chan struct{}) go func() { r, err = s.session.Exec(s.id, s.prmFieldSet, args) close(done) }() select { case <-ctx.Done(): return nil, ctx.Err() case <-done: return r, err } } func (s *stmt) Query(args []driver.Value) (rows driver.Rows, err error) { panic("deprecated") } // Deprecated: see NamedValueChecker. //func (s *stmt) ColumnConverter(idx int) driver.ValueConverter { //} // CheckNamedValue implements NamedValueChecker interface. func (s *stmt) CheckNamedValue(nv *driver.NamedValue) error { if nv.Name == noFlush { //... print("remove variable") return driver.ErrRemoveArgument } return checkNamedValue(s.prmFieldSet, nv) } // driver.Rows drop-in replacement if driver Query or QueryRow is used for statements that doesn't return rows var noColumns = []string{} var noResult = new(noResultType) // check if noResultType implements all required interfaces var ( _ driver.Rows = (*noResultType)(nil) ) type noResultType struct{} func (r *noResultType) Columns() []string { return noColumns } func (r *noResultType) Close() error { return nil } func (r *noResultType) Next(dest []driver.Value) error { return io.EOF } // rows type rows struct { } // query result // check if queryResult implements all required interfaces var ( _ driver.Rows = (*queryResult)(nil) _ driver.RowsColumnTypeDatabaseTypeName = (*queryResult)(nil) // go 1.8 _ driver.RowsColumnTypeLength = (*queryResult)(nil) // go 1.8 _ driver.RowsColumnTypeNullable = (*queryResult)(nil) // go 1.8 _ driver.RowsColumnTypePrecisionScale = (*queryResult)(nil) // go 1.8 _ driver.RowsColumnTypeScanType = (*queryResult)(nil) // go 1.8 ) type queryResult struct { session *p.Session id uint64 resultFieldSet *p.ResultFieldSet fieldValues *p.FieldValues pos int attrs p.PartAttributes columns []string lastErr error } func newQueryResult(session *p.Session, id uint64, resultFieldSet *p.ResultFieldSet, fieldValues *p.FieldValues, attrs p.PartAttributes) (driver.Rows, error) { columns := make([]string, resultFieldSet.NumField()) for i := 0; i < len(columns); i++ { columns[i] = resultFieldSet.Field(i).Name() } return &queryResult{ session: session, id: id, resultFieldSet: resultFieldSet, fieldValues: fieldValues, attrs: attrs, columns: columns, }, nil } func (r *queryResult) Columns() []string { return r.columns } func (r *queryResult) Close() error { // if lastError is set, attrs are nil if r.lastErr != nil { return r.lastErr } if !r.attrs.ResultsetClosed() { return r.session.CloseResultsetID(r.id) } return nil } func (r *queryResult) Next(dest []driver.Value) error { if r.session.IsBad() { return driver.ErrBadConn } if r.pos >= r.fieldValues.NumRow() { if r.attrs.LastPacket() { return io.EOF } var err error if r.attrs, err = r.session.FetchNext(r.id, r.resultFieldSet, r.fieldValues); err != nil { r.lastErr = err //fieldValues and attrs are nil return err } if r.attrs.NoRows() { return io.EOF } r.pos = 0 } r.fieldValues.Row(r.pos, dest) r.pos++ return nil } func (r *queryResult) ColumnTypeDatabaseTypeName(idx int) string { return r.resultFieldSet.Field(idx).TypeCode().TypeName() } func (r *queryResult) ColumnTypeLength(idx int) (int64, bool) { return r.resultFieldSet.Field(idx).TypeLength() } func (r *queryResult) ColumnTypePrecisionScale(idx int) (int64, int64, bool) { return r.resultFieldSet.Field(idx).TypePrecisionScale() } func (r *queryResult) ColumnTypeNullable(idx int) (bool, bool) { return r.resultFieldSet.Field(idx).Nullable(), true } var ( scanTypeUnknown = reflect.TypeOf(new(interface{})).Elem() scanTypeTinyint = reflect.TypeOf(uint8(0)) scanTypeSmallint = reflect.TypeOf(int16(0)) scanTypeInteger = reflect.TypeOf(int32(0)) scanTypeBigint = reflect.TypeOf(int64(0)) scanTypeReal = reflect.TypeOf(float32(0.0)) scanTypeDouble = reflect.TypeOf(float64(0.0)) scanTypeTime = reflect.TypeOf(time.Time{}) scanTypeString = reflect.TypeOf(string("")) scanTypeBytes = reflect.TypeOf([]byte{}) scanTypeDecimal = reflect.TypeOf(Decimal{}) scanTypeLob = reflect.TypeOf(Lob{}) ) func (r *queryResult) ColumnTypeScanType(idx int) reflect.Type { switch r.resultFieldSet.Field(idx).TypeCode().DataType() { default: return scanTypeUnknown case p.DtTinyint: return scanTypeTinyint case p.DtSmallint: return scanTypeSmallint case p.DtInteger: return scanTypeInteger case p.DtBigint: return scanTypeBigint case p.DtReal: return scanTypeReal case p.DtDouble: return scanTypeDouble case p.DtTime: return scanTypeTime case p.DtDecimal: return scanTypeDecimal case p.DtString: return scanTypeString case p.DtBytes: return scanTypeBytes case p.DtLob: return scanTypeLob } }