Source code for aiodocker.stream

from __future__ import annotations

import socket
import struct
import warnings
from collections.abc import Awaitable, Callable
from types import TracebackType
from typing import TYPE_CHECKING, NamedTuple

import aiohttp
import attrs
from yarl import URL

from ._flow_control_queue import FlowControlDataQueue
from .exceptions import DockerError


if TYPE_CHECKING:
    from .docker import Docker


class Message(NamedTuple):
    """A single chunk of output produced by a :class:`Stream`.

    ``stream`` is the stream id carried by the frame header; when the
    stream runs with a TTY every chunk is reported with id ``1``.
    ``data`` is the payload exactly as received.
    """

    stream: int  # numeric stream id of the chunk
    data: bytes  # payload bytes of the chunk
class Stream:
    """Bidirectional stream over an HTTP connection upgraded to raw TCP.

    The connection is set up lazily on first use (see :meth:`_init`):
    *setup* is awaited to obtain ``(url, body, tty)``, then a ``POST`` with
    ``Connection: Upgrade`` / ``Upgrade: tcp`` headers is issued through
    *docker*. Output is read with :meth:`read_out`, input is written with
    :meth:`write_in`, and :meth:`close` (or leaving an ``async with`` block)
    shuts the stream down.
    """

    _resp: aiohttp.ClientResponse | None

    def __init__(
        self,
        docker: "Docker",
        setup: Callable[[], Awaitable[tuple[URL, bytes | None, bool]]],
        timeout: aiohttp.ClientTimeout | None = None,
    ) -> None:
        """
        :param docker: client used to issue the upgrade request.
        :param setup: coroutine factory returning ``(url, body, tty)``.
        :param timeout: optional override of the ``connect`` and
            ``sock_connect`` timeouts; other fields come from the client.
        """
        self._setup = setup
        self.docker = docker
        self._resp = None
        self._closed = False
        self._timeout = timeout
        self._queue: FlowControlDataQueue[Message] | None = None

    async def _init(self) -> None:
        """Send the upgrade request once and install the frame parser.

        Returns immediately if a response has already been obtained.

        :raises DockerError: if the server did not hand over the socket
            (the response has no connection).
        """
        if self._resp is not None:
            return
        url, body, tty = await self._setup()

        # Inherit the parent client's timeout, optionally overriding the
        # connection-establishment fields with the per-stream ones.
        timeout = self.docker._timeout
        if self._timeout is not None:
            timeout = attrs.evolve(
                timeout,
                connect=self._timeout.connect,
                sock_connect=self._timeout.sock_connect,
            )
        # sock_read and total timeouts don't make sense for streaming.
        timeout = attrs.evolve(timeout, sock_read=None, total=None)

        self._resp = resp = await self.docker._do_query(
            url,
            method="POST",
            data=body,
            params=None,
            headers={"Connection": "Upgrade", "Upgrade": "tcp"},
            timeout=timeout,
            chunked=None,
            read_until_eof=False,
            versioned_api=True,
        )
        # Read the body if present; it can contain information about the
        # disconnection.
        body = await resp.read()

        conn = resp.connection
        if conn is None:
            msg = (
                "Cannot upgrade connection to vendored tcp protocol, "
                "the docker server has closed underlying socket."
            )
            msg += f" Status code: {resp.status}."
            msg += f" Headers: {resp.headers}."
            if body:
                if len(body) > 100:
                    # Slice, not index: body[100] would be a single int.
                    msg += f" First 100 bytes of body: [{body[:100]!r}]..."
                else:
                    msg += f" Body: [{body!r}]"
            # The stream can never be used; mark it closed so a later
            # close() is a no-op and __del__ does not report it as leaked.
            self._closed = True
            resp.close()
            raise DockerError(500, msg)

        protocol = conn.protocol
        assert protocol is not None
        assert protocol.transport is not None
        sock = protocol.transport.get_extra_info("socket")
        if sock is not None:
            # Set TCP keepalive for the vendored socket; the socket can
            # already be closed in the case of error, hence the None check.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        queue: FlowControlDataQueue[Message] = FlowControlDataQueue(
            protocol, limit=2**16
        )
        # From here on, incoming bytes go through _ExecParser into the queue.
        protocol.set_parser(_ExecParser(queue, tty=tty), queue)
        protocol.force_close()
        self._queue = queue

    async def read_out(self) -> Message | None:
        """Read from stdout or stderr.

        :returns: the next :class:`Message`, or ``None`` once the queue
            signals end of stream.
        """
        await self._init()
        assert self._queue is not None
        try:
            return await self._queue.read()
        except aiohttp.EofStream:
            return None

    async def write_in(self, data: bytes) -> None:
        """Write into stdin.

        :raises RuntimeError: if the stream has already been closed.
        """
        if self._closed:
            raise RuntimeError("Cannot write to closed transport")
        await self._init()
        assert self._resp is not None
        conn = self._resp.connection
        assert conn is not None
        transport = conn.transport
        assert transport is not None
        transport.write(data)
        protocol = conn.protocol
        assert protocol is not None
        if protocol.transport is not None:
            # Apply backpressure: wait until the transport buffer drains.
            await protocol._drain_helper()

    async def close(self) -> None:
        """Half-close the write side (if supported) and release the response.

        Safe to call multiple times, before initialisation, and after a
        failed upgrade.
        """
        if self._resp is None or self._closed:
            return
        self._closed = True
        conn = self._resp.connection
        if conn is not None:
            transport = conn.transport
            if transport is not None and transport.can_write_eof():
                transport.write_eof()
        self._resp.close()

    async def __aenter__(self) -> Stream:
        await self._init()
        return self

    async def __aexit__(
        self,
        exc_typ: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        await self.close()
        return None

    def __del__(self, _warnings=warnings) -> None:
        # ``_warnings`` is bound at definition time so the module is still
        # reachable during interpreter shutdown; use it, not the global.
        if self._resp is None:
            return
        if not self._closed:
            _warnings.warn("Unclosed ExecStream", ResourceWarning)
class _ExecParser:
    """Payload parser that turns raw bytes into :class:`Message` objects.

    With ``tty=True`` each received chunk is forwarded unchanged as a
    message on stream ``1``. Otherwise the data is treated as a sequence of
    frames, each an 8-byte header (``>BxxxL``: a one-byte stream id, three
    padding bytes and a big-endian 4-byte payload length) followed by the
    payload. Incomplete frames are buffered until the rest arrives.
    """

    def __init__(self, queue, tty: bool = False) -> None:
        self.queue = queue
        self.tty = tty
        self.header_fmt = struct.Struct(">BxxxL")
        # Bytes received but not yet consumed as a complete frame.
        self._buf = bytearray()

    def set_exception(self, exc: BaseException) -> None:
        """Forward *exc* to readers of the queue."""
        self.queue.set_exception(exc)

    def feed_eof(self) -> None:
        """Signal end of stream to readers of the queue."""
        self.queue.feed_eof()

    def feed_data(self, data: bytes) -> tuple[bool, bytes]:
        """Consume *data*, feeding every complete message into the queue.

        :returns: always ``(False, b"")``; this parser never reports EOF or
            leaves an unconsumed tail.
        """
        if self.tty:
            self.queue.feed_data(Message(1, data), len(data))  # stdout
            return False, b""

        buf = self._buf
        buf.extend(data)
        header = self.header_fmt
        header_size = header.size
        offset = 0
        try:
            # Walk complete frames by offset instead of deleting each frame
            # from the front of the buffer, which would shift the remaining
            # bytes once per frame.
            while len(buf) - offset >= header_size:
                fileno, msglen = header.unpack_from(buf, offset)
                start = offset + header_size
                end = start + msglen
                if len(buf) < end:
                    break  # payload not fully received yet
                self.queue.feed_data(Message(fileno, bytes(buf[start:end])), msglen)
                offset = end
        finally:
            # Drop everything consumed so far, even if the queue raised.
            if offset:
                del buf[:offset]
        return False, b""