Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ruby] Re-implemented "Ignore "Throwaway" AST Structures (#4982)" #4985

Merged
merged 2 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,18 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
/** Attempts to extract a type from the base of a member call.
*/
protected def typeFromCallTarget(baseNode: RubyExpression): Option[String] = {
scope.lookupVariable(baseNode.text) match {
// fixme: This should be under type recovery logic
case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName)
case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName)
case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption
case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty =>
decl.dynamicTypeHintFullName.headOption
baseNode match {
case literal: LiteralExpr => Option(literal.typeFullName)
case _ =>
astForExpression(baseNode).nodes
.flatMap(_.properties.get(PropertyNames.TYPE_FULL_NAME).map(_.toString))
.filterNot(_ == XDefines.Any)
.headOption
scope.lookupVariable(baseNode.text) match {
// fixme: This should be under type recovery logic
case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName)
case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName)
case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption
case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty =>
decl.dynamicTypeHintFullName.headOption
case _ => None
}
}
}

Expand Down Expand Up @@ -296,7 +296,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val createAssignmentToTmp = !baseAstCache.contains(target)
val tmpName = baseAstCache
.updateWith(target) {
case Some(tmpName) => Option(tmpName)
case Some(tmpName) =>
// TODO: Type ref nodes are automatically committed on creation, so if we have found a suitable cached AST,
// we want to clean this creation up.
Option(tmpName)
case None =>
val tmpName = this.tmpGen.fresh
val tmpGenLocal = NewLocal().name(tmpName).code(tmpName).typeFullName(Defines.Any)
Expand Down Expand Up @@ -872,16 +875,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

private def astForMemberCallWithoutBlock(node: SimpleCall, memberAccess: MemberAccess): Ast = {
val receiverAst = astForFieldAccess(memberAccess)
val methodName = memberAccess.memberName
// TODO: Type recovery should potentially resolve this
val methodFullName = typeFromCallTarget(memberAccess.target)
.map(x => s"$x.$methodName")
.getOrElse(XDefines.DynamicCallUnknownFullName)
val argumentAsts = node.arguments.map(astForMethodCallArgument)
val call =
callNode(node, code(node), methodName, XDefines.DynamicCallUnknownFullName, DispatchTypes.DYNAMIC_DISPATCH)
.possibleTypes(IndexedSeq(methodFullName))
val receiverAst = astForFieldAccess(memberAccess)
val methodName = memberAccess.memberName
val methodFullName = XDefines.DynamicCallUnknownFullName
val argumentAsts = node.arguments.map(astForMethodCallArgument)
val call = callNode(node, code(node), methodName, methodFullName, DispatchTypes.DYNAMIC_DISPATCH)

callAst(call, argumentAsts, Some(receiverAst))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ import scala.collection.mutable

trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

/** Cache keyed by a closure's source expression, holding the nodes recorded for that closure. As expressions may be
 * discarded, closure ASTs cannot be stored in the diffgraph at the point of creation. Instead, every reference to
 * this map is assumed to mean that the closure AST was successfully propagated.
 */
protected val closureToRefs = mutable.Map.empty[RubyExpression, Seq[NewNode]]

/** Creates method declaration related structures.
* @param node
* the node to create the AST structure from.
Expand Down Expand Up @@ -194,8 +199,11 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
case _ => false
})

val methodRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x }
val typeRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x }

val astChildren = mutable.Buffer.empty[NewNode]
val refEdges = mutable.Buffer.empty[(NewNode, NewNode)]
val captureEdges = mutable.Buffer.empty[(NewNode, NewNode)]
capturedLocalNodes
.collect {
case local: NewLocal =>
Expand All @@ -216,14 +224,17 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
)

// Create new local node for lambda, with corresponding REF edges to identifiers and closure binding
capturedBlockAst.root.foreach(rootBlock => diffGraph.addEdge(rootBlock, capturingLocal, EdgeTypes.AST))
capturedIdentifiers.filter(_.name == name).foreach(i => diffGraph.addEdge(i, capturingLocal, EdgeTypes.REF))
diffGraph.addEdge(closureBinding, capturedLocal, EdgeTypes.REF)
val _refEdges =
capturedIdentifiers.filter(_.name == name).map(i => i -> capturingLocal) :+ (closureBinding, capturedLocal)

methodRefOption.foreach(methodRef => diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE))
astChildren.addOne(capturingLocal)
refEdges.addAll(_refEdges.toList)
captureEdges.addAll(typeRefOption.map(typeRef => typeRef -> closureBinding).toList)
}

capturedBlockAst
val astWithAstChildren = astChildren.foldLeft(capturedBlockAst) { case (ast, child) => ast.withChild(Ast(child)) }
val astWithRefEdges = refEdges.foldLeft(astWithAstChildren) { case (ast, (src, dst)) => ast.withRefEdge(src, dst) }
captureEdges.foldLeft(astWithRefEdges) { case (ast, (src, dst)) => ast.withCaptureEdge(src, dst) }
}

/** Creates the bindings between the method and its types. This is useful for resolving function pointers and imports.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,20 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t

protected def astForDoBlock(block: Block & RubyExpression): Seq[Ast] = {
// Create closure structures: [MethodDecl, TypeRef, MethodRef]
val methodName = nextClosureName()
if (closureToRefs.contains(block)) {
closureToRefs(block).map(x => Ast(x.copy))
} else {
val methodName = nextClosureName()

val methodAstsWithRefs = block.body match {
case x: Block =>
astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true)
case _ =>
astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true)
val methodRefAsts = block.body match {
case x: Block =>
astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true)
case _ =>
astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true)
}
closureToRefs.put(block, methodRefAsts.flatMap(_.root))
methodRefAsts
}

methodAstsWithRefs
}

protected def astForReturnExpression(node: ReturnExpression): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ class DoBlockTests extends RubyCode2CpgFixture {
|""".stripMargin)

inside(cpg.local.l) {
case jfsOutsideLocal :: hashInsideLocal :: jfsCapturedLocal :: tmp0 :: tmp1 :: Nil =>
case jfsOutsideLocal :: hashInsideLocal :: tmp0 :: jfsCapturedLocal :: tmp1 :: Nil =>
jfsOutsideLocal.closureBindingId shouldBe None
hashInsideLocal.closureBindingId shouldBe None
jfsCapturedLocal.closureBindingId shouldBe Some("Test0.rb:<main>.get_pto_schedule.jfs")
Expand All @@ -412,7 +412,7 @@ class DoBlockTests extends RubyCode2CpgFixture {
}

inside(cpg.method.isLambda.local.l) {
case hashLocal :: jfsLocal :: _ :: Nil =>
case hashLocal :: _ :: jfsLocal :: Nil =>
hashLocal.closureBindingId shouldBe None
jfsLocal.closureBindingId shouldBe Some("Test0.rb:<main>.get_pto_schedule.jfs")
case xs => fail(s"Expected 3 locals in lambda, got ${xs.code.mkString(",")}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -984,4 +984,21 @@ class MethodTests extends RubyCode2CpgFixture {
}
}
}

// Regression test for blocks (lambdas) passed along a long method-call chain.
"lambdas as arguments to a long chained call" should {
// The trailing `# n` markers in the Ruby snippet number the five blocks that appear in the chain.
val cpg = code("""
|def foo(xs, total_ys, hex_values)
| xs.map.with_index { |f, i| [f / total_ys, hex_values[i]] } # 1
| .sort_by { |r| -r[0] } # 2
| .reject { |r| r[1].size == 8 && r[1].end_with?('00') } # 3
| .map { |r| Foo::Bar::Baz.new(*r[1][0..5].scan(/../).map { |c| c.to_i(16) }) } # 4 & 5
| .slice(0, quantity)
| end
|""".stripMargin)

"not write lambda nodes that are already assigned to some temp variable" in {
// Exactly five type refs with a full name matching `.*Proc` are expected (presumably one per numbered block).
cpg.typeRef.typeFullName(".*Proc").size shouldBe 5
// No type ref may be left without an AST parent.
cpg.typeRef.whereNot(_.astParent).size shouldBe 0
}
}
}
27 changes: 23 additions & 4 deletions joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ object Ast {
ast.bindsEdges.foreach { edge =>
diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS)
}

ast.captureEdges.foreach { edge =>
diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CAPTURE)
}
}

def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit
Expand Down Expand Up @@ -92,7 +96,8 @@ case class Ast(
refEdges: collection.Seq[AstEdge] = Vector.empty,
bindsEdges: collection.Seq[AstEdge] = Vector.empty,
receiverEdges: collection.Seq[AstEdge] = Vector.empty,
argEdges: collection.Seq[AstEdge] = Vector.empty
argEdges: collection.Seq[AstEdge] = Vector.empty,
captureEdges: collection.Seq[AstEdge] = Vector.empty
)(implicit withSchemaValidation: ValidationMode = ValidationMode.Disabled) {

def root: Option[NewNode] = nodes.headOption
Expand All @@ -114,7 +119,8 @@ case class Ast(
argEdges = argEdges ++ other.argEdges,
receiverEdges = receiverEdges ++ other.receiverEdges,
refEdges = refEdges ++ other.refEdges,
bindsEdges = bindsEdges ++ other.bindsEdges
bindsEdges = bindsEdges ++ other.bindsEdges,
captureEdges = captureEdges ++ other.captureEdges
)
}

Expand All @@ -126,7 +132,8 @@ case class Ast(
argEdges = argEdges ++ other.argEdges,
receiverEdges = receiverEdges ++ other.receiverEdges,
refEdges = refEdges ++ other.refEdges,
bindsEdges = bindsEdges ++ other.bindsEdges
bindsEdges = bindsEdges ++ other.bindsEdges,
captureEdges = captureEdges ++ other.captureEdges
)
}

Expand Down Expand Up @@ -217,6 +224,16 @@ case class Ast(
this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _)))
}

/** Returns a copy of this AST extended with a single CAPTURE edge from `src` to `dst`. The pair is first handed to
  * `Ast.neighbourValidation` for the CAPTURE edge type.
  */
def withCaptureEdge(src: NewNode, dst: NewNode): Ast = {
  Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE)
  val extendedCaptureEdges = captureEdges :+ AstEdge(src, dst)
  copy(captureEdges = extendedCaptureEdges)
}

/** Returns a copy of this AST extended with one CAPTURE edge from `src` to each node in `dsts`. Every (`src`, `dst`)
  * pair is handed to `Ast.neighbourValidation` for the CAPTURE edge type before any edge is appended.
  */
def withCaptureEdges(src: NewNode, dsts: Seq[NewNode]): Ast = {
  for (dst <- dsts) Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE)
  val additionalEdges = dsts.map(dst => AstEdge(src, dst))
  copy(captureEdges = captureEdges ++ additionalEdges)
}

/** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex`
* fields of the new root node are set to `order`. If `replacementNode` is set, then this replaces `node` in the new
* copy.
Expand Down Expand Up @@ -250,14 +267,16 @@ case class Ast(
val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newCaptureEdges = captureEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))

Ast(newNode)
.copy(
argEdges = newArgEdges,
conditionEdges = newConditionEdges,
refEdges = newRefEdges,
bindsEdges = newBindsEdges,
receiverEdges = newReceiverEdges
receiverEdges = newReceiverEdges,
captureEdges = newCaptureEdges
)
.withChildren(newChildren)
}
Expand Down
Loading