255 lines
4.9 KiB
Go
255 lines
4.9 KiB
Go
|
package radius
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
)
|
||
|
|
||
|
// packetResponseWriter sends response packets back to the peer a request
// was received from.
type packetResponseWriter struct {
	// conn is the listener that received the request packet; addr is the
	// remote address the response is written to.
	conn net.PacketConn
	addr net.Addr
}
|
||
|
|
||
|
func (r *packetResponseWriter) Write(packet *Packet) error {
|
||
|
raw, err := packet.Encode()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if _, err := r.conn.WriteTo(raw, r.addr); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// PacketServer listens for RADIUS requests on a packet-based protocols (e.g.
|
||
|
// UDP).
|
||
|
type PacketServer struct {
|
||
|
// The address on which the server listens. Defaults to :1812.
|
||
|
Addr string
|
||
|
// The network on which the server listens. Defaults to udp.
|
||
|
Network string
|
||
|
SecretSource SecretSource
|
||
|
Handler Handler
|
||
|
|
||
|
// Skip incoming packet authenticity validation.
|
||
|
// This should only be set to true for debugging purposes.
|
||
|
InsecureSkipVerify bool
|
||
|
|
||
|
mu sync.Mutex
|
||
|
shuttingDown bool
|
||
|
ctx context.Context
|
||
|
ctxDone context.CancelFunc
|
||
|
running chan struct{}
|
||
|
listeners map[net.PacketConn]int
|
||
|
activeCount int32
|
||
|
}
|
||
|
|
||
|
// TODO: logger on PacketServer
|
||
|
|
||
|
// Serve accepts incoming connections on conn.
|
||
|
func (s *PacketServer) Serve(conn net.PacketConn) error {
|
||
|
if s.Handler == nil {
|
||
|
return errors.New("radius: nil Handler")
|
||
|
}
|
||
|
if s.SecretSource == nil {
|
||
|
return errors.New("radius: nil SecretSource")
|
||
|
}
|
||
|
|
||
|
s.mu.Lock()
|
||
|
if s.shuttingDown {
|
||
|
s.mu.Unlock()
|
||
|
return ErrServerShutdown
|
||
|
}
|
||
|
var ctx context.Context
|
||
|
if s.ctx == nil {
|
||
|
s.ctx, s.ctxDone = context.WithCancel(context.Background())
|
||
|
ctx = s.ctx
|
||
|
}
|
||
|
if s.running == nil {
|
||
|
s.running = make(chan struct{})
|
||
|
}
|
||
|
if s.listeners == nil {
|
||
|
s.listeners = make(map[net.PacketConn]int)
|
||
|
}
|
||
|
s.listeners[conn]++
|
||
|
s.mu.Unlock()
|
||
|
|
||
|
type activeKey struct {
|
||
|
IP string
|
||
|
Identifier byte
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
activeLock sync.Mutex
|
||
|
active = map[activeKey]struct{}{}
|
||
|
)
|
||
|
|
||
|
atomic.AddInt32(&s.activeCount, 1)
|
||
|
defer func() {
|
||
|
s.mu.Lock()
|
||
|
s.listeners[conn]--
|
||
|
if s.listeners[conn] == 0 {
|
||
|
delete(s.listeners, conn)
|
||
|
}
|
||
|
s.mu.Unlock()
|
||
|
|
||
|
if atomic.AddInt32(&s.activeCount, -1) == 0 {
|
||
|
s.mu.Lock()
|
||
|
s.shuttingDown = false
|
||
|
close(s.running)
|
||
|
s.running = nil
|
||
|
s.ctx = nil
|
||
|
s.mu.Unlock()
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
for {
|
||
|
var buff [MaxPacketLength]byte
|
||
|
n, remoteAddr, err := conn.ReadFrom(buff[:])
|
||
|
if err != nil {
|
||
|
s.mu.Lock()
|
||
|
if s.shuttingDown {
|
||
|
s.mu.Unlock()
|
||
|
return nil
|
||
|
}
|
||
|
s.mu.Unlock()
|
||
|
|
||
|
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
||
|
return err
|
||
|
}
|
||
|
// TODO: log error?
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
buffCopy := make([]byte, n)
|
||
|
copy(buffCopy, buff[:n])
|
||
|
|
||
|
atomic.AddInt32(&s.activeCount, 1)
|
||
|
go func(buff []byte, remoteAddr net.Addr) {
|
||
|
secret, err := s.SecretSource.RADIUSSecret(ctx, remoteAddr)
|
||
|
if err != nil {
|
||
|
// TODO: log only if server is not shutting down?
|
||
|
return
|
||
|
}
|
||
|
if len(secret) == 0 {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if !s.InsecureSkipVerify && !IsAuthenticRequest(buff, secret) {
|
||
|
// TODO: log?
|
||
|
return
|
||
|
}
|
||
|
|
||
|
packet, err := Parse(buff, secret)
|
||
|
if err != nil {
|
||
|
// TODO: error logger
|
||
|
return
|
||
|
}
|
||
|
|
||
|
key := activeKey{
|
||
|
IP: remoteAddr.String(),
|
||
|
Identifier: packet.Identifier,
|
||
|
}
|
||
|
activeLock.Lock()
|
||
|
if _, ok := active[key]; ok {
|
||
|
activeLock.Unlock()
|
||
|
return
|
||
|
}
|
||
|
active[key] = struct{}{}
|
||
|
activeLock.Unlock()
|
||
|
|
||
|
response := packetResponseWriter{
|
||
|
conn: conn,
|
||
|
addr: remoteAddr,
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
activeLock.Lock()
|
||
|
delete(active, key)
|
||
|
activeLock.Unlock()
|
||
|
|
||
|
if atomic.AddInt32(&s.activeCount, -1) == 0 {
|
||
|
s.mu.Lock()
|
||
|
s.shuttingDown = false
|
||
|
close(s.running)
|
||
|
s.running = nil
|
||
|
s.ctx = nil
|
||
|
s.mu.Unlock()
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
request := Request{
|
||
|
LocalAddr: conn.LocalAddr(),
|
||
|
RemoteAddr: remoteAddr,
|
||
|
Packet: packet,
|
||
|
ctx: ctx,
|
||
|
}
|
||
|
|
||
|
s.Handler.ServeRADIUS(&response, &request)
|
||
|
}(buffCopy, remoteAddr)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ListenAndServe starts a RADIUS server on the address given in s.
|
||
|
func (s *PacketServer) ListenAndServe() error {
|
||
|
if s.Handler == nil {
|
||
|
return errors.New("radius: nil Handler")
|
||
|
}
|
||
|
if s.SecretSource == nil {
|
||
|
return errors.New("radius: nil SecretSource")
|
||
|
}
|
||
|
|
||
|
addrStr := ":1812"
|
||
|
if s.Addr != "" {
|
||
|
addrStr = s.Addr
|
||
|
}
|
||
|
|
||
|
network := "udp"
|
||
|
if s.Network != "" {
|
||
|
network = s.Network
|
||
|
}
|
||
|
|
||
|
pc, err := net.ListenPacket(network, addrStr)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer pc.Close()
|
||
|
return s.Serve(pc)
|
||
|
}
|
||
|
|
||
|
// Shutdown gracefully stops the server. It first closes all listeners (which
|
||
|
// stops accepting new packets) and then waits for running handlers to complete.
|
||
|
//
|
||
|
// Shutdown returns after all handlers have completed, or when ctx is canceled.
|
||
|
// The PacketServer is ready for re-use once the function returns nil.
|
||
|
func (s *PacketServer) Shutdown(ctx context.Context) error {
|
||
|
s.mu.Lock()
|
||
|
|
||
|
if len(s.listeners) == 0 {
|
||
|
s.mu.Unlock()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
if !s.shuttingDown {
|
||
|
s.shuttingDown = true
|
||
|
s.ctxDone()
|
||
|
for listener := range s.listeners {
|
||
|
listener.Close()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
ch := s.running
|
||
|
s.mu.Unlock()
|
||
|
|
||
|
select {
|
||
|
case <-ch:
|
||
|
return nil
|
||
|
case <-ctx.Done():
|
||
|
return ctx.Err()
|
||
|
}
|
||
|
}
|