diff --git a/vulnerabilities/pipelines/v2_improvers/compute_package_risk.py b/vulnerabilities/pipelines/v2_improvers/compute_package_risk.py index ac7caa49d..9caaaeb95 100644 --- a/vulnerabilities/pipelines/v2_improvers/compute_package_risk.py +++ b/vulnerabilities/pipelines/v2_improvers/compute_package_risk.py @@ -7,7 +7,10 @@ # See https://aboutcode.org for more information about nexB OSS projects. # from aboutcode.pipeline import LoopProgress +from django.db.models import Prefetch +from django.db.models import Q +from vulnerabilities.models import AdvisorySeverity from vulnerabilities.models import AdvisoryV2 from vulnerabilities.models import PackageV2 from vulnerabilities.pipelines import VulnerableCodePipeline @@ -35,7 +38,15 @@ def steps(cls): def compute_and_store_vulnerability_risk_score(self): affected_advisories = ( AdvisoryV2.objects.filter(impacted_packages__affecting_packages__isnull=False) - .prefetch_related("references", "severities", "exploits") + .prefetch_related( + "references", + "severities", + "exploits", + Prefetch( + "related_advisory_severities", + queryset=AdvisoryV2.objects.prefetch_related("severities"), + ), + ) .distinct() ) @@ -50,10 +61,13 @@ def compute_and_store_vulnerability_risk_score(self): batch_size = 5000 for advisory in progress.iter(affected_advisories.iterator(chunk_size=batch_size)): - severities = advisory.severities.all() references = advisory.references.all() exploits = advisory.exploits.all() + severities = AdvisorySeverity.objects.filter( + Q(advisories=advisory) | Q(advisories__related_to_advisory_severities=advisory) + ).distinct() + weighted_severity, exploitability = compute_vulnerability_risk_factors( references=references, severities=severities, diff --git a/vulnerabilities/tests/pipelines/v2_improvers/test_relate_severities.py b/vulnerabilities/tests/pipelines/v2_improvers/test_relate_severities.py index 0c4c3e901..2dadbc679 100644 --- a/vulnerabilities/tests/pipelines/v2_improvers/test_relate_severities.py +++ b/vulnerabilities/tests/pipelines/v2_improvers/test_relate_severities.py @@ -8,7 +8,9 @@ # import pytest +from django.db.models import Q +from vulnerabilities.models import AdvisorySeverity from vulnerabilities.models import AdvisoryV2 from vulnerabilities.pipelines.v2_improvers.relate_severities import RelateSeveritiesPipeline from vulnerabilities.severity_systems import EPSS @@ -42,6 +44,10 @@ def test_relate_severities_by_advisory_id(): pipeline.relate_severities() assert base.related_advisory_severities.filter(id=severity_advisory.id).exists() + severities = AdvisorySeverity.objects.filter( + Q(advisories=base) | Q(advisories__related_to_advisory_severities=base) + ).distinct() + assert severities.filter(id=severity_advisory.severities.first().id).exists() @pytest.mark.django_db