"""SSH connector for aiodocker."""
from __future__ import annotations
import asyncio
import logging
import os
import re
import tempfile
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
import aiohttp
from aiohttp.connector import Connection
from .exceptions import DockerError
try:
import asyncssh
except ImportError:
asyncssh = None # type: ignore
# Try to import SSH config parser (preferably paramiko like docker-py)
try:
from paramiko import SSHConfig
except ImportError:
SSHConfig = None # type: ignore
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Constants
# Port used when the SSH URL does not carry an explicit port.
DEFAULT_SSH_PORT = 22
# Environment variable names that SSHConnector._clean_environment() strips from
# a copy of os.environ before that copy is handed to the SSH connection.
DANGEROUS_ENV_VARS = ["LD_LIBRARY_PATH", "SSL_CERT_FILE", "SSL_CERT_DIR", "PYTHONPATH"]
# Names exported by `from <module> import *`.
__all__ = ["SSHConnector"]
[docs]
class SSHConnector(aiohttp.UnixConnector):
    """SSH tunnel connector that forwards Docker socket connections over SSH.

    The connector listens on a private local Unix socket. Every connection
    accepted on that socket is relayed to a ``docker system dial-stdio``
    process started on the remote host through a shared asyncssh connection.
    aiohttp itself only ever talks to the local Unix socket.
    """

    # Matches the "user:password@" part of an ssh:// URL inside a message.
    _URL_PASSWORD_RE = re.compile(r"ssh://([^:/@]+):([^@]+)@")

    def __init__(
        self,
        ssh_url: str,
        strict_host_keys: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize SSH connector.

        Args:
            ssh_url: SSH connection URL (ssh://[user@]host[:port]).
                The username is optional and can be inferred from ~/.ssh/config.
            strict_host_keys: Enforce strict host key verification (default: True)
            **kwargs: Additional SSH connection options. They are passed to
                ``asyncssh.connect`` and override values read from
                ~/.ssh/config.

        Raises:
            DockerError: If asyncssh is not installed, the URL scheme is not
                ``ssh``, the hostname is missing, or strict host key checking
                is requested but no known_hosts source is available.
            ValueError: If the port in ``ssh_url`` is not a valid port number
                (raised by ``urllib.parse`` when the port is read).

        Note:
            This connector uses 'docker system dial-stdio' to connect to the remote
            Docker daemon, which automatically discovers and uses the correct socket
            path on the remote host (works with standard, rootless, and custom setups).
        """
        if asyncssh is None:
            raise DockerError(
                500,
                "asyncssh is required for SSH connections. "
                "Install with: pip install aiodocker[ssh]",
            )

        # Validate and parse SSH URL
        parsed = urlparse(ssh_url)
        if parsed.scheme != "ssh":
            raise DockerError(400, f"Invalid SSH URL scheme: {parsed.scheme}")
        if not parsed.hostname:
            raise DockerError(400, "SSH URL must include hostname")

        url_port = parsed.port
        self._ssh_host = parsed.hostname
        self._ssh_port = url_port or DEFAULT_SSH_PORT
        # Remember whether the URL carried a port so that an explicit ":22"
        # is not mistaken for "no port given" when merging ~/.ssh/config.
        self._ssh_port_explicit = bool(url_port)
        self._ssh_username = parsed.username
        self._ssh_password = parsed.password
        self._strict_host_keys = strict_host_keys

        # Load SSH config and merge with provided options (kwargs win)
        ssh_config = self._load_ssh_config()
        self._ssh_options: dict[str, Any] = {**ssh_config, **kwargs}

        # Validate and enforce host key verification
        self._setup_host_key_verification()

        # Warn about password in URL
        if self._ssh_password:
            log.warning(
                "Password provided in SSH URL. Consider using SSH key authentication "
                "for better security. Passwords may be exposed in logs or memory dumps."
            )

        # Connection state
        self._ssh_conn: asyncssh.SSHClientConnection | None = None
        self._ssh_context: Any | None = None
        self._tunnel_lock = asyncio.Lock()
        self._socket_server: asyncio.Server | None = None
        self._relay_tasks: set[asyncio.Task[None]] = set()

        # Create secure temporary directory (system chooses location and sets permissions)
        self._temp_dir = tempfile.TemporaryDirectory()
        self._local_socket_path = os.path.join(self._temp_dir.name, "docker.sock")

        # Initialize as Unix connector with our local socket
        super().__init__(path=self._local_socket_path)

    def _load_ssh_config(self) -> dict[str, Any]:
        """Load SSH configuration from ~/.ssh/config like docker-py does.

        Besides returning options, this may update ``self._ssh_port`` and
        ``self._ssh_username`` when the URL did not specify them.

        Returns:
            asyncssh keyword options derived from the config entry for the
            target host (``client_keys`` and/or ``known_hosts``). Empty when
            paramiko is not installed, the config file does not exist, or it
            cannot be parsed.
        """
        if SSHConfig is None:
            log.debug("SSH config parsing not available (paramiko not installed)")
            return {}

        config_options: dict[str, Any] = {}
        ssh_config_path = Path.home() / ".ssh" / "config"
        if not ssh_config_path.exists():
            return config_options

        # Fall back to the historical heuristic when the flag is missing
        # (e.g. an instance created without running __init__).
        port_from_url = getattr(
            self, "_ssh_port_explicit", self._ssh_port != DEFAULT_SSH_PORT
        )

        try:
            config = SSHConfig.from_path(ssh_config_path)
            host_config = config.lookup(self._ssh_host)

            # Only use config port if not specified in URL
            if "port" in host_config and not port_from_url:
                self._ssh_port = int(host_config["port"])

            # Only use config user if not specified in URL
            if "user" in host_config and not self._ssh_username:
                self._ssh_username = host_config["user"]

            # Map SSH config file options to asyncssh parameters
            if "identityfile" in host_config:
                config_options["client_keys"] = host_config["identityfile"]
            if "userknownhostsfile" in host_config:
                # Expand a leading "~" so values such as "~/.ssh/known_hosts_ci"
                # resolve to a real path before being passed on.
                config_options["known_hosts"] = os.path.expanduser(
                    host_config["userknownhostsfile"]
                )

            log.debug("Loaded SSH config for %s", self._ssh_host)
        except Exception:
            # Best effort: a broken config file must not prevent connecting
            # with explicitly supplied options.
            log.exception("Failed to parse SSH config")

        return config_options

    def _setup_host_key_verification(self) -> None:
        """Setup host key verification following docker-py security principles.

        Resolution order for ``known_hosts``: the value already present in
        ``self._ssh_options``, then ``~/.ssh/known_hosts`` if that file exists.
        When neither is available, strict mode raises and non-strict mode
        disables verification with a warning.

        Raises:
            DockerError: If no known_hosts source is available and
                ``strict_host_keys`` is enabled.
        """
        if self._ssh_options.get("known_hosts") is not None:
            return

        # If no known_hosts specified in config, use default location
        default_known_hosts = Path.home() / ".ssh" / "known_hosts"
        if default_known_hosts.exists():
            self._ssh_options["known_hosts"] = str(default_known_hosts)
            return

        if self._strict_host_keys:
            # Docker-py equivalent: enforce host key checking
            raise DockerError(
                400,
                "Host key verification is required for security. "
                "Either add the host to ~/.ssh/known_hosts or set strict_host_keys=False. "
                "SECURITY WARNING: Disabling host key verification makes connections "
                "vulnerable to man-in-the-middle attacks.",
            )

        # Allow but warn (similar to docker-py's WarningPolicy). asyncssh only
        # skips host key validation when known_hosts is passed explicitly as
        # None, so set it here to make the behaviour match the warning.
        self._ssh_options["known_hosts"] = None
        log.warning(
            "SECURITY WARNING: Host key verification disabled for %(ssh_host)s. "
            "Connection is vulnerable to man-in-the-middle attacks. "
            "Add host to ~/.ssh/known_hosts or run: ssh-keyscan -H %(ssh_host)s >> ~/.ssh/known_hosts",
            {"ssh_host": self._ssh_host},
        )

    def _sanitize_error_message(self, error: Exception) -> str:
        """Sanitize error messages to prevent credential leakage.

        Args:
            error: The exception whose text is about to be logged or re-raised.

        Returns:
            ``str(error)`` with the stored password (if any) and the password
            part of any embedded ``ssh://user:password@`` URL redacted.
        """
        message = str(error)

        # Remove password from error messages
        if self._ssh_password:
            message = message.replace(self._ssh_password, "***REDACTED***")

        # Remove password from SSH URLs in error messages
        return self._URL_PASSWORD_RE.sub(r"ssh://\1:***REDACTED***@", message)

    def _clean_environment(self) -> dict[str, str]:
        """Clean environment variables for security like docker-py does.

        Returns:
            A copy of ``os.environ`` without the names in ``DANGEROUS_ENV_VARS``.
        """
        return {
            name: value
            for name, value in os.environ.items()
            if name not in DANGEROUS_ENV_VARS
        }

    async def _relay_data(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Relay data from ``reader`` to ``writer`` until EOF or failure.

        The writer is always closed on exit so the peer sees end-of-stream.
        Cancellation and connection errors end the relay quietly.
        """
        try:
            while True:
                data = await reader.read(8192)
                if not data:
                    break
                writer.write(data)
                await writer.drain()
        except (asyncio.CancelledError, ConnectionError, BrokenPipeError):
            pass
        finally:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                # Best effort: the stream may already be broken.
                pass

    async def _handle_docker_connection(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Handle a Docker API connection by executing dial-stdio and relaying data.

        Called by the local Unix socket server for every accepted connection.
        Errors are logged (sanitized) rather than raised, because there is no
        caller to propagate them to.
        """
        process = None
        relay_tasks: list[asyncio.Task[None]] = []
        try:
            if self._ssh_conn is None:
                raise DockerError(500, "SSH connection not established")

            log.debug("Handling new Docker connection via dial-stdio")

            # Execute docker system dial-stdio on remote host
            # This automatically connects to the correct Docker socket
            # Use encoding=None for binary mode (stdin/stdout handle bytes, not strings)
            process = await self._ssh_conn.create_process(
                "docker system dial-stdio", encoding=None
            )  # type: ignore

            # Create relay tasks for bidirectional communication
            relay_tasks = [
                asyncio.create_task(
                    self._relay_data(reader, process.stdin)  # type: ignore
                ),
                asyncio.create_task(
                    self._relay_data(process.stdout, writer)  # type: ignore
                ),
            ]
            # Track tasks so close() can cancel them
            self._relay_tasks.update(relay_tasks)

            # Wait for either direction to complete
            await asyncio.wait(relay_tasks, return_when=asyncio.FIRST_COMPLETED)
        except Exception as e:
            sanitized_error = self._sanitize_error_message(e)
            log.error("Error in dial-stdio relay: %s", sanitized_error)
        finally:
            # Stop whichever direction is still running. This lives in
            # ``finally`` so the tasks are also cancelled and untracked when
            # this handler fails or is cancelled itself.
            for task in relay_tasks:
                if not task.done():
                    task.cancel()
            if relay_tasks:
                # return_exceptions=True also retrieves relay failures, so they
                # do not surface as "Task exception was never retrieved".
                await asyncio.gather(*relay_tasks, return_exceptions=True)
            self._relay_tasks.difference_update(relay_tasks)

            # Clean up process
            if process:
                try:
                    process.terminate()
                    await process.wait()
                except Exception:
                    pass

            # Close writer
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass

    async def _ensure_ssh_tunnel(self) -> None:
        """Ensure SSH connection and local Unix socket server are established.

        Safe to call before every request: it returns immediately while the
        SSH connection is open, and reconnects when it has been closed.

        Raises:
            DockerError: If the SSH connection or the local socket server
                cannot be set up. The message is sanitized.
        """
        # Use lock to prevent concurrent tunnel creation (docker-py principle)
        async with self._tunnel_lock:
            # Re-check condition after acquiring lock
            if self._ssh_conn is not None and not self._ssh_conn.is_closed():
                return

            # Reconnect case: release the context of the dropped connection
            # instead of overwriting (and leaking) it.
            stale_context = self._ssh_context
            self._ssh_context = None
            self._ssh_conn = None
            if stale_context is not None:
                try:
                    await stale_context.__aexit__(None, None, None)
                except Exception as e:
                    log.debug(
                        "Error releasing stale SSH connection: %s",
                        self._sanitize_error_message(e),
                    )

            log.debug(
                "Establishing SSH connection to %s@%s:%s",
                self._ssh_username,
                self._ssh_host,
                self._ssh_port,
            )

            try:
                # Clean environment like docker-py does
                # NOTE(review): asyncssh documents ``env`` as the environment to
                # set for sessions opened on the connection, so this appears to
                # send the (filtered) local environment to the remote side.
                # Confirm that this is intended.
                clean_env = self._clean_environment()

                # Use asyncssh context manager properly
                self._ssh_context = asyncssh.connect(
                    host=self._ssh_host,
                    port=self._ssh_port,
                    username=self._ssh_username,
                    password=self._ssh_password,
                    env=clean_env,
                    **self._ssh_options,
                )
                self._ssh_conn = await self._ssh_context.__aenter__()

                # Create Unix socket server that handles connections via dial-stdio.
                # The handler reads self._ssh_conn on every connection, so a
                # server that is still serving can be reused after a reconnect
                # instead of binding a second one on the same path.
                if self._socket_server is None or not self._socket_server.is_serving():
                    self._socket_server = await asyncio.start_unix_server(
                        self._handle_docker_connection,
                        path=self._local_socket_path,
                    )

                log.debug(
                    "SSH connection established, Unix socket server listening at %s",
                    self._local_socket_path,
                )

                # Clear password from memory after successful connection
                if self._ssh_password:
                    self._ssh_password = None
            except Exception as e:
                sanitized_error = self._sanitize_error_message(e)
                log.error("Failed to establish SSH connection: %s", sanitized_error)

                # Clean up context if it was created
                if self._ssh_context:
                    try:
                        await self._ssh_context.__aexit__(type(e), e, e.__traceback__)
                    except Exception:
                        pass
                self._ssh_context = None
                self._ssh_conn = None

                # Wrap in DockerError if not already one
                if isinstance(e, DockerError):
                    raise
                raise DockerError(
                    900,
                    f"Cannot connect to Docker via SSH {self._ssh_username}@{self._ssh_host}:{self._ssh_port}: {sanitized_error}",
                ) from e

    async def connect(
        self, req: aiohttp.ClientRequest, traces: Any, timeout: aiohttp.ClientTimeout
    ) -> Connection:
        """Connect through SSH tunnel.

        Makes sure the tunnel is up, then lets the Unix connector open a
        connection to the local relay socket.
        """
        await self._ensure_ssh_tunnel()
        return await super().connect(req, traces, timeout)

    async def close(self) -> None:  # type: ignore[override]
        """Close SSH connection and clean up resources with proper error handling.

        Tunnel resources are released even if closing the underlying aiohttp
        connector raises.
        """
        try:
            await super().close()
        finally:
            await self._release_tunnel_resources()

    async def _release_tunnel_resources(self) -> None:
        """Tear down the socket server, relay tasks, SSH connection and temp dir."""
        # Close socket server
        if self._socket_server:
            try:
                self._socket_server.close()
                await self._socket_server.wait_closed()
            except Exception as e:
                log.warning("Error closing socket server: %s", type(e).__name__)
            finally:
                self._socket_server = None

        # Cancel all relay tasks; gather so that one failing task cannot stop
        # the rest of the cleanup.
        running = [task for task in self._relay_tasks if not task.done()]
        for task in running:
            task.cancel()
        if running:
            await asyncio.gather(*running, return_exceptions=True)
        self._relay_tasks.clear()

        # Close SSH context manager properly
        if self._ssh_context:
            try:
                await self._ssh_context.__aexit__(None, None, None)
            except Exception as e:
                sanitized_error = self._sanitize_error_message(e)
                log.warning("Error closing SSH connection: %s", sanitized_error)
            finally:
                self._ssh_context = None
                self._ssh_conn = None

        # Clean up temporary directory (removes socket file automatically)
        try:
            self._temp_dir.cleanup()
        except Exception as e:
            # Don't log full path for security
            temp_name = self._temp_dir.name[-8:] if self._temp_dir.name else "unknown"
            log.warning(
                "Failed to clean up temporary directory <temp-%s>: %s",
                temp_name,
                type(e).__name__,
            )

        # Clear any remaining sensitive data
        self._ssh_password = None