Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Nov 22, 2024
1 parent 5cf758c commit f66a530
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
1 change: 0 additions & 1 deletion netam/dxsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ def branch_lengths(self, new_branch_lengths):
self._branch_lengths = new_branch_lengths
self.update_neutral_probs()


@abstractmethod
def update_neutral_probs(self):
pass
Expand Down
13 changes: 10 additions & 3 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,23 @@ class PlaceholderEncoder:
def parameters(self):
return {}


class BranchLengthDataset(Dataset):
def __len__(self):
return len(self.branch_lengths)

def export_branch_lengths(self, out_csv_path):
pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
out_csv_path, index=False
)

def load_branch_lengths(self, in_csv_path):
self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values

def __repr__(self):
return f"{self.__class__.__name__}(Size: {len(self)}) on {self.branch_lengths.device}"


class SHMoofDataset(BranchLengthDataset):
def __init__(self, dataframe, kmer_length, site_count):
super().__init__()
Expand All @@ -162,9 +172,6 @@ def __getitem__(self, idx):
self.branch_lengths[idx],
)

def __repr__(self):
return f"{self.__class__.__name__}(Size: {len(self)}) on {self.encoded_parents.device}"

def to(self, device):
self.encoded_parents = self.encoded_parents.to(device)
self.masks = self.masks.to(device)
Expand Down

0 comments on commit f66a530

Please sign in to comment.