diff --git a/account.go b/account.go index 5116905..e096c42 100644 --- a/account.go +++ b/account.go @@ -1,23 +1,182 @@ package main import ( + "bytes" "encoding/json" "errors" "fmt" + mapset "github.com/deckarep/golang-set/v2" "github.com/labstack/echo/v4" "github.com/samber/mo" "gorm.io/gorm" "log" "net/http" - "net/url" "strings" + "time" ) -type playerNameToUUIDResponse struct { +type PlayerNameToIDResponse struct { Name string `json:"name"` ID string `json:"id"` } +type playerNameToIDJob struct { + LowerName string + ReturnCh chan mo.Option[PlayerNameToIDResponse] +} + +func (fallbackAPIServer *FallbackAPIServer) PlayerNamesToIDs(remainingLowerNames mapset.Set[string]) []PlayerNameToIDResponse { + responses := make([]PlayerNameToIDResponse, 0, remainingLowerNames.Cardinality()) + + // Use responses from the cache, if available. + if fallbackAPIServer.PlayerNameToIDCache != nil { + for _, lowerName := range remainingLowerNames.ToSlice() { + cachedResponse, found := fallbackAPIServer.PlayerNameToIDCache.Get(lowerName) + if found { + remainingLowerNames.Remove(lowerName) + if response, isPresent := cachedResponse.(mo.Option[PlayerNameToIDResponse]).Get(); isPresent { + responses = append(responses, response) + } + } + } + } + + playerNameToIDJobs := make([]playerNameToIDJob, 0, remainingLowerNames.Cardinality()) + for lowerName := range mapset.Elements(remainingLowerNames) { + playerNameToIDJobs = append(playerNameToIDJobs, playerNameToIDJob{ + LowerName: lowerName, + ReturnCh: make(chan mo.Option[PlayerNameToIDResponse], 1), + }) + } + fallbackAPIServer.PlayerNameToIDJobCh <- playerNameToIDJobs + + for _, job := range playerNameToIDJobs { + maybeRes := <-job.ReturnCh + if res, ok := maybeRes.Get(); ok { + responses = append(responses, res) + } + } + return responses +} + +func (fallbackAPIServer *FallbackAPIServer) PlayerNamesToIDsWorker() { + // All communication with the POST /profiles/minecraft (a.k.a. POST + // /minecraft/profile/lookup/bulk/byname) route on a fallback API server is + // done by a single goroutine running this function. It buffers a queue of + // requested (lowercase) player names and makes requests to the fallback + // API server in batches of MAX_PLAYER_NAMES_TO_IDS, waiting at least + // MAX_PLAYER_NAMES_TO_IDS_INTERVAL in between requests, in order to avoid + // rate-limiting. + + url := fallbackAPIServer.Config.AccountURL + "/profiles/minecraft" + + // Queue of player names to fetch that may exceed MAX_PLAYER_NAMES_TO_IDS + // in size + lowerNameQueue := make([]*string, 0) + + // Map lowercase player name to a list of return channels where we should + // send the result of the query for that lowercase player name + lowerNameToResponseChs := make(map[string][]chan mo.Option[PlayerNameToIDResponse]) + + var timeout <-chan time.Time = nil + + for { + select { + case jobs := <-fallbackAPIServer.PlayerNameToIDJobCh: + for _, job := range PtrSlice(jobs) { + // Double-check the cache + if fallbackAPIServer.PlayerNameToIDCache != nil { + cachedResponse, found := fallbackAPIServer.PlayerNameToIDCache.Get(job.LowerName) + if found { + job.ReturnCh <- cachedResponse.(mo.Option[PlayerNameToIDResponse]) + continue + } + } + + if _, ok := lowerNameToResponseChs[job.LowerName]; !ok { + lowerNameQueue = append(lowerNameQueue, &job.LowerName) + } + lowerNameToResponseChs[job.LowerName] = append(lowerNameToResponseChs[job.LowerName], job.ReturnCh) + } + case <-timeout: + timeout = nil + } + + // Wait until we have player names in the queue AND have waited long + // enough to make another request + if !(len(lowerNameQueue) > 0 && timeout == nil) { + continue + } + + // Dequeue the next batch of MAX_PLAYER_NAMES_TO_IDS lowercase player names + batchSize := min(len(lowerNameQueue), MAX_PLAYER_NAMES_TO_IDS) + batch := lowerNameQueue[:batchSize] + lowerNameQueue = lowerNameQueue[batchSize:] + + fallbackResponses, fallbackError := (func() ([]PlayerNameToIDResponse, error) { + body, err := json.Marshal(batch) + if err != nil { + return nil, err + } + + res, err := MakeHTTPClient().Post(url, "application/json", bytes.NewBuffer(body)) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("received status code %d", res.StatusCode) + } + + buf := new(bytes.Buffer) + _, err = buf.ReadFrom(res.Body) + if err != nil { + return nil, err + } + + var fallbackResponses []PlayerNameToIDResponse + err = json.Unmarshal(buf.Bytes(), &fallbackResponses) + if err != nil { + return nil, err + } + return fallbackResponses, nil + })() + + timeout = time.After(MAX_PLAYER_NAMES_TO_IDS_INTERVAL) + + lowerNameToResponse := make(map[string]*PlayerNameToIDResponse) + if fallbackError != nil { + log.Printf("Error requesting player IDs from fallback API server at %s: %s", url, fallbackError) + } else { + for _, fallbackResponse := range PtrSlice(fallbackResponses) { + lowerName := strings.ToLower(fallbackResponse.Name) + lowerNameToResponse[lowerName] = fallbackResponse + } + } + + for _, lowerName := range batch { + if fallbackError == nil && fallbackAPIServer.PlayerNameToIDCache != nil { + ttl := time.Duration(fallbackAPIServer.Config.CacheTTLSeconds) * time.Second + if res, ok := lowerNameToResponse[*lowerName]; ok { + fallbackAPIServer.PlayerNameToIDCache.SetWithTTL(*lowerName, mo.Some(*res), 0, ttl) + } else { + fallbackAPIServer.PlayerNameToIDCache.SetWithTTL(*lowerName, mo.None[PlayerNameToIDResponse](), 0, ttl) + } + fallbackAPIServer.PlayerNameToIDCache.Wait() + } + for _, responseCh := range lowerNameToResponseChs[*lowerName] { + if res, ok := lowerNameToResponse[*lowerName]; ok { + responseCh <- mo.Some(*res) + } else { + responseCh <- mo.None[PlayerNameToIDResponse]() + } + } + } + clear(lowerNameToResponseChs) + } +} + // GET /users/profiles/minecraft/:playerName // GET /minecraft/profile/lookup/name/:playerName // https://minecraft.wiki/w/Mojang_API#Query_player's_UUID @@ -25,39 +184,27 @@ func AccountPlayerNameToID(app *App) func(c echo.Context) error { return func(c echo.Context) error { playerName := c.Param("playerName") + if len(playerName) > Constants.MaxPlayerNameLength { + // This error message is consistent with GET + // https://api.mojang.com/users/profiles/minecraft/:playerName as + // of 2025-04-02 + errorMessage := fmt.Sprintf("getProfileName.name: Invalid profile name, getProfileName.name: size must be between 1 and %d", Constants.MaxPlayerNameLength) + return &YggdrasilError{ + Code: http.StatusBadRequest, + Error_: mo.Some("CONSTRAINT_VIOLATION"), + ErrorMessage: mo.Some(errorMessage), + } + } + + lowerName := strings.ToLower(playerName) + var player Player - result := app.DB.First(&player, "name = ?", playerName) + result := app.DB.First(&player, "name = ?", lowerName) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - for _, fallbackAPIServer := range app.Config.FallbackAPIServers { - reqURL, err := url.JoinPath(fallbackAPIServer.AccountURL, "profiles/minecraft") - if err != nil { - log.Println(err) - continue - } - - payload := []string{playerName} - body, err := json.Marshal(payload) - if err != nil { - return err - } - - res, err := app.CachedPostJSON(reqURL, body, fallbackAPIServer.CacheTTLSeconds) - if err != nil { - log.Printf("Couldn't access fallback API server at %s: %s\n", reqURL, err) - continue - } - if res.StatusCode != http.StatusOK { - continue - } - - var fallbackResponses []playerNameToUUIDResponse - err = json.Unmarshal(res.BodyBytes, &fallbackResponses) - if err != nil { - log.Printf("Received invalid response from fallback API server at %s\n", reqURL) - continue - } - if len(fallbackResponses) == 1 && strings.EqualFold(playerName, fallbackResponses[0].Name) { + for _, fallbackAPIServer := range app.FallbackAPIServers { + fallbackResponses := fallbackAPIServer.PlayerNamesToIDs(mapset.NewSet(lowerName)) + if len(fallbackResponses) == 1 && strings.EqualFold(lowerName, fallbackResponses[0].Name) { return c.JSON(http.StatusOK, fallbackResponses[0]) } } @@ -71,7 +218,7 @@ func AccountPlayerNameToID(app *App) func(c echo.Context) error { if err != nil { return err } - res := playerNameToUUIDResponse{ + res := PlayerNameToIDResponse{ Name: player.Name, ID: id, } @@ -90,25 +237,48 @@ func AccountPlayerNamesToIDs(app *App) func(c echo.Context) error { return err } - n := len(playerNames) - if !(1 <= n && n <= 10) { + if len(playerNames) == 0 { + // This error message is consistent with POST + // https://api.mojang.com/profiles/minecraft as of 2025-04-02 + errorMessage := fmt.Sprintf("getProfileName.profileNames: must not be empty") return &YggdrasilError{ Code: http.StatusBadRequest, Error_: mo.Some("CONSTRAINT_VIOLATION"), - ErrorMessage: mo.Some("getProfileName.profileNames: size must be between 1 and 10"), + ErrorMessage: mo.Some(errorMessage), + } + } else if len(playerNames) > MAX_PLAYER_NAMES_TO_IDS { + // This error message is consistent with POST + // https://api.mojang.com/profiles/minecraft as of 2025-04-02 + errorMessage := fmt.Sprintf("getProfileName.profileNames: size must be between 0 and %d", MAX_PLAYER_NAMES_TO_IDS) + return &YggdrasilError{ + Code: http.StatusBadRequest, + Error_: mo.Some("CONSTRAINT_VIOLATION"), + ErrorMessage: mo.Some(errorMessage), } } - response := make([]playerNameToUUIDResponse, 0, n) + response := make([]PlayerNameToIDResponse, 0, len(playerNames)) - remainingPlayers := map[string]bool{} - for _, playerName := range playerNames { + remainingLowerNames := mapset.NewSet[string]() + for i, playerName := range playerNames { + if !(1 <= len(playerName) && len(playerName) <= Constants.MaxPlayerNameLength) { + // This error message is consistent with POST + // https://api.mojang.com/profiles/minecraft as of 2025-04-02 + errorMessage := fmt.Sprintf("getProfileName.profileNames[%d].: size must be between 1 and %d, getProfileName.profileNames[%d].: Invalid profile name", i, Constants.MaxPlayerNameLength, 1) + return &YggdrasilError{ + Code: http.StatusBadRequest, + Error_: mo.Some("CONSTRAINT_VIOLATION"), + ErrorMessage: mo.Some(errorMessage), + } + } + remainingLowerNames.Add(strings.ToLower(playerName)) + } + + for _, lowerName := range remainingLowerNames.ToSlice() { var player Player - result := app.DB.First(&player, "name = ?", playerName) + result := app.DB.First(&player, "name = ?", lowerName) if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - remainingPlayers[strings.ToLower(playerName)] = true - } else { + if !errors.Is(result.Error, gorm.ErrRecordNotFound) { return result.Error } } else { @@ -116,58 +286,28 @@ func AccountPlayerNamesToIDs(app *App) func(c echo.Context) error { if err != nil { return err } - playerRes := playerNameToUUIDResponse{ + playerRes := PlayerNameToIDResponse{ Name: player.Name, ID: id, } response = append(response, playerRes) + remainingLowerNames.Remove(lowerName) } } - for _, fallbackAPIServer := range app.Config.FallbackAPIServers { - reqURL, err := url.JoinPath(fallbackAPIServer.AccountURL, "profiles/minecraft") - if err != nil { - log.Println(err) - continue - } - - payload := make([]string, 0, len(remainingPlayers)) - for remainingPlayer := range remainingPlayers { - payload = append(payload, remainingPlayer) - } - body, err := json.Marshal(payload) - if err != nil { - return err - } - - res, err := app.CachedPostJSON(reqURL, body, fallbackAPIServer.CacheTTLSeconds) - if err != nil { - log.Printf("Couldn't access fallback API server at %s: %s\n", reqURL, err) - continue - } - - if res.StatusCode != http.StatusOK { - continue - } - - var fallbackResponses []playerNameToUUIDResponse - err = json.Unmarshal(res.BodyBytes, &fallbackResponses) - if err != nil { - log.Printf("Received invalid response from fallback API server at %s\n", reqURL) - continue + for _, fallbackAPIServer := range app.FallbackAPIServers { + if remainingLowerNames.Cardinality() == 0 { + break } + fallbackResponses := fallbackAPIServer.PlayerNamesToIDs(remainingLowerNames) for _, fallbackResponse := range fallbackResponses { lowerName := strings.ToLower(fallbackResponse.Name) - if _, ok := remainingPlayers[lowerName]; ok { + if remainingLowerNames.Contains(lowerName) { response = append(response, fallbackResponse) - delete(remainingPlayers, lowerName) + remainingLowerNames.Remove(lowerName) } } - - if len(remainingPlayers) == 0 { - break - } } return c.JSON(http.StatusOK, response) diff --git a/account_test.go b/account_test.go index 1eee3ad..590fc79 100644 --- a/account_test.go +++ b/account_test.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "fmt" "github.com/stretchr/testify/assert" "net/http" "testing" @@ -29,11 +30,14 @@ func TestAccount(t *testing.T) { ts.SetupAux(auxConfig) config := testConfig() - config.FallbackAPIServers = []FallbackAPIServer{ts.ToFallbackAPIServer(ts.AuxApp, "Aux")} + config.FallbackAPIServers = []FallbackAPIServerConfig{ts.ToFallbackAPIServer(ts.AuxApp, "Aux")} ts.Setup(config) defer ts.Teardown() ts.CreateTestUser(t, ts.AuxApp, ts.AuxServer, TEST_USERNAME) + for i := 1; i <= 20; i += 1 { + ts.CreateTestUser(t, ts.AuxApp, ts.AuxServer, fmt.Sprintf("%s%d", TEST_USERNAME, i)) + } t.Run("Test /users/profiles/minecraft/:playerName, fallback API server", ts.testAccountPlayerNameToIDFallback) t.Run("Test /profile/minecraft, fallback API server", ts.testAccountPlayerNamesToIDsFallback) @@ -44,7 +48,7 @@ func (ts *TestSuite) testAccountPlayerNameToID(t *testing.T) { rec := ts.Get(t, ts.Server, "/users/profiles/minecraft/"+TEST_PLAYER_NAME, nil, nil) assert.Equal(t, http.StatusOK, rec.Code) - var response playerNameToUUIDResponse + var response PlayerNameToIDResponse assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) // Check that the player name is correct @@ -72,7 +76,7 @@ func (ts *TestSuite) testAccountPlayerNameToIDFallback(t *testing.T) { { rec := ts.Get(t, ts.Server, "/users/profiles/minecraft/"+TEST_PLAYER_NAME, nil, nil) assert.Equal(t, http.StatusOK, rec.Code) - var response playerNameToUUIDResponse + var response PlayerNameToIDResponse assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) // Check that the player name is correct @@ -110,7 +114,7 @@ func (ts *TestSuite) testAccountPlayerNameToIDFallback(t *testing.T) { // Test a non-existent user { - rec := ts.Get(t, ts.Server, "/users/profiles/minecraft/", nil, nil) + rec := ts.Get(t, ts.Server, "/users/profiles/minecraft/nonexistent", nil, nil) assert.Equal(t, http.StatusNotFound, rec.Code) } } @@ -121,7 +125,7 @@ func (ts *TestSuite) testAccountPlayerNamesToIDsFallback(t *testing.T) { rec := ts.PostJSON(t, ts.Server, "/profiles/minecraft", payload, nil, nil) assert.Equal(t, http.StatusOK, rec.Code) - var response []playerNameToUUIDResponse + var response []PlayerNameToIDResponse assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) // Get the real UUID @@ -132,7 +136,7 @@ func (ts *TestSuite) testAccountPlayerNamesToIDsFallback(t *testing.T) { // There should only be one player, the nonexistent player should not be present id, err := UUIDToID(player.UUID) assert.Nil(t, err) - assert.Equal(t, []playerNameToUUIDResponse{{Name: TEST_PLAYER_NAME, ID: id}}, response) + assert.Equal(t, []PlayerNameToIDResponse{{Name: TEST_PLAYER_NAME, ID: id}}, response) } { payload := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"} @@ -144,6 +148,39 @@ func (ts *TestSuite) testAccountPlayerNamesToIDsFallback(t *testing.T) { assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) assert.Equal(t, "CONSTRAINT_VIOLATION", *response.Error) } + { + // Test multiple batches + { + payload := make([]string, 0) + for i := 1; i <= 10; i += 1 { + payload = append(payload, fmt.Sprintf("%s%d", TEST_PLAYER_NAME, i)) + } + rec := ts.PostJSON(t, ts.Server, "/profiles/minecraft", payload, nil, nil) + assert.Equal(t, http.StatusOK, rec.Code) + var response []PlayerNameToIDResponse + assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) + } + { + payload := make([]string, 0) + for i := 11; i <= 15; i += 1 { + payload = append(payload, fmt.Sprintf("%s%d", TEST_PLAYER_NAME, i)) + } + rec := ts.PostJSON(t, ts.Server, "/profiles/minecraft", payload, nil, nil) + assert.Equal(t, http.StatusOK, rec.Code) + var response []PlayerNameToIDResponse + assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) + } + { + payload := make([]string, 0) + for i := 16; i <= 20; i += 1 { + payload = append(payload, fmt.Sprintf("%s%d", TEST_PLAYER_NAME, i)) + } + rec := ts.PostJSON(t, ts.Server, "/profiles/minecraft", payload, nil, nil) + assert.Equal(t, http.StatusOK, rec.Code) + var response []PlayerNameToIDResponse + assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) + } + } } func (ts *TestSuite) testAccountVerifySecurityLocation(t *testing.T) { diff --git a/authlib_injector_test.go b/authlib_injector_test.go index 6808b4c..87f89ad 100644 --- a/authlib_injector_test.go +++ b/authlib_injector_test.go @@ -40,7 +40,7 @@ func TestAuthlibInjector(t *testing.T) { config := testConfig() fallback := ts.ToFallbackAPIServer(ts.AuxApp, "Aux") fallback.SkinDomains = []string{FALLBACK_SKIN_DOMAIN_A, FALLBACK_SKIN_DOMAIN_B} - config.FallbackAPIServers = []FallbackAPIServer{fallback} + config.FallbackAPIServers = []FallbackAPIServerConfig{fallback} ts.Setup(config) defer ts.Teardown() diff --git a/common.go b/common.go index 3e364b8..56faef6 100644 --- a/common.go +++ b/common.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/dgraph-io/ristretto" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/samber/mo" @@ -28,6 +29,9 @@ import ( "time" ) +const MAX_PLAYER_NAMES_TO_IDS = 10 +const MAX_PLAYER_NAMES_TO_IDS_INTERVAL = 1 * time.Second + func (app *App) AEADEncrypt(plaintext []byte) ([]byte, error) { nonceSize := app.AEAD.NonceSize() @@ -691,6 +695,7 @@ func (app *App) GetFallbackSkinTexturesProperty(player *Player) (*SessionProfile } else { // Otherwise, we only know the player name. Query the fallback API // server to get the fallback player's UUID + // TODO this should POST /profiles/minecraft instead to be authlib-injector-compatible reqURL, err := url.JoinPath(fallbackAPIServer.AccountURL, "/users/profiles/minecraft/", fallbackPlayer) if err != nil { log.Println(err) @@ -707,7 +712,7 @@ func (app *App) GetFallbackSkinTexturesProperty(player *Player) (*SessionProfile continue } - var playerResponse playerNameToUUIDResponse + var playerResponse PlayerNameToIDResponse err = json.Unmarshal(res.BodyBytes, &playerResponse) if err != nil { log.Printf("Received invalid response from fallback API server at %s\n", reqURL) @@ -937,3 +942,25 @@ func (app *App) GetSkinTexturesProperty(player *Player, sign bool) (SessionProfi func MakeHTTPClient() *http.Client { return &http.Client{Timeout: 30 * time.Second} } + +type FallbackAPIServer struct { + Config *FallbackAPIServerConfig + PlayerNameToIDCache *ristretto.Cache + PlayerNameToIDJobCh chan []playerNameToIDJob +} + +func NewFallbackAPIServer(config *FallbackAPIServerConfig) (FallbackAPIServer, error) { + var playerNameToIDCache *ristretto.Cache = nil + if config.CacheTTLSeconds > 0 { + var err error + playerNameToIDCache, err = ristretto.NewCache(DefaultRistrettoConfig) + if err != nil { + return FallbackAPIServer{}, err + } + } + return FallbackAPIServer{ + Config: config, + PlayerNameToIDCache: playerNameToIDCache, + PlayerNameToIDJobCh: make(chan []playerNameToIDJob), + }, nil +} diff --git a/config.go b/config.go index 2133f07..ff0d4d6 100644 --- a/config.go +++ b/config.go @@ -28,7 +28,7 @@ type bodyLimitConfig struct { SizeLimitKiB int } -type FallbackAPIServer struct { +type FallbackAPIServerConfig struct { Nickname string SessionURL string AccountURL string @@ -114,7 +114,7 @@ type Config struct { EnableBackgroundEffect bool EnableFooter bool EnableWebFrontEnd bool - FallbackAPIServers []FallbackAPIServer + FallbackAPIServers []FallbackAPIServerConfig ForwardSkins bool InstanceName string ImportExistingPlayer importExistingPlayerConfig @@ -146,6 +146,13 @@ var defaultBodyLimitConfig = bodyLimitConfig{ SizeLimitKiB: 8192, } +var DefaultRistrettoConfig = &ristretto.Config{ + // Defaults from https://pkg.go.dev/github.com/dgraph-io/ristretto#readme-config + NumCounters: 1e7, + MaxCost: 1 << 30, // 1 GiB + BufferItems: 64, +} + func DefaultConfig() Config { return Config{ AllowCapes: true, @@ -190,12 +197,7 @@ func DefaultConfig() Config { Allow: true, RequireInvite: false, }, - RequestCache: ristretto.Config{ - // Defaults from https://pkg.go.dev/github.com/dgraph-io/ristretto#readme-config - NumCounters: 1e7, - MaxCost: 1 << 30, // 1 GiB - BufferItems: 64, - }, + RequestCache: *DefaultRistrettoConfig, SignPublicKeys: true, SkinSizeLimit: 128, StateDirectory: GetDefaultStateDirectory(), diff --git a/config_test.go b/config_test.go index 48e1d20..0754557 100644 --- a/config_test.go +++ b/config_test.go @@ -92,7 +92,7 @@ func TestConfig(t *testing.T) { assert.NotNil(t, CleanConfig(config)) config = configTestConfig(sd) - testFallbackAPIServer := FallbackAPIServer{ + testFallbackAPIServer := FallbackAPIServerConfig{ Nickname: "Nickname", SessionURL: "https://δρασλ.example.com/", AccountURL: "https://δρασλ.example.com/", @@ -100,10 +100,10 @@ func TestConfig(t *testing.T) { SkinDomains: []string{"δρασλ.example.com"}, } fb := testFallbackAPIServer - config.FallbackAPIServers = []FallbackAPIServer{fb} + config.FallbackAPIServers = []FallbackAPIServerConfig{fb} assert.Nil(t, CleanConfig(config)) - assert.Equal(t, []FallbackAPIServer{{ + assert.Equal(t, []FallbackAPIServerConfig{{ Nickname: fb.Nickname, SessionURL: "https://xn--mxafwwl.example.com", AccountURL: "https://xn--mxafwwl.example.com", @@ -113,37 +113,37 @@ func TestConfig(t *testing.T) { fb = testFallbackAPIServer fb.Nickname = "" - config.FallbackAPIServers = []FallbackAPIServer{fb} + config.FallbackAPIServers = []FallbackAPIServerConfig{fb} assert.NotNil(t, CleanConfig(config)) fb = testFallbackAPIServer fb.SessionURL = "" - config.FallbackAPIServers = []FallbackAPIServer{fb} + config.FallbackAPIServers = []FallbackAPIServerConfig{fb} assert.NotNil(t, CleanConfig(config)) fb = testFallbackAPIServer fb.SessionURL = ":invalid URL" - config.FallbackAPIServers = []FallbackAPIServer{fb} + config.FallbackAPIServers = []FallbackAPIServerConfig{fb} assert.NotNil(t, CleanConfig(config)) fb = testFallbackAPIServer fb.AccountURL = "" - config.FallbackAPIServers = []FallbackAPIServer{fb} + config.FallbackAPIServers = []FallbackAPIServerConfig{fb} assert.NotNil(t, CleanConfig(config)) fb = testFallbackAPIServer fb.AccountURL = ":invalid URL" - config.FallbackAPIServers = []FallbackAPIServer{fb} + config.FallbackAPIServers = []FallbackAPIServerConfig{fb} assert.NotNil(t, CleanConfig(config)) fb = testFallbackAPIServer fb.ServicesURL = "" - config.FallbackAPIServers = []FallbackAPIServer{fb} + config.FallbackAPIServers = []FallbackAPIServerConfig{fb} assert.NotNil(t, CleanConfig(config)) fb = testFallbackAPIServer fb.ServicesURL = ":invalid URL" - config.FallbackAPIServers = []FallbackAPIServer{fb} + config.FallbackAPIServers = []FallbackAPIServerConfig{fb} assert.NotNil(t, CleanConfig(config)) // Test that TEMPLATE_CONFIG_FILE is valid diff --git a/flake.nix b/flake.nix index 19f108a..f6cfa25 100644 --- a/flake.nix +++ b/flake.nix @@ -48,7 +48,7 @@ ]; # Update whenever Go dependencies change - vendorHash = "sha256-jthuA1MlP83sXYuZHX6MwD33JfhjrFPax5B+26iLh20="; + vendorHash = "sha256-iGOYsgrOwx3nbvlc3ln6awg23CZBdtaqQbYY30q25dU="; outputs = ["out"]; diff --git a/front_test.go b/front_test.go index 19e46fb..910e24b 100644 --- a/front_test.go +++ b/front_test.go @@ -44,7 +44,7 @@ func setupRegistrationExistingPlayerTS(t *testing.T, requireSkinVerification boo AccountURL: ts.AuxApp.AccountURL, RequireSkinVerification: requireSkinVerification, } - config.FallbackAPIServers = []FallbackAPIServer{ + config.FallbackAPIServers = []FallbackAPIServerConfig{ { Nickname: "Aux", SessionURL: ts.AuxApp.SessionURL, diff --git a/go.mod b/go.mod index 456a921..af46dc8 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ toolchain go1.23.2 require ( github.com/BurntSushi/toml v1.3.2 - github.com/deckarep/golang-set/v2 v2.6.0 + github.com/deckarep/golang-set/v2 v2.8.0 github.com/dgraph-io/ristretto v0.1.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 2b56e01..c678b5c 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/deckarep/golang-set/v2 v2.6.0 h1:XfcQbWM1LlMB8BsJ8N9vW5ehnnPVIw0je80NsVHagjM= -github.com/deckarep/golang-set/v2 v2.6.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= +github.com/deckarep/golang-set/v2 v2.8.0 h1:swm0rlPCmdWn9mESxKOjWk8hXSqoxOp+ZlfuyaAdFlQ= +github.com/deckarep/golang-set/v2 v2.8.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= diff --git a/main.go b/main.go index cf76b04..fe042df 100644 --- a/main.go +++ b/main.go @@ -70,6 +70,7 @@ type App struct { OIDCProviderNames []string OIDCProvidersByName map[string]*OIDCProvider OIDCProvidersByIssuer map[string]*OIDCProvider + FallbackAPIServers []FallbackAPIServer } func LogInfo(args ...interface{}) { @@ -453,14 +454,15 @@ func setup(config *Config) *App { log.Fatal("Invalid verification skin!") } - // Keys + // Keys, FallbackAPIServers + fallbackAPIServers := make([]FallbackAPIServer, 0, len(config.FallbackAPIServers)) playerCertificateKeys := make([]rsa.PublicKey, 0, 1) profilePropertyKeys := make([]rsa.PublicKey, 0, 1) profilePropertyKeys = append(profilePropertyKeys, key.PublicKey) playerCertificateKeys = append(playerCertificateKeys, key.PublicKey) - for _, fallbackAPIServer := range config.FallbackAPIServers { - reqURL := Unwrap(url.JoinPath(fallbackAPIServer.ServicesURL, "publickeys")) + for _, fallbackAPIServerConfig := range config.FallbackAPIServers { + reqURL := Unwrap(url.JoinPath(fallbackAPIServerConfig.ServicesURL, "publickeys")) res, err := MakeHTTPClient().Get(reqURL) if err != nil { log.Printf("Couldn't access fallback API server at %s: %s\n", reqURL, err) @@ -500,7 +502,10 @@ func setup(config *Config) *App { playerCertificateKeys = append(playerCertificateKeys, *publicKey) } } - log.Printf("Fetched public keys from fallback API server %s", fallbackAPIServer.Nickname) + log.Printf("Fetched public keys from fallback API server %s", fallbackAPIServerConfig.Nickname) + + fallbackAPIServer := Unwrap(NewFallbackAPIServer(&fallbackAPIServerConfig)) + fallbackAPIServers = append(fallbackAPIServers, fallbackAPIServer) } // OIDC providers @@ -564,6 +569,7 @@ func setup(config *Config) *App { OIDCProviderNames: oidcProviderNames, OIDCProvidersByName: oidcProvidersByName, OIDCProvidersByIssuer: oidcProvidersByIssuer, + FallbackAPIServers: fallbackAPIServers, } // Post-setup @@ -604,6 +610,12 @@ func setup(config *Config) *App { return app } +func (app *App) Run() { + for _, fallbackAPIServer := range PtrSlice(app.FallbackAPIServers) { + go (*fallbackAPIServer).PlayerNamesToIDsWorker() + } +} + func main() { defaultConfigPath := path.Join(Constants.ConfigDirectory, "config.toml") @@ -623,6 +635,6 @@ func main() { log.Fatalf("Error in config: %s", err) } app := setup(&config) - + go app.Run() Check(app.MakeServer().Start(app.Config.ListenAddress)) } diff --git a/player.go b/player.go index 5f204ab..21d9dd9 100644 --- a/player.go +++ b/player.go @@ -398,7 +398,7 @@ func (app *App) ValidateChallenge(playerName string, challengeToken *string) (*P return nil, errors.New("registration server returned error") } - var idRes playerNameToUUIDResponse + var idRes PlayerNameToIDResponse err = json.NewDecoder(res.Body).Decode(&idRes) if err != nil { return nil, err diff --git a/services_test.go b/services_test.go index 27c0e80..613593b 100644 --- a/services_test.go +++ b/services_test.go @@ -27,7 +27,7 @@ func TestServices(t *testing.T) { config := testConfig() config.ForwardSkins = false - config.FallbackAPIServers = []FallbackAPIServer{ + config.FallbackAPIServers = []FallbackAPIServerConfig{ { Nickname: "Aux", SessionURL: ts.AuxApp.SessionURL, @@ -502,7 +502,7 @@ func (ts *TestSuite) makeTestAccountPlayerNamesToIDs(url string) func(t *testing ts.Server.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - var response []playerNameToUUIDResponse + var response []PlayerNameToIDResponse assert.Nil(t, json.NewDecoder(rec.Body).Decode(&response)) // Get the real UUID @@ -513,6 +513,6 @@ func (ts *TestSuite) makeTestAccountPlayerNamesToIDs(url string) func(t *testing assert.Nil(t, err) // There should only be one user, the nonexistent user should not be present - assert.Equal(t, []playerNameToUUIDResponse{{Name: TEST_USERNAME, ID: id}}, response) + assert.Equal(t, []PlayerNameToIDResponse{{Name: TEST_USERNAME, ID: id}}, response) } } diff --git a/test_suite_test.go b/test_suite_test.go index ef424f5..ec11f96 100644 --- a/test_suite_test.go +++ b/test_suite_test.go @@ -90,8 +90,10 @@ func (ts *TestSuite) Setup(config *Config) { tsConfig := *config ts.Config = &tsConfig ts.App = setup(config) - ts.Server = ts.App.MakeServer() + go ts.App.Run() + + ts.Server = ts.App.MakeServer() go func() { Ignore(ts.Server.Start("")) }() } @@ -105,8 +107,10 @@ func (ts *TestSuite) SetupAux(config *Config) { auxConfig := *config ts.AuxConfig = &auxConfig ts.AuxApp = setup(config) - ts.AuxServer = ts.AuxApp.MakeServer() + go ts.AuxApp.Run() + + ts.AuxServer = ts.AuxApp.MakeServer() go func() { Ignore(ts.AuxServer.Start("")) }() // Wait until the server has a listen address... polling seems like the @@ -127,13 +131,13 @@ func (ts *TestSuite) SetupAux(config *Config) { ts.AuxApp.SessionURL = Unwrap(url.JoinPath(baseURL, "session")) } -func (ts *TestSuite) ToFallbackAPIServer(app *App, nickname string) FallbackAPIServer { - return FallbackAPIServer{ +func (ts *TestSuite) ToFallbackAPIServer(app *App, nickname string) FallbackAPIServerConfig { + return FallbackAPIServerConfig{ Nickname: nickname, SessionURL: app.SessionURL, AccountURL: app.AccountURL, ServicesURL: app.ServicesURL, - CacheTTLSeconds: 3600, + CacheTTLSeconds: 0, } } @@ -324,7 +328,7 @@ func testConfig() *Config { config.Domain = "drasl.example.com" noRateLimit := rateLimitConfig{Enable: false} config.RateLimit = noRateLimit - config.FallbackAPIServers = []FallbackAPIServer{} + config.FallbackAPIServers = []FallbackAPIServerConfig{} config.LogRequests = false return &config }