diff --git a/chess/player.go b/chess/player.go index 27d43d3..2ebc183 100644 --- a/chess/player.go +++ b/chess/player.go @@ -6,7 +6,7 @@ import ( "errors" "log" "mchess_server/api" - con "mchess_server/connection" + conn "mchess_server/connection" "mchess_server/types" "time" @@ -16,7 +16,7 @@ import ( type Player struct { Uuid uuid.UUID - Conn *con.Connection + Conn *conn.Connection InGame bool wsConnEstablished chan bool context context.Context @@ -28,11 +28,12 @@ func NewPlayer(uuid uuid.UUID) *Player { Conn: nil, InGame: false, wsConnEstablished: make(chan bool), + context: context.Background(), } } -func (p *Player) SetConnection(ctx context.Context, conn *websocket.Conn) { - p.Conn = con.NewConnection(con.WithWebsocket(conn)) +func (p *Player) SetConnection(ctx context.Context, ws *websocket.Conn) { + p.Conn = conn.NewConnection(conn.WithWebsocket(ws), conn.WithContext(p.context)) p.context = ctx p.wsConnEstablished <- true } diff --git a/connection/message_buffer.go b/connection/message_buffer.go index f513361..5eccc5b 100644 --- a/connection/message_buffer.go +++ b/connection/message_buffer.go @@ -1,5 +1,7 @@ package connection +import "sync" + type MessageBuffer struct { messages []message getIndex int @@ -7,6 +9,7 @@ type MessageBuffer struct { size int newDataInserted chan bool firstWriteHappened bool + cond *sync.Cond } type message struct { @@ -15,6 +18,9 @@ type message struct { } func newMessageBuffer(size int) *MessageBuffer { + mutex := &sync.Mutex{} + cond := sync.NewCond(mutex) + return &MessageBuffer{ messages: make([]message, size), size: size, @@ -22,10 +28,14 @@ func newMessageBuffer(size int) *MessageBuffer { insertIndex: 0, newDataInserted: make(chan bool), firstWriteHappened: false, + cond: cond, } } func (b *MessageBuffer) Insert(msg string) { + b.cond.L.Lock() + defer b.cond.L.Unlock() + oldMessage := b.messages[b.insertIndex] b.messages[b.insertIndex] = message{content: msg, new: true} @@ -38,16 +48,15 @@ func (b *MessageBuffer) Insert(msg string) { b.insertIndex = b.incrementAndWrapIndex(b.insertIndex) b.firstWriteHappened = true - - select { - case b.newDataInserted <- true: - default: - } + b.cond.Broadcast() } func (b *MessageBuffer) Get() (string, error) { + b.cond.L.Lock() + defer b.cond.L.Unlock() + if !b.firstWriteHappened { - <-b.newDataInserted + b.cond.Wait() } var msg *message @@ -57,7 +66,7 @@ func (b *MessageBuffer) Get() (string, error) { msg.new = false break } - <-b.newDataInserted + b.cond.Wait() } b.getIndex = b.incrementAndWrapIndex(b.getIndex) diff --git a/connection/message_buffer_test.go b/connection/message_buffer_test.go index 847d016..0b97e64 100644 --- a/connection/message_buffer_test.go +++ b/connection/message_buffer_test.go @@ -1,6 +1,9 @@ package connection import ( + "fmt" + "strconv" + "sync" "testing" "time" @@ -19,7 +22,6 @@ var ( func Test_MessageBuffer_Add(t *testing.T) { buf := newMessageBuffer(3) - buf.newDataInserted = nil //otherwise, this would break the test t.Run("insert without wrapping", func(t *testing.T) { buf.Insert("message-1") @@ -31,18 +33,13 @@ func Test_MessageBuffer_Add(t *testing.T) { buf.Insert("message-3") assert.Equal( t, - &MessageBuffer{ - size: 3, - getIndex: 0, - insertIndex: 0, - firstWriteHappened: true, - messages: []message{ - {content: message1, new: true}, - {content: message2, new: true}, - {content: message3, new: true}, - }, + []message{ + {content: message1, new: true}, + {content: message2, new: true}, + {content: message3, new: true}, }, - buf) + buf.messages, + ) }) t.Run("insert that causes wrapping", func(t *testing.T) { @@ -87,7 +84,7 @@ func Test_MessageBuffer_GetWaitsForNewData(t *testing.T) { assert.Equal(t, "message-1", msg) go func() { - timer := time.NewTimer(500 * time.Millisecond) + timer := time.NewTimer(100 * time.Millisecond) <-timer.C buf.Insert("delayed-message") }() @@ -165,3 +162,38 @@ func Test_MessageBuffer_InsertCatchesUpWithRead(t *testing.T) { assert.NoError(t, err) assert.Equal(t, message3, msg) } + +func Test_MessageBuffer_FuckShitUp(t *testing.T) { + size := 10 + buf := newMessageBuffer(size) + wg := sync.WaitGroup{} + + wg.Add(2) + var readMsg = make([]string, 0) + go func() { + for i := 0; i < size*10; i++ { + msg, _ := buf.Get() + if msg == "99" { + break + } + fmt.Println("i = ", i, ": msg = ", msg) + readMsg = append(readMsg, msg) + } + wg.Done() + }() + + go func() { + for i := 0; i < size*10; i++ { + if i%10 == 0 { + timer := time.NewTimer(1 * time.Millisecond) + <-timer.C + } + buf.Insert(strconv.Itoa(i)) + } + wg.Done() + }() + + wg.Wait() + fmt.Println(buf.messages) + fmt.Println(readMsg) +} diff --git a/connection/type.go b/connection/type.go index a48582a..5b2d31c 100644 --- a/connection/type.go +++ b/connection/type.go @@ -7,16 +7,29 @@ import ( ) type Connection struct { - ws *websocket.Conn + ws *websocket.Conn + ctx context.Context + buffer MessageBuffer } func NewConnection(options ...func(*Connection)) *Connection { - connection := Connection{} + connection := Connection{ + buffer: *newMessageBuffer(100), + } for _, option := range options { option(&connection) } + if connection.ws != nil { + go func() { + for { + _, msg, _ := connection.ws.Read(connection.ctx) + connection.buffer.Insert(string(msg)) + } + }() + } + return &connection } @@ -26,21 +39,23 @@ func WithWebsocket(ws *websocket.Conn) func(*Connection) { } } +func WithContext(ctx context.Context) func(*Connection) { + return func(c *Connection) { + c.ctx = ctx + } +} + func (conn *Connection) Write(ctx context.Context, msg []byte) error { return conn.ws.Write(ctx, websocket.MessageText, msg) } func (conn *Connection) Read(ctx context.Context) ([]byte, error) { - var msg []byte - var err error - for { - _, msg, err = conn.ws.Read(ctx) - if err != nil { - return nil, err // Tell game-handler that connection was lost - } + msg, err := conn.buffer.Get() + if err != nil { + return nil, err // Tell game-handler that connection was lost } - return msg, err + return []byte(msg), err } func (conn *Connection) Close(msg string) {