diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index 75675fe41dcc..112d856f8d2e 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -157,18 +157,18 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { /** Attempts to extract a type from the base of a member call. */ protected def typeFromCallTarget(baseNode: RubyExpression): Option[String] = { - baseNode match { - case literal: LiteralExpr => Option(literal.typeFullName) + scope.lookupVariable(baseNode.text) match { + // fixme: This should be under type recovery logic + case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) + case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) + case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption + case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty => + decl.dynamicTypeHintFullName.headOption case _ => - scope.lookupVariable(baseNode.text) match { - // fixme: This should be under type recovery logic - case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) - case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName) - case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption - case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty => - decl.dynamicTypeHintFullName.headOption - case _ => None - } + astForExpression(baseNode).nodes + .flatMap(_.properties.get(PropertyNames.TYPE_FULL_NAME).map(_.toString)) + .filterNot(_ == XDefines.Any) + .headOption } } @@ -296,10 +296,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val createAssignmentToTmp = !baseAstCache.contains(target) val tmpName = baseAstCache .updateWith(target) { - case Some(tmpName) => - // TODO: Type ref nodes are automatically committed on creation, so if we have found a suitable cached AST, - // we want to clean this creation up. - Option(tmpName) + case Some(tmpName) => Option(tmpName) case None => val tmpName = this.tmpGen.fresh val tmpGenLocal = NewLocal().name(tmpName).code(tmpName).typeFullName(Defines.Any) @@ -875,11 +872,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { } private def astForMemberCallWithoutBlock(node: SimpleCall, memberAccess: MemberAccess): Ast = { - val receiverAst = astForFieldAccess(memberAccess) - val methodName = memberAccess.memberName - val methodFullName = XDefines.DynamicCallUnknownFullName - val argumentAsts = node.arguments.map(astForMethodCallArgument) - val call = callNode(node, code(node), methodName, methodFullName, DispatchTypes.DYNAMIC_DISPATCH) + val receiverAst = astForFieldAccess(memberAccess) + val methodName = memberAccess.memberName + // TODO: Type recovery should potentially resolve this + val methodFullName = typeFromCallTarget(memberAccess.target) + .map(x => s"$x.$methodName") + .getOrElse(XDefines.DynamicCallUnknownFullName) + val argumentAsts = node.arguments.map(astForMethodCallArgument) + val call = + callNode(node, code(node), methodName, XDefines.DynamicCallUnknownFullName, DispatchTypes.DYNAMIC_DISPATCH) + .possibleTypes(IndexedSeq(methodFullName)) callAst(call, argumentAsts, Some(receiverAst)) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala index 6383c90ed897..8da6426c5a0f 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -26,11 +26,6 @@ import scala.collection.mutable trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - /** As expressions may be discarded, we cannot store closure ASTs in the diffgraph at the point of creation. We need - * to only write these at the end. - */ - protected val closureToRefs = mutable.Map.empty[RubyExpression, Seq[Ast]] - /** Creates method declaration related structures. * @param node * the node to create the AST structure from. @@ -199,11 +194,8 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th case _ => false }) - val typeRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x } + val methodRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x } - val astChildren = mutable.Buffer.empty[NewNode] - val refEdges = mutable.Buffer.empty[(NewNode, NewNode)] - val captureEdges = mutable.Buffer.empty[(NewNode, NewNode)] capturedLocalNodes .collect { case local: NewLocal => @@ -224,17 +216,14 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th ) // Create new local node for lambda, with corresponding REF edges to identifiers and closure binding - val _refEdges = - capturedIdentifiers.filter(_.name == name).map(i => i -> capturingLocal) :+ (closureBinding, capturedLocal) + capturedBlockAst.root.foreach(rootBlock => diffGraph.addEdge(rootBlock, capturingLocal, EdgeTypes.AST)) + capturedIdentifiers.filter(_.name == name).foreach(i => diffGraph.addEdge(i, capturingLocal, EdgeTypes.REF)) + diffGraph.addEdge(closureBinding, capturedLocal, EdgeTypes.REF) - astChildren.addOne(capturingLocal) - refEdges.addAll(_refEdges.toList) - captureEdges.addAll(typeRefOption.map(typeRef => typeRef -> closureBinding).toList) + methodRefOption.foreach(methodRef => diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE)) } - val astWithAstChildren = astChildren.foldLeft(capturedBlockAst) { case (ast, child) => ast.withChild(Ast(child)) } - val astWithRefEdges = refEdges.foldLeft(astWithAstChildren) { case (ast, (src, dst)) => ast.withRefEdge(src, dst) } - captureEdges.foldLeft(astWithRefEdges) { case (ast, (src, dst)) => ast.withCaptureEdge(src, dst) } + capturedBlockAst } /** Creates the bindings between the method and its types. This is useful for resolving function pointers and imports. diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index e08cfa04e06b..7c521927ceba 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -92,20 +92,16 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t protected def astForDoBlock(block: Block & RubyExpression): Seq[Ast] = { // Create closure structures: [MethodDecl, TypeRef, MethodRef] - if (closureToRefs.contains(block)) { - closureToRefs(block) - } else { - val methodName = nextClosureName() + val methodName = nextClosureName() - val methodAstsWithRefs = block.body match { - case x: Block => - astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) - case _ => - astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) - } - closureToRefs.put(block, methodAstsWithRefs) - methodAstsWithRefs + val methodAstsWithRefs = block.body match { + case x: Block => + astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) + case _ => + astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true) } + + methodAstsWithRefs } protected def astForReturnExpression(node: ReturnExpression): Ast = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala index 0d678849a43e..e566e250d53f 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala @@ -401,7 +401,7 @@ class DoBlockTests extends RubyCode2CpgFixture { |""".stripMargin) inside(cpg.local.l) { - case jfsOutsideLocal :: hashInsideLocal :: tmp0 :: jfsCapturedLocal :: tmp1 :: Nil => + case jfsOutsideLocal :: hashInsideLocal :: jfsCapturedLocal :: tmp0 :: tmp1 :: Nil => jfsOutsideLocal.closureBindingId shouldBe None hashInsideLocal.closureBindingId shouldBe None jfsCapturedLocal.closureBindingId shouldBe Some("Test0.rb:
.get_pto_schedule.jfs") @@ -412,7 +412,7 @@ class DoBlockTests extends RubyCode2CpgFixture { } inside(cpg.method.isLambda.local.l) { - case hashLocal :: _ :: jfsLocal :: Nil => + case hashLocal :: jfsLocal :: _ :: Nil => hashLocal.closureBindingId shouldBe None jfsLocal.closureBindingId shouldBe Some("Test0.rb:
.get_pto_schedule.jfs") case xs => fail(s"Expected 3 locals in lambda, got ${xs.code.mkString(",")}") diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala index e3c35ce9c92f..4586866b496c 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodTests.scala @@ -984,21 +984,4 @@ class MethodTests extends RubyCode2CpgFixture { } } } - - "lambdas as arguments to a long chained call" should { - val cpg = code(""" - |def foo(xs, total_ys, hex_values) - | xs.map.with_index { |f, i| [f / total_ys, hex_values[i]] } # 1 - | .sort_by { |r| -r[0] } # 2 - | .reject { |r| r[1].size == 8 && r[1].end_with?('00') } # 3 - | .map { |r| Foo::Bar::Baz.new(*r[1][0..5].scan(/../).map { |c| c.to_i(16) }) } # 4 & 5 - | .slice(0, quantity) - | end - |""".stripMargin) - - "not write lambda nodes that are already assigned to some temp variable" in { - cpg.typeRef.typeFullName(".*Proc").size shouldBe 5 - cpg.typeRef.whereNot(_.astParent).size shouldBe 0 - } - } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala index b9f8bed7e67c..0788a8734691 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala @@ -49,10 +49,6 @@ object Ast { ast.bindsEdges.foreach { edge => diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) } - - ast.captureEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CAPTURE) - } } def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit @@ -96,8 +92,7 @@ case class Ast( refEdges: collection.Seq[AstEdge] = Vector.empty, bindsEdges: collection.Seq[AstEdge] = Vector.empty, receiverEdges: collection.Seq[AstEdge] = Vector.empty, - argEdges: collection.Seq[AstEdge] = Vector.empty, - captureEdges: collection.Seq[AstEdge] = Vector.empty + argEdges: collection.Seq[AstEdge] = Vector.empty )(implicit withSchemaValidation: ValidationMode = ValidationMode.Disabled) { def root: Option[NewNode] = nodes.headOption @@ -119,8 +114,7 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges, - captureEdges = captureEdges ++ other.captureEdges + bindsEdges = bindsEdges ++ other.bindsEdges ) } @@ -132,8 +126,7 @@ case class Ast( argEdges = argEdges ++ other.argEdges, receiverEdges = receiverEdges ++ other.receiverEdges, refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges, - captureEdges = captureEdges ++ other.captureEdges + bindsEdges = bindsEdges ++ other.bindsEdges ) } @@ -224,16 +217,6 @@ case class Ast( this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) } - def withCaptureEdge(src: NewNode, dst: NewNode): Ast = { - Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE) - this.copy(captureEdges = captureEdges ++ List(AstEdge(src, dst))) - } - - def withCaptureEdges(src: NewNode, dsts: Seq[NewNode]): Ast = { - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE)) - this.copy(captureEdges = captureEdges ++ dsts.map(AstEdge(src, _))) - } - /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex` * fields of the new root node are set to `order`. If `replacementNode` is set, then this replaces `node` in the new * copy. @@ -267,7 +250,6 @@ case class Ast( val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newCaptureEdges = captureEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) Ast(newNode) .copy( @@ -275,8 +257,7 @@ case class Ast( conditionEdges = newConditionEdges, refEdges = newRefEdges, bindsEdges = newBindsEdges, - receiverEdges = newReceiverEdges, - captureEdges = newCaptureEdges + receiverEdges = newReceiverEdges ) .withChildren(newChildren) }