refactor: improve config and ssh core and add iac integration tests

This commit is contained in:
Fredrick Amnehagen 2026-02-06 00:20:36 +01:00
parent 5d61ef8116
commit 2d9e1bc06e
3 changed files with 94 additions and 29 deletions

View file

@ -9,17 +9,23 @@ class Config:
self.data = self._load() self.data = self._load()
def _load(self): def _load(self):
if not os.path.exists(self.path): data = {}
# Fallback to local config if exists if os.path.exists(self.path):
if os.path.exists("config.yaml"): with open(self.path, 'r') as f:
self.path = os.path.abspath("config.yaml") data = yaml.safe_load(f) or {}
else: elif os.path.exists("config.yaml"):
raise FileNotFoundError(f"Config file not found at {self.path}. Please create it based on config.yaml.example") self.path = os.path.abspath("config.yaml")
with open(self.path, 'r') as f:
data = yaml.safe_load(f) or {}
with open(self.path, 'r') as f: return data
return yaml.safe_load(f)
def get(self, key, default=None): def get(self, key, default=None):
# Support ENV overrides: INFRA_PROXMOX_USER -> proxmox.user
env_key = "INFRA_" + key.upper().replace('.', '_')
if env_key in os.environ:
return os.environ[env_key]
parts = key.split('.') parts = key.split('.')
val = self.data val = self.data
for part in parts: for part in parts:
@ -30,18 +36,7 @@ class Config:
return val return val
def get_node(self, node_name): def get_node(self, node_name):
"""Helper to get proxmox node details by name or default to first if none provided"""
nodes = self.get('proxmox.nodes', {}) nodes = self.get('proxmox.nodes', {})
if not nodes: if not nodes:
# Fallback for old single-host config if present
host = self.get('proxmox.host')
if host:
return {"host": host, "pass": self.get('proxmox.password')}
return None return None
return nodes.get(node_name)
if not node_name:
# Default to first node found
return next(iter(nodes.values()))
return nodes.get(node_name)

View file

@ -1,15 +1,23 @@
import subprocess import subprocess
import os import os
import sys
class SSHClient: class SSHClient:
def __init__(self, host, user="root", key_path=None, password=None): def __init__(self, host, user="root", key_path=None, password=None, timeout=30):
self.host = host self.host = host
self.user = user self.user = user
self.key_path = os.path.expanduser(key_path) if key_path else None self.key_path = os.path.expanduser(key_path) if key_path else None
self.password = password self.password = password
self.timeout = timeout
def run(self, cmd, capture=True): def run(self, cmd, capture=True):
ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null"] ssh_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", f"ConnectTimeout={self.timeout}",
"-o", "BatchMode=yes" if not self.password else "BatchMode=no"
]
if self.key_path: if self.key_path:
ssh_cmd += ["-i", self.key_path] ssh_cmd += ["-i", self.key_path]
@ -17,16 +25,26 @@ class SSHClient:
target = f"{self.user}@{self.host}" target = f"{self.user}@{self.host}"
if self.password: if self.password:
# sshpass is required for password auth
full_cmd = ["sshpass", "-p", self.password] + ssh_cmd + [target, cmd] full_cmd = ["sshpass", "-p", self.password] + ssh_cmd + [target, cmd]
else: else:
full_cmd = ssh_cmd + [target, cmd] full_cmd = ssh_cmd + [target, cmd]
result = subprocess.run( try:
full_cmd, result = subprocess.run(
capture_output=capture, full_cmd,
text=True capture_output=capture,
) text=True,
return result timeout=self.timeout + 10
)
return result
except subprocess.TimeoutExpired:
print(f"Error: SSH command timed out after {self.timeout}s on {self.host}", file=sys.stderr)
# Create a mock result for timeout
return subprocess.CompletedProcess(full_cmd, 1, "", "Timeout expired")
except Exception as e:
print(f"Error: SSH execution failed on {self.host}: {e}", file=sys.stderr)
return subprocess.CompletedProcess(full_cmd, 1, "", str(e))
def scp_to(self, local_path, remote_path): def scp_to(self, local_path, remote_path):
scp_cmd = ["scp", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null"] scp_cmd = ["scp", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null"]
@ -40,4 +58,4 @@ class SSHClient:
else: else:
full_cmd = scp_cmd + [local_path, target] full_cmd = scp_cmd + [local_path, target]
return subprocess.run(full_cmd, capture_output=True) return subprocess.run(full_cmd, capture_output=True, timeout=self.timeout + 60)

View file

@ -0,0 +1,52 @@
import pytest
import os
import subprocess
import json
# Path discovery - We are in external/dynamic-infra-tooling/tests
# Correct project root is 3 levels up
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
ANSIBLE_BIN = "ansible-playbook"
PLAYBOOK_PATH = os.path.join(PROJECT_ROOT, "ansible", "playbooks", "test-infra-plugin.yml")
def test_ansible_lookup_plugin_integration():
"""Verifies the Ansible plugin can be loaded and executed by Ansible"""
if not os.path.exists(PLAYBOOK_PATH):
pytest.skip(f"Playbook not found at {PLAYBOOK_PATH}")
env = os.environ.copy()
env["LC_ALL"] = "en_US.UTF-8"
env["LANG"] = "en_US.UTF-8"
# Ensure Ansible finds the plugin
env["ANSIBLE_LOOKUP_PLUGINS"] = os.path.join(PROJECT_ROOT, "ansible", "plugins", "lookup")
# Run the test playbook
cmd = [ANSIBLE_BIN, PLAYBOOK_PATH]
result = subprocess.run(cmd, capture_output=True, text=True, env=env, cwd=os.path.join(PROJECT_ROOT, "ansible"))
assert result.returncode == 0
assert "The next available IP is" in result.stdout
assert "The certificate for loopaware.com is loopaware.com.pem" in result.stdout
def test_opentofu_external_wrapper():
"""Verifies the OpenTofu wrapper returns valid JSON for the external data source"""
wrapper_script = os.path.join(PROJECT_ROOT, "scripts", "tofu-infra-query.sh")
if not os.path.exists(wrapper_script):
pytest.skip(f"Wrapper script not found at {wrapper_script}")
# Simulate Tofu sending JSON query to stdin
query = json.dumps({"command": "ip next-free"})
# Use shell execution to ensure the wrapper can find its relative paths correctly if needed
result = subprocess.run(
[wrapper_script],
input=query,
capture_output=True,
text=True
)
assert result.returncode == 0
data = json.loads(result.stdout)
assert "result" in data
assert "10.32." in data["result"]