Skip to content

Commit

Permalink
Merge pull request #2042 from minrk/shadow-context
Browse files Browse the repository at this point in the history
preserve context reference when shadowing Socket classes
  • Loading branch information
minrk authored Oct 22, 2024
2 parents dcf8963 + d29dc4a commit 63209e4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def test_shadow(self):
assert s2._shadow_obj is s
assert s.underlying != p.underlying
assert s2.underlying == s.underlying
assert s2.context is s.context
s3 = zmq.Socket(s)
assert s3._shadow_obj is s
assert s3.underlying == s.underlying
Expand Down
6 changes: 6 additions & 0 deletions zmq/sugar/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
shadow: Socket | int = 0,
copy_threshold: int | None = None,
):
shadow_context: zmq.Context | None = None
if isinstance(ctx_or_socket, zmq.Socket):
# positional Socket(other_socket)
shadow = ctx_or_socket
Expand All @@ -145,6 +146,8 @@ def __init__(
# hold a reference to the shadow object
self._shadow_obj = shadow
if not isinstance(shadow, int):
if isinstance(shadow, zmq.Socket):
shadow_context = shadow.context
try:
shadow = cast(int, shadow.underlying)
except AttributeError:
Expand All @@ -159,6 +162,9 @@ def __init__(
shadow=shadow_address,
copy_threshold=copy_threshold,
)
if self._shadow_obj and shadow_context:
# keep self.context reference if shadowing a Socket object
self.context = shadow_context

try:
socket_type = cast(int, self.get(zmq.TYPE))
Expand Down

0 comments on commit 63209e4

Please sign in to comment.