Using Python's ExitStack to mimic move semantics and make your own exception-safe context managers

In Python, it is often necessary to clean something up when you're done with it. Closing a file is the most common example.

with open('file.txt', 'w') as file:
    print('Hello, World!', file=file)
    # The file closes here

You can nest context managers by either nesting them directly or by using the comma shorthand.

with open('file1.txt', 'w') as file1:
    with open('file2.txt', 'w') as file2:
        pass

# This does the same thing
with open('file1.txt', 'w') as file1, open('file2.txt', 'w') as file2:
    pass

# As does this (parenthesized context managers require Python 3.10+)
with (
    open('file1.txt', 'w') as file1,
    open('file2.txt', 'w') as file2
):
    pass

But what if you need to open an unknown number of context managers, such as a list of files? Enter ExitStack.

Basic ExitStack Use

from contextlib import ExitStack

files = ['file1.txt', 'file2.txt']

with ExitStack() as exit_stack:
    filehandles = [exit_stack.enter_context(open(file, 'w')) for file in files]
    for file in filehandles:
        print('Hello, World!', file=file)
    # All files are closed here
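One consequence of the code above is easy to miss. If any open() call fails partway through, ExitStack still closes the files that were already opened. Here is a minimal check. It uses a plain loop instead of the list comprehension so the partially opened list survives, and it assumes a temporary directory and a second path whose parent directory doesn't exist, both made up for illustration:

```python
import os
import tempfile
from contextlib import ExitStack
from typing import IO

opened: list[IO[str]] = []

with tempfile.TemporaryDirectory() as tmp:
    # The second path points into a directory that was never created.
    paths = [os.path.join(tmp, 'file1.txt'), os.path.join(tmp, 'missing', 'file2.txt')]
    try:
        with ExitStack() as exit_stack:
            for path in paths:
                opened.append(exit_stack.enter_context(open(path, 'w')))
    except FileNotFoundError:
        # open() failed on the second path; the stack has already unwound.
        pass

# Only the first file was opened, and the stack closed it on the way out.
print(len(opened), opened[0].closed)  # prints: 1 True
```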

This is more than just a small convenience. You could do the same thing yourself by keeping a list of context managers and calling their __enter__ and __exit__ methods manually, but ExitStack offers some extra power that makes it more useful:

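  • ExitStack always unwinds its context managers in reverse order, even when an exception is raised. It passes each __exit__ method the exception currently in flight and keeps track of when one of them swallows that exception or raises a new one.
  • ExitStack can register arbitrary cleanup functions with callback(), so you don't have to wrap them in a context manager first.
  • ExitStack has a pop_all() method, which moves every registered cleanup into a fresh ExitStack, so cleanup can be cancelled or deferred instead of running at the end of the with block.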
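A rough illustration of these three points follows. The release() function and log list are hypothetical stand-ins for real cleanup, and contextlib.suppress plays the role of a context manager that swallows an exception:

```python
from contextlib import ExitStack, suppress

log: list[str] = []

def release(name: str) -> None:
    '''Hypothetical cleanup function that just records what it released.'''
    log.append(f'release {name}')

# callback() registers plain functions; they run in reverse order on exit.
with ExitStack() as stack:
    stack.callback(release, 'first')
    stack.callback(release, 'second')

# suppress() swallows the KeyError on the way out. ExitStack records that,
# runs release('outer') with no exception in flight, and the with statement
# completes normally instead of re-raising.
with ExitStack() as stack:
    stack.callback(release, 'outer')
    stack.enter_context(suppress(KeyError))
    raise KeyError('swallowed')

# pop_all() moves the callbacks into a new stack, so nothing runs here...
with ExitStack() as stack:
    stack.callback(release, 'deferred')
    deferred = stack.pop_all()

# ...until the new stack is closed explicitly.
deferred.close()
print(log)
# prints: ['release second', 'release first', 'release outer', 'release deferred']
```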

Move Semantics

The use of context managers achieves the same basic goal as the use of destructors in C++ and Rust. There is a problem, however, if you want to do something like open a database connection, do some checks on it, and close it only if an error is thrown.

from contextlib import closing
import sqlite3

def open_connection() -> sqlite3.Connection:
    '''Opens the database, and returns it if it's valid.
    '''
    with closing(sqlite3.connect('test.db')) as connection:
        with closing(connection.cursor()) as cursor:
            cursor.execute('PRAGMA application_id')
            if cursor.fetchone()[0] != 1337:
                raise RuntimeError('Invalid database')

        # ERROR: This will return a closed connection
        return connection

An easy way to solve this is to use an exit stack and, on success, leak its contents into a new exit stack that is never closed.

from contextlib import ExitStack, closing
import sqlite3

def open_connection() -> sqlite3.Connection:
    '''Opens the database, and returns it if it's valid.
    '''
    with ExitStack() as exit_stack:
        connection = exit_stack.enter_context(closing(sqlite3.connect('test.db')))
        with closing(connection.cursor()) as cursor:
            cursor.execute('PRAGMA application_id')
            if cursor.fetchone()[0] != 1337:
                raise RuntimeError('Invalid database')

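        # Move the close() callback into a new ExitStack that is discarded
        # without being closed, so the connection stays open for the caller.
        exit_stack.pop_all()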
        return connection
    assert False, 'unreachable'

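The assert False is there to silence the type checker. ExitStack.__exit__ is annotated as returning bool, so the checker assumes the with block might swallow an exception and fall through to the end of the function without returning a connection.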

In this case, we use the exit stack to guard against exceptions, closing the database when one is thrown. On the happy path, we pop everything into a new ExitStack that is never exited, and return the connection.

This is similar to C++ move semantics, because we can bypass normal lexical destruction on an opt-in basis.
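To make the analogy a little more concrete, here is a minimal sketch in which pop_all() hands ownership to the caller instead of leaking it. The returned ExitStack becomes the new owner, and the caller decides when cleanup runs. The acquire() function, the name 'bad' that triggers a failure, and the log-based "resources" are all hypothetical:

```python
from contextlib import ExitStack

log: list[str] = []

def acquire(names: list[str]) -> ExitStack:
    '''Hypothetical helper: "opens" each name, cleans up on failure, and
    moves ownership of every cleanup to the caller on success.'''
    with ExitStack() as stack:
        for name in names:
            log.append(f'open {name}')
            stack.callback(log.append, f'close {name}')
            if name == 'bad':
                raise ValueError(f'cannot use {name}')
        # "Move" the registered callbacks into a new stack owned by the caller.
        return stack.pop_all()
    assert False, 'unreachable'

# On success, cleanup happens when the caller's with block ends.
with acquire(['a', 'b']):
    log.append('use')

# On failure, everything opened so far is closed before the exception escapes.
try:
    acquire(['c', 'bad'])
except ValueError:
    pass

print(log)
```

The caller's with statement plays the role of the destination's destructor: the cleanup runs exactly once, at the end of whichever scope currently owns the stack.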

In this case, an easier option would be to turn open_connection into a context manager itself, yielding the connection out. That's not always possible, though, notably when you're writing your own class that acts as a context manager.

Custom Context Managers

Making your own context manager isn't very hard, but making one that wraps other context managers can be devilishly complex. This simple example has some notable problems:

from collections.abc import Generator
from contextlib import contextmanager
from types import TracebackType

@contextmanager
def cm(label: str) -> Generator[str, None, None]:
    print(f'pre-{label}')
    try:
        yield label
    finally:
        print(f'post-{label}')

class MyContextManager:
    def __enter__(self):
        self.__alpha_cm = cm('alpha')
        self.__alpha = self.__alpha_cm.__enter__()
        self.__beta_cm = cm('beta')
        self.__beta = self.__beta_cm.__enter__()
        self.__gamma_cm = cm('gamma')
        self.__gamma = self.__gamma_cm.__enter__()
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ):
        self.__alpha_cm.__exit__(exc_type, exc_value, traceback)
        self.__beta_cm.__exit__(exc_type, exc_value, traceback)
        self.__gamma_cm.__exit__(exc_type, exc_value, traceback)

On the surface, this looks fine, but there are a few major issues here:

  1. An exception thrown during __enter__ will prevent __exit__ from running, so any already entered context managers will not be cleaned up.
  2. An exception thrown mid-__exit__ will prevent following context managers from being cleaned up.
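  3. If one of the inner context managers swallows an exception, this context manager can't swallow it too. Its __exit__ always returns None, and it keeps passing the original exception to the remaining inner context managers.
  4. Context managers might be order-sensitive, so they really should be cleaned up in reverse order; this __exit__ cleans them up in the order they were entered.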

To clean this up and do it right, you have to make things incredibly ugly.

from collections.abc import Generator
from contextlib import contextmanager
from types import TracebackType

@contextmanager
def cm(label: str) -> Generator[str, None, None]:
    print(f'pre-{label}')
    try:
        yield label
    finally:
        print(f'post-{label}')

class MyContextManager:
    def __enter__(self):
        self.__alpha_cm = cm('alpha')
        self.__alpha = self.__alpha_cm.__enter__()
        try:
            self.__beta_cm = cm('beta')
            self.__beta = self.__beta_cm.__enter__()
            try:
                self.__gamma_cm = cm('gamma')
                self.__gamma = self.__gamma_cm.__enter__()
            except BaseException as e:
                self.__beta_cm.__exit__(type(e), e, e.__traceback__)
                raise
        except BaseException as e:
            self.__alpha_cm.__exit__(type(e), e, e.__traceback__)
            raise
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        try:
            if self.__gamma_cm.__exit__(exc_type, exc_value, traceback):
                exc_type = None
                exc_value = None
                traceback = None
        except BaseException as e:
            exc_type = type(e)
            exc_value = e
            traceback = e.__traceback__
        finally:
            try:
                if self.__beta_cm.__exit__(exc_type, exc_value, traceback):
                    exc_type = None
                    exc_value = None
                    traceback = None
            except BaseException as e:
                exc_type = type(e)
                exc_value = e
                traceback = e.__traceback__
            finally:
                if self.__alpha_cm.__exit__(exc_type, exc_value, traceback):
                    exc_type = None
                    exc_value = None
                    traceback = None
        if exc_value is not None:
            raise exc_value
        else:
            return True

This is hideous and unreadable. It's also not perfect: it deals awkwardly with exceptions in flight, and it raises exceptions manually instead of just returning None to let the exception in flight continue. Fortunately, ExitStack can make this much, much nicer for us. This is one of the major boons of using pop_all() to mimic move semantics: you can protect your __enter__() method from exceptions, then pop the exit stack into an attribute and close it in __exit__().

from collections.abc import Generator
from contextlib import ExitStack, contextmanager
from types import TracebackType

@contextmanager
def cm(label: str) -> Generator[str, None, None]:
    print(f'pre-{label}')
    try:
        yield label
    finally:
        print(f'post-{label}')

class MyContextManager:
    def __enter__(self):
        with ExitStack() as exit_stack:
            self.__alpha = exit_stack.enter_context(cm('alpha'))
            self.__beta = exit_stack.enter_context(cm('beta'))
            self.__gamma = exit_stack.enter_context(cm('gamma'))
            self.__exit_stack = exit_stack.pop_all()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        return self.__exit_stack.__exit__(exc_type, exc_value, traceback)

This becomes even more valuable when we start inheriting from context managers. Inheriting a context manager brings some of the same frustrations as composing a class out of context managers, but ExitStack makes it much easier.

from collections.abc import Generator
from contextlib import ExitStack, contextmanager
from types import TracebackType

@contextmanager
def cm(label: str) -> Generator[str, None, None]:
    print(f'pre-{label}')
    try:
        yield label
    finally:
        print(f'post-{label}')

class Parent:
    def __enter__(self):
        with ExitStack() as exit_stack:
            self.__alpha = exit_stack.enter_context(cm('alpha'))
            self.__beta = exit_stack.enter_context(cm('beta'))
            self.__gamma = exit_stack.enter_context(cm('gamma'))
            self.__exit_stack = exit_stack.pop_all()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        return self.__exit_stack.__exit__(exc_type, exc_value, traceback)

class Child(Parent):
    def __enter__(self):
        with ExitStack() as exit_stack:
            self.__super = super().__enter__()
            exit_stack.push(super().__exit__)
            self.__delta = exit_stack.enter_context(cm('delta'))
            self.__epsilon = exit_stack.enter_context(cm('epsilon'))
            self.__zeta = exit_stack.enter_context(cm('zeta'))
            self.__exit_stack = exit_stack.pop_all()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        return self.__exit_stack.__exit__(exc_type, exc_value, traceback)

Afterward

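Python's context managers become both flexible and powerful with ExitStack. They still aren't perfect. Classes written this way aren't re-entrant, for example, because entering the same instance twice overwrites its stored exit stack. But they become much easier to manage.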

Also note AsyncExitStack, which can achieve the same thing in asynchronous code, using enter_async_context() and async with.