Skip to content

Commit

Permalink
Feature distribution changes - migrated tests to Pytest, use of abstr…
Browse files Browse the repository at this point in the history
…act base classes (#277)

* added use of ABC, refactored tests

* wip

* fixed base method in Distribution
  • Loading branch information
ronanstokes-db authored Jun 1, 2024
1 parent 82ce5ce commit 8136ccf
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 123 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ All notable changes to the Databricks Labs Data Generator will be documented in
### Changed
* Modified data generator to allow specification of constraints to the data generation process
* Updated documentation for generating text data.
* Modified data distributions to use abstract base classes
* Migrated data distribution tests to use `pytest`

### Added
* Added classes for constraints on the data generation via new package `dbldatagen.constraints`
Expand Down
23 changes: 13 additions & 10 deletions dbldatagen/distributions/data_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
and no further scaling is needed.
"""
import copy
import pyspark.sql.functions as F
from abc import ABC, abstractmethod

import numpy as np
import pyspark.sql.functions as F


class DataDistribution(object):
class DataDistribution(ABC):
""" Base class for all distributions"""

def __init__(self):
self._rounding = False
self._randomSeed = None
Expand All @@ -37,8 +40,8 @@ def get_np_random_generator(random_seed):
:param random_seed: Numeric random seed to use. If < 0, then no random
:return:
"""
assert random_seed is None or type(random_seed) in [ np.int32, np.int64, int],\
f"`randomSeed` must be int or int-like not {type(random_seed)}"
assert random_seed is None or type(random_seed) in [np.int32, np.int64, int], \
f"`randomSeed` must be int or int-like not {type(random_seed)}"
from numpy.random import default_rng
if random_seed not in (-1, -1.0):
rng = default_rng(random_seed)
Expand All @@ -47,17 +50,17 @@ def get_np_random_generator(random_seed):

return rng

@abstractmethod
def generateNormalizedDistributionSample(self):
""" Generate sample of data for distribution
:return: random samples from distribution scaled to values between 0 and 1
Note implementors should provide implementation for this,
Return value is expected to be a Pyspark SQL column expression such as F.expr("rand()")
"""
if self.randomSeed == -1 or self.randomSeed is None:
newDef = F.expr("rand()")
else:
assert type(self.randomSeed) in [int, float], "random seed should be numeric"
newDef = F.expr(f"rand({self.randomSeed})")
return newDef
pass

def withRounding(self, rounding):
""" Create copy of object and set the rounding attribute
Expand Down
Loading

0 comments on commit 8136ccf

Please sign in to comment.