Skip to content

Commit

Permalink
[php] Type Recovery on Method Returns
Browse files Browse the repository at this point in the history
Added a test that tests type flow to a method return which is inherited by another class and should then be propagated to the assigned identifier.

So far this test only fixes the bug where calls in the method return were discovered and resolved but not persisted.

Next steps: Investigate call resolution for superclass methods
  • Loading branch information
DavidBakerEffendi committed May 17, 2024
1 parent 7e347a9 commit e2a69fc
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes
import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.{Assignment, FieldAccess}
import overflowdb.BatchedUpdate.DiffGraphBuilder

import scala.annotation.tailrec
import scala.collection.mutable

class PhpTypeRecoveryPassGenerator(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig(iterations = 3))
Expand Down Expand Up @@ -41,8 +40,8 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph
override protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match {
case x: Call =>
x.methodFullName match {
case Operators.alloc =>
case _ => symbolTable.append(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet)
case s"<operator>.$_" =>
case _ => symbolTable.append(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet)
}
case _ => super.prepopulateSymbolTableEntry(x)
}
Expand Down Expand Up @@ -117,11 +116,10 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph
)
existingTypes.addAll(methodTypesTable.getOrElse(m, mutable.HashSet()))

@tailrec
def extractTypes(xs: List[CfgNode]): Set[String] = xs match {
case ::(head: Literal, Nil) if head.typeFullName != "ANY" =>
Set(head.typeFullName)
case ::(head: Call, Nil) if head.name == Operators.fieldAccess =>
case (head: Call) :: _ if head.name == Operators.fieldAccess =>
val fieldAccess = head.asInstanceOf[FieldAccess]
val (sym, ts) = getSymbolFromCall(fieldAccess)
val cpgTypes = cpg.typeDecl
Expand All @@ -133,21 +131,23 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph
.toSet
if (cpgTypes.nonEmpty) cpgTypes
else symbolTable.get(sym)
case ::(head: Call, Nil) if symbolTable.contains(head) =>
case (head: Call) :: _ if symbolTable.contains(head) =>
val callPaths = symbolTable.get(head)
val returnValues = methodReturnValues(callPaths.toSeq)
if (returnValues.isEmpty)
callPaths.map(c => s"$c$pathSep${XTypeRecovery.DummyReturnType}")
else
returnValues
case ::(head: Call, Nil) if head.argumentOut.headOption.exists(symbolTable.contains) =>
case (head: Call) :: _ if head.receiver.headOption.exists(symbolTable.contains) =>
symbolTable
.get(head.argumentOut.head)
.get(head.receiver.head)
.map(t => Seq(t, head.name, XTypeRecovery.DummyReturnType).mkString(pathSep))
case ::(identifier: Identifier, Nil) if symbolTable.contains(identifier) =>
symbolTable.get(identifier)
case ::(head: Call, Nil) =>
extractTypes(head.argument.l)
case (head: Call) :: _ =>
val callees =
extractTypes(head.argument.l).map(t => Seq(t, head.name, XTypeRecovery.DummyReturnType).mkString(pathSep))
symbolTable.append(head, callees)
case _ => Set.empty
}
val returnTypes = extractTypes(ret.argumentOut.l)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package io.joern.php2cpg.passes

import io.joern.php2cpg.testfixtures.PhpCode2CpgFixture
import io.shiftleft.semanticcpg.language._
import io.shiftleft.codepropertygraph.generated.nodes.Identifier
import io.shiftleft.semanticcpg.language.*

class PhpTypeRecoveryPassTests extends PhpCode2CpgFixture() {

Expand Down Expand Up @@ -445,4 +446,99 @@ class PhpTypeRecoveryPassTests extends PhpCode2CpgFixture() {
}
}

"a reference to a field of some external type (propagated via inherited calls)" should {
val cpg = code(
"""
|<?php
|
|declare(strict_types=1);
|
|namespace Jobcloud\Marketplace\Core\Repository;
|
|use Doctrine\ORM\EntityManagerInterface;
|use Doctrine\ORM\EntityRepository;
|use Doctrine\ORM\QueryBuilder;
|use Some\Repository\EntityRepositoryInterface;
|
|abstract class AbstractEntityRepository implements EntityRepositoryInterface
|{
| protected AliasHelperInterface $aliasHelper;
| private EntityManagerInterface $entityManager;
|
| private string $entityClassName;
|
| public function __construct(
| EntityManagerInterface $entityManager,
| ) {
| $this->entityManager = $entityManager;
| }
|
| protected function createQueryBuilder(string $alias, ?string $indexBy = null): QueryBuilder
| {
| return $this->entityManager->createQueryBuilder()->select($alias)
| ->from("ABC", $alias, $indexBy);
| }
|}
|""".stripMargin,
"AbstractEntity.php"
).moreCode(
"""
|<?php
|
|declare(strict_types=1);
|
|namespace Some\Repository;
|
|use Some\Entity\User;
|use Some\Entity\EntityInterface;
|
|class SomeRepository extends AbstractEntityRepository
|{
| public function findSomething(
| \DateTimeImmutable $date,
| string $accountId,
| ): ?EntityInterface {
| $rootAlias = $this->getRootAlias();
| $userAlias = $this->aliasHelper->getAliasForClass(User::class);
|
| $queryBuilder = $this->createQueryBuilder($rootAlias);
|
| $queryBuilder
| ->leftJoin(sprintf('%s.foo', $rootAlias), $userAlias)
| ->setParameter('userName', $userName)
|
| return $queryBuilder->getQuery()->execute()[0] ?? null;
| }
|}
|""".stripMargin,
"User.php"
)

"resolve the correct full name for the wrapped QueryBuilder call off the field" in {
inside(cpg.method.nameExact("createQueryBuilder").call.name(".*createQueryBuilder").l) {
case queryBuilderCall :: Nil =>
queryBuilderCall.methodFullName shouldBe "Doctrine\\ORM\\EntityManagerInterface->createQueryBuilder-><returnValue>"
case xs => fail(s"Expected one call, instead got [$xs]")
}
}

"propagate this QueryBuilder type to the identifier assigned to the inherited call for the wrapped `createQueryBuilder`" ignore {
cpg.method
.nameExact("findSomething")
._containsOut
.collectAll[Identifier]
.nameExact("queryBuilder")
.typeFullName
.head shouldBe "Doctrine\\ORM\\EntityManagerInterface->createQueryBuilder-><returnValue>->select-><returnValue>->from-><returnValue>"
}

"resolve the correct full name a call based on the QueryBuilder return value" ignore {
inside(cpg.call.nameExact("setParameter").l) {
case setParamCall :: Nil =>
setParamCall.methodFullName shouldBe "Doctrine\\ORM\\EntityManagerInterface->createQueryBuilder-><returnValue>->leftJoin-><returnValue>->setParameter-><returnValue>"
case xs => fail(s"Expected one call, instead got [$xs]")
}
}
}

}

0 comments on commit e2a69fc

Please sign in to comment.