Using Python's ExitStack to mimic move semantics and make your own exception-safe context managers
In Python, it is often necessary to clean something up when you're done with it. Closing a file is the most common example.
with open('file.txt', 'w') as file:
    # Writes "Hello, World!" plus a newline into file.txt.
    print('Hello, World!', file=file)
# The file closes here
You can nest context managers by either nesting them directly or by using the comma shorthand.
# Three equivalent ways to hold two files open at the same time.
# In every form the files are closed in reverse order (file2, then file1).
with open('file1.txt', 'w') as file1:
    with open('file2.txt', 'w') as file2:
        pass

# This does the same thing
with open('file1.txt', 'w') as file1, open('file2.txt', 'w') as file2:
    pass

# As does this (parenthesized context managers, Python 3.10+ syntax)
with (
    open('file1.txt', 'w') as file1,
    open('file2.txt', 'w') as file2
):
    pass
But what if you need to open an unknown number of context managers, such as a list of files? Enter ExitStack.
Basic ExitStack Use
from contextlib import ExitStack

files = ['file1.txt', 'file2.txt']
with ExitStack() as exit_stack:
    # enter_context() enters each file and registers it for closing, so the
    # number of open files does not need to be known in advance.
    filehandles = [exit_stack.enter_context(open(file, 'w')) for file in files]
    for file in filehandles:
        print('Hello, World!', file=file)
# All files are closed here
This is more than just a small convenience. You could do the same thing yourself by keeping a list of context managers and invoking their `__enter__` and `__exit__` methods manually, but `ExitStack` offers some power that makes it more useful:
- `ExitStack` will always unwind the stack of context managers on exception, passing in the exception objects as needed and keeping track of when an `__exit__` method swallows an exception in the stack.
- `ExitStack` can have arbitrary callback functions inserted (with `callback()` or `push()`) for simpler use.
- `ExitStack` has a `pop_all()` method, allowing it to be used for many more flexible purposes.
Move Semantics
The use of context managers achieves the same basic goal as the use of destructors in C++ and Rust. There is a problem, however, if you want to do something like open a database connection, do some checks on it, and close it only if an error is thrown.
from contextlib import closing
import sqlite3


def open_connection() -> sqlite3.Connection:
    '''Opens the database, and returns it if it's valid.

    Deliberately broken example: ``closing`` closes the connection whenever
    the outer ``with`` block is left, and that includes leaving it through
    ``return``. The caller therefore receives a connection that is already
    closed.
    '''
    # Closes the connection on *every* exit from this block, success included.
    with closing(sqlite3.connect('test.db')) as connection:
        with closing(connection.cursor()) as cursor:
            cursor.execute('PRAGMA application_id')
            # Any database whose application_id is not 1337 is rejected.
            if cursor.fetchone()[0] != 1337:
                raise RuntimeError('Invalid database')
        # ERROR: This will return a closed connection
        return connection
An easy way to solve this is to use an exit stack and, on success, leak its contents into a new exit stack that is never closed.
from contextlib import ExitStack, closing
import sqlite3


def open_connection() -> sqlite3.Connection:
    '''Opens the database, and returns it if it's valid.

    The database is considered valid when its ``application_id`` is 1337.

    Returns:
        An open connection. The caller becomes responsible for closing it.

    Raises:
        RuntimeError: if the ``application_id`` check fails. The connection
            is closed before the exception propagates.
    '''
    with ExitStack() as exit_stack:
        # Registered on the stack, so any exception below closes it.
        connection = exit_stack.enter_context(closing(sqlite3.connect('test.db')))
        with closing(connection.cursor()) as cursor:
            cursor.execute('PRAGMA application_id')
            if cursor.fetchone()[0] != 1337:
                raise RuntimeError('Invalid database')
        # Success: move the pending close() into a new ExitStack that is
        # discarded without ever being exited. Leaving this block therefore
        # does not close the connection we are about to return.
        exit_stack.pop_all()
        return connection
    # ExitStack.__exit__ is typed as able to suppress exceptions, so type
    # checkers think control can reach this line. It cannot: the only
    # registered manager is ``closing``, which never suppresses.
    assert False, 'unreachable'
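The `assert False` is there to silence the type checker, which gets upset because `ExitStack` is allowed to swallow exceptions: as far as the checker knows, control could fall out of the `with` block without returning anything.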
In this case, we use the exit stack to guard against exceptions, closing the database when one is raised. On the happy path, we `pop_all()` its contents into a new `ExitStack` that is never exited, and return the connection.
This is similar to C++ move semantics, because we can bypass normal lexical destruction on an opt-in basis.
In this case, an easier option would be to turn `open_connection` into a context manager itself and use it that way, yielding the connection out. That's not always an option, though — notably, when you are making your own class a context manager.
Custom Context Managers
Making your own context manager isn't very complex, but making a context manager that wraps contained context managers can be devilishly complex. This simple example has some notable problems:
from collections.abc import Generator
from contextlib import contextmanager
from types import TracebackType


@contextmanager
def cm(label: str) -> Generator[str, None, None]:
    """Demo manager: print ``pre-<label>``, yield *label*, then print ``post-<label>``."""
    print(f'pre-{label}')
    try:
        yield label
    finally:
        print(f'post-{label}')


class MyContextManager:
    """Naive composition of three ``cm`` managers (deliberately flawed example)."""

    def __enter__(self) -> None:
        # FLAW: if beta or gamma fails to enter, the managers already entered
        # (alpha, and possibly beta) are never exited, because the interpreter
        # does not call __exit__ when __enter__ raises.
        self.__alpha_cm = cm('alpha')
        self.__alpha = self.__alpha_cm.__enter__()
        self.__beta_cm = cm('beta')
        self.__beta = self.__beta_cm.__enter__()
        self.__gamma_cm = cm('gamma')
        self.__gamma = self.__gamma_cm.__enter__()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # FLAWS:
        # - managers are exited in entry order (alpha, beta, gamma), not in
        #   reverse order;
        # - an exception raised by one __exit__ skips the remaining ones;
        # - the return values are ignored, so a manager that suppresses the
        #   exception cannot make this manager suppress it.
        self.__alpha_cm.__exit__(exc_type, exc_value, traceback)
        self.__beta_cm.__exit__(exc_type, exc_value, traceback)
        self.__gamma_cm.__exit__(exc_type, exc_value, traceback)
On the surface, this looks fine, but there are a few major issues here:
- An exception raised during `__enter__` will prevent `__exit__` from running, so any already-entered context managers will not be cleaned up.
- An exception raised mid-`__exit__` will prevent the following context managers from being cleaned up.
- Any of the context managers swallowing exceptions will not inherently allow this context manager to swallow exceptions.
- Context managers might be order-sensitive, so they really should be cleaned up in reverse order.
To clean this up and do it right, you have to make things incredibly ugly.
from collections.abc import Generator
from contextlib import contextmanager
from types import TracebackType


@contextmanager
def cm(label: str) -> Generator[str, None, None]:
    """Demo manager: print ``pre-<label>``, yield *label*, then print ``post-<label>``."""
    print(f'pre-{label}')
    try:
        yield label
    finally:
        print(f'post-{label}')


class MyContextManager:
    """Hand-written, exception-aware composition of three ``cm`` managers.

    It handles the cases the naive version gets wrong, at the cost of deeply
    nested ``try`` blocks.
    """

    def __enter__(self) -> None:
        # Each later manager is entered inside a ``try``. If one fails, the
        # managers already entered are exited (innermost first) before the
        # exception is re-raised.
        self.__alpha_cm = cm('alpha')
        self.__alpha = self.__alpha_cm.__enter__()
        try:
            self.__beta_cm = cm('beta')
            self.__beta = self.__beta_cm.__enter__()
            try:
                self.__gamma_cm = cm('gamma')
                self.__gamma = self.__gamma_cm.__enter__()
            except BaseException as e:
                # gamma failed: undo beta, then let the outer handler undo alpha.
                self.__beta_cm.__exit__(type(e), e, e.__traceback__)
                raise
        except BaseException as e:
            # beta or gamma failed: undo alpha.
            self.__alpha_cm.__exit__(type(e), e, e.__traceback__)
            raise

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        # Exit in reverse order: gamma, beta, alpha. The nested ``finally``
        # clauses guarantee every manager is exited even if an earlier
        # __exit__ raises. After each step, the (exc_type, exc_value,
        # traceback) triple is updated:
        # - cleared if that manager returned a truthy value (it suppressed
        #   the exception);
        # - replaced if its __exit__ raised a new exception.
        try:
            if self.__gamma_cm.__exit__(exc_type, exc_value, traceback):
                exc_type = None
                exc_value = None
                traceback = None
        except BaseException as e:
            exc_type = type(e)
            exc_value = e
            traceback = e.__traceback__
        finally:
            try:
                if self.__beta_cm.__exit__(exc_type, exc_value, traceback):
                    exc_type = None
                    exc_value = None
                    traceback = None
            except BaseException as e:
                exc_type = type(e)
                exc_value = e
                traceback = e.__traceback__
            finally:
                if self.__alpha_cm.__exit__(exc_type, exc_value, traceback):
                    exc_type = None
                    exc_value = None
                    traceback = None
        if exc_value:
            # NOTE(review): this re-raises manually, even when the exception is
            # the unchanged in-flight one, instead of returning a falsy value
            # and letting it continue. It also relies on exception objects
            # being truthy; ``is not None`` would be stricter.
            raise exc_value
        else:
            # Nothing is left to propagate. Returning True tells the
            # interpreter to suppress the original exception, if there was one.
            return True
This is hideous and unreadable. It's also not perfect: it deals awkwardly with exceptions in flight, and it re-raises them manually instead of just returning a falsy value (such as `None`) to indicate that the in-flight exception should continue propagating. Fortunately, `ExitStack` can make this much, much nicer for us. This is one of the major boons of using `pop_all()` to mimic move semantics. You can protect your `__enter__()` method from exceptions, pop the exit stack's contents into an attribute on success, and then unwind that stack to clean up in `__exit__()`.
from collections.abc import Generator from contextlib import ExitStack, contextmanager from types import TracebackType @contextmanager def cm(label: str) -> Generator[str, None, None]: print(f'pre-{label}') try: yield label finally: print(f'post-{label}') class MyContextManager: def __enter__(self): with ExitStack() as exit_stack: self.__alpha = exit_stack.enter_context(cm('alpha')) self.__beta = exit_stack.enter_context(cm('beta')) self.__gamma = exit_stack.enter_context(cm('gamma')) self.__exit_stack = exit_stack.pop_all() def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> bool: return self.__exit_stack.__exit__(exc_type, exc_value, traceback)
This becomes even more valuable when we start inheriting context managers. Inheriting a context manager brings some of the same frustrations as composing your class out of context managers, but `ExitStack` makes it much easier.
from collections.abc import Generator from contextlib import ExitStack, contextmanager from types import TracebackType @contextmanager def cm(label: str) -> Generator[str, None, None]: print(f'pre-{label}') try: yield label finally: print(f'post-{label}') class Parent: def __enter__(self): with ExitStack() as exit_stack: self.__alpha = exit_stack.enter_context(cm('alpha')) self.__beta = exit_stack.enter_context(cm('beta')) self.__gamma = exit_stack.enter_context(cm('gamma')) self.__exit_stack = exit_stack.pop_all() def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> bool: return self.__exit_stack.__exit__(exc_type, exc_value, traceback) class Child(Parent): def __enter__(self): with ExitStack() as exit_stack: self.__super = super().__enter__() exit_stack.push(super().__exit__) self.__delta = exit_stack.enter_context(cm('delta')) self.__epsilon = exit_stack.enter_context(cm('epsilon')) self.__zeta = exit_stack.enter_context(cm('zeta')) self.__exit_stack = exit_stack.pop_all() def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> bool: return self.__exit_stack.__exit__(exc_type, exc_value, traceback)
Afterword
Python's context managers can be both flexible and powerful with `ExitStack`. They still aren't perfect — the classes above, for example, aren't re-entrant, since entering one a second time overwrites its stored exit stack — but they become much easier to manage.
Also note `AsyncExitStack`, which can achieve the same thing in asynchronous code.