diff --git a/connection/message_buffer.go b/connection/message_buffer.go index d6afa71..f513361 100644 --- a/connection/message_buffer.go +++ b/connection/message_buffer.go @@ -1,9 +1,7 @@ package connection -import "fmt" - type MessageBuffer struct { - messages []*string + messages []message getIndex int insertIndex int size int @@ -11,9 +9,14 @@ type MessageBuffer struct { firstWriteHappened bool } +type message struct { + content string + new bool +} + func newMessageBuffer(size int) *MessageBuffer { return &MessageBuffer{ - messages: make([]*string, size), + messages: make([]message, size), size: size, getIndex: 0, insertIndex: 0, @@ -23,32 +26,42 @@ func newMessageBuffer(size int) *MessageBuffer { } func (b *MessageBuffer) Insert(msg string) { - b.messages[b.insertIndex] = &msg + oldMessage := b.messages[b.insertIndex] + b.messages[b.insertIndex] = message{content: msg, new: true} + + if b.firstWriteHappened && + b.insertIndex == b.getIndex && + oldMessage.new { // insertIndex caught up with getIndex + b.getIndex = b.incrementAndWrapIndex(b.getIndex) + } + b.insertIndex = b.incrementAndWrapIndex(b.insertIndex) + b.firstWriteHappened = true + select { case b.newDataInserted <- true: default: } - - if b.firstWriteHappened && b.insertIndex-1 == b.getIndex { // insertIndex caught up with getIndex - b.getIndex = b.incrementAndWrapIndex(b.getIndex) - } - b.firstWriteHappened = true } func (b *MessageBuffer) Get() (string, error) { - if !b.firstWriteHappened || b.messages[b.getIndex] == nil { + if !b.firstWriteHappened { <-b.newDataInserted } - msg := b.messages[b.getIndex] - if msg == nil { - return "", fmt.Errorf("error getting value from buffer: value was nil") + var msg *message + for { + msg = &b.messages[b.getIndex] + if msg.new { + msg.new = false + break + } + <-b.newDataInserted } b.getIndex = b.incrementAndWrapIndex(b.getIndex) - return *msg, nil + return msg.content, nil } func (b MessageBuffer) incrementAndWrapIndex(index int) int { diff --git a/connection/message_buffer_test.go b/connection/message_buffer_test.go index 423ea2e..847d016 100644 --- a/connection/message_buffer_test.go +++ b/connection/message_buffer_test.go @@ -12,6 +12,9 @@ var ( message2 = "message-2" message3 = "message-3" message4 = "message-4" + message5 = "message-5" + message6 = "message-6" + message7 = "message-7" ) func Test_MessageBuffer_Add(t *testing.T) { @@ -33,7 +36,11 @@ func Test_MessageBuffer_Add(t *testing.T) { getIndex: 0, insertIndex: 0, firstWriteHappened: true, - messages: []*string{&message1, &message2, &message3}, + messages: []message{ + {content: message1, new: true}, + {content: message2, new: true}, + {content: message3, new: true}, + }, }, buf) }) @@ -42,10 +49,10 @@ func Test_MessageBuffer_Add(t *testing.T) { buf.Insert("message-4") assert.Equal( t, - []*string{ - &message4, - &message2, - &message3, + []message{ + {content: message4, new: true}, + {content: message2, new: true}, + {content: message3, new: true}, }, buf.messages) }) @@ -80,7 +87,7 @@ func Test_MessageBuffer_GetWaitsForNewData(t *testing.T) { assert.Equal(t, "message-1", msg) go func() { - timer := time.NewTimer(50 * time.Millisecond) + timer := time.NewTimer(500 * time.Millisecond) <-timer.C buf.Insert("delayed-message") }() @@ -96,10 +103,22 @@ func Test_MessageBuffer_IndexesAreCorrectAfterOverwritingOldData(t *testing.T) { buf.Insert("message-1") buf.Insert("message-2") - assert.Equal(t, []*string{&message1, &message2}, buf.messages) + assert.Equal( + t, + []message{ + {content: message1, new: true}, + {content: message2, new: true}, + }, + buf.messages) buf.Insert("message-3") - assert.Equal(t, []*string{&message3, &message2}, buf.messages) + assert.Equal( + t, + []message{ + {content: message3, new: true}, + {content: message2, new: true}, + }, + buf.messages) msg, err := buf.Get() assert.NoError(t, err) @@ -110,12 +129,11 @@ func Test_MessageBuffer_GetWaitsForNewDataIfOldOneWasAlreadyGotten(t *testing.T) buf := newMessageBuffer(2) buf.Insert(message1) - buf.Insert(message2) - msg, err := buf.Get() assert.NoError(t, err) assert.Equal(t, message1, msg) + buf.Insert(message2) msg, err = buf.Get() assert.NoError(t, err) assert.Equal(t, message2, msg) @@ -130,3 +148,20 @@ func Test_MessageBuffer_GetWaitsForNewDataIfOldOneWasAlreadyGotten(t *testing.T) assert.Equal(t, message3, msg) } + +func Test_MessageBuffer_InsertCatchesUpWithRead(t *testing.T) { + buf := newMessageBuffer(5) + + buf.Insert(message1) + buf.Insert(message2) + buf.Insert(message3) + buf.Insert(message4) + buf.Insert(message5) + buf.Insert(message6) + buf.Insert(message7) + + msg, err := buf.Get() + + assert.NoError(t, err) + assert.Equal(t, message3, msg) +}