Optimized getMovementToTilesAtPosition tilesToIgnore to a bitset instead of a hashset, saving 7.4% of next turn calculation time!

This commit is contained in:
yairm210 2025-06-22 20:09:54 +03:00
parent be8935439c
commit 435e5805f9
5 changed files with 87 additions and 11 deletions

View File

@ -232,12 +232,16 @@ object HexMath {
/** Get number of hexes from [origin] to [destination] _without respecting world-wrap_ */ /** Get number of hexes from [origin] to [destination] _without respecting world-wrap_ */
fun getDistance(origin: Vector2, destination: Vector2): Int { fun getDistance(origin: Vector2, destination: Vector2): Int {
val relativeX = origin.x - destination.x return getDistance(origin.x.toInt(), origin.y.toInt(), destination.x.toInt(), destination.y.toInt())
val relativeY = origin.y - destination.y }
fun getDistance(originX: Int, originY: Int, destinationX: Int, destinationY: Int): Int {
val relativeX = originX - destinationX
val relativeY = originY - destinationY
return if (relativeX * relativeY >= 0) return if (relativeX * relativeY >= 0)
max(abs(relativeX), abs(relativeY)).toInt() max(abs(relativeX), abs(relativeY))
else else
(abs(relativeX) + abs(relativeY)).toInt() (abs(relativeX) + abs(relativeY))
} }
private val clockPositionToHexVectorMap: Map<Int, Vector2> = mapOf( private val clockPositionToHexVectorMap: Map<Int, Vector2> = mapOf(
@ -296,4 +300,40 @@ object HexMath {
return min(getDistance(vector, Vector2(1f, radius.toFloat())), getDistance(vector, Vector2(-radius.toFloat(), -1f))) return min(getDistance(vector, Vector2(1f, radius.toFloat())), getDistance(vector, Vector2(-radius.toFloat(), -1f)))
} }
} }
/**
* The goal here is to map from hexagonal positions (centered on 0,0) to positive integers (starting from 0) so we can replace hashmap/hashset with arrays/bitsets
* Places 1-6 are ring 1, 7-18 are ring 2, etc.
*/
fun getZeroBasedIndex(x: Int, y: Int): Int {
if (x == 0 && y == 0) return 0
val ring = getDistance(0,0, x, y)
val ringStart = 1 + 6 * ring * (ring - 1) / 2 // 1 for the center tile, then 6 for each ring
// total number of elements in the ring is 6 * ring
// We divide the ring into its 6 edges, each of which can be determined by an equality comparison
// Each edge has a start index, a variable from 0 to the number of elements in that edge
val positionInRing = when (ring) {
y -> 0 /* start index*/ + x /*variable*/ // contains `ring+1` elements
x -> ring + 1 /* start index */ + y /*variable*/ // contains `ring` elements - 1 already taken by x=y=ring above
-x -> 2 * ring + 1 /* start index */ -y /*variable*/ // contains `ring+1` elements
-y -> 3 * ring + 2 /*start index*/ -x /*variable*/ // contains `ring` elements - 1 already taken by -x=-y=ring above
x-y -> 4 * ring + 2 /* start index */ +x-1 /*variable*/ // contains `ring-1` elements. -1 because x=0 is already taken by ring=-y above
y-x -> 5 * ring + 1 /* start index */ +y-1 /*variable*/ // contains `ring-1` elements. -1 because y=0 is already taken by ring=-x above
else -> throw Exception("How???")
}
return ringStart + positionInRing
}
// Much simpler to understand, passes same tests, but ~5x slower than the above
fun mapRelativePositionToPositiveIntRedblob(x: Int, y: Int): Int {
if (x == 0 && y == 0) return 0
val ring = getDistance(0,0, x, y)
val ringStart = 1 + 6 * ring * (ring - 1) / 2 // 1 for the center tile, then 6 for each ring
val vectorsInRing = getVectorsAtDistance(Vector2.Zero, ring, ring, false)
val positionInRing = vectorsInRing.indexOf(Vector2(x.toFloat(), y.toFloat()))
return ringStart + positionInRing
}
} }

View File

@ -516,6 +516,7 @@ class TileMap(initialCapacity: Int = 10) : IsPartOfGameInfoSerialization {
// looks at tileMatrix. Thus filling Tiles into tileMatrix and setting their // looks at tileMatrix. Thus filling Tiles into tileMatrix and setting their
// transients in the same loop will leave incomplete cached `neighbors`. // transients in the same loop will leave incomplete cached `neighbors`.
tileInfo.tileMap = this tileInfo.tileMap = this
tileInfo.zeroBasedIndex = HexMath.getZeroBasedIndex(tileInfo.position.x.toInt(), tileInfo.position.y.toInt())
tileInfo.ruleset = this.ruleset!! tileInfo.ruleset = this.ruleset!!
tileInfo.setTerrainTransients() tileInfo.setTerrainTransients()
tileInfo.setUnitTransients(setUnitCivTransients) tileInfo.setUnitTransients(setUnitCivTransients)

View File

@ -11,6 +11,7 @@ import com.unciv.logic.map.tile.Tile
import com.unciv.models.UnitActionType import com.unciv.models.UnitActionType
import com.unciv.models.ruleset.unique.UniqueType import com.unciv.models.ruleset.unique.UniqueType
import com.unciv.ui.components.UnitMovementMemoryType import com.unciv.ui.components.UnitMovementMemoryType
import java.util.BitSet
class UnitMovement(val unit: MapUnit) { class UnitMovement(val unit: MapUnit) {
@ -30,7 +31,7 @@ class UnitMovement(val unit: MapUnit) {
position: Vector2, position: Vector2,
unitMovement: Float, unitMovement: Float,
considerZoneOfControl: Boolean = true, considerZoneOfControl: Boolean = true,
tilesToIgnore: HashSet<Tile>? = null, tilesToIgnoreBitset: BitSet? = null,
passThroughCache: HashMap<Tile, Boolean> = HashMap(), passThroughCache: HashMap<Tile, Boolean> = HashMap(),
movementCostCache: HashMap<Pair<Tile, Tile>, Float> = HashMap(), movementCostCache: HashMap<Pair<Tile, Tile>, Float> = HashMap(),
includeOtherEscortUnit: Boolean = true includeOtherEscortUnit: Boolean = true
@ -49,12 +50,13 @@ class UnitMovement(val unit: MapUnit) {
&& unit.getOtherEscortUnit()?.currentMovement == 0f) return distanceToTiles && unit.getOtherEscortUnit()?.currentMovement == 0f) return distanceToTiles
var tilesToCheck = listOf(unitTile) var tilesToCheck = listOf(unitTile)
while (tilesToCheck.isNotEmpty()) { while (tilesToCheck.isNotEmpty()) {
val updatedTiles = ArrayList<Tile>() val updatedTiles = ArrayList<Tile>()
for (tileToCheck in tilesToCheck) for (tileToCheck in tilesToCheck)
for (neighbor in tileToCheck.neighbors) { for (neighbor in tileToCheck.neighbors) {
if (tilesToIgnore?.contains(neighbor) == true) continue // ignore this tile // ignore this tile
if (tilesToIgnoreBitset != null && tilesToIgnoreBitset.get(neighbor.zeroBasedIndex)) continue // ignore this tile
var totalDistanceToTile: Float = when { var totalDistanceToTile: Float = when {
!neighbor.isExplored(unit.civ) -> !neighbor.isExplored(unit.civ) ->
distanceToTiles[tileToCheck]!!.totalMovement + 1f // If we don't know then we just guess it to be 1. distanceToTiles[tileToCheck]!!.totalMovement + 1f // If we don't know then we just guess it to be 1.
@ -131,7 +133,7 @@ class UnitMovement(val unit: MapUnit) {
var distance = 1 var distance = 1
val unitMaxMovement = unit.getMaxMovement().toFloat() val unitMaxMovement = unit.getMaxMovement().toFloat()
val newTilesToCheck = ArrayList<Tile>() val newTilesToCheck = ArrayList<Tile>()
val visitedTiles: HashSet<Tile> = hashSetOf(currentTile) val visitedTilesBitset = BitSet().apply { set(currentTile.zeroBasedIndex) }
val civilization = unit.civ val civilization = unit.civ
val passThroughCache = HashMap<Tile, Boolean>() val passThroughCache = HashMap<Tile, Boolean>()
@ -160,7 +162,7 @@ class UnitMovement(val unit: MapUnit) {
tileToCheck.position, tileToCheck.position,
unitMaxMovement, unitMaxMovement,
false, false,
visitedTiles, visitedTilesBitset,
passThroughCache, passThroughCache,
movementCostCache movementCostCache
) )
@ -203,11 +205,11 @@ class UnitMovement(val unit: MapUnit) {
} }
// add newTilesToCheck to visitedTiles so we do not path over these tiles in a later iteration // add newTilesToCheck to visitedTiles so we do not path over these tiles in a later iteration
visitedTiles.addAll(newTilesToCheck) for (tile in newTilesToCheck) visitedTilesBitset.set(tile.zeroBasedIndex)
// no need to check tiles that are surrounded by reachable tiles, only need to check the edgemost tiles. // no need to check tiles that are surrounded by reachable tiles, only need to check the edgemost tiles.
// Because anything we can reach from intermediate tiles, can be more easily reached by the edgemost tiles, // Because anything we can reach from intermediate tiles, can be more easily reached by the edgemost tiles,
// since we'll have to pass through an edgemost tile in order to reach the destination anyway // since we'll have to pass through an edgemost tile in order to reach the destination anyway
tilesToCheck = newTilesToCheck.filterNot { tile -> tile.neighbors.all { it in visitedTiles } } tilesToCheck = newTilesToCheck.filterNot { tile -> tile.neighbors.all { visitedTilesBitset.get(it.zeroBasedIndex) } }
distance++ distance++
} }

View File

@ -98,6 +98,9 @@ class Tile : IsPartOfGameInfoSerialization, Json.Serializable {
//region Transient fields //region Transient fields
@Transient @Transient
lateinit var tileMap: TileMap lateinit var tileMap: TileMap
@Transient
var zeroBasedIndex: Int = 0
@Transient @Transient
lateinit var ruleset: Ruleset // a tile can be a tile with a ruleset, even without a map. lateinit var ruleset: Ruleset // a tile can be a tile with a ruleset, even without a map.

View File

@ -0,0 +1,30 @@
package com.unciv.logic.map
import com.badlogic.gdx.math.Vector2
import org.junit.Assert
import org.junit.Test
class HexmathTests {
// Looks like our current movement is actually unoptimized, since it fails this test :)
@Test
fun zeroIndexed(){
Assert.assertEquals(0, HexMath.getZeroBasedIndex(0,0))
}
@Test
fun testMappingIsOneToOne(){
val seenCoordsMapping = hashSetOf<Int>()
for (ring in 1..100) {
val coords = HexMath.getVectorsAtDistance(Vector2.Zero, ring, 100, false)
val ringStartCoordinate = 1 + 6 * ring * (ring - 1) / 2
for (coord in coords) {
val mapping = HexMath.getZeroBasedIndex(coord.x.toInt(), coord.y.toInt())
Assert.assertFalse("Duplicate coords found: $coord", seenCoordsMapping.contains(mapping))
Assert.assertTrue("Coords $coord should be in ring $ring, actual mapping $mapping", mapping in ringStartCoordinate .. (ringStartCoordinate + 6 * ring - 1))
seenCoordsMapping.add(mapping)
}
}
}
}