diff --git a/.gitignore b/.gitignore index 5fb42c6..4456aa5 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ /node_modules /public/bundle.js /swagger +result diff --git a/auth.go b/auth.go index 374dd7b..d84ab79 100644 --- a/auth.go +++ b/auth.go @@ -14,6 +14,21 @@ import ( Authentication server */ +func getAvailableProfiles(user *User) ([]Profile, error) { + var availableProfiles []Profile + for _, player := range user.Players { + id, err := UUIDToID(player.UUID) + if err != nil { + return nil, err + } + availableProfiles = append(availableProfiles, Profile{ + ID: id, + Name: player.Name, + }) + } + return availableProfiles, nil +} + type UserProperty struct { Name string `json:"name"` Value string `json:"value"` @@ -31,6 +46,10 @@ var invalidAccessTokenBlob []byte = Unwrap(json.Marshal(ErrorResponse{ Error: Ptr("ForbiddenOperationException"), ErrorMessage: Ptr("Invalid token."), })) +var playerNotFoundBlob []byte = Unwrap(json.Marshal(ErrorResponse{ + Error: Ptr("IllegalArgumentException"), + ErrorMessage: Ptr("Player not found."), +})) type serverInfoResponse struct { Status string `json:"Status"` @@ -83,27 +102,43 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { return err } - playerName := req.Username + playerNameOrUsername := req.Username - var player Player - result := app.DB.Preload("User").First(&player, "name = ?", playerName) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob) + var user User + var player *Player + + var playerStruct Player + if err := app.DB.Preload("User").First(&playerStruct, "name = ?", playerNameOrUsername).Error; err == nil { + player = &playerStruct + user = player.User + } else { + if errors.Is(err, gorm.ErrRecordNotFound) { + if err := app.DB.First(&user, "username = ?", playerNameOrUsername).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob) + } else { + return err + } + } } else { - return result.Error + return err } } - passwordHash, err := HashPassword(req.Password, player.User.PasswordSalt) + passwordHash, err := HashPassword(req.Password, user.PasswordSalt) if err != nil { return err } - if !bytes.Equal(passwordHash, player.User.PasswordHash) { + if !bytes.Equal(passwordHash, user.PasswordHash) { return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob) } + var playerUUID *string = nil + if player != nil { + playerUUID = &player.UUID + } + var client Client if req.ClientToken == nil { clientToken, err := RandomHex(16) @@ -114,20 +149,24 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { UUID: uuid.New().String(), ClientToken: clientToken, Version: 0, + PlayerUUID: playerUUID, } - player.Clients = append(player.Clients, client) + user.Clients = append(user.Clients, client) } else { clientToken := *req.ClientToken clientExists := false - for i := range player.Clients { - if player.Clients[i].ClientToken == clientToken { + + for i := range user.Clients { + if user.Clients[i].ClientToken == clientToken { clientExists = true - player.Clients[i].Version += 1 - client = player.Clients[i] + user.Clients[i].Version += 1 + client = user.Clients[i] break } else { - if !app.Config.AllowMultipleAccessTokens { - player.Clients[i].Version += 1 + // If AllowMultipleAccessTokens is disabled, invalidate all + // clients associated with the same player + if !app.Config.AllowMultipleAccessTokens && player != nil && user.Clients[i].PlayerUUID != nil && *user.Clients[i].PlayerUUID == player.UUID { + user.Clients[i].Version += 1 } } } @@ -137,34 +176,38 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { UUID: uuid.New().String(), ClientToken: clientToken, Version: 0, + PlayerUUID: playerUUID, } - player.Clients = append(player.Clients, client) + user.Clients = append(user.Clients, client) } } - // Save changes to player.Clients - result = app.DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&player) - if result.Error != nil { - return result.Error - } - - id, err := UUIDToID(player.UUID) - if err != nil { - return err - } - - var selectedProfile *Profile - var availableProfiles *[]Profile + var selectedProfile *Profile = nil + var availableProfiles *[]Profile = nil if req.Agent != nil { - selectedProfile = &Profile{ - ID: id, - Name: player.Name, + if player != nil { + id, err := UUIDToID(player.UUID) + if err != nil { + return err + } + selectedProfile = &Profile{ + ID: id, + Name: player.Name, + } } - availableProfiles = &[]Profile{*selectedProfile} + availableProfilesArray, err := getAvailableProfiles(&user) + if err != nil { + return err + } + availableProfiles = &availableProfilesArray } var userResponse *UserResponse - if req.RequestUser { + if req.RequestUser && player != nil { + id, err := UUIDToID(player.UUID) + if err != nil { + return err + } userResponse = &UserResponse{ ID: id, Properties: []UserProperty{{ @@ -179,6 +222,11 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { return err } + // Save changes to user.Clients + if err := app.DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + return err + } + res := authenticateResponse{ ClientToken: client.ClientToken, AccessToken: accessToken, @@ -191,14 +239,15 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { } type refreshRequest struct { - AccessToken string `json:"accessToken"` - ClientToken string `json:"clientToken"` - RequestUser bool `json:"requestUser"` + AccessToken string `json:"accessToken"` + ClientToken string `json:"clientToken"` + RequestUser bool `json:"requestUser"` + SelectedProfile *Profile `json:"selectedProfile"` } type refreshResponse struct { AccessToken string `json:"accessToken"` ClientToken string `json:"clientToken"` - SelectedProfile Profile `json:"selectedProfile,omitempty"` + SelectedProfile *Profile `json:"selectedProfile,omitempty"` AvailableProfiles []Profile `json:"availableProfiles,omitempty"` User *UserResponse `json:"user,omitempty"` } @@ -216,26 +265,53 @@ func AuthRefresh(app *App) func(c echo.Context) error { if client == nil || client.ClientToken != req.ClientToken { return c.JSONBlob(http.StatusUnauthorized, invalidAccessTokenBlob) } + user := client.User player := client.Player - id, err := UUIDToID(player.UUID) + if req.SelectedProfile != nil { + if player == nil { + // Just ignore if there is already a selectedProfile for the + // client + for _, userPlayer := range user.Players { + requestedUUID, err := IDToUUID(req.SelectedProfile.ID) + if err != nil { + return err + } + if userPlayer.UUID == requestedUUID { + client.PlayerUUID = &userPlayer.UUID + player = &userPlayer + break + } + } + if player == nil { + return c.JSONBlob(http.StatusBadRequest, playerNotFoundBlob) + } + } + } + + var selectedProfile *Profile = nil + if player != nil { + id, err := UUIDToID(player.UUID) + if err != nil { + return err + } + selectedProfile = &Profile{ + ID: id, + Name: player.Name, + } + } + availableProfiles, err := getAvailableProfiles(&user) if err != nil { return err } - selectedProfile := Profile{ - ID: id, - Name: player.Name, - } - availableProfiles := []Profile{selectedProfile} - var userResponse *UserResponse - if req.RequestUser { + if req.RequestUser && selectedProfile != nil { userResponse = &UserResponse{ - ID: id, + ID: selectedProfile.ID, Properties: []UserProperty{{ Name: "preferredLanguage", - Value: player.User.PreferredLanguage, + Value: user.PreferredLanguage, }}, } } @@ -246,9 +322,8 @@ func AuthRefresh(app *App) func(c echo.Context) error { return err } - result := app.DB.Save(client) - if result.Error != nil { - return result.Error + if err := app.DB.Save(client).Error; err != nil { + return err } res := refreshResponse{ @@ -300,22 +375,22 @@ func AuthSignout(app *App) func(c echo.Context) error { return err } - var player Player - result := app.DB.Preload("User").First(&player, "name = ?", req.Username) + var user User + result := app.DB.First(&user, "username = ?", req.Username) if result.Error != nil { return result.Error } - passwordHash, err := HashPassword(req.Password, player.User.PasswordSalt) + passwordHash, err := HashPassword(req.Password, user.PasswordSalt) if err != nil { return err } - if !bytes.Equal(passwordHash, player.User.PasswordHash) { + if !bytes.Equal(passwordHash, user.PasswordHash) { return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob) } - err = app.InvalidatePlayer(app.DB, &player) + err = app.InvalidateUser(app.DB, &user) if err != nil { return err } @@ -339,13 +414,20 @@ func AuthInvalidate(app *App) func(c echo.Context) error { } client := app.GetClient(req.AccessToken, StalePolicyAllow) - if client == nil || client.ClientToken != req.ClientToken { + if client == nil { return c.JSONBlob(http.StatusUnauthorized, invalidAccessTokenBlob) } - err := app.InvalidatePlayer(app.DB, &client.Player) - if err != nil { - return err + if client.Player == nil { + err := app.InvalidateUser(app.DB, &client.User) + if err != nil { + return err + } + } else { + err := app.InvalidatePlayer(app.DB, client.Player) + if err != nil { + return err + } } return c.NoContent(http.StatusNoContent) diff --git a/auth_test.go b/auth_test.go index 0873e82..c32dc38 100644 --- a/auth_test.go +++ b/auth_test.go @@ -93,12 +93,15 @@ func (ts *TestSuite) testAuthenticate(t *testing.T) { // Check that the database was updated var client Client - result := ts.App.DB.Preload("Player.User").First(&client, "client_token = ?", response.ClientToken) + result := ts.App.DB.Preload("Player").First(&client, "client_token = ?", response.ClientToken) assert.Nil(t, result.Error) + assert.NotNil(t, client.Player) assert.Equal(t, TEST_PLAYER_NAME, client.Player.Name) accessTokenClient := ts.App.GetClient(response.AccessToken, StalePolicyDeny) assert.NotNil(t, accessTokenClient) + accessTokenClient.Player = client.Player + accessTokenClient.User = client.User assert.Equal(t, client, *accessTokenClient) @@ -261,19 +264,6 @@ func (ts *TestSuite) testInvalidate(t *testing.T) { authenticateRes = ts.authenticate(t, TEST_PLAYER_NAME, TEST_PASSWORD) clientToken = authenticateRes.ClientToken accessToken = authenticateRes.AccessToken - { - // Invalidation should fail when client token is invalid - payload := refreshRequest{ - ClientToken: "invalid", - AccessToken: accessToken, - } - rec := ts.PostJSON(t, ts.Server, "/invalidate", payload, nil, nil) - - // Invalidate should fail - var response ErrorResponse - assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) - assert.Equal(t, "ForbiddenOperationException", *response.Error) - } { // Invalidate should fail if we send an invalid access token payload := refreshRequest{ @@ -285,6 +275,7 @@ func (ts *TestSuite) testInvalidate(t *testing.T) { // Invalidate should fail var response ErrorResponse assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Equal(t, http.StatusUnauthorized, rec.Code) assert.Equal(t, "ForbiddenOperationException", *response.Error) assert.Equal(t, "Invalid token.", *response.ErrorMessage) } @@ -327,7 +318,7 @@ func (ts *TestSuite) testRefresh(t *testing.T) { ID: Unwrap(UUIDToID(player.UUID)), Name: player.Name, } - assert.Equal(t, expectedProfile, refreshRes.SelectedProfile) + assert.Equal(t, expectedProfile, *refreshRes.SelectedProfile) assert.Equal(t, []Profile{expectedProfile}, refreshRes.AvailableProfiles) // We did not pass requestUser @@ -400,15 +391,15 @@ func (ts *TestSuite) testSignout(t *testing.T) { accessToken := authenticateRes.AccessToken { // Successful signout - var player Player - result := ts.App.DB.First(&player, "name = ?", TEST_PLAYER_NAME) + var user User + result := ts.App.DB.First(&user, "username = ?", TEST_USERNAME) assert.Nil(t, result.Error) // We should start with valid clients in the database client := ts.App.GetClient(accessToken, StalePolicyDeny) assert.NotNil(t, client) var clients []Client - result = ts.App.DB.Model(Client{}).Where("player_uuid = ?", client.Player.UUID).Find(&clients) + result = ts.App.DB.Model(Client{}).Where("user_uuid = ?", client.UserUUID).Find(&clients) assert.Nil(t, result.Error) assert.True(t, len(clients) > 0) oldVersions := make(map[string]int) @@ -417,7 +408,7 @@ func (ts *TestSuite) testSignout(t *testing.T) { } payload := signoutRequest{ - Username: TEST_PLAYER_NAME, + Username: TEST_USERNAME, Password: TEST_PASSWORD, } rec := ts.PostJSON(t, ts.Server, "/signout", payload, nil, nil) @@ -428,7 +419,7 @@ func (ts *TestSuite) testSignout(t *testing.T) { // The token version of each client should have been incremented, // invalidating all previously-issued JWTs assert.Nil(t, ts.App.GetClient(accessToken, StalePolicyDeny)) - result = ts.App.DB.Model(Client{}).Where("player_uuid = ?", client.Player.UUID).Find(&clients) + result = ts.App.DB.Model(Client{}).Where("user_uuid = ?", client.UserUUID).Find(&clients) assert.Nil(t, result.Error) assert.True(t, len(clients) > 0) for _, client := range clients { @@ -438,7 +429,7 @@ func (ts *TestSuite) testSignout(t *testing.T) { { // Should fail when incorrect password is sent payload := signoutRequest{ - Username: TEST_PLAYER_NAME, + Username: TEST_USERNAME, Password: "incorrect", } rec := ts.PostJSON(t, ts.Server, "/signout", payload, nil, nil) @@ -446,6 +437,7 @@ func (ts *TestSuite) testSignout(t *testing.T) { // Signout should fail var response ErrorResponse assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Equal(t, http.StatusUnauthorized, rec.Code) assert.Equal(t, "ForbiddenOperationException", *response.Error) assert.Equal(t, "Invalid credentials. Invalid username or password.", *response.ErrorMessage) } diff --git a/db.go b/db.go index ada4eda..beb1131 100644 --- a/db.go +++ b/db.go @@ -249,7 +249,8 @@ func migrate(db *gorm.DB, alreadyExisted bool) error { UUID: v3Client.UUID, ClientToken: v3Client.ClientToken, Version: v3Client.Version, - PlayerUUID: v3Client.UserUUID, + UserUUID: v3Client.UserUUID, + PlayerUUID: &v3Client.UserUUID, }) } player := V4Player{ diff --git a/model.go b/model.go index 3c06812..6167070 100644 --- a/model.go +++ b/model.go @@ -333,7 +333,7 @@ func (app *App) GetClient(accessToken string, stalePolicy StaleTokenPolicy) *Cli } var client Client - result := app.DB.Preload("Player.User").First(&client, "uuid = ?", claims.RegisteredClaims.Subject) + result := app.DB.Preload("User").Preload("Player").First(&client, "uuid = ?", claims.RegisteredClaims.Subject) if result.Error != nil { return nil } @@ -368,6 +368,7 @@ type User struct { PreferredLanguage string Players []Player MaxPlayerCount int + Clients []Client } func (user *User) BeforeDelete(tx *gorm.DB) error { @@ -378,11 +379,50 @@ func (user *User) BeforeDelete(tx *gorm.DB) error { if len(players) > 0 { return tx.Delete(&players).Error } + + var clients []Client + if err := tx.Where("user_uuid = ?", user.UUID).Find(&clients).Error; err != nil { + return err + } + if len(clients) > 0 { + if err := tx.Delete(&clients).Error; err != nil { + return err + } + } + + return nil +} + +func (player *Player) BeforeDelete(tx *gorm.DB) error { + var clients []Client + if err := tx.Where("player_uuid = ?", player.UUID).Find(&clients).Error; err != nil { + return err + } + if len(clients) > 0 { + if err := tx.Delete(&clients).Error; err != nil { + return err + } + } + return nil +} + +func (player *Player) AfterFind(tx *gorm.DB) error { + if err := tx.Find(&player.Clients, "player_uuid = ?", player.UUID).Error; err != nil { + return err + } return nil } func (user *User) AfterFind(tx *gorm.DB) error { - return tx.Find(&user.Players, "user_uuid = ?", user.UUID).Error + err := tx.Find(&user.Players, "user_uuid = ?", user.UUID).Error + if err != nil { + return err + } + err = tx.Find(&user.Clients, "user_uuid = ?", user.UUID).Error + if err != nil { + return err + } + return nil } type Player struct { @@ -396,32 +436,19 @@ type Player struct { CapeHash sql.NullString `gorm:"index"` ServerID sql.NullString FallbackPlayer string - Clients []Client User User UserUUID string `gorm:"not null"` -} - -func (player *Player) BeforeDelete(tx *gorm.DB) (err error) { - var clients []Client - if err := tx.Where("player_uuid = ?", player.UUID).Find(&clients).Error; err != nil { - return err - } - if len(clients) > 0 { - return tx.Delete(&clients).Error - } - return nil -} - -func (player *Player) AfterFind(tx *gorm.DB) error { - return tx.Find(&player.Clients, "player_uuid = ?", player.UUID).Error + Clients []Client } type Client struct { UUID string `gorm:"primaryKey"` ClientToken string Version int - PlayerUUID string `gorm:"not null"` - Player Player + UserUUID string `gorm:"not null"` + User User + PlayerUUID *string + Player *Player } func (app *App) GetSkinURL(player *Player) (*string, error) { diff --git a/player.go b/player.go index a6b669d..d475204 100644 --- a/player.go +++ b/player.go @@ -572,10 +572,18 @@ func (app *App) GetChallengeSkin(playerName string, challengeToken string) ([]by } func (app *App) InvalidatePlayer(db *gorm.DB, player *Player) error { + if player == nil { + return nil + } result := db.Model(Client{}).Where("player_uuid = ?", player.UUID).Update("version", gorm.Expr("version + ?", 1)) return result.Error } +func (app *App) InvalidateUser(db *gorm.DB, user *User) error { + result := db.Model(Client{}).Where("user_uuid = ?", user.UUID).Update("version", gorm.Expr("version + ?", 1)) + return result.Error +} + func (app *App) DeletePlayer(player *Player) error { if err := app.DB.Delete(player).Error; err != nil { return err diff --git a/services.go b/services.go index 1609e16..a797725 100644 --- a/services.go +++ b/services.go @@ -41,8 +41,11 @@ func withBearerAuthentication(app *App, f func(c echo.Context, player *Player) e return c.JSON(http.StatusUnauthorized, ErrorResponse{Path: Ptr(c.Request().URL.Path)}) } player := client.Player + if player == nil { + return c.JSON(http.StatusBadRequest, ErrorResponse{Path: Ptr(c.Request().URL.Path), ErrorMessage: Ptr("Access token does not have a selected profile.")}) + } - return f(c, &player) + return f(c, player) } }