diff --git a/infra_cli/config.py b/infra_cli/config.py index 878f218..8948b49 100644 --- a/infra_cli/config.py +++ b/infra_cli/config.py @@ -9,17 +9,23 @@ class Config: self.data = self._load() def _load(self): - if not os.path.exists(self.path): - # Fallback to local config if exists - if os.path.exists("config.yaml"): - self.path = os.path.abspath("config.yaml") - else: - raise FileNotFoundError(f"Config file not found at {self.path}. Please create it based on config.yaml.example") + data = {} + if os.path.exists(self.path): + with open(self.path, 'r') as f: + data = yaml.safe_load(f) or {} + elif os.path.exists("config.yaml"): + self.path = os.path.abspath("config.yaml") + with open(self.path, 'r') as f: + data = yaml.safe_load(f) or {} - with open(self.path, 'r') as f: - return yaml.safe_load(f) + return data def get(self, key, default=None): + # Support ENV overrides: INFRA_PROXMOX_USER -> proxmox.user + env_key = "INFRA_" + key.upper().replace('.', '_') + if env_key in os.environ: + return os.environ[env_key] + parts = key.split('.') val = self.data for part in parts: @@ -30,18 +36,7 @@ class Config: return val def get_node(self, node_name): - """Helper to get proxmox node details by name or default to first if none provided""" nodes = self.get('proxmox.nodes', {}) if not nodes: - # Fallback for old single-host config if present - host = self.get('proxmox.host') - if host: - return {"host": host, "pass": self.get('proxmox.password')} return None - - if not node_name: - # Default to first node found - return next(iter(nodes.values())) - - return nodes.get(node_name) - + return nodes.get(node_name) \ No newline at end of file diff --git a/infra_cli/ssh.py b/infra_cli/ssh.py index 1461aaa..735f1fe 100644 --- a/infra_cli/ssh.py +++ b/infra_cli/ssh.py @@ -1,15 +1,23 @@ import subprocess import os +import sys class SSHClient: - def __init__(self, host, user="root", key_path=None, password=None): + def __init__(self, host, user="root", key_path=None, password=None, timeout=30): self.host = host self.user = user self.key_path = os.path.expanduser(key_path) if key_path else None self.password = password + self.timeout = timeout def run(self, cmd, capture=True): - ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null"] + ssh_cmd = [ + "ssh", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", f"ConnectTimeout={self.timeout}", + "-o", "BatchMode=yes" if not self.password else "BatchMode=no" + ] if self.key_path: ssh_cmd += ["-i", self.key_path] @@ -17,16 +25,26 @@ class SSHClient: target = f"{self.user}@{self.host}" if self.password: + # sshpass is required for password auth full_cmd = ["sshpass", "-p", self.password] + ssh_cmd + [target, cmd] else: full_cmd = ssh_cmd + [target, cmd] - result = subprocess.run( - full_cmd, - capture_output=capture, - text=True - ) - return result + try: + result = subprocess.run( + full_cmd, + capture_output=capture, + text=True, + timeout=self.timeout + 10 + ) + return result + except subprocess.TimeoutExpired: + print(f"Error: SSH command timed out after {self.timeout}s on {self.host}", file=sys.stderr) + # Create a mock result for timeout + return subprocess.CompletedProcess(full_cmd, 1, "", "Timeout expired") + except Exception as e: + print(f"Error: SSH execution failed on {self.host}: {e}", file=sys.stderr) + return subprocess.CompletedProcess(full_cmd, 1, "", str(e)) def scp_to(self, local_path, remote_path): scp_cmd = ["scp", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null"] @@ -40,4 +58,4 @@ class SSHClient: else: full_cmd = scp_cmd + [local_path, target] - return subprocess.run(full_cmd, capture_output=True) + return subprocess.run(full_cmd, capture_output=True, timeout=self.timeout + 60) \ No newline at end of file diff --git a/tests/test_iac_integration.py b/tests/test_iac_integration.py new file mode 100644 index 0000000..db09513 --- /dev/null +++ b/tests/test_iac_integration.py @@ -0,0 +1,52 @@ +import pytest +import os +import subprocess +import json + +# Path discovery - We are in external/dynamic-infra-tooling/tests +# Correct project root is 3 levels up +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +ANSIBLE_BIN = "ansible-playbook" +PLAYBOOK_PATH = os.path.join(PROJECT_ROOT, "ansible", "playbooks", "test-infra-plugin.yml") + +def test_ansible_lookup_plugin_integration(): + """Verifies the Ansible plugin can be loaded and executed by Ansible""" + if not os.path.exists(PLAYBOOK_PATH): + pytest.skip(f"Playbook not found at {PLAYBOOK_PATH}") + + env = os.environ.copy() + env["LC_ALL"] = "en_US.UTF-8" + env["LANG"] = "en_US.UTF-8" + # Ensure Ansible finds the plugin + env["ANSIBLE_LOOKUP_PLUGINS"] = os.path.join(PROJECT_ROOT, "ansible", "plugins", "lookup") + + # Run the test playbook + cmd = [ANSIBLE_BIN, PLAYBOOK_PATH] + result = subprocess.run(cmd, capture_output=True, text=True, env=env, cwd=os.path.join(PROJECT_ROOT, "ansible")) + + assert result.returncode == 0 + assert "The next available IP is" in result.stdout + assert "The certificate for loopaware.com is loopaware.com.pem" in result.stdout + +def test_opentofu_external_wrapper(): + """Verifies the OpenTofu wrapper returns valid JSON for the external data source""" + wrapper_script = os.path.join(PROJECT_ROOT, "scripts", "tofu-infra-query.sh") + + if not os.path.exists(wrapper_script): + pytest.skip(f"Wrapper script not found at {wrapper_script}") + + # Simulate Tofu sending JSON query to stdin + query = json.dumps({"command": "ip next-free"}) + + # Use shell execution to ensure the wrapper can find its relative paths correctly if needed + result = subprocess.run( + [wrapper_script], + input=query, + capture_output=True, + text=True + ) + + assert result.returncode == 0 + data = json.loads(result.stdout) + assert "result" in data + assert "10.32." in data["result"] \ No newline at end of file