Skip to content

Commit

Permalink
Allow for multiple copies of the same function in different modules.
Browse files Browse the repository at this point in the history
Previous iterations of the cache assumed a one-to-one function name to
module mapping; however, it's possible to end up with many modules which
contain the same function (e.g. due to a race condition when using multiple
processes, but other scenarios could exist in the future). This commit
separates the func_name (the id for the function) from the link_name
(the unique id for a given function in a given module). This distinction
is not exposed outside the cache - when asked for a target, the cache
chooses which version to return (currently just the first one it sees).
  • Loading branch information
William Grant committed Mar 2, 2023
1 parent 12e35a8 commit d77f679
Show file tree
Hide file tree
Showing 3 changed files with 372 additions and 90 deletions.
178 changes: 104 additions & 74 deletions typed_python/compiler/compiler_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,77 +49,88 @@ class CompilerCache:
when we first boot up, which could be slow. We could improve this substantially
by making it possible to determine if a given function is in the cache by organizing
the manifests by, say, function name.
Due to the potential for race conditions, we must distinguish between the following:
func_name - The identifier for the function, based on its identity hash.
link_name - The identifier for the specific realization of that function, which lives in a specific
cache module.
"""
def __init__(self, cacheDir):
self.cacheDir = cacheDir

ensureDirExists(cacheDir)

self.loadedBinarySharedObjects = Dict(str, LoadedBinarySharedObject)()
self.nameToModuleHash = Dict(str, str)()

self.link_name_to_module_hash = Dict(str, str)()
self.moduleManifestsLoaded = set()

# link_names with an associated module in loadedBinarySharedObjects
self.targetsLoaded: Dict[str, TypedCallTarget] = {}
# the set of link_names for functions with linked and validated globals (i.e. ready to be run).
self.targetsValidated = set()
# link_name -> link_name
self.function_dependency_graph = DirectedGraph()
# dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions)
self.global_dependencies = Dict(str, ListOf(str))()
self.func_name_to_link_names = Dict(str, ListOf(str))()
for moduleHash in os.listdir(self.cacheDir):
if len(moduleHash) == 40:
self.loadNameManifestFromStoredModuleByHash(moduleHash)

# the set of functions with an associated module in loadedBinarySharedObjects
self.targetsLoaded: Dict[str, TypedCallTarget] = {}
def hasSymbol(self, func_name: str) -> bool:
"""Returns true if there are any versions of `func_name` in the cache.
# the set of functions with linked and validated globals (i.e. ready to be run).
self.targetsValidated = set()
There may be multiple copies in different modules with different link_names.
"""
return any(link_name in self.link_name_to_module_hash for link_name in self.func_name_to_link_names.get(func_name, []))

self.function_dependency_graph = DirectedGraph()
# dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
self.global_dependencies = Dict(str, ListOf(str))()
def getTarget(self, func_name: str) -> TypedCallTarget:
    """Return the TypedCallTarget for one cached version of `func_name`.

    Several link_names may exist for the same func_name; the choice between
    them is delegated to `_select_link_name`. The chosen link_name is passed
    to `loadForSymbol` before its entry in `targetsLoaded` is returned.

    Raises:
        ValueError: if `hasSymbol` reports that nothing is cached for
            `func_name`.
    """
    if self.hasSymbol(func_name):
        chosen_link_name = self._select_link_name(func_name)
        # Loading must happen first: it is what populates targetsLoaded.
        self.loadForSymbol(chosen_link_name)
        return self.targetsLoaded[chosen_link_name]

    raise ValueError(f'symbol not found for func_name {func_name}')

def hasSymbol(self, linkName: str) -> bool:
"""NB this will return True even if the linkName is ultimately unretrievable."""
return linkName in self.nameToModuleHash
def _generate_link_name(self, func_name: str, module_hash: str) -> str:
    """Build the link_name identifying `func_name` within module `module_hash`.

    The result is the func_name and the module hash joined by a single '.'.
    """
    return ".".join((func_name, module_hash))

def getTarget(self, linkName: str) -> TypedCallTarget:
if not self.hasSymbol(linkName):
raise ValueError(f'symbol not found for linkName {linkName}')
self.loadForSymbol(linkName)
return self.targetsLoaded[linkName]
def _select_link_name(self, func_name) -> str:
"""choose a link name for a given func name.
def dependencies(self, linkName: str) -> Optional[List[str]]:
"""Returns all the function names that `linkName` depends on"""
return list(self.function_dependency_graph.outgoing(linkName))
Currently we just choose the first available option.
Throws a KeyError if func_name isn't in the cache.
"""
link_name_candidates = self.func_name_to_link_names[func_name]
return link_name_candidates[0]

def dependencies(self, link_name: str) -> Optional[List[str]]:
    """Return, as a list, the outgoing neighbours of `link_name` in the function dependency graph."""
    outgoing = self.function_dependency_graph.outgoing(link_name)
    return [dependency for dependency in outgoing]

def loadForSymbol(self, linkName: str) -> None:
"""Loads the whole module, and any submodules, into LoadedBinarySharedObjects"""
moduleHash = self.nameToModuleHash[linkName]
"""Loads the whole module, and any dependant modules, into LoadedBinarySharedObjects"""
moduleHash = self.link_name_to_module_hash[linkName]

self.loadModuleByHash(moduleHash)

if linkName not in self.targetsValidated:
dependantFuncs = self.dependencies(linkName) + [linkName]
globalsToLink = {} # dict from modulehash to list of globals.
for funcName in dependantFuncs:
if funcName not in self.targetsValidated:
funcModuleHash = self.nameToModuleHash[funcName]
# append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
globalsToLink[funcModuleHash] = globalsToLink.get(funcModuleHash, []) + self.global_dependencies.get(funcName, [])

for moduleHash, globs in globalsToLink.items(): # this works because loadModuleByHash loads submodules too.
if globs:
definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x]
for x in globs
}
self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink)
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
raise RuntimeError('failed to validate globals when loading:', linkName)

self.targetsValidated.update(dependantFuncs)
self.targetsValidated.add(linkName)
for dependant_func in self.dependencies(linkName):
self.loadForSymbol(dependant_func)

globalsToLink = self.global_dependencies.get(linkName, [])
if globalsToLink:
definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x]
for x in globalsToLink
}
self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink)
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
raise RuntimeError('failed to validate globals when loading:', linkName)

def loadModuleByHash(self, moduleHash: str) -> None:
"""Load a module by name.
As we load, place all the newly imported typed call targets into
'nameToTypedCallTarget' so that the rest of the system knows what functions
have been uncovered.
Add the module contents to targetsLoaded, generate a LoadedBinarySharedObject,
and update the function and global dependency graphs.
"""
if moduleHash in self.loadedBinarySharedObjects:
return
Expand All @@ -128,6 +139,7 @@ def loadModuleByHash(self, moduleHash: str) -> None:

# TODO (Will) - store these names as module consts, use one .dat only
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
# func_name -> typedcalltarget
callTargets = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
Expand Down Expand Up @@ -156,45 +168,68 @@ def loadModuleByHash(self, moduleHash: str) -> None:
serializedGlobalVarDefs,
functionNameToNativeType,
globalDependencies

).loadFromPath(modulePath)

self.loadedBinarySharedObjects[moduleHash] = loaded

self.targetsLoaded.update(callTargets)
for func_name, callTarget in callTargets.items():
link_name = self._generate_link_name(func_name, moduleHash)
assert link_name not in self.targetsLoaded
self.targetsLoaded[link_name] = callTarget

assert not any(key in self.global_dependencies for key in globalDependencies) # should only happen if there's a hash collision.
self.global_dependencies.update(globalDependencies)
link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()}

assert not any(key in self.global_dependencies for key in link_name_global_dependencies)

self.global_dependencies.update(link_name_global_dependencies)
# update the cache's dependency graph with our new edges.
for function_name, dependant_function_name in dependency_edgelist:
self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name)

def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist):
"""Add new code to the compiler cache.
Args:
binarySharedObject: a BinarySharedObject containing the actual assembler
we've compiled.
nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
nameToTypedCallTarget: a dict from func_name to TypedCallTarget telling us
the formal python types for all the objects.
linkDependencies: a set of linknames we depend on directly.
linkDependencies: a set of func_names we depend on directly. (this becomes submodules)
dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
module.
TODO (Will): the notion of submodules/linkDependencies can be refactored out.
"""
dependentHashes = set()

hashToUse = SerializationContext().sha_hash(str(uuid.uuid4())).hexdigest

# the linkDependencies and dependencyEdgelist are in terms of func_name.
dependentHashes = set()
for name in linkDependencies:
dependentHashes.add(self.nameToModuleHash[name])
link_name = self._select_link_name(name)
dependentHashes.add(self.link_name_to_module_hash[link_name])

link_name_dependency_edgelist = []
for source, dest in dependencyEdgelist:
assert source in binarySharedObject.definedSymbols
source_link_name = self._generate_link_name(source, hashToUse)
if dest in binarySharedObject.definedSymbols:
dest_link_name = self._generate_link_name(dest, hashToUse)
else:
dest_link_name = self._select_link_name(dest)
link_name_dependency_edgelist.append([source_link_name, dest_link_name])

path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes, dependencyEdgelist)
path = self.writeModuleToDisk(binarySharedObject, hashToUse, nameToTypedCallTarget, dependentHashes, link_name_dependency_edgelist)

self.loadedBinarySharedObjects[hashToUse] = (
binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
)

for n in binarySharedObject.definedSymbols:
self.nameToModuleHash[n] = hashToUse
for func_name in binarySharedObject.definedSymbols:
link_name = self._generate_link_name(func_name, hashToUse)
self.link_name_to_module_hash[link_name] = hashToUse
self.func_name_to_link_names.setdefault(func_name, []).append(link_name)

# link & validate all globals for the new module
self.loadedBinarySharedObjects[hashToUse].linkGlobalVariables()
Expand All @@ -208,20 +243,18 @@ def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:

targetDir = os.path.join(self.cacheDir, moduleHash)

with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
submodules = SerializationContext().deserialize(f.read(), ListOf(str))

for subHash in submodules:
self.loadNameManifestFromStoredModuleByHash(subHash)

# TODO (Will) the name_manifest module_hash is the same throughout so this doesn't need to be a dict.
with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f:
self.nameToModuleHash.update(
SerializationContext().deserialize(f.read(), Dict(str, str))
)
func_name_to_module_hash = SerializationContext().deserialize(f.read(), Dict(str, str))

for func_name, module_hash in func_name_to_module_hash.items():
link_name = self._generate_link_name(func_name, module_hash)
self.func_name_to_link_names.setdefault(func_name, []).append(link_name)
self.link_name_to_module_hash[link_name] = module_hash

self.moduleManifestsLoaded.add(moduleHash)

def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules, dependencyEdgelist):
def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget, submodules, dependencyEdgelist):
"""Write out a disk representation of this module.
This includes writing both the shared object, a manifest of the function names
Expand All @@ -235,7 +268,6 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
to interact with the compiler cache simultaneously without relying on
individual file-level locking.
"""
hashToUse = SerializationContext().sha_hash(str(uuid.uuid4())).hexdigest

targetDir = os.path.join(
self.cacheDir,
Expand Down Expand Up @@ -264,23 +296,20 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
for sourceName in manifest:
f.write(sourceName + "\n")

# write the type manifest
with open(os.path.join(tempTargetDir, "type_manifest.dat"), "wb") as f:
f.write(SerializationContext().serialize(nameToTypedCallTarget))

# write the nativetype manifest
with open(os.path.join(tempTargetDir, "native_type_manifest.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.functionTypes))

# write the type manifest
with open(os.path.join(tempTargetDir, "globals_manifest.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.serializedGlobalVariableDefinitions))

with open(os.path.join(tempTargetDir, "submodules.dat"), "wb") as f:
f.write(SerializationContext().serialize(ListOf(str)(submodules), ListOf(str)))

with open(os.path.join(tempTargetDir, "function_dependencies.dat"), "wb") as f:
f.write(SerializationContext().serialize(dependencyEdgelist)) # might need a listof
f.write(SerializationContext().serialize(dependencyEdgelist))

with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.globalDependencies))
Expand All @@ -293,14 +322,15 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
else:
shutil.rmtree(tempTargetDir)

return targetDir, hashToUse
return targetDir

def function_pointer_by_name(self, linkName):
moduleHash = self.nameToModuleHash.get(linkName)
def function_pointer_by_name(self, func_name):
linkName = self._select_link_name(func_name)
moduleHash = self.link_name_to_module_hash.get(linkName)
if moduleHash is None:
raise Exception("Can't find a module for " + linkName)

if moduleHash not in self.loadedBinarySharedObjects:
self.loadForSymbol(linkName)

return self.loadedBinarySharedObjects[moduleHash].functionPointers[linkName]
return self.loadedBinarySharedObjects[moduleHash].functionPointers[func_name]
Loading

0 comments on commit d77f679

Please sign in to comment.