diff --git a/src/main/java/ldbc/finbench/datagen/generation/events/WithdrawEvent.java b/src/main/java/ldbc/finbench/datagen/generation/events/WithdrawEvent.java index 866cc167..73fb1235 100644 --- a/src/main/java/ldbc/finbench/datagen/generation/events/WithdrawEvent.java +++ b/src/main/java/ldbc/finbench/datagen/generation/events/WithdrawEvent.java @@ -1,7 +1,7 @@ package ldbc.finbench.datagen.generation.events; import java.io.Serializable; -import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Random; @@ -15,22 +15,17 @@ public class WithdrawEvent implements Serializable { private final RandomGeneratorFarm randomFarm; private final Random randIndex; - private final Random amountRandom; private final Map multiplicityMap; - private final double probWithdraw; public WithdrawEvent() { - this.probWithdraw = DatagenParams.accountWithdrawFraction; randomFarm = new RandomGeneratorFarm(); randIndex = new Random(DatagenParams.defaultSeed); - amountRandom = new Random(DatagenParams.defaultSeed); multiplicityMap = new ConcurrentHashMap<>(); } private void resetState(int seed) { randomFarm.resetRandomGenerators(seed); randIndex.setSeed(seed); - amountRandom.setSeed(seed); } public long getMultiplicityIdAndInc(Account from, Account to) { @@ -39,25 +34,24 @@ public long getMultiplicityIdAndInc(Account from, Account to) { return atomicInt.getAndIncrement(); } - public List withdraw(List sources, List cards, int blockId) { + public List withdraw(Account from, List cards, int blockId) { resetState(blockId); - List withdraws = new ArrayList<>(); - sources.forEach(from -> { - if (!from.getType().equals("debit card") - && randomFarm.get(RandomGeneratorFarm.Aspect.ACCOUNT_WHETHER_WITHDRAW).nextDouble() < probWithdraw) { - int count = 0; - while (count++ < DatagenParams.maxWithdrawals) { - Account to = cards.get(randIndex.nextInt(cards.size())); - if (cannotWithdraw(from, to)) { - continue; - } - withdraws.add( - Withdraw.createWithdraw(randomFarm.get(RandomGeneratorFarm.Aspect.WITHDRAW_DATE), from, to, - getMultiplicityIdAndInc(from, to), amountRandom.nextDouble() - * DatagenParams.withdrawMaxAmount)); + + Random amountRand = randomFarm.get(RandomGeneratorFarm.Aspect.WITHDRAW_AMOUNT); + Random dateRand = randomFarm.get(RandomGeneratorFarm.Aspect.WITHDRAW_DATE); + + List withdraws = new LinkedList<>(); + if (!from.getType().equals("debit card")) { + int count = 0; + while (count++ < DatagenParams.maxWithdrawals) { + Account to = cards.get(randIndex.nextInt(cards.size())); + if (cannotWithdraw(from, to)) { + continue; } + withdraws.add(Withdraw.createWithdraw(dateRand, from, to, getMultiplicityIdAndInc(from, to), + amountRand.nextDouble() * DatagenParams.withdrawMaxAmount)); } - }); + } return withdraws; } diff --git a/src/main/java/ldbc/finbench/datagen/util/RandomGeneratorFarm.java b/src/main/java/ldbc/finbench/datagen/util/RandomGeneratorFarm.java index 461a2302..1c78334a 100644 --- a/src/main/java/ldbc/finbench/datagen/util/RandomGeneratorFarm.java +++ b/src/main/java/ldbc/finbench/datagen/util/RandomGeneratorFarm.java @@ -67,6 +67,7 @@ public enum Aspect { // edge: withdraw ACCOUNT_WHETHER_WITHDRAW, WITHDRAW_DATE, + WITHDRAW_AMOUNT, // edge: signin SIGNIN_DATE, diff --git a/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala b/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala index 661e66a6..3aa3e986 100644 --- a/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala +++ b/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala @@ -158,7 +158,7 @@ class ActivityGenerator()(implicit spark: SparkSession) .asJava val signInEvent = new SignInEvent - val signRels = mediumRDD.mapPartitionsWithIndex((index, mediums) => { + mediumRDD.mapPartitionsWithIndex((index, mediums) => { mediums.flatMap(medium => { signInEvent .signIn( @@ -169,7 +169,6 @@ class ActivityGenerator()(implicit spark: SparkSession) .asScala }) }) - signRels } def personGuaranteeEvent(personRDD: RDD[Person]): RDD[Person] = { @@ -257,17 +256,26 @@ class ActivityGenerator()(implicit spark: SparkSession) // TODO: rewrite it with account centric def withdrawEvent(accountRDD: RDD[Account]): RDD[Withdraw] = { - val withdrawEvent = new WithdrawEvent val cards = accountRDD.filter(_.getType == "debit card").collect().toList.asJava - accountRDD.mapPartitions(sources => { - val withdrawList = withdrawEvent.withdraw( - sources.toList.asJava, - cards, - TaskContext.getPartitionId() + val withdrawEvent = new WithdrawEvent + accountRDD + .sample( + withReplacement = false, + DatagenParams.accountWithdrawFraction, + sampleRandom.nextLong() ) - for { withdraw <- withdrawList.iterator().asScala } yield withdraw - }) + .mapPartitionsWithIndex((index, sources) => { + sources.flatMap(source => { + withdrawEvent + .withdraw( + source, + cards, + index + ) + .asScala + }) + }) } // TODO: rewrite it with loan centric