Skip to content

Commit

Permalink
WIP: Accept generic ExceptionGroups for raises
Browse files Browse the repository at this point in the history
  • Loading branch information
tapetersen committed Jan 14, 2025
1 parent f017df4 commit ae40048
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ repos:
- attrs>=19.2.0
- packaging
- tomli
- types-pkg_resources
- types-setuptools
# for mypy running on python>=3.11 since exceptiongroup is only a dependency
# on <3.11
- exceptiongroup>=1.0.0rc8
Expand Down
48 changes: 44 additions & 4 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import pprint
import sys
from collections.abc import Collection
from collections.abc import Sized
from decimal import Decimal
Expand All @@ -10,6 +11,8 @@
from typing import cast
from typing import ContextManager
from typing import final
from typing import get_args
from typing import get_origin
from typing import List
from typing import Mapping
from typing import Optional
Expand All @@ -26,6 +29,9 @@
from _pytest.compat import STRING_TYPES
from _pytest.outcomes import fail

if sys.version_info[:2] < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup

if TYPE_CHECKING:
from numpy import ndarray

Expand Down Expand Up @@ -772,6 +778,10 @@ def _as_numpy_array(obj: object) -> Optional["ndarray"]:

E = TypeVar("E", bound=BaseException)

# This will be `typing_GenericAlias` in the backport as opposed to
# `types.GenericAlias` for native ExceptionGroup. The cast is to not confuse mypy
_generic_exc_group_type = cast(type, type(ExceptionGroup[Exception]))


@overload
def raises(
Expand Down Expand Up @@ -941,15 +951,45 @@ def raises( # noqa: F811
f"Raising exceptions is already understood as failing the test, so you don't need "
f"any special code to say 'this should never raise an exception'."
)
if isinstance(expected_exception, type):
if isinstance(expected_exception, (type, type(ExceptionGroup[Exception]))):
expected_exceptions: Tuple[Type[E], ...] = (expected_exception,)
else:
expected_exceptions = expected_exception

_expected_exceptions: List[Type[E]] = []
for exc in expected_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):
if isinstance(exc, type(ExceptionGroup[Exception])):
if issubclass(
(origin := cast(Type[E], get_origin(exc))), BaseExceptionGroup
):
exc_type = get_args(exc)[0]
if issubclass(origin, ExceptionGroup) and exc_type is Exception:
_expected_exceptions.append(cast(Type[E], origin))
continue
elif (
issubclass(origin, BaseExceptionGroup) and exc_type is BaseException
):
_expected_exceptions.append(cast(Type[E], origin))
continue

raise ValueError(
f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
f"are accepted as generic types but got `{exc}`. "
f"`raises` will catch all instances of the base-type regardless so the "
f"returned type will be wider regardless and has to be checked with `ExceptionInfo.group_contains()`"
)

elif not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a))
_expected_exceptions.append(exc)

_expected_exception: Union[Type[E], Tuple[Type[E], ...]] = (
_expected_exceptions[0]
if len(_expected_exceptions) == 1
else tuple(_expected_exceptions)
)

message = f"DID NOT RAISE {expected_exception}"

Expand All @@ -960,14 +1000,14 @@ def raises( # noqa: F811
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
return RaisesContext(expected_exception, message, match)
return RaisesContext(_expected_exception, message, match)
else:
func = args[0]
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
try:
func(*args[1:], **kwargs)
except expected_exception as e:
except _expected_exception as e:
return _pytest._code.ExceptionInfo.from_exception(e)
fail(message)

Expand Down
28 changes: 27 additions & 1 deletion testing/code/test_excinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from _pytest._code.code import _TracebackStyle

if sys.version_info[:2] < (3, 11):
from exceptiongroup import ExceptionGroup
from exceptiongroup import ExceptionGroup, BaseExceptionGroup


@pytest.fixture
Expand Down Expand Up @@ -447,6 +447,32 @@ def test_division_zero():
result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match])


def test_raises_accepts_generic_group() -> None:
exc_group = ExceptionGroup("", [RuntimeError()])
with pytest.raises(ExceptionGroup[Exception]) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError)


def test_raises_accepts_generic_base_group() -> None:
exc_group = ExceptionGroup("", [RuntimeError()])
with pytest.raises(BaseExceptionGroup[BaseException]) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError)


def test_raises_rejects_specific_generic() -> None:
with pytest.raises(ValueError):
pytest.raises(ExceptionGroup[RuntimeError])


def test_raises_accepts_generic_in_tuple() -> None:
exc_group = ExceptionGroup("", [RuntimeError()])
with pytest.raises((ValueError, ExceptionGroup[Exception])) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError)


class TestGroupContains:
def test_contains_exception_type(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError()])
Expand Down

0 comments on commit ae40048

Please sign in to comment.