Test authentication multiple profiles

This commit is contained in:
Evan Goode 2024-12-08 23:56:30 -05:00
parent 4fca7cc8e4
commit 568aab84f6
4 changed files with 114 additions and 24 deletions

44
auth.go
View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/samber/mo"
"gorm.io/gorm" "gorm.io/gorm"
"net/http" "net/http"
) )
@ -102,24 +103,27 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
return err return err
} }
playerNameOrUsername := req.Username usernameOrPlayerName := req.Username
var user User var user User
var player *Player player := mo.None[Player]()
var playerStruct Player if err := app.DB.First(&user, "username = ?", usernameOrPlayerName).Error; err == nil {
if err := app.DB.Preload("User").First(&playerStruct, "name = ?", playerNameOrUsername).Error; err == nil { if len(user.Players) == 1 {
player = &playerStruct player = mo.Some(user.Players[0])
user = player.User }
} else { } else {
var playerStruct Player
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
if err := app.DB.First(&user, "username = ?", playerNameOrUsername).Error; err != nil { if err := app.DB.Preload("User").First(&playerStruct, "name = ?", usernameOrPlayerName).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob) return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob)
} else { } else {
return err return err
} }
} }
player = mo.Some(playerStruct)
user = playerStruct.User
} else { } else {
return err return err
} }
@ -134,9 +138,9 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob) return c.JSONBlob(http.StatusUnauthorized, invalidCredentialsBlob)
} }
var playerUUID *string = nil playerUUID := mo.None[string]()
if player != nil { if p, ok := player.Get(); ok {
playerUUID = &player.UUID playerUUID = mo.Some(p.UUID)
} }
var client Client var client Client
@ -149,7 +153,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
UUID: uuid.New().String(), UUID: uuid.New().String(),
ClientToken: clientToken, ClientToken: clientToken,
Version: 0, Version: 0,
PlayerUUID: playerUUID, PlayerUUID: OptionToNullString(playerUUID),
} }
user.Clients = append(user.Clients, client) user.Clients = append(user.Clients, client)
} else { } else {
@ -165,7 +169,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
} else { } else {
// If AllowMultipleAccessTokens is disabled, invalidate all // If AllowMultipleAccessTokens is disabled, invalidate all
// clients associated with the same player // clients associated with the same player
if !app.Config.AllowMultipleAccessTokens && player != nil && user.Clients[i].PlayerUUID != nil && *user.Clients[i].PlayerUUID == player.UUID { if !app.Config.AllowMultipleAccessTokens && NullStringToOption(&user.Clients[i].PlayerUUID) == playerUUID {
user.Clients[i].Version += 1 user.Clients[i].Version += 1
} }
} }
@ -176,7 +180,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
UUID: uuid.New().String(), UUID: uuid.New().String(),
ClientToken: clientToken, ClientToken: clientToken,
Version: 0, Version: 0,
PlayerUUID: playerUUID, PlayerUUID: OptionToNullString(playerUUID),
} }
user.Clients = append(user.Clients, client) user.Clients = append(user.Clients, client)
} }
@ -185,14 +189,14 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
var selectedProfile *Profile = nil var selectedProfile *Profile = nil
var availableProfiles *[]Profile = nil var availableProfiles *[]Profile = nil
if req.Agent != nil { if req.Agent != nil {
if player != nil { if p, ok := player.Get(); ok {
id, err := UUIDToID(player.UUID) id, err := UUIDToID(p.UUID)
if err != nil { if err != nil {
return err return err
} }
selectedProfile = &Profile{ selectedProfile = &Profile{
ID: id, ID: id,
Name: player.Name, Name: p.Name,
} }
} }
availableProfilesArray, err := getAvailableProfiles(&user) availableProfilesArray, err := getAvailableProfiles(&user)
@ -203,8 +207,8 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
} }
var userResponse *UserResponse var userResponse *UserResponse
if req.RequestUser && player != nil { if p, ok := player.Get(); ok && req.RequestUser {
id, err := UUIDToID(player.UUID) id, err := UUIDToID(p.UUID)
if err != nil { if err != nil {
return err return err
} }
@ -212,7 +216,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
ID: id, ID: id,
Properties: []UserProperty{{ Properties: []UserProperty{{
Name: "preferredLanguage", Name: "preferredLanguage",
Value: player.User.PreferredLanguage, Value: user.PreferredLanguage,
}}, }},
} }
} }
@ -278,7 +282,7 @@ func AuthRefresh(app *App) func(c echo.Context) error {
return err return err
} }
if userPlayer.UUID == requestedUUID { if userPlayer.UUID == requestedUUID {
client.PlayerUUID = &userPlayer.UUID client.PlayerUUID = MakeNullString(&userPlayer.UUID)
player = &userPlayer player = &userPlayer
break break
} }

View File

@ -2,6 +2,7 @@ package main
import ( import (
"encoding/json" "encoding/json"
"github.com/samber/mo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"net/http" "net/http"
"testing" "testing"
@ -21,6 +22,7 @@ func TestAuth(t *testing.T) {
t.Run("Test /", ts.testGetServerInfo) t.Run("Test /", ts.testGetServerInfo)
t.Run("Test /authenticate", ts.testAuthenticate) t.Run("Test /authenticate", ts.testAuthenticate)
t.Run("Test /authenticate, multiple profiles", ts.testAuthenticateMultipleProfiles)
t.Run("Test /invalidate", ts.testInvalidate) t.Run("Test /invalidate", ts.testInvalidate)
t.Run("Test /refresh", ts.testRefresh) t.Run("Test /refresh", ts.testRefresh)
t.Run("Test /signout", ts.testSignout) t.Run("Test /signout", ts.testSignout)
@ -223,6 +225,72 @@ func (ts *TestSuite) testAuthenticate(t *testing.T) {
} }
} }
func (ts *TestSuite) testAuthenticateMultipleProfiles(t *testing.T) {
{
var user User
assert.Nil(t, ts.App.DB.First(&user, "username = ?", TEST_USERNAME).Error)
secondPlayerName := "SecondPlayer"
// player := user.Players[0]
otherPlayer, err := ts.App.CreatePlayer(&GOD, user.UUID, secondPlayerName, nil, false, nil, nil, nil, nil, nil, nil, nil)
assert.Nil(t, err)
authenticatePayload := authenticateRequest{
Username: TEST_USERNAME,
Password: TEST_PASSWORD,
RequestUser: false,
Agent: &Agent{
Name: "Minecraft",
Version: 1,
},
}
rec := ts.PostJSON(t, ts.Server, "/authenticate", authenticatePayload, nil, nil)
assert.Equal(t, http.StatusOK, rec.Code)
var authenticateRes authenticateResponse
assert.Nil(t, json.NewDecoder(rec.Body).Decode(&authenticateRes))
// We did not pass requestUser
assert.Nil(t, authenticateRes.User)
// User has multiple players, selectedProfile should be missing
assert.Nil(t, authenticateRes.SelectedProfile)
assert.Equal(t, 2, len(*authenticateRes.AvailableProfiles))
p := mo.None[Profile]()
for _, availableProfile := range *authenticateRes.AvailableProfiles {
if availableProfile.Name == secondPlayerName {
p = mo.Some(availableProfile)
break
}
}
profile, ok := p.Get()
assert.True(t, ok)
// Now, refresh to select a profile
refreshPayload := refreshRequest{
ClientToken: authenticateRes.ClientToken,
AccessToken: authenticateRes.AccessToken,
RequestUser: false,
SelectedProfile: &profile,
}
rec = ts.PostJSON(t, ts.Server, "/refresh", refreshPayload, nil, nil)
// Refresh should succeed and we should get a new accessToken
assert.Equal(t, http.StatusOK, rec.Code)
var refreshRes refreshResponse
assert.Nil(t, json.NewDecoder(rec.Body).Decode(&refreshRes))
assert.Equal(t, authenticateRes.ClientToken, refreshRes.ClientToken)
assert.NotEqual(t, authenticateRes.AccessToken, refreshRes.AccessToken)
assert.Equal(t, profile, *refreshRes.SelectedProfile)
assert.Nil(t, ts.App.DeletePlayer(&GOD, &otherPlayer))
}
}
func (ts *TestSuite) testInvalidate(t *testing.T) { func (ts *TestSuite) testInvalidate(t *testing.T) {
{ {
authenticateRes := ts.authenticate(t, TEST_PLAYER_NAME, TEST_PASSWORD) authenticateRes := ts.authenticate(t, TEST_PLAYER_NAME, TEST_PASSWORD)
@ -234,7 +302,7 @@ func (ts *TestSuite) testInvalidate(t *testing.T) {
client := ts.App.GetClient(accessToken, StalePolicyDeny) client := ts.App.GetClient(accessToken, StalePolicyDeny)
assert.NotNil(t, client) assert.NotNil(t, client)
var clients []Client var clients []Client
result := ts.App.DB.Model(Client{}).Where("player_uuid = ?", client.Player.UUID).Find(&clients) result := ts.App.DB.Model(Client{}).Where("player_uuid = ?", &client.Player.UUID).Find(&clients)
assert.Nil(t, result.Error) assert.Nil(t, result.Error)
assert.True(t, len(clients) > 0) assert.True(t, len(clients) > 0)
oldVersions := make(map[string]int) oldVersions := make(map[string]int)
@ -254,7 +322,7 @@ func (ts *TestSuite) testInvalidate(t *testing.T) {
// The token version of each client should have been incremented, // The token version of each client should have been incremented,
// invalidating all previously-issued JWTs // invalidating all previously-issued JWTs
assert.Nil(t, ts.App.GetClient(accessToken, StalePolicyDeny)) assert.Nil(t, ts.App.GetClient(accessToken, StalePolicyDeny))
result = ts.App.DB.Model(Client{}).Where("player_uuid = ?", client.Player.UUID).Find(&clients) result = ts.App.DB.Model(Client{}).Where("player_uuid = ?", &client.Player.UUID).Find(&clients)
assert.Nil(t, result.Error) assert.Nil(t, result.Error)
for _, client := range clients { for _, client := range clients {
assert.Equal(t, oldVersions[client.ClientToken]+1, client.Version) assert.Equal(t, oldVersions[client.ClientToken]+1, client.Version)

2
db.go
View File

@ -268,7 +268,7 @@ func Migrate(config *Config, db *gorm.DB, alreadyExisted bool, targetUserVersion
ClientToken: v3Client.ClientToken, ClientToken: v3Client.ClientToken,
Version: v3Client.Version, Version: v3Client.Version,
UserUUID: v3Client.UserUUID, UserUUID: v3Client.UserUUID,
PlayerUUID: &v3Client.UserUUID, PlayerUUID: MakeNullString(&v3Client.UserUUID),
}) })
} }
// If the player name is in use as someone else's username, // If the player name is in use as someone else's username,

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/samber/mo"
"golang.org/x/crypto/scrypt" "golang.org/x/crypto/scrypt"
"gorm.io/gorm" "gorm.io/gorm"
"net/url" "net/url"
@ -43,6 +44,23 @@ func UnmakeNullString(ns *sql.NullString) *string {
return nil return nil
} }
func NullStringToOption(ns *sql.NullString) mo.Option[string] {
if ns.Valid {
return mo.Some(ns.String)
}
return mo.None[string]()
}
func OptionToNullString(option mo.Option[string]) sql.NullString {
if s, ok := option.Get(); ok {
return sql.NullString{
String: s,
Valid: true,
}
}
return sql.NullString{Valid: false}
}
func IsValidSkinModel(model string) bool { func IsValidSkinModel(model string) bool {
switch model { switch model {
case SkinModelSlim, SkinModelClassic: case SkinModelSlim, SkinModelClassic:
@ -454,7 +472,7 @@ type Client struct {
Version int Version int
UserUUID string `gorm:"not null"` UserUUID string `gorm:"not null"`
User User User User
PlayerUUID *string PlayerUUID sql.NullString `gorm:"index"`
Player *Player Player *Player
} }