242 lines
4.8 KiB
Go
242 lines
4.8 KiB
Go
package radius
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// packetResponseWriter is the response writer PacketServer passes to
// Handler.ServeRADIUS. Replies are sent back over the same packet listener
// that received the request, addressed to the request's sender.
type packetResponseWriter struct {
	// conn is the listener the request was read from; replies are sent on it.
	conn net.PacketConn

	// addr is the remote address the request came from.
	addr net.Addr
}
|
|
|
|
func (r *packetResponseWriter) Write(packet *Packet) error {
|
|
encoded, err := packet.Encode()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := r.conn.WriteTo(encoded, r.addr); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// PacketServer listens for RADIUS requests on a packet-based protocols (e.g.
|
|
// UDP).
|
|
type PacketServer struct {
|
|
// The address on which the server listens. Defaults to :1812.
|
|
Addr string
|
|
|
|
// The network on which the server listens. Defaults to udp.
|
|
Network string
|
|
|
|
// The source from which the secret is obtained for parsing and validating
|
|
// the request.
|
|
SecretSource SecretSource
|
|
|
|
// Handler which is called to process the request.
|
|
Handler Handler
|
|
|
|
// Skip incoming packet authenticity validation.
|
|
// This should only be set to true for debugging purposes.
|
|
InsecureSkipVerify bool
|
|
|
|
shutdownRequested int32
|
|
|
|
mu sync.Mutex
|
|
ctx context.Context
|
|
ctxDone context.CancelFunc
|
|
listeners map[net.PacketConn]uint
|
|
lastActive chan struct{} // closed when the last active item finishes
|
|
activeCount int32
|
|
}
|
|
|
|
func (s *PacketServer) initLocked() {
|
|
if s.ctx == nil {
|
|
s.ctx, s.ctxDone = context.WithCancel(context.Background())
|
|
s.listeners = make(map[net.PacketConn]uint)
|
|
s.lastActive = make(chan struct{})
|
|
}
|
|
}
|
|
|
|
func (s *PacketServer) activeAdd() {
|
|
atomic.AddInt32(&s.activeCount, 1)
|
|
}
|
|
|
|
func (s *PacketServer) activeDone() {
|
|
if atomic.AddInt32(&s.activeCount, -1) == -1 {
|
|
close(s.lastActive)
|
|
}
|
|
}
|
|
|
|
// TODO: logger on PacketServer
|
|
|
|
// Serve accepts incoming connections on conn.
|
|
func (s *PacketServer) Serve(conn net.PacketConn) error {
|
|
if s.Handler == nil {
|
|
return errors.New("radius: nil Handler")
|
|
}
|
|
if s.SecretSource == nil {
|
|
return errors.New("radius: nil SecretSource")
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.initLocked()
|
|
if atomic.LoadInt32(&s.shutdownRequested) == 1 {
|
|
s.mu.Unlock()
|
|
return ErrServerShutdown
|
|
}
|
|
|
|
s.listeners[conn]++
|
|
s.mu.Unlock()
|
|
|
|
type requestKey struct {
|
|
IP string
|
|
Identifier byte
|
|
}
|
|
|
|
var (
|
|
requestsLock sync.Mutex
|
|
requests = map[requestKey]struct{}{}
|
|
)
|
|
|
|
s.activeAdd()
|
|
defer func() {
|
|
s.mu.Lock()
|
|
s.listeners[conn]--
|
|
if s.listeners[conn] == 0 {
|
|
delete(s.listeners, conn)
|
|
}
|
|
s.mu.Unlock()
|
|
s.activeDone()
|
|
}()
|
|
|
|
var buff [MaxPacketLength]byte
|
|
for {
|
|
n, remoteAddr, err := conn.ReadFrom(buff[:])
|
|
if err != nil {
|
|
if atomic.LoadInt32(&s.shutdownRequested) == 1 {
|
|
return ErrServerShutdown
|
|
}
|
|
|
|
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
|
|
s.activeAdd()
|
|
go func(buff []byte, remoteAddr net.Addr) {
|
|
defer s.activeDone()
|
|
|
|
secret, err := s.SecretSource.RADIUSSecret(s.ctx, remoteAddr)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if len(secret) == 0 {
|
|
return
|
|
}
|
|
|
|
if !s.InsecureSkipVerify && !IsAuthenticRequest(buff, secret) {
|
|
return
|
|
}
|
|
|
|
packet, err := Parse(buff, secret)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
key := requestKey{
|
|
IP: remoteAddr.String(),
|
|
Identifier: packet.Identifier,
|
|
}
|
|
requestsLock.Lock()
|
|
if _, ok := requests[key]; ok {
|
|
requestsLock.Unlock()
|
|
return
|
|
}
|
|
requests[key] = struct{}{}
|
|
requestsLock.Unlock()
|
|
|
|
response := packetResponseWriter{
|
|
conn: conn,
|
|
addr: remoteAddr,
|
|
}
|
|
|
|
defer func() {
|
|
requestsLock.Lock()
|
|
delete(requests, key)
|
|
requestsLock.Unlock()
|
|
}()
|
|
|
|
request := Request{
|
|
LocalAddr: conn.LocalAddr(),
|
|
RemoteAddr: remoteAddr,
|
|
Packet: packet,
|
|
ctx: s.ctx,
|
|
}
|
|
|
|
s.Handler.ServeRADIUS(&response, &request)
|
|
}(append([]byte(nil), buff[:n]...), remoteAddr)
|
|
}
|
|
}
|
|
|
|
// ListenAndServe starts a RADIUS server on the address given in s.
|
|
func (s *PacketServer) ListenAndServe() error {
|
|
if s.Handler == nil {
|
|
return errors.New("radius: nil Handler")
|
|
}
|
|
if s.SecretSource == nil {
|
|
return errors.New("radius: nil SecretSource")
|
|
}
|
|
|
|
addrStr := ":1812"
|
|
if s.Addr != "" {
|
|
addrStr = s.Addr
|
|
}
|
|
|
|
network := "udp"
|
|
if s.Network != "" {
|
|
network = s.Network
|
|
}
|
|
|
|
pc, err := net.ListenPacket(network, addrStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer pc.Close()
|
|
return s.Serve(pc)
|
|
}
|
|
|
|
// Shutdown gracefully stops the server. It first closes all listeners and then
|
|
// waits for any running handlers to complete.
|
|
//
|
|
// Shutdown returns after nil all handlers have completed. ctx.Err() is
|
|
// returned if ctx is canceled.
|
|
//
|
|
// Any Serve methods return ErrShutdown after Shutdown is called.
|
|
func (s *PacketServer) Shutdown(ctx context.Context) error {
|
|
s.mu.Lock()
|
|
s.initLocked()
|
|
if atomic.CompareAndSwapInt32(&s.shutdownRequested, 0, 1) {
|
|
for listener := range s.listeners {
|
|
listener.Close()
|
|
}
|
|
|
|
s.ctxDone()
|
|
s.activeDone()
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
select {
|
|
case <-s.lastActive:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|