Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ruby] Add handling for multiple call args #4948

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,16 @@ command
# commandTernaryOperatorExpression
| primary NL? (AMPDOT | DOT | COLON2) methodName commandArgument
# memberAccessCommand
| methodIdentifier commandArgument
| methodIdentifier simpleCommandArgumentList
# simpleCommand
;

simpleCommandArgumentList
: associationList
| primaryValueList (COMMA NL* associationList)?
| argumentList
;

commandArgument
: commandArgumentList
# commandArgumentCommandArgumentList
Expand Down Expand Up @@ -251,7 +257,7 @@ splatArgList
commandArgumentList
: associationList
| primaryValueList (COMMA NL* associationList)?
;
;

primaryValueList
: primaryValue (COMMA NL* primaryValue)*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ object AntlrContextHelpers {
def parameters: List[ParserRuleContext] = Option(ctx.blockParameterList()).map(_.parameters).getOrElse(List())
}

sealed implicit class CommandArgumentContextHelper(ctx: CommandArgumentContext) {
sealed implicit class CommandArgumentContextelper(ctx: CommandArgumentContext) {
def arguments: List[ParserRuleContext] = ctx match {
case ctx: CommandCommandArgumentListContext => ctx.command() :: Nil
case ctx: CommandArgumentCommandArgumentListContext => ctx.commandArgumentList().elements
Expand All @@ -162,6 +162,15 @@ object AntlrContextHelpers {
}
}

sealed implicit class SimpleCommandArgumentListContextHelper(ctx: SimpleCommandArgumentListContext) {
def arguments: List[ParserRuleContext] = {
val primaryValues = Option(ctx.primaryValueList()).map(_.primaryValue().asScala.toList).getOrElse(List())
val associations = Option(ctx.associationList()).map(_.association().asScala.toList).getOrElse(List())
val argumentLists = Option(ctx.argumentList()).map(_.elements).getOrElse(List())
primaryValues ++ associations ++ argumentLists
}
}

sealed implicit class PrimaryValueListWithAssociationContextHelper(ctx: PrimaryValueListWithAssociationContext) {
def elements: List[ParserRuleContext] = {
ctx.children.asScala.collect {
Expand Down Expand Up @@ -332,6 +341,8 @@ object AntlrContextHelpers {
Option(ctx.blockArgument()).toList
case ctx: ArrayArgumentListContext =>
Option(ctx.indexingArgumentList()).toList
case ctx: SingleCommandArgumentListContext =>
Option(ctx.command()).toList
case ctx =>
logger.warn(s"ArgumentListContextHelper - Unsupported element type ${ctx.getClass.getSimpleName}")
List()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,13 +601,13 @@ class AstPrinter extends RubyParserBaseVisitor[String] {
}

override def visitSimpleCommand(ctx: RubyParser.SimpleCommandContext): String = {
if (Option(ctx.commandArgument()).map(_.getText).exists(_.startsWith("::"))) {
val memberName = ctx.commandArgument().getText.stripPrefix("::")
if (Option(ctx.simpleCommandArgumentList()).map(_.getText).exists(_.startsWith("::"))) {
val memberName = ctx.simpleCommandArgumentList().getText.stripPrefix("::")
val methodIdentifier = visit(ctx.methodIdentifier())
s"$methodIdentifier::$memberName"
} else if (!ctx.methodIdentifier().isAttrDeclaration) {
val identifierCtx = ctx.methodIdentifier()
val arguments = ctx.commandArgument().arguments.map(visit)
val arguments = ctx.simpleCommandArgumentList().arguments.map(visit)
(identifierCtx.getText, arguments) match {
case ("require", List(argument)) =>
s"require ${arguments.mkString(",")}"
Expand All @@ -627,7 +627,7 @@ class AstPrinter extends RubyParserBaseVisitor[String] {
s"${visit(identifierCtx)} ${arguments.mkString(",")}"
}
} else {
s"${ctx.commandArgument.arguments.map(visit).mkString(",")}"
s"${ctx.simpleCommandArgumentList.arguments.map(visit).mkString(",")}"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,16 +681,16 @@ class RubyNodeCreator(
}

override def visitSimpleCommand(ctx: RubyParser.SimpleCommandContext): RubyExpression = {
if (Option(ctx.commandArgument()).map(_.getText).exists(_.startsWith("::"))) {
val memberName = ctx.commandArgument().getText.stripPrefix("::")
if (Option(ctx.simpleCommandArgumentList()).map(_.getText).exists(_.startsWith("::"))) {
val memberName = ctx.simpleCommandArgumentList().getText.stripPrefix("::")
if (memberName.headOption.exists(_.isUpper)) { // Constant accesses are upper-case 1st letter
MemberAccess(visit(ctx.methodIdentifier()), "::", memberName)(ctx.toTextSpan)
} else {
MemberCall(visit(ctx.methodIdentifier()), "::", memberName, Nil)(ctx.toTextSpan)
}
} else if (!ctx.methodIdentifier().isAttrDeclaration) {
val identifierCtx = ctx.methodIdentifier()
val arguments = ctx.commandArgument().arguments.map(visit)
val arguments = ctx.simpleCommandArgumentList().arguments.map(visit)
(identifierCtx.getText, arguments) match {
case (requireLike, List(argument)) if ImportsPass.ImportCallNames.contains(requireLike) =>
val isRelative = requireLike == "require_relative" || requireLike == "require_all"
Expand All @@ -713,14 +713,14 @@ class RubyNodeCreator(
val lhsIdentifier = SimpleIdentifier(None)(identifierCtx.toTextSpan.spanStart(idAssign.stripSuffix("=")))
val argNode = arguments match {
case arg :: Nil => arg
case xs => ArrayLiteral(xs)(ctx.commandArgument().toTextSpan)
case xs => ArrayLiteral(xs)(ctx.simpleCommandArgumentList().toTextSpan)
}
SingleAssignment(lhsIdentifier, "=", argNode)(ctx.toTextSpan)
case _ =>
SimpleCall(visit(identifierCtx), arguments)(ctx.toTextSpan)
}
} else {
FieldsDeclaration(ctx.commandArgument().arguments.map(visit))(ctx.toTextSpan)
FieldsDeclaration(ctx.simpleCommandArgumentList().arguments.map(visit))(ctx.toTextSpan)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class AssignmentParserTests extends RubyParserFixture with Matchers {
test("*a, b, c = 1, 2, 3, 4")
test("a, b, c = 1, 2, *list")
test("a, b, c = 1, *list")
test("a = b, *c, d")
test("a = *c, b, d")
}

"Class Constant Assign" in {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.joern.rubysrc2cpg.querying

import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier, Literal}
import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Identifier, Literal}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}
import io.shiftleft.semanticcpg.language.*
import io.joern.rubysrc2cpg.passes.Defines as RubyDefines
Expand Down Expand Up @@ -489,4 +489,20 @@ class SingleAssignmentTests extends RubyCode2CpgFixture {
case xs => fail(s"Expected one call for assignment, got ${xs.code.mkString(",")}")
}
}

"MethodInvocationWithoutParentheses multiple call args" in {
val cpg = code("""
|def gl_badge_tag(*args, &block)
| render :some_symbol, &block
|end
|""".stripMargin)

inside(cpg.call.name("render").argument.l) {
case _ :: (blockArg: Identifier) :: (symbolArg: Literal) :: Nil =>
blockArg.code shouldBe "block"
symbolArg.code shouldBe ":some_symbol"

case xs => fail(s"Expected two args, found [${xs.code.mkString(",")}]")
}
}
}
Loading