open-vault/vendor/github.com/go-ldap/ldap/conn.go

460 lines
12 KiB
Go
Raw Normal View History

2015-05-07 17:59:38 +00:00
package ldap
import (
"crypto/tls"
"errors"
2015-06-29 21:50:55 +00:00
"fmt"
2015-05-07 17:59:38 +00:00
"log"
"net"
"sync"
2017-02-02 21:19:55 +00:00
"sync/atomic"
2015-06-29 21:50:55 +00:00
"time"
2015-10-07 20:10:00 +00:00
"gopkg.in/asn1-ber.v1"
2015-05-07 17:59:38 +00:00
)
// Operation codes carried in messagePacket.Op. The processMessages loop
// switches on them to decide how to handle each queued message.
const (
	// MessageQuit causes the processMessages loop to exit
	MessageQuit = iota
	// MessageRequest sends a request to the server
	MessageRequest
	// MessageResponse receives a response from the server
	MessageResponse
	// MessageFinish indicates the client considers a particular message ID to be finished
	MessageFinish
	// MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
	MessageTimeout
)
2016-07-23 00:11:47 +00:00
// PacketResponse contains the packet or error encountered reading a response
2016-04-26 00:18:04 +00:00
type PacketResponse struct {
2016-07-23 00:11:47 +00:00
// Packet is the packet read from the server
2016-04-26 00:18:04 +00:00
Packet *ber.Packet
2016-07-23 00:11:47 +00:00
// Error is an error encountered while reading
Error error
2016-04-26 00:18:04 +00:00
}
2016-07-23 00:11:47 +00:00
// ReadPacket returns the packet or an error
2016-04-26 00:18:04 +00:00
func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
}
return pr.Packet, pr.Error
}
2016-06-30 18:19:03 +00:00
type messageContext struct {
2016-07-23 00:11:47 +00:00
id int64
// close(done) should only be called from finishMessage()
done chan struct{}
// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
2016-06-30 18:19:03 +00:00
responses chan *PacketResponse
}
// sendResponse hands packet to the request handler, or gives up as soon as
// the handler has signalled (via done) that it is finished.
//
// It should only be called within the processMessages() loop, which is also
// responsible for closing the responses channel.
func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
	select {
	case <-msgCtx.done:
		// The request handler is done and will not receive more
		// packets, so the packet is dropped.
	case msgCtx.responses <- packet:
		// Successfully sent packet to message handler.
	}
}
2015-05-07 17:59:38 +00:00
type messagePacket struct {
Op int
2015-06-29 21:50:55 +00:00
MessageID int64
2015-05-07 17:59:38 +00:00
Packet *ber.Packet
2016-06-30 18:19:03 +00:00
Context *messageContext
2015-05-07 17:59:38 +00:00
}
2015-06-29 21:50:55 +00:00
// sendMessageFlags is a bit set that modifies how sendMessageWithFlags
// treats a request.
type sendMessageFlags uint

const (
	// startTLS marks the request as a StartTLS extended request.
	startTLS sendMessageFlags = 1 << iota
)
2015-05-07 17:59:38 +00:00
// Conn represents an LDAP Connection
type Conn struct {
2015-06-29 21:50:55 +00:00
conn net.Conn
isTLS bool
2017-10-27 19:06:04 +00:00
closing uint32
2018-06-15 17:13:57 +00:00
closeErr atomic.Value
2015-06-29 21:50:55 +00:00
isStartingTLS bool
Debug debugging
2017-10-27 19:06:04 +00:00
chanConfirm chan struct{}
2016-06-30 18:19:03 +00:00
messageContexts map[int64]*messageContext
2015-06-29 21:50:55 +00:00
chanMessage chan *messagePacket
chanMessageID chan int64
wgClose sync.WaitGroup
outstandingRequests uint
messageMutex sync.Mutex
2017-10-27 19:06:04 +00:00
requestTimeout int64
2015-06-29 21:50:55 +00:00
}
2015-11-02 18:43:12 +00:00
// Compile-time check that *Conn implements the Client interface.
var _ Client = (*Conn)(nil)
2015-06-29 22:05:44 +00:00
// DefaultTimeout is the package-level timeout applied when establishing a
// connection with Dial or DialTLS.
//
// WARNING: because this is a package-level variable, changing it from
// several places will probably lead to undesired behaviour.
var DefaultTimeout time.Duration = 60 * time.Second
2015-05-07 17:59:38 +00:00
// Dial connects to the given address on the given network using net.Dial
// and then returns a new Conn for the connection.
func Dial(network, addr string) (*Conn, error) {
2015-06-29 22:05:44 +00:00
c, err := net.DialTimeout(network, addr, DefaultTimeout)
2015-05-07 17:59:38 +00:00
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
2015-06-29 21:50:55 +00:00
conn := NewConn(c, false)
conn.Start()
2015-05-07 17:59:38 +00:00
return conn, nil
}
// DialTLS connects to the given address on the given network using tls.Dial
// and then returns a new Conn for the connection.
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
2018-06-15 17:13:57 +00:00
c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
2015-05-07 17:59:38 +00:00
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
2015-06-29 21:50:55 +00:00
conn := NewConn(c, true)
conn.Start()
2015-05-07 17:59:38 +00:00
return conn, nil
}
// NewConn returns a new Conn using conn for network I/O.
2015-06-29 21:50:55 +00:00
func NewConn(conn net.Conn, isTLS bool) *Conn {
2015-05-07 17:59:38 +00:00
return &Conn{
2016-06-30 18:19:03 +00:00
conn: conn,
2017-10-27 19:06:04 +00:00
chanConfirm: make(chan struct{}),
2016-06-30 18:19:03 +00:00
chanMessageID: make(chan int64),
chanMessage: make(chan *messagePacket, 10),
messageContexts: map[int64]*messageContext{},
requestTimeout: 0,
isTLS: isTLS,
2015-05-07 17:59:38 +00:00
}
}
2016-07-23 00:11:47 +00:00
// Start initializes goroutines to read responses and process messages
2015-06-29 21:50:55 +00:00
func (l *Conn) Start() {
2015-05-07 17:59:38 +00:00
go l.reader()
go l.processMessages()
2015-06-29 21:50:55 +00:00
l.wgClose.Add(1)
2015-05-07 17:59:38 +00:00
}
2017-02-02 21:19:55 +00:00
// isClosing returns whether or not we're currently closing.
func (l *Conn) isClosing() bool {
2017-10-27 19:06:04 +00:00
return atomic.LoadUint32(&l.closing) == 1
2017-02-02 21:19:55 +00:00
}
// setClosing sets the closing value to true
2017-10-27 19:06:04 +00:00
func (l *Conn) setClosing() bool {
return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
2017-02-02 21:19:55 +00:00
}
2015-05-07 17:59:38 +00:00
// Close closes the connection.
func (l *Conn) Close() {
2017-10-27 19:06:04 +00:00
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
2015-05-07 17:59:38 +00:00
2017-10-27 19:06:04 +00:00
if l.setClosing() {
2015-05-07 17:59:38 +00:00
l.Debug.Printf("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
<-l.chanConfirm
close(l.chanMessage)
l.Debug.Printf("Closing network connection")
if err := l.conn.Close(); err != nil {
2017-10-27 19:06:04 +00:00
log.Println(err)
2015-05-07 17:59:38 +00:00
}
2015-06-29 21:50:55 +00:00
l.wgClose.Done()
2017-10-27 19:06:04 +00:00
}
2015-06-29 21:50:55 +00:00
l.wgClose.Wait()
2015-05-07 17:59:38 +00:00
}
2016-07-23 00:11:47 +00:00
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
2016-04-26 00:18:04 +00:00
func (l *Conn) SetTimeout(timeout time.Duration) {
if timeout > 0 {
2017-10-27 19:06:04 +00:00
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
2016-04-26 00:18:04 +00:00
}
}
2015-05-07 17:59:38 +00:00
// Returns the next available messageID
2015-06-29 21:50:55 +00:00
func (l *Conn) nextMessageID() int64 {
2017-10-27 19:06:04 +00:00
if messageID, ok := <-l.chanMessageID; ok {
return messageID
2015-05-07 17:59:38 +00:00
}
return 0
}
// StartTLS sends the command to start a TLS session and then creates a new TLS Client
func (l *Conn) StartTLS(config *tls.Config) error {
if l.isTLS {
return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
}
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
2016-06-30 18:19:03 +00:00
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
2015-05-07 17:59:38 +00:00
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
packet.AppendChild(request)
l.Debug.PrintPacket(packet)
2016-06-30 18:19:03 +00:00
msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
2015-05-07 17:59:38 +00:00
if err != nil {
2015-06-29 21:50:55 +00:00
return err
2015-05-07 17:59:38 +00:00
}
2016-06-30 18:19:03 +00:00
defer l.finishMessage(msgCtx)
2015-05-07 17:59:38 +00:00
2016-06-30 18:19:03 +00:00
l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
2016-04-26 00:18:04 +00:00
if !ok {
2016-06-30 18:19:03 +00:00
return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
2016-04-26 00:18:04 +00:00
}
packet, err = packetResponse.ReadPacket()
2016-06-30 18:19:03 +00:00
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
2016-04-26 00:18:04 +00:00
if err != nil {
return err
}
2015-06-29 21:50:55 +00:00
2015-05-07 17:59:38 +00:00
if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
2015-06-29 21:50:55 +00:00
l.Close()
2015-05-07 17:59:38 +00:00
return err
}
ber.PrintPacket(packet)
}
2015-10-30 22:07:00 +00:00
if resultCode, message := getLDAPResultCode(packet); resultCode == LDAPResultSuccess {
2015-05-07 17:59:38 +00:00
conn := tls.Client(l.conn, config)
2015-06-29 21:50:55 +00:00
if err := conn.Handshake(); err != nil {
l.Close()
return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", err))
}
2015-05-07 17:59:38 +00:00
l.isTLS = true
l.conn = conn
2015-08-19 01:12:51 +00:00
} else {
2015-10-30 22:07:00 +00:00
return NewError(resultCode, fmt.Errorf("ldap: cannot StartTLS (%s)", message))
2015-05-07 17:59:38 +00:00
}
2015-06-29 21:50:55 +00:00
go l.reader()
2015-05-07 17:59:38 +00:00
return nil
}
2016-06-30 18:19:03 +00:00
func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
2015-06-29 21:50:55 +00:00
return l.sendMessageWithFlags(packet, 0)
2015-05-07 17:59:38 +00:00
}
2016-06-30 18:19:03 +00:00
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
2017-02-02 21:19:55 +00:00
if l.isClosing() {
2015-05-07 17:59:38 +00:00
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
}
2015-06-29 21:50:55 +00:00
l.messageMutex.Lock()
l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
if l.isStartingTLS {
l.messageMutex.Unlock()
2016-07-23 00:11:47 +00:00
return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
2015-06-29 21:50:55 +00:00
}
if flags&startTLS != 0 {
if l.outstandingRequests != 0 {
l.messageMutex.Unlock()
return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
}
2016-07-23 00:11:47 +00:00
l.isStartingTLS = true
2015-06-29 21:50:55 +00:00
}
l.outstandingRequests++
l.messageMutex.Unlock()
2016-06-30 18:19:03 +00:00
responses := make(chan *PacketResponse)
messageID := packet.Children[0].Value.(int64)
2015-05-07 17:59:38 +00:00
message := &messagePacket{
Op: MessageRequest,
2016-06-30 18:19:03 +00:00
MessageID: messageID,
2015-05-07 17:59:38 +00:00
Packet: packet,
2016-06-30 18:19:03 +00:00
Context: &messageContext{
id: messageID,
done: make(chan struct{}),
responses: responses,
},
2015-05-07 17:59:38 +00:00
}
l.sendProcessMessage(message)
2016-06-30 18:19:03 +00:00
return message.Context, nil
2015-05-07 17:59:38 +00:00
}
2016-06-30 18:19:03 +00:00
func (l *Conn) finishMessage(msgCtx *messageContext) {
close(msgCtx.done)
2017-02-02 21:19:55 +00:00
if l.isClosing() {
2015-05-07 17:59:38 +00:00
return
}
2015-06-29 21:50:55 +00:00
l.messageMutex.Lock()
l.outstandingRequests--
if l.isStartingTLS {
l.isStartingTLS = false
}
l.messageMutex.Unlock()
2015-05-07 17:59:38 +00:00
message := &messagePacket{
Op: MessageFinish,
2016-06-30 18:19:03 +00:00
MessageID: msgCtx.id,
2015-05-07 17:59:38 +00:00
}
l.sendProcessMessage(message)
}
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
2017-10-27 19:06:04 +00:00
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
2017-02-02 21:19:55 +00:00
if l.isClosing() {
2015-05-07 17:59:38 +00:00
return false
}
l.chanMessage <- message
return true
}
func (l *Conn) processMessages() {
defer func() {
2015-06-29 21:50:55 +00:00
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in processMessages: %v", err)
}
2016-06-30 18:19:03 +00:00
for messageID, msgCtx := range l.messageContexts {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
2017-02-02 21:19:55 +00:00
if l.isClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
2016-06-30 18:19:03 +00:00
}
2015-05-07 17:59:38 +00:00
l.Debug.Printf("Closing channel for MessageID %d", messageID)
2016-06-30 18:19:03 +00:00
close(msgCtx.responses)
delete(l.messageContexts, messageID)
2015-05-07 17:59:38 +00:00
}
close(l.chanMessageID)
close(l.chanConfirm)
}()
2015-06-29 21:50:55 +00:00
var messageID int64 = 1
2015-05-07 17:59:38 +00:00
for {
select {
case l.chanMessageID <- messageID:
messageID++
2017-10-27 19:06:04 +00:00
case message := <-l.chanMessage:
2016-04-26 00:18:04 +00:00
switch message.Op {
2015-05-07 17:59:38 +00:00
case MessageQuit:
l.Debug.Printf("Shutting down - quit message received")
return
case MessageRequest:
// Add to message list and write to network
2016-04-26 00:18:04 +00:00
l.Debug.Printf("Sending message %d", message.MessageID)
2015-05-07 17:59:38 +00:00
2016-04-26 00:18:04 +00:00
buf := message.Packet.Bytes()
2015-05-07 17:59:38 +00:00
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
2016-06-30 18:19:03 +00:00
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
close(message.Context.responses)
2015-05-07 17:59:38 +00:00
break
}
2016-04-26 00:18:04 +00:00
2016-06-30 18:19:03 +00:00
// Only add to messageContexts if we were able to
// successfully write the message.
l.messageContexts[message.MessageID] = message.Context
2016-04-26 00:18:04 +00:00
// Add timeout if defined
2017-10-27 19:06:04 +00:00
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
if requestTimeout > 0 {
2016-04-26 00:18:04 +00:00
go func() {
defer func() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
}
}()
2017-10-27 19:06:04 +00:00
time.Sleep(requestTimeout)
2016-04-26 00:18:04 +00:00
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
}
l.sendProcessMessage(timeoutMessage)
}()
}
2015-05-07 17:59:38 +00:00
case MessageResponse:
2016-04-26 00:18:04 +00:00
l.Debug.Printf("Receiving message %d", message.MessageID)
2016-06-30 18:19:03 +00:00
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
2015-05-07 17:59:38 +00:00
} else {
2017-02-02 21:19:55 +00:00
log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing())
2016-04-26 00:18:04 +00:00
ber.PrintPacket(message.Packet)
}
case MessageTimeout:
// Handle the timeout by closing the channel
// All reads will return immediately
2016-06-30 18:19:03 +00:00
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
2016-04-26 00:18:04 +00:00
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
2016-06-30 18:19:03 +00:00
msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
2015-05-07 17:59:38 +00:00
}
case MessageFinish:
2016-04-26 00:18:04 +00:00
l.Debug.Printf("Finished message %d", message.MessageID)
2016-06-30 18:19:03 +00:00
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
2016-04-26 00:18:04 +00:00
}
2015-05-07 17:59:38 +00:00
}
}
}
}
func (l *Conn) reader() {
2015-06-29 21:50:55 +00:00
cleanstop := false
2015-05-07 17:59:38 +00:00
defer func() {
2015-06-29 21:50:55 +00:00
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in reader: %v", err)
}
if !cleanstop {
l.Close()
}
2015-05-07 17:59:38 +00:00
}()
for {
2015-06-29 21:50:55 +00:00
if cleanstop {
l.Debug.Printf("reader clean stopping (without closing the connection)")
return
}
2015-05-07 17:59:38 +00:00
packet, err := ber.ReadPacket(l.conn)
if err != nil {
2015-06-29 21:50:55 +00:00
// A read error is expected here if we are closing the connection...
2017-02-02 21:19:55 +00:00
if !l.isClosing() {
l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
2015-06-29 21:50:55 +00:00
l.Debug.Printf("reader error: %s", err.Error())
}
2015-05-07 17:59:38 +00:00
return
}
addLDAPDescriptions(packet)
2015-06-29 21:50:55 +00:00
if len(packet.Children) == 0 {
l.Debug.Printf("Received bad ldap packet")
continue
}
l.messageMutex.Lock()
if l.isStartingTLS {
cleanstop = true
}
l.messageMutex.Unlock()
2015-05-07 17:59:38 +00:00
message := &messagePacket{
Op: MessageResponse,
2015-06-29 21:50:55 +00:00
MessageID: packet.Children[0].Value.(int64),
2015-05-07 17:59:38 +00:00
Packet: packet,
}
if !l.sendProcessMessage(message) {
return
}
}
}