#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Azure Security-Aware Discovery
-------------------------------
Discovers Azure resources, enriches them with configuration data
for security/compliance analysis, and saves a JSON snapshot.

Requires:
  - Azure CLI installed and logged in (`az login`)
  - Python 3.9+ with the `click` package installed
  - Reader-level access on the subscriptions

Usage:
  python azure_analysis.py
  python azure_analysis.py -v
  python azure_analysis.py -s <sub_id> --max-workers 8
"""

import json
import logging
import os
import re
import shlex
import shutil
import subprocess
import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

import click

# ---------- Logging ----------

LOG = logging.getLogger("azure_discovery")

# Log line layout. The trailing "Z" in the date format asserts UTC, so the
# formatter below must really render times in UTC.
LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
LOG_DATEFMT = "%Y-%m-%dT%H:%M:%SZ"

def build_utc_log_formatter() -> logging.Formatter:
    """
    Return a log formatter that renders ``%(asctime)s`` in UTC.

    ``logging.Formatter`` converts record times with ``time.localtime`` by
    default. Combined with a date format ending in ``Z``, that would stamp
    local wall-clock time as UTC on any host not running in UTC. Switching
    the converter to ``time.gmtime`` makes the ``Z`` suffix truthful.
    """
    formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=LOG_DATEFMT)
    formatter.converter = time.gmtime
    return formatter

def setup_logging(verbosity: int) -> None:
    """
    Configure root logging from the ``-v`` count.

    Args:
        verbosity: 0 (or less) -> WARNING, 1 -> INFO, 2 or more -> DEBUG.

    Log records go to stderr with UTC timestamps. Like ``logging.basicConfig``,
    this does nothing if the root logger already has handlers.
    """
    if verbosity >= 2:
        level = logging.DEBUG
    elif verbosity == 1:
        level = logging.INFO
    else:
        level = logging.WARNING

    handler = logging.StreamHandler()  # stderr, same as basicConfig's default
    handler.setFormatter(build_utc_log_formatter())
    logging.basicConfig(level=level, handlers=[handler])

# ---------- Helpers ----------

AZ_TIMEOUT_SECS = 60      # per-attempt subprocess timeout for `az`, in seconds
RETRY_MAX = 3             # total attempts per `az` command
RETRY_BASE_DELAY = 1.5    # seconds; multiplied by the attempt number between retries

def utc_now_iso() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (second precision)."""
    moment = datetime.now(tz=timezone.utc)
    return f"{moment:%Y-%m-%dT%H:%M:%SZ}"

# Lower-cased stderr fragments that mark a failed `az` call as transient and
# therefore worth retrying.
_TRANSIENT_ERROR_MARKERS = (
    "throttle",
    "too many requests",
    "timeout",
    "server busy",
    "try again",
    "temporarily unavailable",
)

def run_az(args: List[str], timeout: int = AZ_TIMEOUT_SECS) -> Any:
    """
    Run an `az` CLI command and return its parsed JSON output.

    ``--output json`` is appended to *args*. A non-zero exit whose stderr
    contains one of ``_TRANSIENT_ERROR_MARKERS``, and a subprocess timeout,
    are retried up to ``RETRY_MAX`` attempts in total. The delay between
    attempts is ``RETRY_BASE_DELAY * attempt`` seconds.

    Args:
        args: arguments passed after ``az``.
        timeout: per-attempt timeout in seconds.

    Returns:
        The decoded JSON value on success.
        ``{}`` on a non-retryable or exhausted non-zero exit, or on empty stdout.
        ``{"raw_output": <stdout>}`` when stdout is not valid JSON.
        ``{"error": ..., "args": args}`` when every attempt timed out.

    Raises:
        FileNotFoundError: if no ``az`` executable can be launched.
    """
    # The Azure CLI is installed as "az.cmd" on Windows. subprocess does not
    # apply PATHEXT when resolving a bare "az", but shutil.which does. Fall
    # back to "az" so a missing CLI still fails the same way as before.
    az_exe = shutil.which("az") or "az"
    cmd = [az_exe, *args, "--output", "json"]

    for attempt in range(1, RETRY_MAX + 1):
        LOG.debug("Running: %s", " ".join(shlex.quote(p) for p in cmd))
        try:
            proc = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=False,
                timeout=timeout,
                text=True,
            )
        except subprocess.TimeoutExpired:
            LOG.warning("az command timed out after %ss (attempt %s/%s)", timeout, attempt, RETRY_MAX)
            if attempt < RETRY_MAX:
                time.sleep(RETRY_BASE_DELAY * attempt)
                continue
            # Final timeout: give up on this command rather than hang.
            return {"error": f"timeout after {timeout}s", "args": args}

        if proc.returncode != 0:
            stderr = proc.stderr.strip()
            LOG.warning("az returned %s: %s", proc.returncode, stderr[:200])
            lowered = stderr.lower()
            if attempt < RETRY_MAX and any(marker in lowered for marker in _TRANSIENT_ERROR_MARKERS):
                time.sleep(RETRY_BASE_DELAY * attempt)
                continue
            return {}

        raw = proc.stdout.strip()
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {"raw_output": raw}

    # Reached only when RETRY_MAX < 1 (the loop body never runs).
    return {}

# Resource-group segment. It may end the ID (a resource-group ID) or be
# followed by further segments.
ARM_RG_RE = re.compile(r"/resourceGroups/([^/]+)(?:/|$)", re.IGNORECASE)
# First "<namespace>/<type>" pair after "/providers/". This is no longer used
# by parse_arm_id; it is kept in case other code imports it.
ARM_PROVIDER_RE = re.compile(r"/providers/([^/]+)/([^/]+)", re.IGNORECASE)
# Case-insensitive splitter on every "/providers/" marker in an ARM ID.
_ARM_PROVIDERS_SPLIT_RE = re.compile(r"/providers/", re.IGNORECASE)

def parse_arm_id(arm_id: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract ``(resource_group, provider_type)`` from an ARM resource ID.

    The provider type is taken from the last ``/providers/`` section, so an
    extension resource resolves to its own type. Every type segment is kept,
    so child resources keep their full type. For example,
    ``.../providers/Microsoft.Sql/servers/s1/databases/db1`` yields
    ``"Microsoft.Sql/servers/databases"``.

    Args:
        arm_id: full ARM ID. Empty or ``None`` yields ``(None, None)``.

    Returns:
        A tuple of the resource group name (or ``None``) and
        ``"<Namespace>/<type>[/<childType>...]"`` (or ``None``).
    """
    if not arm_id:
        return None, None

    m_rg = ARM_RG_RE.search(arm_id)
    rg = m_rg.group(1) if m_rg else None

    provider_type = None
    sections = _ARM_PROVIDERS_SPLIT_RE.split(arm_id)
    if len(sections) > 1:
        # After the namespace, segments alternate <type>/<name>/<type>/<name>...
        parts = sections[-1].strip("/").split("/")
        if len(parts) >= 2:
            namespace, type_and_names = parts[0], parts[1:]
            provider_type = "/".join([namespace, *type_and_names[0::2]])
    return rg, provider_type

# Ordered (substrings, category) rules; the first rule whose substring occurs
# in the lower-cased provider namespace wins.
_CATEGORY_RULES: Tuple[Tuple[Tuple[str, ...], str], ...] = (
    (("storage",), "storage"),
    (("sql", "db"), "database"),
    (("network",), "network"),
    (("keyvault",), "keyvault"),
    (("web",), "web"),
    (("cdn",), "cdn"),
    (("operationalinsights",), "monitoring"),
)

def normalize_category(provider_type: Optional[str]) -> Tuple[str, str]:
    """
    Map a provider type to a ``(category, resource_type)`` pair.

    The namespace (the part before the first ``/``) is matched against
    ``_CATEGORY_RULES`` case-insensitively, because ARM namespaces are
    case-insensitive and may not always arrive in canonical casing.
    Unmatched namespaces use the lower-cased namespace as their category.

    Args:
        provider_type: e.g. ``"Microsoft.Storage/storageAccounts"``, or ``None``.

    Returns:
        ``("other", "unknown")`` for an empty/``None`` input. Otherwise the
        category plus everything after the first ``/`` (``"unknown"`` when
        there is nothing after it).
    """
    if not provider_type:
        return ("other", "unknown")
    # partition (unlike split + unpack) tolerates a type with no "/".
    provider, _, rtype = provider_type.partition("/")
    rtype = rtype or "unknown"
    provider_lc = provider.lower()
    for needles, category in _CATEGORY_RULES:
        if any(needle in provider_lc for needle in needles):
            return category, rtype
    return provider_lc, rtype

# ---------- Enrichment Registry ----------

@dataclass
class EnrichmentCommand:
    """
    Recipe for turning one discovered resource into `az` CLI arguments.

    At most one addressing mode is applied, checked in this order:
    ``use_ids`` (``--ids <ARM id>``), then ``use_account_name``
    (``--account-name <name>``), then ``use_name_rg`` (``--name <name>``).
    The two name-based modes also append ``--resource-group <rg>`` when
    ``require_rg`` is set and a resource group can be parsed from the ID.
    With no mode flag set, only a copy of ``base_args`` is produced.
    """
    base_args: List[str]            # fixed command prefix, e.g. ["keyvault", "show"]
    use_ids: bool = False           # address the resource by its full ARM ID
    use_name_rg: bool = False       # address by --name (+ optional --resource-group)
    use_account_name: bool = False  # address by --account-name (+ optional --resource-group)
    key: Optional[str] = None       # label for this command's result; None lets the caller derive one
    require_rg: bool = True         # whether name-based modes append --resource-group

    def build(self, res: Dict[str, Any]) -> List[str]:
        """Return a new argument list for resource *res*; ``base_args`` is never mutated."""
        resource_id = res.get("id", "")
        resource_name = res.get("name", "")
        resource_group, _provider = parse_arm_id(resource_id)

        if self.use_ids:
            return [*self.base_args, "--ids", resource_id]

        if self.use_account_name:
            selector = ["--account-name", resource_name]
        elif self.use_name_rg:
            selector = ["--name", resource_name]
        else:
            return list(self.base_args)

        if self.require_rg and resource_group:
            selector += ["--resource-group", resource_group]
        return [*self.base_args, *selector]

# Enrichment command registry, keyed by ARM resource type
# (e.g. "Microsoft.Storage/storageAccounts"). Each EnrichmentCommand is run
# for every discovered resource of that type, and its JSON output is stored
# under the command's `key`.
# NOTE(review): the flags below have not been re-verified against current
# `az` docs. Confirm in particular that `sql server firewall-rule list`
# accepts `--ids`, and whether `network front-door` requires a CLI extension.
ENRICHMENT_REGISTRY: Dict[str, List[EnrichmentCommand]] = {
    "Microsoft.Storage/storageAccounts": [
        EnrichmentCommand(["storage", "account", "show"], use_ids=True, key="account"),
        # The same `show` call, narrowed with az's `--query` filter to the
        # networkRuleSet sub-object.
        EnrichmentCommand(["storage", "account", "show", "--query", "networkRuleSet"], use_ids=True, key="networkRules"),
    ],
    "Microsoft.Sql/servers": [
        EnrichmentCommand(["sql", "server", "show"], use_ids=True, key="server"),
        EnrichmentCommand(["sql", "server", "firewall-rule", "list"], use_ids=True, key="firewallRules"),
    ],
    # Key Vault commands are addressed by --name/--resource-group, not --ids.
    # No --subscription is passed, so they presumably resolve against the
    # subscription currently selected in the CLI. TODO confirm.
    "Microsoft.KeyVault/vaults": [
        EnrichmentCommand(["keyvault", "show"], use_name_rg=True, key="vault"),
        EnrichmentCommand(["keyvault", "network-rule", "list"], use_name_rg=True, key="networkRules"),
    ],
    "Microsoft.Web/sites": [
        EnrichmentCommand(["webapp", "show"], use_ids=True, key="site"),
        EnrichmentCommand(["webapp", "config", "show"], use_ids=True, key="config"),
    ],
    "Microsoft.Network/frontdoorWebApplicationFirewallPolicies": [
        EnrichmentCommand(["network", "front-door", "waf-policy", "show"], use_ids=True, key="wafPolicy"),
    ],
}

# ---------- Discovery Core ----------

def _registry_commands(provider_type: Optional[str]) -> "List[EnrichmentCommand]":
    """
    Return the ENRICHMENT_REGISTRY commands registered for *provider_type*.

    ARM resource types are case-insensitive, and a type may be reported in
    different casing than the registry key (e.g. all lower-case). An exact
    lookup is tried first, then a case-insensitive scan. Returns ``[]`` when
    nothing matches or *provider_type* is empty.
    """
    if not provider_type:
        return []
    exact = ENRICHMENT_REGISTRY.get(provider_type)
    if exact is not None:
        return exact
    wanted = provider_type.lower()
    for registered_type, commands in ENRICHMENT_REGISTRY.items():
        if registered_type.lower() == wanted:
            return commands
    return []

def enrich_one_resource(res: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the snapshot record for one resource from `az resource list`.

    Every registry command matching the resource's type is run through
    run_az. Each result is stored under ``enrichment[<cmd.key>]``, or under
    a key derived from the first three arguments when the command has no key.

    Args:
        res: one element of the `az resource list` JSON array. The keys
            read are ``id``, ``name``, ``type``, ``location``, ``tags``,
            ``kind``, ``sku`` and ``properties``.

    Returns:
        A normalized record with the resource's core fields, ``_meta``
        (category, type, resource group, collection time) and
        ``enrichment`` (per-command CLI output).
    """
    rid = res.get("id") or ""
    rg, parsed_type = parse_arm_id(rid)
    # Prefer the type reported in the listing. Parsing the ARM ID is only a
    # fallback, and depending on the parser it may not capture child types
    # such as "servers/databases".
    provider_type = res.get("type") or parsed_type
    name = res.get("name", "")
    category, rtype = normalize_category(provider_type)
    meta = {
        "category": category,
        "provider_type": provider_type,
        "rtype": rtype,
        "resource_group": rg,
        "collected_at": utc_now_iso(),
    }

    enrichment: Dict[str, Any] = {}
    for cmd in _registry_commands(provider_type):
        args = cmd.build(res)
        # A base command containing a bare "--resource" flag (diagnostic-settings
        # style) needs the ARM ID as that flag's value. Appending the ID assumes
        # "--resource" is the last argument.
        if "--resource" in args:
            args.append(rid)
        key = cmd.key or "-".join(args[:3])
        enrichment[key] = run_az(args)

    return {
        "name": name,
        "id": rid,
        "type": provider_type,
        "location": res.get("location"),
        "tags": res.get("tags", {}) or {},
        "kind": res.get("kind"),
        "sku": res.get("sku"),
        "properties": res.get("properties", {}) or {},
        "_meta": meta,
        "enrichment": enrichment,
    }

def list_subscriptions() -> List[Dict[str, Any]]:
    """
    Return the subscriptions visible to the logged-in identity (`az account list --all`).

    run_az reports some failures as a dict, e.g. ``{"error": ...}`` after
    repeated timeouts or ``{"raw_output": ...}`` for non-JSON output.
    Anything other than a JSON array is logged and treated as "no
    subscriptions", so callers can always iterate subscription dicts.
    """
    result = run_az(["account", "list", "--all"])
    if isinstance(result, list):
        return [sub for sub in result if isinstance(sub, dict)]
    if result:
        LOG.warning("Unexpected output from 'az account list': %.200r", result)
    return []

def set_subscription(sub_id: str) -> None:
    """
    Select *sub_id* as the active subscription via `az account set`.

    The command's result is discarded, so a failed switch is not reported
    to the caller.
    NOTE(review): `az account set` presumably persists the default
    subscription in the user's CLI profile. Confirm if that side effect matters.
    """
    run_az(["account", "set", "--subscription", sub_id])

def current_account() -> Dict[str, Any]:
    """Return `az account show` output, or an empty dict when the call yields nothing."""
    account = run_az(["account", "show"])
    return account if account else {}

def list_resources() -> List[Dict[str, Any]]:
    """
    Return the resources reported by `az resource list` for the current CLI subscription.

    Non-array output, such as run_az's ``{"error": ...}`` timeout marker or
    ``{"raw_output": ...}``, is logged and treated as "no resources". This
    prevents callers from iterating the dict's keys as if they were resources.
    """
    result = run_az(["resource", "list"])
    if isinstance(result, list):
        return [resource for resource in result if isinstance(resource, dict)]
    if result:
        LOG.warning("Unexpected output from 'az resource list': %.200r", result)
    return []

def discover_subscription(sub: Dict[str, Any], max_workers: int) -> Dict[str, Any]:
    """
    Discover and enrich every resource in one subscription.

    The function:
    - switches the CLI to the subscription and confirms the switch took effect;
    - lists the subscription's resources;
    - enriches the resources concurrently on a thread pool;
    - groups the results as ``services[category][rtype]``.

    Categories, resource types and each group's records (by resource ID)
    are sorted, so repeated runs produce the same layout.

    Args:
        sub: one entry from `az account list` (``id`` and ``name`` are read).
        max_workers: thread-pool size; values below 1 are treated as 1.

    Returns:
        A dict with ``name``, ``subscriptionId``, ``discovered_at`` and
        ``services``. An ``error`` key is added, and ``services`` is left
        empty, when the subscription could not be selected.
    """
    sid = sub.get("id")
    sname = sub.get("name")
    click.echo(f"🔍 Discovering subscription: {sname} ({sid})")
    set_subscription(sid)

    # A failed `az account set` only surfaces as an empty run_az result,
    # which set_subscription discards. Without this check, the previously
    # active subscription would be scanned and filed under this ID.
    active_id = current_account().get("id")
    if str(active_id or "").lower() != str(sid or "").lower():
        LOG.error("Could not switch to subscription %s (active: %s); skipping it.", sid, active_id)
        return {
            "name": sname,
            "subscriptionId": sid,
            "discovered_at": utc_now_iso(),
            "services": {},
            "error": f"could not select subscription (active: {active_id})",
        }

    resources = list_resources()
    click.echo(f"   Found {len(resources)} resources.")

    grouped: Dict[str, Dict[str, List[Dict[str, Any]]]] = defaultdict(lambda: defaultdict(list))
    workers = None if max_workers is None else max(1, max_workers)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Map each future back to its input so a failure can name the resource.
        futures = {pool.submit(enrich_one_resource, r): r for r in resources}
        for fut in as_completed(futures):
            try:
                enriched = fut.result()
            except Exception as exc:  # best-effort: one bad resource must not abort the scan
                failed = futures[fut]
                failed_id = failed.get("id") if isinstance(failed, dict) else failed
                LOG.warning("⚠️ Enrichment failed for %s: %s", failed_id, exc)
                continue
            meta = enriched["_meta"]
            grouped[meta["category"]][meta["rtype"]].append(enriched)

    # as_completed yields in completion order; sort for reproducible snapshots.
    services: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
    for category in sorted(grouped):
        services[category] = {
            rtype: sorted(items, key=lambda item: str(item.get("id") or ""))
            for rtype, items in sorted(grouped[category].items())
        }

    return {
        "name": sname,
        "subscriptionId": sid,
        "discovered_at": utc_now_iso(),
        "services": services,
    }

def build_discovery(max_workers: int, only_subs: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Build the full multi-subscription discovery document.

    Args:
        max_workers: per-subscription enrichment thread-pool size.
        only_subs: subscription IDs to restrict discovery to, compared
            case-insensitively. Empty or ``None`` means every subscription
            from list_subscriptions(). Requested IDs that match nothing are
            reported with a warning instead of being ignored silently.

    Returns:
        A snapshot dict with tenant and collector metadata, and a
        ``subscriptions`` mapping keyed by subscription ID.
    """
    acct = current_account()
    tenant_id = acct.get("tenantId", "unknown")

    wanted = {str(s).lower() for s in only_subs} if only_subs else set()
    matched = set()
    subs = []
    for sub in list_subscriptions():
        if not isinstance(sub, dict):  # guard against non-list run_az output
            continue
        sid_lc = str(sub.get("id") or "").lower()
        if wanted and sid_lc not in wanted:
            continue
        matched.add(sid_lc)
        subs.append(sub)

    missing = wanted - matched
    if missing:
        LOG.warning(
            "Requested subscription(s) not found or not accessible: %s",
            ", ".join(sorted(missing)),
        )

    discovery = {
        "provider": "azure",
        "tenant": tenant_id,
        "collector_version": "1.3.1",
        "collected_at": utc_now_iso(),
        "subscriptions": {},
        "_meta": {"cli_info": run_az(["version"])},
    }

    for sub in subs:
        sub_blob = discover_subscription(sub, max_workers)
        discovery["subscriptions"][sub_blob["subscriptionId"]] = sub_blob

    return discovery

# ---------- CLI ----------

@click.command()
@click.option("--output", "-o", type=click.Path(), default=None, help="Output file path")
@click.option(
    "--max-workers",
    default=6,
    show_default=True,
    # Reject 0/negatives at parse time; ThreadPoolExecutor raises on them.
    type=click.IntRange(min=1),
    help="Number of parallel workers",
)
@click.option("--subscription", "-s", multiple=True, help="Limit to specific subscription IDs")
@click.option("-v", "--verbose", count=True, help="Increase verbosity (-v, -vv)")
def cli(output, max_workers, subscription, verbose):
    """Discover Azure resources and enrich them for security/compliance."""
    # NOTE: the docstring above doubles as click's --help text.
    setup_logging(verbose)

    # Timezone-aware equivalent of the deprecated datetime.utcnow().
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H%M%S")
    if not output:
        output = os.path.join(os.getcwd(), f"azure_discovery_{timestamp}.json")

    click.echo(f"🚀 Starting Azure discovery ({max_workers} workers)...")
    data = build_discovery(max_workers=max_workers, only_subs=subscription)

    # dirname() is "" for a bare filename such as `-o snapshot.json`, and
    # os.makedirs("") raises FileNotFoundError, so only create a real directory.
    out_dir = os.path.dirname(output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    click.echo(f"✅ Discovery snapshot saved to {output}")

if __name__ == "__main__":
    # Equivalent to sys.exit(cli()): any value returned by cli() becomes the
    # process exit status.
    raise SystemExit(cli())

