Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rubysrc2cpg: Prints for parser issues causing all context cases to mismatch #3003

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class AstCreator(
val node = createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
case _ =>
logger.error("astForSingleLeftHandSideContext() All contexts mismatched.")
logger.error(s"astForSingleLeftHandSideContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())

}
Expand Down Expand Up @@ -340,7 +340,7 @@ class AstCreator(
case ctx: ChainedInvocationWithoutArgumentsPrimaryContext =>
astForChainedInvocationWithoutArgumentsPrimaryContext(ctx)
case _ =>
logger.error("astForPrimaryContext() All contexts mismatched.")
logger.error(s"astForPrimaryContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand All @@ -364,7 +364,7 @@ class AstCreator(
case ctx: MultipleAssignmentExpressionContext => astForMultipleAssignmentExpressionContext(ctx)
case ctx: IsDefinedExpressionContext => Seq(astForIsDefinedExpression(ctx))
case _ =>
logger.error("astForExpressionContext() All contexts mismatched.")
logger.error(s"astForExpressionContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -415,7 +415,7 @@ class AstCreator(
case ctx: RubyParser.SplattingOnlyIndexingArgumentsContext =>
astForSplattingArgumentContext(ctx.splattingArgument())
case _ =>
logger.error("astForIndexingArgumentsContext() All contexts mismatched.")
logger.error(s"astForIndexingArgumentsContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -627,7 +627,7 @@ class AstCreator(
case ctx: GroupedLeftHandSideOnlyMultipleLeftHandSideContext =>
astForGroupedLeftHandSideContext(ctx.groupedLeftHandSide())
case _ =>
logger.error("astForMultipleLeftHandSideContext() All contexts mismatched.")
logger.error(s"astForMultipleLeftHandSideContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -735,7 +735,7 @@ class AstCreator(
.withChildren(astForArguments(ctx.arguments()))
)
case _ =>
logger.error("astForInvocationWithoutParenthesesContext() All contexts mismatched.")
logger.error(s"astForInvocationWithoutParenthesesContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -973,7 +973,7 @@ class AstCreator(
case ctx: SimpleMethodNamePartContext => astForSimpleMethodNamePartContext(ctx)
case ctx: SingletonMethodNamePartContext => astForSingletonMethodNamePartContext(ctx)
case _ =>
logger.error("astForMethodNamePartContext() All contexts mismatched.")
logger.error(s"astForMethodNamePartContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -1050,7 +1050,7 @@ class AstCreator(
}

def astForBodyStatementContext(ctx: BodyStatementContext, addReturnNode: Boolean = false): Seq[Ast] = {
val compoundStatementAsts = astForCompoundStatement(ctx.compoundStatement())
val compoundStatementAsts = astForCompoundStatement(ctx.compoundStatement(), !addReturnNode)

val compoundStatementAstsWithReturn =
if (addReturnNode && compoundStatementAsts.size > 0) {
Expand Down Expand Up @@ -1320,7 +1320,7 @@ class AstCreator(
val primaryAsts = astForPrimaryContext(ctx.primary())
primaryAsts ++ methodNameAsts ++ argsAsts ++ doBlockAsts
case _ =>
logger.error("astForCommandWithDoBlockContext() All contexts mismatched.")
logger.error(s"astForCommandWithDoBlockContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -1354,7 +1354,7 @@ class AstCreator(
case ctx: ChainedCommandWithDoBlockOnlyArgumentsWithParenthesesContext =>
astForChainedCommandWithDoBlockContext(ctx.chainedCommandWithDoBlock())
case _ =>
logger.error("astForArgumentsWithParenthesesContext() All contexts mismatched.")
logger.error(s"astForArgumentsWithParenthesesContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import io.joern.x2cpg.Ast
import io.joern.x2cpg.Defines.DynamicCallUnknownFullName
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewCall, NewControlStructure, NewImport, NewLiteral}
import org.slf4j.LoggerFactory
import org.antlr.v4.runtime.ParserRuleContext

import scala.jdk.CollectionConverters.CollectionHasAsScala

trait AstForStatementsCreator {
this: AstCreator =>

private val logger = LoggerFactory.getLogger(this.getClass)
protected def astForAliasStatement(ctx: AliasStatementContext): Ast = {
val aliasName = ctx.definedMethodNameOrSymbol(0).getText.substring(1)
val methodName = ctx.definedMethodNameOrSymbol(1).getText.substring(1)
Expand Down Expand Up @@ -80,9 +82,13 @@ trait AstForStatementsCreator {
controlStructureAst(throwNode, rhs.headOption, lhs)
}

// Builds the ASTs for a compound statement from its optional `statements()` child;
// when that child is absent (null) the statement list is empty.
// In the two-parameter form, `packInBlock` (default true) selects between wrapping the
// statement ASTs in a single block AST and returning them unwrapped.
protected def astForCompoundStatement(ctx: CompoundStatementContext): Seq[Ast] = {
protected def astForCompoundStatement(ctx: CompoundStatementContext, packInBlock: Boolean = true): Seq[Ast] = {
val stmtAsts = Option(ctx.statements()).map(astForStatements).getOrElse(Seq())
Seq(blockAst(blockNode(ctx), stmtAsts.toList))
if (packInBlock) {
Seq(blockAst(blockNode(ctx), stmtAsts.toList))
} else {
stmtAsts
}
}

protected def astForStatements(ctx: StatementsContext): Seq[Ast] = {
Expand Down Expand Up @@ -110,7 +116,9 @@ trait AstForStatementsCreator {
case ctx: NotExpressionOrCommandContext => Seq(astForNotKeywordExpressionOrCommand(ctx))
case ctx: OrAndExpressionOrCommandContext => Seq(astForOrAndExpressionOrCommand(ctx))
case ctx: ExpressionExpressionOrCommandContext => astForExpressionContext(ctx.expression())
case _ => Seq(Ast())
case _ =>
logger.error(s"astForExpressionOrCommand() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

protected def astForNotKeywordExpressionOrCommand(ctx: NotExpressionOrCommandContext): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,29 @@ class DataFlowTests extends DataFlowCodeToCpgSuite {
}
}

// Scenario: `x` is assigned before a begin/rescue/else construct and is read only by
// `puts x` in the `else` branch; the begin body and all rescue branches print strings.
"Data flow for begin/rescue with sink in else" should {
val cpg = code("""
|x = 1
|begin
| puts "In begin"
|rescue SomeException
| puts "SomeException occurred"
|rescue => exceptionVar
| puts "Caught exception in variable #{exceptionVar}"
|rescue
| puts "Catch-all block"
|else
| puts x
|end
|""".stripMargin)

"find flows to the sink" in {
// Sources are all identifiers named `x`; sinks are all calls named `puts`.
val source = cpg.identifier.name("x").l
val sink = cpg.call.name("puts").l
// The test pins the number of reported flows to exactly two.
sink.reachableByFlows(source).size shouldBe 2
}
}

"Data flow for begin/rescue with sink in rescue" should {
val cpg = code("""
|x = 1
Expand Down Expand Up @@ -1010,6 +1033,70 @@ class DataFlowTests extends DataFlowCodeToCpgSuite {
}
}

// Scenario: `x` is assigned before a begin/rescue/ensure construct and is read only by
// `puts x` in the `ensure` branch; the begin body and all rescue branches print strings.
"Data flow for begin/rescue with sink in ensure" should {
val cpg = code("""
|x = 1
|begin
| puts "in begin"
|rescue SomeException
| puts "SomeException occurred"
|rescue => exceptionVar
| puts "Caught exception in variable #{exceptionVar}"
|rescue
| puts "In rescue all"
|ensure
| puts x
|end
|
|""".stripMargin)

"find flows to the sink" in {
// Sources are all identifiers named `x`; sinks are all calls named `puts`.
val source = cpg.identifier.name("x").l
val sink = cpg.call.name("puts").l
// The test pins the number of reported flows to exactly two.
sink.reachableByFlows(source).size shouldBe 2
}
}

// Parsing issue: this case is currently disabled via `ignore`. Once the parser issue is
// fixed, re-enable it by changing `ignore` back to `should` (do not comment the test out).
// Scenario: inside `rescue ZeroDivisionError => e`, `y` is built from `x` and `e.message`
// and then passed to `puts`.
"Data flow for begin/rescue with data flow through the exception" ignore {
val cpg = code("""
|x = "Exception message: "
|begin
|1/0
|rescue ZeroDivisionError => e
| y = x + e.message
| puts y
|end
|
|""".stripMargin)

"find flows to the sink" in {
// Sources are all identifiers named `x`; sinks are all calls named `puts`.
val source = cpg.identifier.name("x").l
val sink = cpg.call.name("puts").l
// The test pins the number of reported flows to exactly two.
sink.reachableByFlows(source).size shouldBe 2
}
}

// Scenario: a single rescue clause lists two exception classes; inside it `y` is
// reassigned from `x`, and `y` is printed by `puts` after the begin/rescue block ends.
"Data flow for begin/rescue with data flow through block with multiple exceptions being caught" should {
val cpg = code("""
|x = 1
|y = 10
|begin
|1/0
|rescue SystemCallError, ZeroDivisionError
| y = x + 100
|end
|
|puts y
|""".stripMargin)

"find flows to the sink" in {
// Sources are all identifiers named `x`; sinks are all calls named `puts`.
val source = cpg.identifier.name("x").l
val sink = cpg.call.name("puts").l
// The test pins the number of reported flows to exactly two.
sink.reachableByFlows(source).size shouldBe 2
}
}

"Data flow for begin/rescue with sink in function without begin" ignore {
val cpg = code("""
|def foo(arg)
Expand Down