From 8df04c0869a7da40684a5c0190e51f4cdc19d500 Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Wed, 2 Oct 2024 12:11:42 +0200 Subject: [PATCH] Revert "Revert "[ruby] Ignore "Throwaway" AST Structures (#4982)" (#4983)" This reverts commit 464480d7e41ac3b92628c03d04b433a2ae0606bc. --- .../AstForExpressionsCreator.scala | 42 +++++++++---------- .../astcreation/AstForFunctionsCreator.scala | 23 +++++++--- .../astcreation/AstForStatementsCreator.scala | 20 +++++---- .../rubysrc2cpg/querying/DoBlockTests.scala | 4 +- .../rubysrc2cpg/querying/MethodTests.scala | 17 ++++++++ .../src/main/scala/io/joern/x2cpg/Ast.scala | 27 ++++++++++-- 6 files changed, 91 insertions(+), 42 deletions(-) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 112d856f8d2e..75675fe41dcc 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -157,18 +157,18 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { /** Attempts to extract a type from the base of a member call. */ protected def typeFromCallTarget(baseNode: RubyExpression): Option[String] = { - scope.lookupVariable(baseNode.text) match { - // fixme: This should be under type recovery logic - case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) - case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) - case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption - case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty => - decl.dynamicTypeHintFullName.headOption + baseNode match { + case literal: LiteralExpr => Option(literal.typeFullName) case _ => - astForExpression(baseNode).nodes - .flatMap(_.properties.get(PropertyNames.TYPE_FULL_NAME).map(_.toString)) - .filterNot(_ == XDefines.Any) - .headOption + scope.lookupVariable(baseNode.text) match { + // fixme: This should be under type recovery logic + case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) + case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) + case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption + case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty => + decl.dynamicTypeHintFullName.headOption + case _ => None + } } } @@ -296,7 +296,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val createAssignmentToTmp = !baseAstCache.contains(target) val tmpName = baseAstCache .updateWith(target) { - case Some(tmpName) => Option(tmpName) + case Some(tmpName) => + // TODO: Type ref nodes are automatically committed on creation, so if we have found a suitable cached AST, + // we want to clean this creation up. + Option(tmpName) case None => val tmpName = this.tmpGen.fresh val tmpGenLocal = NewLocal().name(tmpName).code(tmpName).typeFullName(Defines.Any) @@ -872,16 +875,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } private def astForMemberCallWithoutBlock(node: SimpleCall, memberAccess: MemberAccess): Ast = { - val receiverAst = astForFieldAccess(memberAccess) - val methodName = memberAccess.memberName - // TODO: Type recovery should potentially resolve this - val methodFullName = typeFromCallTarget(memberAccess.target) - .map(x => s"$x.$methodName") - .getOrElse(XDefines.DynamicCallUnknownFullName) - val argumentAsts = node.arguments.map(astForMethodCallArgument) - val call = - callNode(node, code(node), methodName, XDefines.DynamicCallUnknownFullName, DispatchTypes.DYNAMIC_DISPATCH) - .possibleTypes(IndexedSeq(methodFullName)) + val receiverAst = astForFieldAccess(memberAccess) + val methodName = memberAccess.memberName + val methodFullName = XDefines.DynamicCallUnknownFullName + val argumentAsts = node.arguments.map(astForMethodCallArgument) + val call = callNode(node, code(node), methodName, methodFullName, DispatchTypes.DYNAMIC_DISPATCH) callAst(call, argumentAsts, Some(receiverAst)) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala index 8da6426c5a0f..6383c90ed897 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -26,6 +26,11 @@ import scala.collection.mutable trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => + /** As expressions may be discarded, we cannot store closure ASTs in the diffgraph at the point of creation. We need + * to only write these at the end. + */ + protected val closureToRefs = mutable.Map.empty[RubyExpression, Seq[Ast]] + /** Creates method declaration related structures. * @param node * the node to create the AST structure from. @@ -194,8 +199,11 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th case _ => false }) - val methodRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x } + val typeRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x } + val astChildren = mutable.Buffer.empty[NewNode] + val refEdges = mutable.Buffer.empty[(NewNode, NewNode)] + val captureEdges = mutable.Buffer.empty[(NewNode, NewNode)] capturedLocalNodes .collect { case local: NewLocal => @@ -216,14 +224,17 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th ) // Create new local node for lambda, with corresponding REF edges to identifiers and closure binding - capturedBlockAst.root.foreach(rootBlock => diffGraph.addEdge(rootBlock, capturingLocal, EdgeTypes.AST)) - capturedIdentifiers.filter(_.name == name).foreach(i => diffGraph.addEdge(i, capturingLocal, EdgeTypes.REF)) - diffGraph.addEdge(closureBinding, capturedLocal, EdgeTypes.REF) + val _refEdges = + capturedIdentifiers.filter(_.name == name).map(i => i -> capturingLocal) :+ (closureBinding, capturedLocal) - methodRefOption.foreach(methodRef => diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE)) + astChildren.addOne(capturingLocal) + refEdges.addAll(_refEdges.toList) + captureEdges.addAll(typeRefOption.map(typeRef => typeRef -> closureBinding).toList) } - capturedBlockAst + val astWithAstChildren = astChildren.foldLeft(capturedBlockAst) { case (ast, child) => ast.withChild(Ast(child)) } + val astWithRefEdges = refEdges.foldLeft(astWithAstChildren) { case (ast, (src, dst)) => ast.withRefEdge(src, dst) } + captureEdges.foldLeft(astWithRefEdges) { case (ast, (src, dst)) => ast.withCaptureEdge(src, dst) } } /** Creates the bindings between the method and its types. This is useful for resolving function pointers and imports. diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 7c521927ceba..e08cfa04e06b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -92,16 +92,20 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t protected def astForDoBlock(block: Block & RubyExpression): Seq[Ast] = { // Create closure structures: [MethodDecl, TypeRef, MethodRef] - val methodName = nextClosureName() + if (closureToRefs.contains(block)) { + closureToRefs(block) + } else { + val methodName = nextClosureName() - val methodAstsWithRefs = block.body match { - case x: Block => - astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) - case _ => - astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) + val methodAstsWithRefs = block.body match { + case x: Block => + astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) + case _ => + astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) + } + closureToRefs.put(block, methodAstsWithRefs) + methodAstsWithRefs } - - methodAstsWithRefs } protected def astForReturnExpression(node: ReturnExpression): Ast = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala index e566e250d53f..0d678849a43e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala @@ -401,7 +401,7 @@ class DoBlockTests extends RubyCode2CpgFixture { |""".stripMargin) inside(cpg.local.l) { - case jfsOutsideLocal :: hashInsideLocal :: jfsCapturedLocal :: tmp0 :: tmp1 :: Nil => + case jfsOutsideLocal :: hashInsideLocal :: tmp0 :: jfsCapturedLocal :: tmp1 :: Nil => jfsOutsideLocal.closureBindingId shouldBe None hashInsideLocal.closureBindingId shouldBe None jfsCapturedLocal.closureBindingId shouldBe Some("Test0.rb:
.get_pto_schedule.jfs") @@ -412,7 +412,7 @@ class DoBlockTests extends RubyCode2CpgFixture { } inside(cpg.method.isLambda.local.l) { - case hashLocal :: jfsLocal :: _ :: Nil => + case hashLocal :: _ :: jfsLocal :: Nil => hashLocal.closureBindingId shouldBe None jfsLocal.closureBindingId shouldBe Some("Test0.rb:
.get_pto_schedule.jfs") case xs => fail(s"Expected 3 locals in lambda, got ${xs.code.mkString(",")}") diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala index 4586866b496c..e3c35ce9c92f 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala @@ -984,4 +984,21 @@ class MethodTests extends RubyCode2CpgFixture { } } } + + "lambdas as arguments to a long chained call" should { + val cpg = code(""" + |def foo(xs, total_ys, hex_values) + | xs.map.with_index { |f, i| [f / total_ys, hex_values[i]] } # 1 + | .sort_by { |r| -r[0] } # 2 + | .reject { |r| r[1].size == 8 && r[1].end_with?('00') } # 3 + | .map { |r| Foo::Bar::Baz.new(*r[1][0..5].scan(/../).map { |c| c.to_i(16) }) } # 4 & 5 + | .slice(0, quantity) + | end + |""".stripMargin) + + "not write lambda nodes that are already assigned to some temp variable" in { + cpg.typeRef.typeFullName(".*Proc").size shouldBe 5 + cpg.typeRef.whereNot(_.astParent).size shouldBe 0 + } + } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala index 0788a8734691..b9f8bed7e67c 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala @@ -49,6 +49,10 @@ object Ast { ast.bindsEdges.foreach { edge => diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) } + + ast.captureEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CAPTURE) + } } def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit @@ -92,7 +96,8 @@ case class Ast( refEdges: collection.Seq[AstEdge] = Vector.empty, bindsEdges: collection.Seq[AstEdge] = Vector.empty, receiverEdges: collection.Seq[AstEdge] = Vector.empty, - argEdges: collection.Seq[AstEdge] = Vector.empty + argEdges: collection.Seq[AstEdge] = Vector.empty, + captureEdges: collection.Seq[AstEdge] = Vector.empty )(implicit withSchemaValidation: ValidationMode = ValidationMode.Disabled) { def root: Option[NewNode] = nodes.headOption @@ -114,7 +119,8 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges + bindsEdges = bindsEdges ++ other.bindsEdges, + captureEdges = captureEdges ++ other.captureEdges ) } @@ -126,7 +132,8 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges + bindsEdges = bindsEdges ++ other.bindsEdges, + captureEdges = captureEdges ++ other.captureEdges ) } @@ -217,6 +224,16 @@ case class Ast( this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) } + def withCaptureEdge(src: NewNode, dst: NewNode): Ast = { + Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE) + this.copy(captureEdges = captureEdges ++ List(AstEdge(src, dst))) + } + + def withCaptureEdges(src: NewNode, dsts: Seq[NewNode]): Ast = { + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE)) + this.copy(captureEdges = captureEdges ++ dsts.map(AstEdge(src, _))) + } + /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex` * fields of the new root node are set to `order`. If `replacementNode` is set, then this replaces `node` in the new * copy. @@ -250,6 +267,7 @@ case class Ast( val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newCaptureEdges = captureEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) Ast(newNode) .copy( @@ -257,7 +275,8 @@ case class Ast( conditionEdges = newConditionEdges, refEdges = newRefEdges, bindsEdges = newBindsEdges, - receiverEdges = newReceiverEdges + receiverEdges = newReceiverEdges, + captureEdges = newCaptureEdges ) .withChildren(newChildren) }