Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff: add and fix some SIM rules #10926

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions dojo/api_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,7 @@ def to_representation(self, value):
if not isinstance(value, RequestResponseDict):
if not isinstance(value, list):
# this will trigger when a queryset is found...
if self.order_by:
burps = value.all().order_by(*self.order_by)
else:
burps = value.all()
burps = value.all().order_by(*self.order_by) if self.order_by else value.all()
value = [
{
"request": burp.get_request(),
Expand Down Expand Up @@ -552,10 +549,7 @@ def update(self, instance, validated_data):
return instance

def create(self, validated_data):
if "password" in validated_data:
password = validated_data.pop("password")
else:
password = None
password = validated_data.pop("password", None)

new_configuration_permissions = None
if (
Expand All @@ -581,10 +575,7 @@ def create(self, validated_data):
return user

def validate(self, data):
if self.instance is not None:
instance_is_superuser = self.instance.is_superuser
else:
instance_is_superuser = False
instance_is_superuser = self.instance.is_superuser if self.instance is not None else False
data_is_superuser = data.get("is_superuser", False)
if not self.context["request"].user.is_superuser and (
instance_is_superuser or data_is_superuser
Expand Down Expand Up @@ -1217,7 +1208,7 @@ class Meta:

def validate(self, data):

if not self.context["request"].method == "PATCH":
if self.context["request"].method != "PATCH":
if "product" not in data:
msg = "Product is required"
raise serializers.ValidationError(msg)
Expand Down Expand Up @@ -2248,7 +2239,7 @@ def setup_common_context(self, data: dict) -> dict:
"""
context = dict(data)
# update some vars
context["scan"] = data.pop("file", None)
context["scan"] = data.pop("file")

if context.get("auto_create_context"):
environment = Development_Environment.objects.get_or_create(name=data.get("environment", "Development"))[0]
Expand Down Expand Up @@ -2293,7 +2284,7 @@ def setup_common_context(self, data: dict) -> dict:

# engagement end date was not being used at all and so target_end would also turn into None
# in this case, do not want to change target_end unless engagement_end exists
eng_end_date = context.get("engagement_end_date", None)
eng_end_date = context.get("engagement_end_date")
if eng_end_date:
context["target_end"] = context.get("engagement_end_date")

Expand Down
49 changes: 10 additions & 39 deletions dojo/api_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,9 +1481,7 @@ def metadata(self, request, pk=None):
return self._get_metadata(request, finding)
if request.method == "POST":
return self._add_metadata(request, finding)
if request.method == "PUT":
return self._edit_metadata(request, finding)
if request.method == "PATCH":
if request.method in ["PUT", "PATCH"]:
return self._edit_metadata(request, finding)
if request.method == "DELETE":
return self._remove_metadata(request, finding)
Expand Down Expand Up @@ -2892,24 +2890,15 @@ def report_generate(request, obj, options):
if eng.name:
engagement_name = eng.name
engagement_target_start = eng.target_start
if eng.target_end:
engagement_target_end = eng.target_end
else:
engagement_target_end = "ongoing"
engagement_target_end = eng.target_end or "ongoing"
if eng.test_set.all():
for t in eng.test_set.all():
test_type_name = t.test_type.name
if t.environment:
test_environment_name = t.environment.name
test_target_start = t.target_start
if t.target_end:
test_target_end = t.target_end
else:
test_target_end = "ongoing"
if eng.test_strategy:
test_strategy_ref = eng.test_strategy
else:
test_strategy_ref = ""
test_target_end = t.target_end or "ongoing"
test_strategy_ref = eng.test_strategy or ""
total_findings = len(findings.qs.all())

elif type(obj).__name__ == "Product":
Expand All @@ -2919,59 +2908,41 @@ def report_generate(request, obj, options):
if eng.name:
engagement_name = eng.name
engagement_target_start = eng.target_start
if eng.target_end:
engagement_target_end = eng.target_end
else:
engagement_target_end = "ongoing"
engagement_target_end = eng.target_end or "ongoing"

if eng.test_set.all():
for t in eng.test_set.all():
test_type_name = t.test_type.name
if t.environment:
test_environment_name = t.environment.name
if eng.test_strategy:
test_strategy_ref = eng.test_strategy
else:
test_strategy_ref = ""
test_strategy_ref = eng.test_strategy or ""
total_findings = len(findings.qs.all())

elif type(obj).__name__ == "Engagement":
eng = obj
if eng.name:
engagement_name = eng.name
engagement_target_start = eng.target_start
if eng.target_end:
engagement_target_end = eng.target_end
else:
engagement_target_end = "ongoing"
engagement_target_end = eng.target_end or "ongoing"

if eng.test_set.all():
for t in eng.test_set.all():
test_type_name = t.test_type.name
if t.environment:
test_environment_name = t.environment.name
if eng.test_strategy:
test_strategy_ref = eng.test_strategy
else:
test_strategy_ref = ""
test_strategy_ref = eng.test_strategy or ""
total_findings = len(findings.qs.all())

elif type(obj).__name__ == "Test":
t = obj
test_type_name = t.test_type.name
test_target_start = t.target_start
if t.target_end:
test_target_end = t.target_end
else:
test_target_end = "ongoing"
test_target_end = t.target_end or "ongoing"
total_findings = len(findings.qs.all())
if t.engagement.name:
engagement_name = t.engagement.name
engagement_target_start = t.engagement.target_start
if t.engagement.target_end:
engagement_target_end = t.engagement.target_end
else:
engagement_target_end = "ongoing"
engagement_target_end = t.engagement.target_end or "ongoing"
else:
pass # do nothing

Expand Down
31 changes: 9 additions & 22 deletions dojo/authorization/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def user_has_permission(user, obj, permission):
if user.is_superuser:
return True

if isinstance(obj, Product_Type) or isinstance(obj, Product):
if isinstance(obj, Product_Type | Product):
# Global roles are only relevant for product types, products and their
# dependent objects
if user_has_global_permission(user, permission):
Expand Down Expand Up @@ -97,13 +97,9 @@ def user_has_permission(user, obj, permission):
and permission in Permissions.get_test_permissions()
):
return user_has_permission(user, obj.engagement.product, permission)
if (
isinstance(obj, Finding) or isinstance(obj, Stub_Finding)
) and permission in Permissions.get_finding_permissions():
return user_has_permission(
user, obj.test.engagement.product, permission,
)
if (
if ((
isinstance(obj, Finding | Stub_Finding)
) and permission in Permissions.get_finding_permissions()) or (
isinstance(obj, Finding_Group)
and permission in Permissions.get_finding_group_permissions()
):
Expand All @@ -113,23 +109,17 @@ def user_has_permission(user, obj, permission):
if (
isinstance(obj, Endpoint)
and permission in Permissions.get_endpoint_permissions()
):
return user_has_permission(user, obj.product, permission)
if (
) or (
isinstance(obj, Languages)
and permission in Permissions.get_language_permissions()
):
return user_has_permission(user, obj.product, permission)
if (
) or ((
isinstance(obj, App_Analysis)
and permission in Permissions.get_technology_permissions()
):
return user_has_permission(user, obj.product, permission)
if (
) or (
isinstance(obj, Product_API_Scan_Configuration)
and permission
in Permissions.get_product_api_scan_configuration_permissions()
):
)):
return user_has_permission(user, obj.product, permission)
if (
isinstance(obj, Product_Type_Member)
Expand Down Expand Up @@ -351,10 +341,7 @@ def get_product_groups_dict(user):
.select_related("role")
.filter(group__users=user)
):
if pg_dict.get(product_group.product.id) is None:
pgu_list = []
else:
pgu_list = pg_dict[product_group.product.id]
pgu_list = [] if pg_dict.get(product_group.product.id) is None else pg_dict[product_group.product.id]
pgu_list.append(product_group)
pg_dict[product_group.product.id] = pgu_list
return pg_dict
Expand Down
5 changes: 2 additions & 3 deletions dojo/benchmark/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import logging

from crum import get_current_user
Expand Down Expand Up @@ -37,10 +38,8 @@ def add_benchmark(queryset, product):
benchmark_product.control = requirement
requirements.append(benchmark_product)

try:
with contextlib.suppress(Exception):
Benchmark_Product.objects.bulk_create(requirements)
except Exception:
pass


@user_is_authorized(Product, Permissions.Benchmark_Edit, "pid")
Expand Down
5 changes: 1 addition & 4 deletions dojo/cred/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ def get_authorized_cred_mappings(permission, queryset=None):
if user is None:
return Cred_Mapping.objects.none()

if queryset is None:
cred_mappings = Cred_Mapping.objects.all().order_by("id")
else:
cred_mappings = queryset
cred_mappings = Cred_Mapping.objects.all().order_by("id") if queryset is None else queryset

if user.is_superuser:
return cred_mappings
Expand Down
5 changes: 2 additions & 3 deletions dojo/cred/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import logging

from django.contrib import messages
Expand Down Expand Up @@ -585,10 +586,8 @@ def new_cred_finding(request, fid):
@user_is_authorized(Cred_User, Permissions.Credential_Delete, "ttid")
def delete_cred_controller(request, destination_url, id, ttid):
cred = None
try:
with contextlib.suppress(Exception):
cred = Cred_Mapping.objects.get(pk=ttid)
except:
pass
if request.method == "POST":
tform = CredMappingForm(request.POST, instance=cred)
message = ""
Expand Down
10 changes: 2 additions & 8 deletions dojo/endpoint/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ def get_authorized_endpoints(permission, queryset=None, user=None):
if user is None:
return Endpoint.objects.none()

if queryset is None:
endpoints = Endpoint.objects.all().order_by("id")
else:
endpoints = queryset
endpoints = Endpoint.objects.all().order_by("id") if queryset is None else queryset

if user.is_superuser:
return endpoints
Expand Down Expand Up @@ -66,10 +63,7 @@ def get_authorized_endpoint_status(permission, queryset=None, user=None):
if user is None:
return Endpoint_Status.objects.none()

if queryset is None:
endpoint_status = Endpoint_Status.objects.all().order_by("id")
else:
endpoint_status = queryset
endpoint_status = Endpoint_Status.objects.all().order_by("id") if queryset is None else queryset

if user.is_superuser:
return endpoint_status
Expand Down
41 changes: 11 additions & 30 deletions dojo/endpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,11 @@
def endpoint_filter(**kwargs):
qs = Endpoint.objects.all()

if kwargs.get("protocol"):
qs = qs.filter(protocol__iexact=kwargs["protocol"])
else:
qs = qs.filter(protocol__isnull=True)
qs = qs.filter(protocol__iexact=kwargs["protocol"]) if kwargs.get("protocol") else qs.filter(protocol__isnull=True)

if kwargs.get("userinfo"):
qs = qs.filter(userinfo__exact=kwargs["userinfo"])
else:
qs = qs.filter(userinfo__isnull=True)
qs = qs.filter(userinfo__exact=kwargs["userinfo"]) if kwargs.get("userinfo") else qs.filter(userinfo__isnull=True)

if kwargs.get("host"):
qs = qs.filter(host__iexact=kwargs["host"])
else:
qs = qs.filter(host__isnull=True)
qs = qs.filter(host__iexact=kwargs["host"]) if kwargs.get("host") else qs.filter(host__isnull=True)

if kwargs.get("port"):
if (kwargs.get("protocol")) and \
Expand All @@ -48,20 +39,11 @@ def endpoint_filter(**kwargs):
else:
qs = qs.filter(port__isnull=True)

if kwargs.get("path"):
qs = qs.filter(path__exact=kwargs["path"])
else:
qs = qs.filter(path__isnull=True)
qs = qs.filter(path__exact=kwargs["path"]) if kwargs.get("path") else qs.filter(path__isnull=True)

if kwargs.get("query"):
qs = qs.filter(query__exact=kwargs["query"])
else:
qs = qs.filter(query__isnull=True)
qs = qs.filter(query__exact=kwargs["query"]) if kwargs.get("query") else qs.filter(query__isnull=True)

if kwargs.get("fragment"):
qs = qs.filter(fragment__exact=kwargs["fragment"])
else:
qs = qs.filter(fragment__isnull=True)
qs = qs.filter(fragment__exact=kwargs["fragment"]) if kwargs.get("fragment") else qs.filter(fragment__isnull=True)

if kwargs.get("product"):
qs = qs.filter(product__exact=kwargs["product"])
Expand Down Expand Up @@ -267,12 +249,11 @@ def validate_endpoints_to_add(endpoints_to_add):
endpoints = endpoints_to_add.split()
for endpoint in endpoints:
try:
if "://" in endpoint: # is it full uri?
endpoint_ins = Endpoint.from_uri(endpoint) # from_uri validate URI format + split to components
else:
# from_uri parse any '//localhost', '//127.0.0.1:80', '//foo.bar/path' correctly
# format doesn't follow RFC 3986 but users use it
endpoint_ins = Endpoint.from_uri("//" + endpoint)
# is it full uri?
# 1. from_uri validate URI format + split to components
# 2. from_uri parse any '//localhost', '//127.0.0.1:80', '//foo.bar/path' correctly
# format doesn't follow RFC 3986 but users use it
endpoint_ins = Endpoint.from_uri(endpoint) if "://" in endpoint else Endpoint.from_uri("//" + endpoint)
endpoint_ins.clean()
endpoint_list.append([
endpoint_ins.protocol,
Expand Down
5 changes: 1 addition & 4 deletions dojo/endpoint/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def process_endpoints_view(request, host_view=False, vulnerable=False):

paged_endpoints = get_page_items(request, endpoints.qs, 25)

if vulnerable:
view_name = "Vulnerable"
else:
view_name = "All"
view_name = "Vulnerable" if vulnerable else "All"

if host_view:
view_name += " Hosts"
Expand Down
Loading