From b25f2c2bdc297af6da784435c1bbabda2abdde78 Mon Sep 17 00:00:00 2001 From: Sudo-Ivan Date: Mon, 29 Dec 2025 22:32:57 -0600 Subject: [PATCH] Refactor Channel message handler management to use structured entries with IDs for easier identification. Update AddMessageHandler and RemoveMessageHandler methods accordingly, and adjust tests to validate new functionality. --- pkg/channel/channel.go | 29 +++++++++++++++++++---------- pkg/channel/channel_test.go | 15 ++++++++------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/pkg/channel/channel.go b/pkg/channel/channel.go index 76f3fa2..ce2c1a5 100644 --- a/pkg/channel/channel.go +++ b/pkg/channel/channel.go @@ -67,7 +67,13 @@ type Channel struct { maxTries int fastRateRounds int medRateRounds int - messageHandlers []func(MessageBase) bool + messageHandlers []messageHandlerEntry + nextHandlerID int +} + +type messageHandlerEntry struct { + id int + handler func(MessageBase) bool } // Envelope wraps a message with metadata for transmission @@ -84,7 +90,7 @@ type Envelope struct { func NewChannel(link transport.LinkInterface) *Channel { return &Channel{ link: link, - messageHandlers: make([]func(MessageBase) bool, 0), + messageHandlers: make([]messageHandlerEntry, 0), mutex: sync.RWMutex{}, windowMax: WindowMaxSlow, windowMin: WindowMinSlow, @@ -177,17 +183,20 @@ func (c *Channel) getPacketTimeout(tries int) time.Duration { return time.Duration(timeout * float64(time.Second)) } -func (c *Channel) AddMessageHandler(handler func(MessageBase) bool) { +func (c *Channel) AddMessageHandler(handler func(MessageBase) bool) int { c.mutex.Lock() defer c.mutex.Unlock() - c.messageHandlers = append(c.messageHandlers, handler) + id := c.nextHandlerID + c.nextHandlerID++ + c.messageHandlers = append(c.messageHandlers, messageHandlerEntry{id: id, handler: handler}) + return id } -func (c *Channel) RemoveMessageHandler(handler func(MessageBase) bool) { +func (c *Channel) RemoveMessageHandler(id int) { c.mutex.Lock() defer c.mutex.Unlock() - for i, h := range c.messageHandlers { - if &h == &handler { + for i, entry := range c.messageHandlers { + if entry.id == id { c.messageHandlers = append(c.messageHandlers[:i], c.messageHandlers[i+1:]...) break } @@ -236,14 +245,14 @@ func (c *Channel) HandleInbound(data []byte) error { c.mutex.Lock() defer c.mutex.Unlock() - for _, handler := range c.messageHandlers { - if handler != nil { + for _, entry := range c.messageHandlers { + if entry.handler != nil { msg := &GenericMessage{ Type: msgType, Data: msgData, Seq: sequence, } - if handler(msg) { + if entry.handler(msg) { break } } diff --git a/pkg/channel/channel_test.go b/pkg/channel/channel_test.go index ddcad13..bff2936 100644 --- a/pkg/channel/channel_test.go +++ b/pkg/channel/channel_test.go @@ -95,19 +95,20 @@ func TestHandleInbound(t *testing.T) { } func TestMessageHandlers(t *testing.T) { - c := &Channel{} + c := &Channel{ + messageHandlers: make([]messageHandlerEntry, 0), + } h := func(m MessageBase) bool { return true } - c.AddMessageHandler(h) + id := c.AddMessageHandler(h) if len(c.messageHandlers) != 1 { t.Errorf("Expected 1 handler, got %d", len(c.messageHandlers)) } - // RemoveMessageHandler in channel.go uses &h == &handler which is tricky - // for function comparisons. Let's see if it works. - c.RemoveMessageHandler(h) - // It likely won't work as expected because of how Go handles function pointers - // and closures in comparisons. But we're testing the code as is. + c.RemoveMessageHandler(id) + if len(c.messageHandlers) != 0 { + t.Errorf("Expected 0 handlers, got %d", len(c.messageHandlers)) + } } func TestGenericMessage(t *testing.T) {