diff --git a/CHANGELOG.md b/CHANGELOG.md index deff14b..19db6e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0. ### Fixed - Fix import error for python<3.8 ([#3](https://github.com/AustinT/mol_ga/pull/3)) ([@austint]) +- Fix unintended use of system random in sampling ([#4](https://github.com/AustinT/mol_ga/pull/4)) ([@austint]) ## [0.1.0] - 2023-09-05 diff --git a/mol_ga/general_ga.py b/mol_ga/general_ga.py index 55cb9d4..0b242ad 100644 --- a/mol_ga/general_ga.py +++ b/mol_ga/general_ga.py @@ -29,7 +29,7 @@ def run_ga_maximization( *, scoring_func: Union[Callable[[list[str]], list[float]], CachedBatchFunction], starting_population_smiles: set[str], - sampling_func: Callable[[list[tuple[float, str]], int], list[str]], + sampling_func: Callable[[list[tuple[float, str]], int, random.Random], list[str]], offspring_gen_func: Callable[[list[str], int, random.Random, Optional[joblib.Parallel]], set[str]], selection_func: Callable[[int, list[tuple[float, str]]], list[tuple[float, str]]], max_generations: int, @@ -108,7 +108,7 @@ def run_ga_maximization( _, population_smiles = tuple(zip(*population)) # type: ignore[assignment] # Sample SMILES from population to create offspring - samples_from_population = sampling_func(population, num_samples_per_generation) + samples_from_population = sampling_func(population, num_samples_per_generation, rng) # Create the offspring offspring = offspring_gen_func( diff --git a/mol_ga/sample_population.py b/mol_ga/sample_population.py index bfa9565..4d7fa30 100644 --- a/mol_ga/sample_population.py +++ b/mol_ga/sample_population.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -import random +from random import Random import numpy as np @@ -9,6 +9,7 @@ def uniform_qualitle_sampling( population: list[tuple[float, str]], n_sample: int, + rng: Random, shuffle: bool = True, ) -> list[str]: """Sample SMILES by sampling uniformly from logarithmically spaced top-N.""" @@ -19,10 +20,10 @@ def uniform_qualitle_sampling( for q in quantiles: score_threshold = np.quantile([s for s, _ in population], q) eligible_population = [smiles for score, smiles in population if score >= score_threshold] - samples.extend(random.choices(population=eligible_population, k=n_samples_per_quanitile)) + samples.extend(rng.choices(population=eligible_population, k=n_samples_per_quanitile)) # Shuffle samples to decrease correlations between adjacent samples if shuffle: - random.shuffle(samples) + rng.shuffle(samples) return samples[:n_sample] # in case there are slightly too many samples