diff --git a/auth.go b/auth.go index d84ab79..9e29bd9 100644 --- a/auth.go +++ b/auth.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/google/uuid" "github.com/labstack/echo/v4" + "github.com/samber/mo" "gorm.io/gorm" "net/http" ) @@ -102,24 +103,27 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { return err } - playerNameOrUsername := req.Username + usernameOrPlayerName := req.Username var user User - var player *Player + player := mo.None[Player]() - var playerStruct Player - if err := app.DB.Preload("User").First(&playerStruct, "name = ?", playerNameOrUsername).Error; err == nil { - player = &playerStruct - user = player.User + if err := app.DB.First(&user, "username = ?", usernameOrPlayerName).Error; err == nil { + if len(user.Players) == 1 { + player = mo.Some(user.Players[0]) + } } else { + var playerStruct Player if errors.Is(err, gorm.ErrRecordNotFound) { - if err := app.DB.First(&user, "username = ?", playerNameOrUsername).Error; err != nil { + if err := app.DB.Preload("User").First(&playerStruct, "name = ?", usernameOrPlayerName).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob) } else { return err } } + player = mo.Some(playerStruct) + user = playerStruct.User } else { return err } @@ -134,9 +138,9 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob) } - var playerUUID *string = nil - if player != nil { - playerUUID = &player.UUID + playerUUID := mo.None[string]() + if p, ok := player.Get(); ok { + playerUUID = mo.Some(p.UUID) } var client Client @@ -149,7 +153,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { UUID: uuid.New().String(), ClientToken: clientToken, Version: 0, - PlayerUUID: playerUUID, + PlayerUUID: OptionToNullString(playerUUID), } user.Clients = append(user.Clients, client) } else { @@ -165,7 +169,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { } else { // If AllowMultipleAccessTokens is disabled, invalidate all // clients associated with the same player - if !app.Config.AllowMultipleAccessTokens && player != nil && user.Clients[i].PlayerUUID != nil && *user.Clients[i].PlayerUUID == player.UUID { + if !app.Config.AllowMultipleAccessTokens && NullStringToOption(&user.Clients[i].PlayerUUID) == playerUUID { user.Clients[i].Version += 1 } } @@ -176,7 +180,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { UUID: uuid.New().String(), ClientToken: clientToken, Version: 0, - PlayerUUID: playerUUID, + PlayerUUID: OptionToNullString(playerUUID), } user.Clients = append(user.Clients, client) } @@ -185,14 +189,14 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { var selectedProfile *Profile = nil var availableProfiles *[]Profile = nil if req.Agent != nil { - if player != nil { - id, err := UUIDToID(player.UUID) + if p, ok := player.Get(); ok { + id, err := UUIDToID(p.UUID) if err != nil { return err } selectedProfile = &Profile{ ID: id, - Name: player.Name, + Name: p.Name, } } availableProfilesArray, err := getAvailableProfiles(&user) @@ -203,8 +207,8 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { } var userResponse *UserResponse - if req.RequestUser && player != nil { - id, err := UUIDToID(player.UUID) + if p, ok := player.Get(); ok && req.RequestUser { + id, err := UUIDToID(p.UUID) if err != nil { return err } @@ -212,7 +216,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error { ID: id, Properties: []UserProperty{{ Name: "preferredLanguage", - Value: player.User.PreferredLanguage, + Value: user.PreferredLanguage, }}, } } @@ -278,7 +282,7 @@ func AuthRefresh(app *App) func(c echo.Context) error { return err } if userPlayer.UUID == requestedUUID { - client.PlayerUUID = &userPlayer.UUID + client.PlayerUUID = MakeNullString(&userPlayer.UUID) player = &userPlayer break } diff --git a/auth_test.go b/auth_test.go index b4bd655..0169fe0 100644 --- a/auth_test.go +++ b/auth_test.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "github.com/samber/mo" "github.com/stretchr/testify/assert" "net/http" "testing" @@ -21,6 +22,7 @@ func TestAuth(t *testing.T) { t.Run("Test /", ts.testGetServerInfo) t.Run("Test /authenticate", ts.testAuthenticate) + t.Run("Test /authenticate, multiple profiles", ts.testAuthenticateMultipleProfiles) t.Run("Test /invalidate", ts.testInvalidate) t.Run("Test /refresh", ts.testRefresh) t.Run("Test /signout", ts.testSignout) @@ -223,6 +225,72 @@ func (ts *TestSuite) testAuthenticate(t *testing.T) { } } +func (ts *TestSuite) testAuthenticateMultipleProfiles(t *testing.T) { + { + var user User + assert.Nil(t, ts.App.DB.First(&user, "username = ?", TEST_USERNAME).Error) + + secondPlayerName := "SecondPlayer" + + // player := user.Players[0] + otherPlayer, err := ts.App.CreatePlayer(&GOD, user.UUID, secondPlayerName, nil, false, nil, nil, nil, nil, nil, nil, nil) + assert.Nil(t, err) + + authenticatePayload := authenticateRequest{ + Username: TEST_USERNAME, + Password: TEST_PASSWORD, + RequestUser: false, + Agent: &Agent{ + Name: "Minecraft", + Version: 1, + }, + } + rec := ts.PostJSON(t, ts.Server, "/authenticate", authenticatePayload, nil, nil) + + assert.Equal(t, http.StatusOK, rec.Code) + var authenticateRes authenticateResponse + assert.Nil(t, json.NewDecoder(rec.Body).Decode(&authenticateRes)) + + // We did not pass requestUser + assert.Nil(t, authenticateRes.User) + + // User has multiple players, selectedProfile should be missing + assert.Nil(t, authenticateRes.SelectedProfile) + + assert.Equal(t, 2, len(*authenticateRes.AvailableProfiles)) + + p := mo.None[Profile]() + for _, availableProfile := range *authenticateRes.AvailableProfiles { + if availableProfile.Name == secondPlayerName { + p = mo.Some(availableProfile) + break + } + } + profile, ok := p.Get() + assert.True(t, ok) + + // Now, refresh to select a profile + refreshPayload := refreshRequest{ + ClientToken: authenticateRes.ClientToken, + AccessToken: authenticateRes.AccessToken, + RequestUser: false, + SelectedProfile: &profile, + } + rec = ts.PostJSON(t, ts.Server, "/refresh", refreshPayload, nil, nil) + + // Refresh should succeed and we should get a new accessToken + assert.Equal(t, http.StatusOK, rec.Code) + var refreshRes refreshResponse + assert.Nil(t, json.NewDecoder(rec.Body).Decode(&refreshRes)) + assert.Equal(t, authenticateRes.ClientToken, refreshRes.ClientToken) + assert.NotEqual(t, authenticateRes.AccessToken, refreshRes.AccessToken) + + assert.Equal(t, profile, *refreshRes.SelectedProfile) + + assert.Nil(t, ts.App.DeletePlayer(&GOD, &otherPlayer)) + } +} + func (ts *TestSuite) testInvalidate(t *testing.T) { { authenticateRes := ts.authenticate(t, TEST_PLAYER_NAME, TEST_PASSWORD) @@ -234,7 +302,7 @@ func (ts *TestSuite) testInvalidate(t *testing.T) { client := ts.App.GetClient(accessToken, StalePolicyDeny) assert.NotNil(t, client) var clients []Client - result := ts.App.DB.Model(Client{}).Where("player_uuid = ?", client.Player.UUID).Find(&clients) + result := ts.App.DB.Model(Client{}).Where("player_uuid = ?", &client.Player.UUID).Find(&clients) assert.Nil(t, result.Error) assert.True(t, len(clients) > 0) oldVersions := make(map[string]int) @@ -254,7 +322,7 @@ func (ts *TestSuite) testInvalidate(t *testing.T) { // The token version of each client should have been incremented, // invalidating all previously-issued JWTs assert.Nil(t, ts.App.GetClient(accessToken, StalePolicyDeny)) - result = ts.App.DB.Model(Client{}).Where("player_uuid = ?", client.Player.UUID).Find(&clients) + result = ts.App.DB.Model(Client{}).Where("player_uuid = ?", &client.Player.UUID).Find(&clients) assert.Nil(t, result.Error) for _, client := range clients { assert.Equal(t, oldVersions[client.ClientToken]+1, client.Version) diff --git a/db.go b/db.go index b18dedc..5f25a03 100644 --- a/db.go +++ b/db.go @@ -268,7 +268,7 @@ func Migrate(config *Config, db *gorm.DB, alreadyExisted bool, targetUserVersion ClientToken: v3Client.ClientToken, Version: v3Client.Version, UserUUID: v3Client.UserUUID, - PlayerUUID: &v3Client.UserUUID, + PlayerUUID: MakeNullString(&v3Client.UserUUID), }) } // If the player name is in use as someone else's username, diff --git a/model.go b/model.go index 7957e4e..0c82a64 100644 --- a/model.go +++ b/model.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" + "github.com/samber/mo" "golang.org/x/crypto/scrypt" "gorm.io/gorm" "net/url" @@ -43,6 +44,23 @@ func UnmakeNullString(ns *sql.NullString) *string { return nil } +func NullStringToOption(ns *sql.NullString) mo.Option[string] { + if ns.Valid { + return mo.Some(ns.String) + } + return mo.None[string]() +} + +func OptionToNullString(option mo.Option[string]) sql.NullString { + if s, ok := option.Get(); ok { + return sql.NullString{ + String: s, + Valid: true, + } + } + return sql.NullString{Valid: false} +} + func IsValidSkinModel(model string) bool { switch model { case SkinModelSlim, SkinModelClassic: @@ -454,7 +472,7 @@ type Client struct { Version int UserUUID string `gorm:"not null"` User User - PlayerUUID *string + PlayerUUID sql.NullString `gorm:"index"` Player *Player }