Kaydet (Commit) c495c66e authored tarafından Nick Coghlan's avatar Nick Coghlan

Fix the speed regression in inspect.py by adding another cache to speed up…

Fix the speed regression in inspect.py by adding another cache to speed up getmodule(). Patch #1553314
üst c563a1c3
...@@ -403,6 +403,7 @@ def getabsfile(object, _filename=None): ...@@ -403,6 +403,7 @@ def getabsfile(object, _filename=None):
return os.path.normcase(os.path.abspath(_filename)) return os.path.normcase(os.path.abspath(_filename))
modulesbyfile = {} modulesbyfile = {}
_filesbymodname = {}
def getmodule(object, _filename=None): def getmodule(object, _filename=None):
"""Return the module an object was defined in, or None if not found.""" """Return the module an object was defined in, or None if not found."""
...@@ -410,19 +411,32 @@ def getmodule(object, _filename=None): ...@@ -410,19 +411,32 @@ def getmodule(object, _filename=None):
return object return object
if hasattr(object, '__module__'): if hasattr(object, '__module__'):
return sys.modules.get(object.__module__) return sys.modules.get(object.__module__)
# Try the filename to modulename cache
if _filename is not None and _filename in modulesbyfile:
return sys.modules.get(modulesbyfile[_filename])
# Try the cache again with the absolute file name
try: try:
file = getabsfile(object, _filename) file = getabsfile(object, _filename)
except TypeError: except TypeError:
return None return None
if file in modulesbyfile: if file in modulesbyfile:
return sys.modules.get(modulesbyfile[file]) return sys.modules.get(modulesbyfile[file])
for module in sys.modules.values(): # Update the filename to module name cache and check yet again
# Copy sys.modules in order to cope with changes while iterating
for modname, module in sys.modules.items():
if ismodule(module) and hasattr(module, '__file__'): if ismodule(module) and hasattr(module, '__file__'):
f = module.__file__
if f == _filesbymodname.get(modname, None):
# Have already mapped this module, so skip it
continue
_filesbymodname[modname] = f
f = getabsfile(module) f = getabsfile(module)
# Always map to the name the module knows itself by
modulesbyfile[f] = modulesbyfile[ modulesbyfile[f] = modulesbyfile[
os.path.realpath(f)] = module.__name__ os.path.realpath(f)] = module.__name__
if file in modulesbyfile: if file in modulesbyfile:
return sys.modules.get(modulesbyfile[file]) return sys.modules.get(modulesbyfile[file])
# Check the main module
main = sys.modules['__main__'] main = sys.modules['__main__']
if not hasattr(object, '__name__'): if not hasattr(object, '__name__'):
return None return None
...@@ -430,6 +444,7 @@ def getmodule(object, _filename=None): ...@@ -430,6 +444,7 @@ def getmodule(object, _filename=None):
mainobject = getattr(main, object.__name__) mainobject = getattr(main, object.__name__)
if mainobject is object: if mainobject is object:
return main return main
# Check builtins
builtin = sys.modules['__builtin__'] builtin = sys.modules['__builtin__']
if hasattr(builtin, object.__name__): if hasattr(builtin, object.__name__):
builtinobject = getattr(builtin, object.__name__) builtinobject = getattr(builtin, object.__name__)
...@@ -444,7 +459,7 @@ def findsource(object): ...@@ -444,7 +459,7 @@ def findsource(object):
in the file and the line number indexes a line in that list. An IOError in the file and the line number indexes a line in that list. An IOError
is raised if the source code cannot be retrieved.""" is raised if the source code cannot be retrieved."""
file = getsourcefile(object) or getfile(object) file = getsourcefile(object) or getfile(object)
module = getmodule(object) module = getmodule(object, file)
if module: if module:
lines = linecache.getlines(file, module.__dict__) lines = linecache.getlines(file, module.__dict__)
else: else:
......
...@@ -178,7 +178,18 @@ class TestRetrievingSourceCode(GetSourceBase): ...@@ -178,7 +178,18 @@ class TestRetrievingSourceCode(GetSourceBase):
self.assertEqual(inspect.getcomments(mod.StupidGit), '# line 20\n') self.assertEqual(inspect.getcomments(mod.StupidGit), '# line 20\n')
def test_getmodule(self): def test_getmodule(self):
# Check actual module
self.assertEqual(inspect.getmodule(mod), mod)
# Check class (uses __module__ attribute)
self.assertEqual(inspect.getmodule(mod.StupidGit), mod) self.assertEqual(inspect.getmodule(mod.StupidGit), mod)
# Check a method (no __module__ attribute, falls back to filename)
self.assertEqual(inspect.getmodule(mod.StupidGit.abuse), mod)
# Do it again (check the caching isn't broken)
self.assertEqual(inspect.getmodule(mod.StupidGit.abuse), mod)
# Check a builtin
self.assertEqual(inspect.getmodule(str), sys.modules["__builtin__"])
# Check filename override
self.assertEqual(inspect.getmodule(None, modfile), mod)
def test_getsource(self): def test_getsource(self):
self.assertSourceEqual(git.abuse, 29, 39) self.assertSourceEqual(git.abuse, 29, 39)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment