diff --git a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx index 2bc6defe..dd46d7ad 100644 --- a/gui/src/app/SamplerOutputView/SamplerOutputView.tsx +++ b/gui/src/app/SamplerOutputView/SamplerOutputView.tsx @@ -2,7 +2,7 @@ import { SmallIconButton } from "@fi-sci/misc" import { Download } from "@mui/icons-material" import JSZip from 'jszip' import { FunctionComponent, useCallback, useMemo, useState } from "react" -import StanSampler from "../StanSampler/StanSampler" +import StanSampler, { SamplingOpts } from "../StanSampler/StanSampler" import { useSamplerOutput } from "../StanSampler/useStanSampler" import TabWidget from "../TabWidget/TabWidget" import { triggerDownload } from "../util/triggerDownload" @@ -30,6 +30,7 @@ const SamplerOutputView: FunctionComponent = ({width, he paramNames={paramNames} numChains={numChains} computeTimeSec={computeTimeSec} + samplingOpts={sampler.samplingOpts} /> ) } @@ -41,6 +42,7 @@ type DrawsDisplayProps = { numChains: number, paramNames: string[] computeTimeSec: number | undefined + samplingOpts: SamplingOpts // for including in exported zip } const tabs = [ @@ -70,7 +72,7 @@ const tabs = [ } ] -const DrawsDisplay: FunctionComponent = ({ width, height, draws, paramNames, numChains, computeTimeSec }) => { +const DrawsDisplay: FunctionComponent = ({ width, height, draws, paramNames, numChains, computeTimeSec, samplingOpts }) => { const [currentTabId, setCurrentTabId] = useState('summary'); @@ -106,6 +108,7 @@ const DrawsDisplay: FunctionComponent = ({ width, height, dra paramNames={paramNames} drawChainIds={drawChainIds} drawNumbers={drawNumbers} + samplingOpts={samplingOpts} /> = ({ width, height, draws, paramNames, drawChainIds, drawNumbers }) => { +const DrawsView: FunctionComponent = ({ width, height, draws, paramNames, drawChainIds, drawNumbers, samplingOpts }) => { const [abbreviatedToNumRows, setAbbreviatedToNumRows] = useState(300); const draws2 = useMemo(() => { if (abbreviatedToNumRows === undefined) return draws; @@ -147,7 +151,7 @@ const DrawsView: FunctionComponent = ({ width, height, draws, pa const handleExportToMultipleCsvs = useCallback(async () => { const uniqueChainIds = Array.from(new Set(drawChainIds)); const csvTexts = prepareMultipleCsvsText(draws, paramNames, drawChainIds, uniqueChainIds); - const blob = await createZipBlobForMultipleCsvs(csvTexts, uniqueChainIds); + const blob = await createZipBlobForMultipleCsvs(csvTexts, uniqueChainIds, samplingOpts); const fileName = 'SP-draws.zip'; const url = URL.createObjectURL(blob); const a = document.createElement('a'); @@ -155,7 +159,7 @@ const DrawsView: FunctionComponent = ({ width, height, draws, pa a.download = fileName; a.click(); URL.revokeObjectURL(url); - }, [draws, paramNames, drawChainIds]); + }, [draws, paramNames, drawChainIds, samplingOpts]); return (
@@ -237,7 +241,7 @@ const prepareMultipleCsvsText = (draws: number[][], paramNames: string[], drawCh }) } -const createZipBlobForMultipleCsvs = async (csvTexts: string[], uniqueChainIds: number[]) => { +const createZipBlobForMultipleCsvs = async (csvTexts: string[], uniqueChainIds: number[], samplingOpts: SamplingOpts) => { const zip = new JSZip(); // put them all in a folder called 'draws' const folder = zip.folder('draws'); @@ -245,6 +249,8 @@ const createZipBlobForMultipleCsvs = async (csvTexts: string[], uniqueChainIds: csvTexts.forEach((text, i) => { folder.file(`chain_${uniqueChainIds[i]}.csv`, text); }); + const samplingOptsText = JSON.stringify(samplingOpts, null, 2); + folder.file('sampling_opts.json', samplingOptsText); const blob = await zip.generateAsync({type: 'blob'}); return blob; } diff --git a/gui/src/app/StanSampler/StanSampler.ts b/gui/src/app/StanSampler/StanSampler.ts index 36bd6e9c..ebb16816 100644 --- a/gui/src/app/StanSampler/StanSampler.ts +++ b/gui/src/app/StanSampler/StanSampler.ts @@ -29,8 +29,8 @@ class StanSampler { #draws: number[][] = []; #computeTimeSec: number | undefined = undefined; #paramNames: string[] = []; - #numChains: number = 0; #samplingStartTimeSec: number = 0; + #samplingOpts: SamplingOpts = defaultSamplingOpts; // the sampling options used in the last sample call private constructor(private compiledUrl: string) { this._initialize() @@ -98,7 +98,6 @@ class StanSampler { console.warn('Number of chains not specified') return } - this.#numChains = sampleConfig.num_chains; if (this.#status === 'sampling') { console.warn('Already sampling') return @@ -107,6 +106,7 @@ class StanSampler { console.warn('Model not loaded yet') return } + this.#samplingOpts = samplingOpts; this.#draws = []; this.#paramNames = []; this.#worker @@ -137,9 +137,6 @@ class StanSampler { get paramNames() { return this.#paramNames; } - get numChains() { - return this.#numChains; - } get status() { return this.#status; } @@ -149,6 +146,9 @@ class StanSampler { get computeTimeSec() { return this.#computeTimeSec; } + get samplingOpts() { + return this.#samplingOpts; + } } const calculateReasonableRefreshRate = (samplingOpts: SamplingOpts) => { diff --git a/gui/src/app/StanSampler/useStanSampler.ts b/gui/src/app/StanSampler/useStanSampler.ts index 8cc0bb81..d163fde2 100644 --- a/gui/src/app/StanSampler/useStanSampler.ts +++ b/gui/src/app/StanSampler/useStanSampler.ts @@ -78,7 +78,7 @@ export const useSamplerOutput = (sampler: StanSampler | undefined) => { if (sampler.status === 'completed') { setDraws(sampler.draws); setParamNames(sampler.paramNames); - setNumChains(sampler.numChains); + setNumChains(sampler.samplingOpts.num_chains); setComputeTimeSec(sampler.computeTimeSec); } else {