refactor: enhance UDPInterface with improved concurrency handling and consistent mutex naming

This commit is contained in:
2025-12-30 21:14:42 -06:00
parent 9f36e37f94
commit 63454b3bbb

View File

@@ -14,8 +14,9 @@ type UDPInterface struct {
conn *net.UDPConn conn *net.UDPConn
addr *net.UDPAddr addr *net.UDPAddr
targetAddr *net.UDPAddr targetAddr *net.UDPAddr
mutex sync.RWMutex
readBuffer []byte readBuffer []byte
done chan struct{}
stopOnce sync.Once
} }
func NewUDPInterface(name string, addr string, target string, enabled bool) (*UDPInterface, error) { func NewUDPInterface(name string, addr string, target string, enabled bool) (*UDPInterface, error) {
@@ -36,10 +37,11 @@ func NewUDPInterface(name string, addr string, target string, enabled bool) (*UD
BaseInterface: NewBaseInterface(name, common.IF_TYPE_UDP, enabled), BaseInterface: NewBaseInterface(name, common.IF_TYPE_UDP, enabled),
addr: udpAddr, addr: udpAddr,
targetAddr: targetAddr, targetAddr: targetAddr,
readBuffer: make([]byte, 1064), readBuffer: make([]byte, common.NUM_1064),
done: make(chan struct{}),
} }
ui.MTU = 1064 ui.MTU = common.NUM_1064
return ui, nil return ui, nil
} }
@@ -57,24 +59,30 @@ func (ui *UDPInterface) GetMode() common.InterfaceMode {
} }
func (ui *UDPInterface) IsOnline() bool { func (ui *UDPInterface) IsOnline() bool {
ui.mutex.RLock() ui.Mutex.RLock()
defer ui.mutex.RUnlock() defer ui.Mutex.RUnlock()
return ui.Online return ui.Online
} }
func (ui *UDPInterface) IsDetached() bool { func (ui *UDPInterface) IsDetached() bool {
ui.mutex.RLock() ui.Mutex.RLock()
defer ui.mutex.RUnlock() defer ui.Mutex.RUnlock()
return ui.Detached return ui.Detached
} }
func (ui *UDPInterface) Detach() { func (ui *UDPInterface) Detach() {
ui.mutex.Lock() ui.Mutex.Lock()
defer ui.mutex.Unlock() defer ui.Mutex.Unlock()
ui.Detached = true ui.Detached = true
ui.Online = false
if ui.conn != nil { if ui.conn != nil {
ui.conn.Close() // #nosec G104 ui.conn.Close() // #nosec G104
} }
ui.stopOnce.Do(func() {
if ui.done != nil {
close(ui.done)
}
})
} }
func (ui *UDPInterface) Send(data []byte, addr string) error { func (ui *UDPInterface) Send(data []byte, addr string) error {
@@ -88,10 +96,9 @@ func (ui *UDPInterface) Send(data []byte, addr string) error {
return fmt.Errorf("no target address configured") return fmt.Errorf("no target address configured")
} }
// Update TX stats before sending ui.Mutex.Lock()
ui.mutex.Lock()
ui.TxBytes += uint64(len(data)) ui.TxBytes += uint64(len(data))
ui.mutex.Unlock() ui.Mutex.Unlock()
_, err := ui.conn.WriteTo(data, ui.targetAddr) _, err := ui.conn.WriteTo(data, ui.targetAddr)
if err != nil { if err != nil {
@@ -103,14 +110,14 @@ func (ui *UDPInterface) Send(data []byte, addr string) error {
} }
func (ui *UDPInterface) SetPacketCallback(callback common.PacketCallback) { func (ui *UDPInterface) SetPacketCallback(callback common.PacketCallback) {
ui.mutex.Lock() ui.Mutex.Lock()
defer ui.mutex.Unlock() defer ui.Mutex.Unlock()
ui.packetCallback = callback ui.packetCallback = callback
} }
func (ui *UDPInterface) GetPacketCallback() common.PacketCallback { func (ui *UDPInterface) GetPacketCallback() common.PacketCallback {
ui.mutex.RLock() ui.Mutex.RLock()
defer ui.mutex.RUnlock() defer ui.Mutex.RUnlock()
return ui.packetCallback return ui.packetCallback
} }
@@ -134,9 +141,9 @@ func (ui *UDPInterface) ProcessOutgoing(data []byte) error {
return fmt.Errorf("UDP write failed: %v", err) return fmt.Errorf("UDP write failed: %v", err)
} }
ui.mutex.Lock() ui.Mutex.Lock()
ui.TxBytes += uint64(len(data)) ui.TxBytes += uint64(len(data))
ui.mutex.Unlock() ui.Mutex.Unlock()
return nil return nil
} }
@@ -146,14 +153,14 @@ func (ui *UDPInterface) GetConn() net.Conn {
} }
func (ui *UDPInterface) GetTxBytes() uint64 { func (ui *UDPInterface) GetTxBytes() uint64 {
ui.mutex.RLock() ui.Mutex.RLock()
defer ui.mutex.RUnlock() defer ui.Mutex.RUnlock()
return ui.TxBytes return ui.TxBytes
} }
func (ui *UDPInterface) GetRxBytes() uint64 { func (ui *UDPInterface) GetRxBytes() uint64 {
ui.mutex.RLock() ui.Mutex.RLock()
defer ui.mutex.RUnlock() defer ui.Mutex.RUnlock()
return ui.RxBytes return ui.RxBytes
} }
@@ -166,18 +173,36 @@ func (ui *UDPInterface) GetBitrate() int {
} }
func (ui *UDPInterface) Enable() { func (ui *UDPInterface) Enable() {
ui.mutex.Lock() ui.Mutex.Lock()
defer ui.mutex.Unlock() defer ui.Mutex.Unlock()
ui.Online = true ui.Online = true
} }
func (ui *UDPInterface) Disable() { func (ui *UDPInterface) Disable() {
ui.mutex.Lock() ui.Mutex.Lock()
defer ui.mutex.Unlock() defer ui.Mutex.Unlock()
ui.Online = false ui.Online = false
} }
func (ui *UDPInterface) Start() error { func (ui *UDPInterface) Start() error {
ui.Mutex.Lock()
if ui.conn != nil {
ui.Mutex.Unlock()
return fmt.Errorf("UDP interface already started")
}
// Only recreate done if it's nil or was closed
select {
case <-ui.done:
ui.done = make(chan struct{})
ui.stopOnce = sync.Once{}
default:
if ui.done == nil {
ui.done = make(chan struct{})
ui.stopOnce = sync.Once{}
}
}
ui.Mutex.Unlock()
conn, err := net.ListenUDP("udp", ui.addr) conn, err := net.ListenUDP("udp", ui.addr)
if err != nil { if err != nil {
return err return err
@@ -187,15 +212,17 @@ func (ui *UDPInterface) Start() error {
// Enable broadcast mode if we have a target address // Enable broadcast mode if we have a target address
if ui.targetAddr != nil { if ui.targetAddr != nil {
// Get the raw connection file descriptor to set SO_BROADCAST // Get the raw connection file descriptor to set SO_BROADCAST
if err := conn.SetReadBuffer(1064); err != nil { if err := conn.SetReadBuffer(common.NUM_1064); err != nil {
debug.Log(debug.DEBUG_ERROR, "Failed to set read buffer size", "error", err) debug.Log(debug.DEBUG_ERROR, "Failed to set read buffer size", "error", err)
} }
if err := conn.SetWriteBuffer(1064); err != nil { if err := conn.SetWriteBuffer(common.NUM_1064); err != nil {
debug.Log(debug.DEBUG_ERROR, "Failed to set write buffer size", "error", err) debug.Log(debug.DEBUG_ERROR, "Failed to set write buffer size", "error", err)
} }
} }
ui.Mutex.Lock()
ui.Online = true ui.Online = true
ui.Mutex.Unlock()
// Start the read loop in a goroutine // Start the read loop in a goroutine
go ui.readLoop() go ui.readLoop()
@@ -203,19 +230,43 @@ func (ui *UDPInterface) Start() error {
return nil return nil
} }
func (ui *UDPInterface) Stop() error {
ui.Detach()
return nil
}
func (ui *UDPInterface) readLoop() { func (ui *UDPInterface) readLoop() {
buffer := make([]byte, 1064) buffer := make([]byte, common.NUM_1064)
for ui.IsOnline() && !ui.IsDetached() { for {
n, remoteAddr, err := ui.conn.ReadFromUDP(buffer) ui.Mutex.RLock()
online := ui.Online
detached := ui.Detached
conn := ui.conn
done := ui.done
ui.Mutex.RUnlock()
if !online || detached || conn == nil {
return
}
select {
case <-done:
return
default:
}
n, remoteAddr, err := conn.ReadFromUDP(buffer)
if err != nil { if err != nil {
if ui.IsOnline() { ui.Mutex.RLock()
stillOnline := ui.Online
ui.Mutex.RUnlock()
if stillOnline {
debug.Log(debug.DEBUG_ERROR, "Error reading from UDP interface", "name", ui.Name, "error", err) debug.Log(debug.DEBUG_ERROR, "Error reading from UDP interface", "name", ui.Name, "error", err)
} }
return return
} }
// Update RX stats ui.Mutex.Lock()
ui.mutex.Lock()
// #nosec G115 - Network read sizes are always positive and within safe range // #nosec G115 - Network read sizes are always positive and within safe range
ui.RxBytes += uint64(n) ui.RxBytes += uint64(n)
@@ -224,16 +275,17 @@ func (ui *UDPInterface) readLoop() {
debug.Log(debug.DEBUG_ALL, "UDP interface discovered peer", "name", ui.Name, "peer", remoteAddr.String()) debug.Log(debug.DEBUG_ALL, "UDP interface discovered peer", "name", ui.Name, "peer", remoteAddr.String())
ui.targetAddr = remoteAddr ui.targetAddr = remoteAddr
} }
ui.mutex.Unlock() callback := ui.packetCallback
ui.Mutex.Unlock()
if ui.packetCallback != nil { if callback != nil {
ui.packetCallback(buffer[:n], ui) callback(buffer[:n], ui)
} }
} }
} }
func (ui *UDPInterface) IsEnabled() bool { func (ui *UDPInterface) IsEnabled() bool {
ui.mutex.RLock() ui.Mutex.RLock()
defer ui.mutex.RUnlock() defer ui.Mutex.RUnlock()
return ui.Enabled && ui.Online && !ui.Detached return ui.Enabled && ui.Online && !ui.Detached
} }