From fd2fb3fab641f4dbde4475a1ce3bcc65221bf18d Mon Sep 17 00:00:00 2001 From: Marco Date: Tue, 21 May 2024 23:41:37 +0200 Subject: [PATCH] change behavior of reading/writing from websockets --- api/handler/websocket.go | 6 +- chess/game.go | 10 ++- chess/player.go | 38 ++++------- connection/message_buffer.go | 4 +- connection/message_buffer_test.go | 26 +++----- connection/type.go | 101 ++++++++++++++---------------- lobbies/usher.go | 12 +--- 7 files changed, 83 insertions(+), 114 deletions(-) diff --git a/api/handler/websocket.go b/api/handler/websocket.go index 646064a..028a497 100644 --- a/api/handler/websocket.go +++ b/api/handler/websocket.go @@ -68,11 +68,9 @@ func waitForAndHandlePlayerID(ctx context.Context, conn *gorillaws.Conn) { conn.Close() return } - if player.Conn.HasWebsocketConnection() { - conn.WriteMessage(msgType, []byte("player already connected")) - return - } + lobby.Game.SetWebsocketConnectionFor(ctx, player, conn) + log.Println("player after setting connection: ") log.Println("id: ", player.Uuid) log.Println("color: ", player.GetColor()) diff --git a/chess/game.go b/chess/game.go index c768873..f0fe484 100644 --- a/chess/game.go +++ b/chess/game.go @@ -117,7 +117,7 @@ func (game *Game) Handle() { log.Println("Error marshalling 'colorDetermined' message for player 1", err) return } - game.currentTurnPlayer.writeMessage(invalidMoveMessage) + game.currentTurnPlayer.writeMessage(string(invalidMoveMessage)) game.gameState = PlayerToMove continue } @@ -181,10 +181,10 @@ func (game Game) notifyPlayersAboutGameStart() error { return err } - game.GetPlayer1().writeMessage(colorDeterminedPlayer1) + game.GetPlayer1().writeMessage(string(colorDeterminedPlayer1)) game.GetPlayer1().SendBoardState(types.Move{}, game.board.PGN(), types.White) - game.GetPlayer2().writeMessage(colorDeterminedPlayer2) + game.GetPlayer2().writeMessage(string(colorDeterminedPlayer2)) game.GetPlayer2().SendBoardState(types.Move{}, game.board.PGN(), types.White) return nil } @@ -223,3 +223,7 @@ func (game *Game) playerDisconnected(p *Player) { func (game *Game) SetWebsocketConnectionFor(ctx context.Context, p *Player, ws *gorillaws.Conn) { p.SetWebsocketConnectionAndSendBoardState(ctx, ws, &game.board) } + +func (game *Game) SendBoardStateTo(p *Player) { + p.SendBoardState(game.board.getLastMove(), game.board.PGN(), game.board.colorToMove) +} diff --git a/chess/player.go b/chess/player.go index f2d9278..79a8c1a 100644 --- a/chess/player.go +++ b/chess/player.go @@ -89,11 +89,8 @@ func (p *Player) SendBoardState(move types.Move, boardPosition string, turnColor return err } - err = p.writeMessage(messageToSend) - if err != nil { - log.Println("Error during message writing:", err) - return err - } + p.writeMessage(string(messageToSend)) + return nil } @@ -108,11 +105,8 @@ func (p *Player) SendMoveAndPosition(move types.Move, boardPosition string) erro return err } - err = p.writeMessage(messageToSend) - if err != nil { - log.Println("Error during message writing:", err) - return err - } + p.writeMessage(string(messageToSend)) + return nil } @@ -126,26 +120,20 @@ func (p *Player) SendGameEnded(reason GameEndedReason) error { log.Println("Error while marshalling: ", err) return err } - err = p.writeMessage(messageToSend) - if err != nil { - log.Println("Error during message writing:", err) - return err - } + p.writeMessage(string(messageToSend)) + return nil } -func (p *Player) writeMessage(msg []byte) error { - return p.Conn.Write(msg) +func (p *Player) writeMessage(msg string) { + p.Conn.Write(msg) } func (p *Player) ReadMove() (types.Move, error) { - receivedMessage, err := p.readMessage() - if err != nil { - return types.Move{}, err - } + receivedMessage := p.readMessage() var msg api.WebsocketMessage - err = json.Unmarshal(receivedMessage, &msg) + err := json.Unmarshal(receivedMessage, &msg) if err != nil { return types.Move{}, err } @@ -157,9 +145,9 @@ func (p *Player) ReadMove() (types.Move, error) { return *msg.Move, nil } -func (p *Player) readMessage() ([]byte, error) { - msg, err := p.Conn.Read() +func (p *Player) readMessage() []byte { + msg := p.Conn.Read() log.Printf("Reading message from %s: %s", p.color.String(), string(msg)) - return msg, err + return msg } diff --git a/connection/message_buffer.go b/connection/message_buffer.go index a2613ce..45704e8 100644 --- a/connection/message_buffer.go +++ b/connection/message_buffer.go @@ -50,7 +50,7 @@ func (b *MessageBuffer) Insert(msg string) { b.cond.Broadcast() } -func (b *MessageBuffer) Get() (string, error) { +func (b *MessageBuffer) Get() string { b.cond.L.Lock() defer b.cond.L.Unlock() @@ -69,7 +69,7 @@ func (b *MessageBuffer) Get() (string, error) { } b.getIndex = b.incrementAndWrapIndex(b.getIndex) - return msg.content, nil + return msg.content } func (b MessageBuffer) incrementAndWrapIndex(index int) int { diff --git a/connection/message_buffer_test.go b/connection/message_buffer_test.go index 0b97e64..e3deb61 100644 --- a/connection/message_buffer_test.go +++ b/connection/message_buffer_test.go @@ -66,8 +66,7 @@ func Test_MessageBuffer_GetWaitsForFirstData(t *testing.T) { buf.Insert("delayed-message") }() - msg, err := buf.Get() - assert.NoError(t, err) + msg := buf.Get() endTime := time.Now() @@ -79,8 +78,7 @@ func Test_MessageBuffer_GetWaitsForNewData(t *testing.T) { buf := newMessageBuffer(2) buf.Insert("message-1") - msg, err := buf.Get() - assert.NoError(t, err) + msg := buf.Get() assert.Equal(t, "message-1", msg) go func() { @@ -89,8 +87,7 @@ func Test_MessageBuffer_GetWaitsForNewData(t *testing.T) { buf.Insert("delayed-message") }() - msg, err = buf.Get() - assert.NoError(t, err) + msg = buf.Get() assert.Equal(t, "delayed-message", msg) } @@ -117,8 +114,7 @@ func Test_MessageBuffer_IndexesAreCorrectAfterOverwritingOldData(t *testing.T) { }, buf.messages) - msg, err := buf.Get() - assert.NoError(t, err) + msg := buf.Get() assert.Equal(t, "message-2", msg) } @@ -126,13 +122,11 @@ func Test_MessageBuffer_GetWaitsForNewDataIfOldOneWasAlreadyGotten(t *testing.T) buf := newMessageBuffer(2) buf.Insert(message1) - msg, err := buf.Get() - assert.NoError(t, err) + msg := buf.Get() assert.Equal(t, message1, msg) buf.Insert(message2) - msg, err = buf.Get() - assert.NoError(t, err) + msg = buf.Get() assert.Equal(t, message2, msg) go func() { @@ -140,8 +134,7 @@ func Test_MessageBuffer_GetWaitsForNewDataIfOldOneWasAlreadyGotten(t *testing.T) buf.Insert(message3) }() - msg, err = buf.Get() - assert.NoError(t, err) + msg = buf.Get() assert.Equal(t, message3, msg) } @@ -157,9 +150,8 @@ func Test_MessageBuffer_InsertCatchesUpWithRead(t *testing.T) { buf.Insert(message6) buf.Insert(message7) - msg, err := buf.Get() + msg := buf.Get() - assert.NoError(t, err) assert.Equal(t, message3, msg) } @@ -172,7 +164,7 @@ func Test_MessageBuffer_FuckShitUp(t *testing.T) { var readMsg = make([]string, 0) go func() { for i := 0; i < size*10; i++ { - msg, _ := buf.Get() + msg := buf.Get() if msg == "99" { break } diff --git a/connection/type.go b/connection/type.go index 7cc3ba9..f3566d7 100644 --- a/connection/type.go +++ b/connection/type.go @@ -4,28 +4,26 @@ import ( "context" "log" "mchess_server/types" - "sync" "github.com/google/uuid" gorillaws "github.com/gorilla/websocket" ) type Connection struct { - ID uuid.UUID - ws *gorillaws.Conn - wsConnectionEstablished chan bool - wsWriteLock sync.Mutex - ctx context.Context - buffer MessageBuffer - disconnectCallback func() - forColor types.ChessColor + ID uuid.UUID + ws *gorillaws.Conn + ctx context.Context + rxBuffer *MessageBuffer + txBuffer *MessageBuffer + disconnectCallback func() + forColor types.ChessColor } func NewConnection(options ...func(*Connection)) *Connection { connection := Connection{ - ID: uuid.New(), - buffer: *newMessageBuffer(100), - wsConnectionEstablished: make(chan bool), + ID: uuid.New(), + rxBuffer: newMessageBuffer(100), + txBuffer: newMessageBuffer(100), } for _, option := range options { @@ -67,6 +65,35 @@ func (conn *Connection) HasWebsocketConnection() bool { return conn.ws != nil } +func (conn *Connection) readFromRxBuffer() { + for { + _, msg, err := conn.ws.ReadMessage() + if err != nil { + conn.logConnection("while reading from websocket: %w", err) + conn.Close("") + conn.txBuffer.Insert("we do this to make txBuffer.Get() return") + return + } + conn.rxBuffer.Insert(string(msg)) + } +} + +func (conn *Connection) writeTxBuffer() { + for { + msg := conn.txBuffer.Get() + + if conn.ws == nil { + return + } + + err := conn.ws.WriteMessage(gorillaws.TextMessage, []byte(msg)) + if err != nil { + conn.logConnection("while writing to websocket: %w", err) + return + } + } +} + func (conn *Connection) SetWebsocketConnection(ws *gorillaws.Conn) { if ws == nil { conn.logConnection("ERROR: setting ws = null") @@ -75,28 +102,9 @@ func (conn *Connection) SetWebsocketConnection(ws *gorillaws.Conn) { conn.ws = ws - select { - case conn.wsConnectionEstablished <- true: - conn.logConnection("case wsConnectionEstablished <- true") - default: - conn.logConnection("DEFAULT CASE") - } + go conn.readFromRxBuffer() + go conn.writeTxBuffer() - go func() { - for { - _, msg, err := conn.ws.ReadMessage() - if err != nil { - conn.logConnection("while reading from websocket: %w", err) - - conn.unsetWebsocketConnection() - if conn.disconnectCallback != nil { - conn.disconnectCallback() - } - return - } - conn.buffer.Insert(string(msg)) - } - }() defer conn.logConnection("websocket connection set") } @@ -105,30 +113,15 @@ func (conn *Connection) unsetWebsocketConnection() { conn.ws = nil } -func (conn *Connection) Write(msg []byte) error { - conn.logConnection("about to write") - conn.logConnection("locking") - conn.wsWriteLock.Lock() - defer conn.logConnection("unlocking") - defer conn.wsWriteLock.Unlock() - - if conn.ws == nil { //if ws is not yet set, we wait for it - conn.logConnection("waiting for wsConnectionEstablished channel") - <-conn.wsConnectionEstablished - } - - conn.logConnection("Writing message: %s", string(msg)) - return conn.ws.WriteMessage(gorillaws.TextMessage, msg) +func (conn *Connection) Write(msg string) { + conn.logConnection("Writing message: ", string(msg)) + conn.txBuffer.Insert(msg) } -func (conn *Connection) Read() ([]byte, error) { - msg, err := conn.buffer.Get() - if err != nil { - conn.ws = nil - return nil, err // TODO: Tell game-handler that connection was lost - } +func (conn *Connection) Read() []byte { + msg := conn.rxBuffer.Get() - return []byte(msg), err + return []byte(msg) } func (conn *Connection) Close(msg string) { diff --git a/lobbies/usher.go b/lobbies/usher.go index 00fcfc1..37a6b69 100644 --- a/lobbies/usher.go +++ b/lobbies/usher.go @@ -22,21 +22,15 @@ func GetUsher() *Usher { } func (u *Usher) WelcomeNewPlayer(player *chess.Player) *Lobby { - lobby := GetLobbyRegistry().GetLobbyForPlayer() - return lobby + return GetLobbyRegistry().GetLobbyForPlayer() } func (u *Usher) CreateNewPrivateLobby(player *chess.Player) *Lobby { - lobby := GetLobbyRegistry().CreateNewPrivateLobby() - return lobby + return GetLobbyRegistry().CreateNewPrivateLobby() } func (u *Usher) FindExistingPrivateLobby(p utils.Passphrase) *Lobby { - lobby := GetLobbyRegistry().GetLobbyByPassphrase(p) - if lobby == nil || lobby.AreBothPlayersConnected() { - return nil - } - return lobby + return GetLobbyRegistry().GetLobbyByPassphrase(p) } func (u *Usher) AddPlayerToLobbyAndStartGameIfFull(player *chess.Player, lobby *Lobby) {