Skip to content

Commit

Permalink
Revert "[ruby] Ignore "Throwaway" AST Structures (#4982)" (#4983)
Browse files Browse the repository at this point in the history
This reverts commit 3c27bf6.
  • Loading branch information
DavidBakerEffendi authored Oct 1, 2024
1 parent 3c27bf6 commit 464480d
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,18 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
/** Attempts to extract a type from the base of a member call.
*/
protected def typeFromCallTarget(baseNode: RubyExpression): Option[String] = {
baseNode match {
case literal: LiteralExpr => Option(literal.typeFullName)
scope.lookupVariable(baseNode.text) match {
// fixme: This should be under type recovery logic
case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName)
case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName)
case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption
case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty =>
decl.dynamicTypeHintFullName.headOption
case _ =>
scope.lookupVariable(baseNode.text) match {
// fixme: This should be under type recovery logic
case Some(decl: NewLocal) if decl.typeFullName != Defines.Any => Option(decl.typeFullName)
case Some(decl: NewMethodParameterIn) if decl.typeFullName != Defines.Any => Option(decl.typeFullName)
case Some(decl: NewLocal) if decl.dynamicTypeHintFullName.nonEmpty => decl.dynamicTypeHintFullName.headOption
case Some(decl: NewMethodParameterIn) if decl.dynamicTypeHintFullName.nonEmpty =>
decl.dynamicTypeHintFullName.headOption
case _ => None
}
astForExpression(baseNode).nodes
.flatMap(_.properties.get(PropertyNames.TYPE_FULL_NAME).map(_.toString))
.filterNot(_ == XDefines.Any)
.headOption
}
}

Expand Down Expand Up @@ -296,10 +296,7 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val createAssignmentToTmp = !baseAstCache.contains(target)
val tmpName = baseAstCache
.updateWith(target) {
case Some(tmpName) =>
// TODO: Type ref nodes are automatically committed on creation, so if we have found a suitable cached AST,
// we want to clean this creation up.
Option(tmpName)
case Some(tmpName) => Option(tmpName)
case None =>
val tmpName = this.tmpGen.fresh
val tmpGenLocal = NewLocal().name(tmpName).code(tmpName).typeFullName(Defines.Any)
Expand Down Expand Up @@ -875,11 +872,16 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
}

private def astForMemberCallWithoutBlock(node: SimpleCall, memberAccess: MemberAccess): Ast = {
val receiverAst = astForFieldAccess(memberAccess)
val methodName = memberAccess.memberName
val methodFullName = XDefines.DynamicCallUnknownFullName
val argumentAsts = node.arguments.map(astForMethodCallArgument)
val call = callNode(node, code(node), methodName, methodFullName, DispatchTypes.DYNAMIC_DISPATCH)
val receiverAst = astForFieldAccess(memberAccess)
val methodName = memberAccess.memberName
// TODO: Type recovery should potentially resolve this
val methodFullName = typeFromCallTarget(memberAccess.target)
.map(x => s"$x.$methodName")
.getOrElse(XDefines.DynamicCallUnknownFullName)
val argumentAsts = node.arguments.map(astForMethodCallArgument)
val call =
callNode(node, code(node), methodName, XDefines.DynamicCallUnknownFullName, DispatchTypes.DYNAMIC_DISPATCH)
.possibleTypes(IndexedSeq(methodFullName))

callAst(call, argumentAsts, Some(receiverAst))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ import scala.collection.mutable

trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

/** As expressions may be discarded, we cannot store closure ASTs in the diffgraph at the point of creation. We need
* to only write these at the end.
*/
protected val closureToRefs = mutable.Map.empty[RubyExpression, Seq[Ast]]

/** Creates method declaration related structures.
* @param node
* the node to create the AST structure from.
Expand Down Expand Up @@ -199,11 +194,8 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
case _ => false
})

val typeRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x }
val methodRefOption = refs.flatMap(_.nodes).collectFirst { case x: NewTypeRef => x }

val astChildren = mutable.Buffer.empty[NewNode]
val refEdges = mutable.Buffer.empty[(NewNode, NewNode)]
val captureEdges = mutable.Buffer.empty[(NewNode, NewNode)]
capturedLocalNodes
.collect {
case local: NewLocal =>
Expand All @@ -224,17 +216,14 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
)

// Create new local node for lambda, with corresponding REF edges to identifiers and closure binding
val _refEdges =
capturedIdentifiers.filter(_.name == name).map(i => i -> capturingLocal) :+ (closureBinding, capturedLocal)
capturedBlockAst.root.foreach(rootBlock => diffGraph.addEdge(rootBlock, capturingLocal, EdgeTypes.AST))
capturedIdentifiers.filter(_.name == name).foreach(i => diffGraph.addEdge(i, capturingLocal, EdgeTypes.REF))
diffGraph.addEdge(closureBinding, capturedLocal, EdgeTypes.REF)

astChildren.addOne(capturingLocal)
refEdges.addAll(_refEdges.toList)
captureEdges.addAll(typeRefOption.map(typeRef => typeRef -> closureBinding).toList)
methodRefOption.foreach(methodRef => diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE))
}

val astWithAstChildren = astChildren.foldLeft(capturedBlockAst) { case (ast, child) => ast.withChild(Ast(child)) }
val astWithRefEdges = refEdges.foldLeft(astWithAstChildren) { case (ast, (src, dst)) => ast.withRefEdge(src, dst) }
captureEdges.foldLeft(astWithRefEdges) { case (ast, (src, dst)) => ast.withCaptureEdge(src, dst) }
capturedBlockAst
}

/** Creates the bindings between the method and its types. This is useful for resolving function pointers and imports.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,16 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t

protected def astForDoBlock(block: Block & RubyExpression): Seq[Ast] = {
// Create closure structures: [MethodDecl, TypeRef, MethodRef]
if (closureToRefs.contains(block)) {
closureToRefs(block)
} else {
val methodName = nextClosureName()
val methodName = nextClosureName()

val methodAstsWithRefs = block.body match {
case x: Block =>
astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true)
case _ =>
astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true)
}
closureToRefs.put(block, methodAstsWithRefs)
methodAstsWithRefs
val methodAstsWithRefs = block.body match {
case x: Block =>
astForMethodDeclaration(x.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true)
case _ =>
astForMethodDeclaration(block.toMethodDeclaration(methodName, Option(block.parameters)), isClosure = true)
}

methodAstsWithRefs
}

protected def astForReturnExpression(node: ReturnExpression): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ class DoBlockTests extends RubyCode2CpgFixture {
|""".stripMargin)

inside(cpg.local.l) {
case jfsOutsideLocal :: hashInsideLocal :: tmp0 :: jfsCapturedLocal :: tmp1 :: Nil =>
case jfsOutsideLocal :: hashInsideLocal :: jfsCapturedLocal :: tmp0 :: tmp1 :: Nil =>
jfsOutsideLocal.closureBindingId shouldBe None
hashInsideLocal.closureBindingId shouldBe None
jfsCapturedLocal.closureBindingId shouldBe Some("Test0.rb:<main>.get_pto_schedule.jfs")
Expand All @@ -412,7 +412,7 @@ class DoBlockTests extends RubyCode2CpgFixture {
}

inside(cpg.method.isLambda.local.l) {
case hashLocal :: _ :: jfsLocal :: Nil =>
case hashLocal :: jfsLocal :: _ :: Nil =>
hashLocal.closureBindingId shouldBe None
jfsLocal.closureBindingId shouldBe Some("Test0.rb:<main>.get_pto_schedule.jfs")
case xs => fail(s"Expected 3 locals in lambda, got ${xs.code.mkString(",")}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -984,21 +984,4 @@ class MethodTests extends RubyCode2CpgFixture {
}
}
}

"lambdas as arguments to a long chained call" should {
val cpg = code("""
|def foo(xs, total_ys, hex_values)
| xs.map.with_index { |f, i| [f / total_ys, hex_values[i]] } # 1
| .sort_by { |r| -r[0] } # 2
| .reject { |r| r[1].size == 8 && r[1].end_with?('00') } # 3
| .map { |r| Foo::Bar::Baz.new(*r[1][0..5].scan(/../).map { |c| c.to_i(16) }) } # 4 & 5
| .slice(0, quantity)
| end
|""".stripMargin)

"not write lambda nodes that are already assigned to some temp variable" in {
cpg.typeRef.typeFullName(".*Proc").size shouldBe 5
cpg.typeRef.whereNot(_.astParent).size shouldBe 0
}
}
}
27 changes: 4 additions & 23 deletions joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ object Ast {
ast.bindsEdges.foreach { edge =>
diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS)
}

ast.captureEdges.foreach { edge =>
diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CAPTURE)
}
}

def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit
Expand Down Expand Up @@ -96,8 +92,7 @@ case class Ast(
refEdges: collection.Seq[AstEdge] = Vector.empty,
bindsEdges: collection.Seq[AstEdge] = Vector.empty,
receiverEdges: collection.Seq[AstEdge] = Vector.empty,
argEdges: collection.Seq[AstEdge] = Vector.empty,
captureEdges: collection.Seq[AstEdge] = Vector.empty
argEdges: collection.Seq[AstEdge] = Vector.empty
)(implicit withSchemaValidation: ValidationMode = ValidationMode.Disabled) {

def root: Option[NewNode] = nodes.headOption
Expand All @@ -119,8 +114,7 @@ case class Ast(
argEdges = argEdges ++ other.argEdges,
receiverEdges = receiverEdges ++ other.receiverEdges,
refEdges = refEdges ++ other.refEdges,
bindsEdges = bindsEdges ++ other.bindsEdges,
captureEdges = captureEdges ++ other.captureEdges
bindsEdges = bindsEdges ++ other.bindsEdges
)
}

Expand All @@ -132,8 +126,7 @@ case class Ast(
argEdges = argEdges ++ other.argEdges,
receiverEdges = receiverEdges ++ other.receiverEdges,
refEdges = refEdges ++ other.refEdges,
bindsEdges = bindsEdges ++ other.bindsEdges,
captureEdges = captureEdges ++ other.captureEdges
bindsEdges = bindsEdges ++ other.bindsEdges
)
}

Expand Down Expand Up @@ -224,16 +217,6 @@ case class Ast(
this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _)))
}

def withCaptureEdge(src: NewNode, dst: NewNode): Ast = {
Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE)
this.copy(captureEdges = captureEdges ++ List(AstEdge(src, dst)))
}

def withCaptureEdges(src: NewNode, dsts: Seq[NewNode]): Ast = {
dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CAPTURE))
this.copy(captureEdges = captureEdges ++ dsts.map(AstEdge(src, _)))
}

/** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and `argumentIndex`
* fields of the new root node are set to `order`. If `replacementNode` is set, then this replaces `node` in the new
* copy.
Expand Down Expand Up @@ -267,16 +250,14 @@ case class Ast(
val newRefEdges = refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newBindsEdges = bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newReceiverEdges = receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))
val newCaptureEdges = captureEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst)))

Ast(newNode)
.copy(
argEdges = newArgEdges,
conditionEdges = newConditionEdges,
refEdges = newRefEdges,
bindsEdges = newBindsEdges,
receiverEdges = newReceiverEdges,
captureEdges = newCaptureEdges
receiverEdges = newReceiverEdges
)
.withChildren(newChildren)
}
Expand Down

0 comments on commit 464480d

Please sign in to comment.