232 lines
5.5 KiB
Go
232 lines
5.5 KiB
Go
|
package freeport
|
||
|
|
||
|
import (
	"fmt"
	"io"
	"net"
	"testing"
	"time"

	"github.com/hashicorp/consul/sdk/testutil/retry"
)
|
||
|
|
||
|
func TestTakeReturn(t *testing.T) {
|
||
|
// NOTE: for global var reasons this cannot execute in parallel
|
||
|
// t.Parallel()
|
||
|
|
||
|
// Since this test is destructive (i.e. it leaks all ports) it means that
|
||
|
// any other test cases in this package will not function after it runs. To
|
||
|
// help out we reset the global state after we run this test.
|
||
|
defer reset()
|
||
|
|
||
|
// OK: do a simple take/return cycle to trigger the package initialization
|
||
|
func() {
|
||
|
ports, err := Take(1)
|
||
|
if err != nil {
|
||
|
t.Fatalf("err: %v", err)
|
||
|
}
|
||
|
defer Return(ports)
|
||
|
|
||
|
if len(ports) != 1 {
|
||
|
t.Fatalf("expected %d but got %d ports", 1, len(ports))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
waitForStatsReset := func() (numTotal int) {
|
||
|
t.Helper()
|
||
|
numTotal, numPending, numFree := stats()
|
||
|
if numTotal != numFree+numPending {
|
||
|
t.Fatalf("expected total (%d) and free+pending (%d) ports to match", numTotal, numFree+numPending)
|
||
|
}
|
||
|
retry.Run(t, func(r *retry.R) {
|
||
|
numTotal, numPending, numFree = stats()
|
||
|
if numPending != 0 {
|
||
|
r.Fatalf("pending is still non zero: %d", numPending)
|
||
|
}
|
||
|
if numTotal != numFree {
|
||
|
r.Fatalf("total (%d) does not equal free (%d)", numTotal, numFree)
|
||
|
}
|
||
|
})
|
||
|
return numTotal
|
||
|
}
|
||
|
|
||
|
// Reset
|
||
|
numTotal := waitForStatsReset()
|
||
|
|
||
|
// --------------------
|
||
|
// OK: take the max
|
||
|
func() {
|
||
|
ports, err := Take(numTotal)
|
||
|
if err != nil {
|
||
|
t.Fatalf("err: %v", err)
|
||
|
}
|
||
|
defer Return(ports)
|
||
|
|
||
|
if len(ports) != numTotal {
|
||
|
t.Fatalf("expected %d but got %d ports", numTotal, len(ports))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
// Reset
|
||
|
numTotal = waitForStatsReset()
|
||
|
|
||
|
expectError := func(expected string, got error) {
|
||
|
t.Helper()
|
||
|
if got == nil {
|
||
|
t.Fatalf("expected error but was nil")
|
||
|
}
|
||
|
if got.Error() != expected {
|
||
|
t.Fatalf("expected error %q but got %q", expected, got.Error())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// --------------------
|
||
|
// ERROR: take too many ports
|
||
|
func() {
|
||
|
ports, err := Take(numTotal + 1)
|
||
|
defer Return(ports)
|
||
|
expectError("freeport: block size too small", err)
|
||
|
}()
|
||
|
|
||
|
// --------------------
|
||
|
// ERROR: invalid ports request (negative)
|
||
|
func() {
|
||
|
_, err := Take(-1)
|
||
|
expectError("freeport: cannot take -1 ports", err)
|
||
|
}()
|
||
|
|
||
|
// --------------------
|
||
|
// ERROR: invalid ports request (zero)
|
||
|
func() {
|
||
|
_, err := Take(0)
|
||
|
expectError("freeport: cannot take 0 ports", err)
|
||
|
}()
|
||
|
|
||
|
// --------------------
|
||
|
// OK: Steal a port under the covers and let freeport detect the theft and compensate
|
||
|
leakedPort := peekFree()
|
||
|
func() {
|
||
|
leakyListener, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", leakedPort))
|
||
|
if err != nil {
|
||
|
t.Fatalf("err: %v", err)
|
||
|
}
|
||
|
defer leakyListener.Close()
|
||
|
|
||
|
func() {
|
||
|
ports, err := Take(3)
|
||
|
if err != nil {
|
||
|
t.Fatalf("err: %v", err)
|
||
|
}
|
||
|
defer Return(ports)
|
||
|
|
||
|
if len(ports) != 3 {
|
||
|
t.Fatalf("expected %d but got %d ports", 3, len(ports))
|
||
|
}
|
||
|
|
||
|
for _, port := range ports {
|
||
|
if port == leakedPort {
|
||
|
t.Fatalf("did not expect for Take to return the leaked port")
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
newNumTotal := waitForStatsReset()
|
||
|
if newNumTotal != numTotal-1 {
|
||
|
t.Fatalf("expected total to drop to %d but got %d", numTotal-1, newNumTotal)
|
||
|
}
|
||
|
numTotal = newNumTotal // update outer variable for later tests
|
||
|
}()
|
||
|
|
||
|
// --------------------
|
||
|
// OK: sequence it so that one Take must wait on another Take to Return.
|
||
|
func() {
|
||
|
mostPorts, err := Take(numTotal - 5)
|
||
|
if err != nil {
|
||
|
t.Fatalf("err: %v", err)
|
||
|
}
|
||
|
|
||
|
type reply struct {
|
||
|
ports []int
|
||
|
err error
|
||
|
}
|
||
|
ch := make(chan reply, 1)
|
||
|
go func() {
|
||
|
ports, err := Take(10)
|
||
|
ch <- reply{ports: ports, err: err}
|
||
|
}()
|
||
|
|
||
|
Return(mostPorts)
|
||
|
|
||
|
r := <-ch
|
||
|
if r.err != nil {
|
||
|
t.Fatalf("err: %v", r.err)
|
||
|
}
|
||
|
defer Return(r.ports)
|
||
|
|
||
|
if len(r.ports) != 10 {
|
||
|
t.Fatalf("expected %d ports but got %d", 10, len(r.ports))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
// Reset
|
||
|
numTotal = waitForStatsReset()
|
||
|
|
||
|
// --------------------
|
||
|
// ERROR: Now we end on the crazy "Ocean's 11" level port theft where we
|
||
|
// orchestrate a situation where all ports are stolen and we don't find out
|
||
|
// until Take.
|
||
|
func() {
|
||
|
// 1. Grab all of the ports.
|
||
|
allPorts := peekAllFree()
|
||
|
|
||
|
// 2. Leak all of the ports
|
||
|
leaked := make([]io.Closer, 0, len(allPorts))
|
||
|
defer func() {
|
||
|
for _, c := range leaked {
|
||
|
c.Close()
|
||
|
}
|
||
|
}()
|
||
|
for _, port := range allPorts {
|
||
|
ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port))
|
||
|
if err != nil {
|
||
|
t.Fatalf("err: %v", err)
|
||
|
}
|
||
|
leaked = append(leaked, ln)
|
||
|
}
|
||
|
|
||
|
// 3. Request 1 port which will detect the leaked ports and fail.
|
||
|
_, err := Take(1)
|
||
|
expectError("freeport: impossible to satisfy request; there are no actual free ports in the block anymore", err)
|
||
|
|
||
|
// 4. Wait for the block to zero out.
|
||
|
newNumTotal := waitForStatsReset()
|
||
|
if newNumTotal != 0 {
|
||
|
t.Fatalf("expected total to drop to %d but got %d", 0, newNumTotal)
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
func TestIntervalOverlap(t *testing.T) {
|
||
|
cases := []struct {
|
||
|
min1, max1, min2, max2 int
|
||
|
overlap bool
|
||
|
}{
|
||
|
{0, 0, 0, 0, true},
|
||
|
{1, 1, 1, 1, true},
|
||
|
{1, 3, 1, 3, true}, // same
|
||
|
{1, 3, 4, 6, false}, // serial
|
||
|
{1, 4, 3, 6, true}, // inner overlap
|
||
|
{1, 6, 3, 4, true}, // nest
|
||
|
}
|
||
|
|
||
|
for _, tc := range cases {
|
||
|
t.Run(fmt.Sprintf("%d:%d vs %d:%d", tc.min1, tc.max1, tc.min2, tc.max2), func(t *testing.T) {
|
||
|
if tc.overlap != intervalOverlap(tc.min1, tc.max1, tc.min2, tc.max2) { // 1 vs 2
|
||
|
t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap)
|
||
|
}
|
||
|
if tc.overlap != intervalOverlap(tc.min2, tc.max2, tc.min1, tc.max1) { // 2 vs 1
|
||
|
t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|