# Source code for package_scan.core.threat_database

"""Threat database management for multi-ecosystem scanning"""

import csv
from collections import defaultdict
from pathlib import Path
from typing import Dict, Set, Optional, List

import click

from .threat_metadata import get_csv_reader_without_comments


class ThreatDatabase:
    """
    Manages threat data from CSV files with multi-ecosystem support

    Supports:
    - Loading specific threats by name (e.g., 'sha1-Hulud', 'other-threat')
    - Loading all threats from threats/ directory
    - Loading custom CSV files

    CSV Format:
        ecosystem,name,version
        npm,left-pad,1.3.0
        maven,org.apache.logging.log4j:log4j-core,2.14.1
        pip,requests,2.25.1
        gem,strong_migrations,0.7.9

    Ecosystem names are lower-cased both when loading and when looking up;
    package names and versions are only stripped of surrounding whitespace
    and are otherwise compared exactly.
    """

    def __init__(self, threats_dir: str = "threats"):
        """
        Args:
            threats_dir: Directory scanned for ``<threat_name>.csv`` files.
        """
        self.threats_dir = Path(threats_dir)
        self.loaded_threats: List[str] = []  # Track loaded threat names
        # Structure: {ecosystem: {package_name: set(versions)}}
        self.threats: Dict[str, Dict[str, Set[str]]] = defaultdict(lambda: defaultdict(set))
        self._is_loaded = False

    def load_threats(self, threat_names: Optional[List[str]] = None, csv_file: Optional[str] = None) -> bool:
        """
        Load threats by name or from custom CSV

        Loading is cumulative: entries from each successfully loaded file are
        merged into whatever this instance already holds.

        Args:
            threat_names: List of threat names to load (e.g., ['sha1-Hulud']).
                If None or empty, loads all threats from threats/ directory.
            csv_file: Path to custom CSV file (overrides threat_names).

        Returns:
            True if at least one threat loaded successfully, False otherwise
        """
        success = False

        if csv_file:
            # Load custom CSV file
            if self._load_csv(Path(csv_file), threat_name='custom'):
                success = True
        elif threat_names:
            # Load specific threats by name
            for threat_name in threat_names:
                csv_path = self.threats_dir / f"{threat_name}.csv"
                if self._load_csv(csv_path, threat_name=threat_name):
                    success = True
        else:
            # Load all threats from directory
            if not self.threats_dir.exists():
                click.echo(click.style(
                    f"✗ Error: Threats directory not found: {self.threats_dir}",
                    fg='red', bold=True), err=True)
                return False

            csv_files = sorted(self.threats_dir.glob("*.csv"))
            if not csv_files:
                click.echo(click.style(
                    f"✗ Error: No threat CSV files found in {self.threats_dir}",
                    fg='red', bold=True), err=True)
                return False

            for csv_path in csv_files:
                threat_name = csv_path.stem
                if self._load_csv(csv_path, threat_name=threat_name):
                    success = True

        if success:
            self._is_loaded = True

        return success

    def _load_csv(self, csv_path: Path, threat_name: str) -> bool:
        """
        Load a single CSV file

        Args:
            csv_path: Path to CSV file
            threat_name: Name of the threat (for tracking)

        Returns:
            True if loaded successfully, False otherwise. Errors are reported
            on stderr rather than raised.
        """
        if not csv_path.exists():
            click.echo(click.style(f"✗ Error: Threat CSV file not found: {csv_path}", fg='red', bold=True), err=True)
            return False

        try:
            # Use comment-filtered reader to skip # lines
            csv_content = get_csv_reader_without_comments(csv_path)
            reader = csv.DictReader(csv_content)

            headers = reader.fieldnames
            if not headers:
                click.echo(click.style(f"✗ Error: CSV file has no headers: {csv_path}", fg='red', bold=True), err=True)
                return False

            # Tolerate whitespace around header names ("ecosystem, name, version")
            # and a UTF-8 byte-order mark if one survives decoding (common in
            # spreadsheet exports). Re-assigning fieldnames makes every row
            # produced afterwards use the normalized keys.
            normalized = [header.replace('\ufeff', '').strip() for header in headers]
            reader.fieldnames = normalized

            # Check for required headers
            if not {'ecosystem', 'name', 'version'}.issubset(normalized):
                click.echo(click.style(
                    f"✗ Error: Invalid CSV format in {csv_path}. "
                    f"Expected headers: 'ecosystem,name,version'. "
                    f"Got: {','.join(headers)}",
                    fg='red', bold=True), err=True)
                return False

            self._load_multi_ecosystem_format(reader)
            self.loaded_threats.append(threat_name)
            return True

        except UnicodeDecodeError as e:
            click.echo(click.style(f"✗ Error: CSV file encoding issue in {csv_path}: {e}", fg='red', bold=True), err=True)
            return False
        except Exception as e:
            # Boundary for user-supplied files: report and let the caller
            # continue with the remaining threat files.
            click.echo(click.style(f"✗ Error loading {csv_path}: {e}", fg='red', bold=True), err=True)
            return False

    def _load_multi_ecosystem_format(self, reader: csv.DictReader) -> None:
        """
        Load CSV rows in multi-ecosystem format into ``self.threats``

        Rows are collected into a staging map and merged only after the whole
        reader has been consumed. If iteration raises part-way through (e.g. a
        malformed line), nothing from this file is merged, so a file reported
        as failed never leaves partial entries behind.

        Rows with an empty or missing ecosystem, name or version are skipped
        with a warning on stderr.

        Args:
            reader: csv.DictReader whose fieldnames include 'ecosystem',
                'name' and 'version'
        """
        staged: Dict[str, Dict[str, Set[str]]] = defaultdict(lambda: defaultdict(set))

        # Start at 2 (header is line 1). NOTE(review): lines removed by the
        # comment filter are presumably not counted, so this may differ from
        # the physical line number in the file.
        for row_num, row in enumerate(reader, start=2):
            # DictReader fills missing trailing fields with None (its default
            # restval); coerce to '' so a short row is skipped with a warning
            # instead of raising AttributeError and failing the whole file.
            ecosystem = (row.get('ecosystem') or '').strip().lower()
            name = (row.get('name') or '').strip()
            version = (row.get('version') or '').strip()

            if not ecosystem or not name or not version:
                click.echo(click.style(
                    f"⚠️ Warning: Skipping row {row_num} with empty fields: {row}",
                    fg='yellow'), err=True)
                continue

            staged[ecosystem][name].add(version)

        for ecosystem, packages in staged.items():
            for name, versions in packages.items():
                self.threats[ecosystem][name].update(versions)

    def get_compromised_versions(self, ecosystem: str, package_name: str) -> Set[str]:
        """
        Get all compromised versions for a specific package in an ecosystem

        Args:
            ecosystem: Ecosystem name (npm, maven, pip, gem, etc.); case-insensitive
            package_name: Package identifier (exact match)

        Returns:
            A new set of compromised version strings (empty if the database is
            not loaded or the package is unknown). Mutating it does not affect
            the database.
        """
        if not self._is_loaded:
            return set()

        # .get() on the defaultdicts avoids creating empty entries on lookup.
        return set(self.threats.get(ecosystem.lower(), {}).get(package_name, ()))

    def is_compromised(self, ecosystem: str, package_name: str, version: str) -> bool:
        """
        Check if a specific package version is compromised

        Args:
            ecosystem: Ecosystem name; case-insensitive
            package_name: Package identifier (exact match)
            version: Version string (exact match)

        Returns:
            True if compromised, False otherwise (always False when the
            database is not loaded)
        """
        if not self._is_loaded:
            return False

        # Direct lookup: no copy of the version set is needed for a membership test.
        versions = self.threats.get(ecosystem.lower(), {}).get(package_name)
        return versions is not None and version in versions

    def get_all_packages(self, ecosystem: Optional[str] = None) -> Dict[str, Set[str]]:
        """
        Get all compromised packages, optionally filtered by ecosystem

        Args:
            ecosystem: Optional ecosystem filter; case-insensitive

        Returns:
            New dictionary mapping package names to new sets of compromised
            versions. Without a filter, packages sharing a name across
            ecosystems are merged into one entry.
        """
        if not self._is_loaded:
            return {}

        if ecosystem:
            eco_threats = self.threats.get(ecosystem.lower(), {})
            return {pkg_name: set(versions) for pkg_name, versions in eco_threats.items()}

        # Return all ecosystems merged (not recommended, use with caution)
        all_packages: Dict[str, Set[str]] = defaultdict(set)
        for eco_threats in self.threats.values():
            for pkg_name, versions in eco_threats.items():
                all_packages[pkg_name].update(versions)
        return dict(all_packages)

    def get_ecosystems(self) -> Set[str]:
        """
        Get all ecosystems present in the threat database

        Returns:
            Set of ecosystem names (empty if not loaded)
        """
        if not self._is_loaded:
            return set()
        return set(self.threats.keys())

    def get_loaded_threats(self) -> List[str]:
        """
        Get list of loaded threat names

        Returns:
            Copy of the list of threat names that were loaded, in load order
        """
        return self.loaded_threats.copy()

    def get_package_count(self, ecosystem: Optional[str] = None) -> int:
        """
        Get count of unique packages in threat database

        Args:
            ecosystem: Optional ecosystem filter; case-insensitive

        Returns:
            Number of unique packages (summed per ecosystem when unfiltered)
        """
        if ecosystem:
            return len(self.threats.get(ecosystem.lower(), {}))

        # Total across all ecosystems
        return sum(len(packages) for packages in self.threats.values())

    def get_version_count(self, ecosystem: Optional[str] = None) -> int:
        """
        Get count of compromised package versions

        Args:
            ecosystem: Optional ecosystem filter; case-insensitive

        Returns:
            Total number of compromised (package, version) entries
        """
        if ecosystem:
            eco_threats = self.threats.get(ecosystem.lower(), {})
            return sum(len(versions) for versions in eco_threats.values())

        # Total across all ecosystems
        return sum(
            len(versions)
            for eco_threats in self.threats.values()
            for versions in eco_threats.values()
        )

    def print_summary(self):
        """Print a summary of loaded threats"""
        if not self._is_loaded:
            click.echo(click.style("✗ Threat database not loaded", fg='red', bold=True))
            return

        ecosystems = self.get_ecosystems()
        if not ecosystems:
            click.echo(click.style("⚠️ Warning: Threat database is empty", fg='yellow', bold=True))
            return

        total_packages = self.get_package_count()
        total_versions = self.get_version_count()

        # Show loaded threats
        if self.loaded_threats:
            threat_list = ', '.join(self.loaded_threats)
            click.echo(click.style(f"✓ Loaded threats: {threat_list}", fg='green', bold=True))

        click.echo(click.style(f"✓ Threat database: {total_packages} packages, {total_versions} versions", fg='green', bold=True))

        if len(ecosystems) > 1:
            click.echo(click.style(f" Ecosystems: {', '.join(sorted(ecosystems))}", fg='cyan'))
            for ecosystem in sorted(ecosystems):
                pkg_count = self.get_package_count(ecosystem)
                ver_count = self.get_version_count(ecosystem)
                click.echo(click.style(f" • {ecosystem}: {pkg_count} packages, {ver_count} versions", fg='cyan', dim=True))
        else:
            # Single ecosystem: print just its name
            ecosystem = next(iter(ecosystems))
            click.echo(click.style(f" Ecosystem: {ecosystem}", fg='cyan'))