#!/usr/bin/python3

# SPDX-License-Identifier: MPL-2.0
# SPDX-FileCopyrightText: 2021-2025 Collabora Ltd.
# SPDX-FileCopyrightText: 2021-2025 Walter Lozano <walter.lozano@collabora.com>
# SPDX-FileCopyrightText: 2025 Dylan Aïssi <dylan.aissi@collabora.com>

import argparse
import json
import os
import sys

# Verbosity levels (-v): how deep into the BOM the license check goes.
VERBOSE_IMAGE = 0
VERBOSE_PACKAGE = 1
VERBOSE_BINARY = 2
VERBOSE_SOURCE = 3  # accepted on the command line; not used by BomChecker

# Error levels (-e): which kind of finding makes the check fail.
ERROR_NONE = 0
ERROR_ERROR = 1
ERROR_WARN = 2

# License identifiers that are reported as errors / warnings when found.
PROBLEMATIC_LICENSES_ERRORS = []
PROBLEMATIC_LICENSES_WARNS = ["UNKNOWN", "", "NoInfoFound", "no-info-found"]


class BomChecker:
    """Check the licenses recorded in a JSON BOM file.

    Licenses listed in ``PROBLEMATIC_LICENSES_ERRORS`` are reported as errors,
    those in ``PROBLEMATIC_LICENSES_WARNS`` as warnings. Every finding is
    printed; whether it makes the check fail depends on ``error_level`` and
    on the package whitelist. The outcome is accumulated in ``self.error``.
    """

    def __init__(
        self, bom_file, verbose, error_level=ERROR_NONE, whitelisted_packages=""
    ):
        """Create a checker.

        :param bom_file: readable file object containing the JSON BOM.
        :param verbose: one of the ``VERBOSE_*`` levels. ``VERBOSE_PACKAGE``
            or higher also checks every package, ``VERBOSE_BINARY`` or higher
            also checks every package's binaries.
        :param error_level: one of the ``ERROR_*`` levels. ``ERROR_ERROR``
            makes error findings fail the check, ``ERROR_WARN`` makes warning
            findings fail it as well.
        :param whitelisted_packages: path of a file listing package names
            whose findings are printed but never fail the check. An empty
            string or a non-existent path means "no whitelist".
        """
        self.bom_file = bom_file
        self.verbose = verbose
        self.error_level = error_level
        # Becomes True as soon as a finding reaches ``error_level``.
        self.error = False
        self.whitelisted_packages = whitelisted_packages

    def check_licenses(
        self, licenses, source_type="image", source_name="", whitelisted=False
    ):
        """Print problematic licenses found in *licenses* and update ``self.error``.

        :param licenses: iterable of license identifier strings.
        :param source_type: kind of entity being checked, used in messages.
        :param source_name: name of the entity, used in messages.
        :param whitelisted: when True, findings are still printed but never
            mark the check as failed (for errors and warnings alike).
        """
        licenses = set(licenses)
        # Sorted so the reported list is identical from run to run; the
        # iteration order of a set of strings depends on hash randomisation.
        license_errors = sorted(licenses.intersection(PROBLEMATIC_LICENSES_ERRORS))
        license_warnings = sorted(licenses.intersection(PROBLEMATIC_LICENSES_WARNS))
        if license_errors:
            if self.error_level >= ERROR_ERROR and not whitelisted:
                self.error = True
            found = " ".join(license_errors)
            print(
                f"ERROR on {source_type} {source_name} license {found} found, whitelisted {whitelisted}"
            )
        if license_warnings:
            if self.error_level >= ERROR_WARN and not whitelisted:
                self.error = True
            found = " ".join(license_warnings)
            print(
                f"WARNING on {source_type} {source_name} license {found} found, whitelisted {whitelisted}"
            )

    def get_whitelisted_packages(self):
        """Return the package names listed in the whitelist file.

        Surrounding whitespace is stripped from every line; blank lines and
        lines whose first non-blank character is ``#`` are ignored. Returns
        an empty list when no whitelist path is set or the file does not
        exist.
        """
        path = self.whitelisted_packages
        if not path or not os.path.isfile(path):
            return []
        whitelisted_packages = []
        with open(path, encoding="utf-8") as wp:
            for line in wp:
                name = line.strip()
                if not name or name.startswith("#"):
                    continue
                whitelisted_packages.append(name)
        return whitelisted_packages

    def check_bom(self):
        """Load the BOM from ``self.bom_file`` and check the licenses it records.

        The top-level ``license`` list is checked unless a whitelist is in
        use and the BOM lists packages (presumably because the image-level
        list cannot be attributed to individual packages, so the whitelist
        can only be honoured per package). Packages and their binaries are
        checked depending on ``self.verbose``.

        :raises json.JSONDecodeError: if the BOM is not valid JSON.
        :raises KeyError: if a checked entry lacks ``license`` or ``name``.
        """
        bom = json.load(self.bom_file)
        # Set for O(1) membership tests in the per-package loop below.
        whitelisted_packages = set(self.get_whitelisted_packages())
        # A missing, null or empty "packages" entry all mean "no packages".
        packages = bom.get("packages") or []

        if not whitelisted_packages or not packages:
            self.check_licenses(bom["license"])
        # NOTE(review): with a whitelist, a non-empty package list and
        # verbose < VERBOSE_PACKAGE, neither the image nor any package is
        # checked. Confirm this is intended.

        if self.verbose >= VERBOSE_PACKAGE:
            for package in packages:
                whitelisted = package["name"] in whitelisted_packages
                self.check_licenses(
                    package["license"], "package", package["name"], whitelisted
                )

                if self.verbose >= VERBOSE_BINARY:
                    # Binaries inherit the whitelist status of their package.
                    for binary in package.get("binaries") or []:
                        self.check_licenses(
                            binary["license"], "binary", binary["name"], whitelisted
                        )


def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("bom_file", type=open, help="BOM file to check")
    parser.add_argument(
        "-e",
        "--error-level",
        type=int,
        default=ERROR_NONE,
        help="type of error that triggers a return code unsuccessful 0: none , 1: error, 2: warning",
    )
    parser.add_argument(
        "-w",
        "--whitelisted-packages",
        default="",
        help="file containing a list of whitelisted packages",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        type=int,
        default=VERBOSE_IMAGE,
        help="verbose use in output 0: image, 1: package, 2: binary, 3: source",
    )

    args = parser.parse_args()

    bom_checker = BomChecker(
        args.bom_file, args.verbose, args.error_level, args.whitelisted_packages
    )
    bom_checker.check_bom()

    if bom_checker.error:
        print("BOM check has failed", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main(sys.argv[1:])
