Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion tb_plugin/fe/src/api/generated/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,25 @@ export interface Performance {
*/
children?: Array<Performance>
}
/**
*
* @export
* @interface Runs
*/
export interface Runs {
/**
*
* @type {Array<string>}
* @memberof Runs
*/
runs: Array<string>
/**
*
* @type {boolean}
* @memberof Runs
*/
loading: boolean
}
/**
*
* @export
Expand Down Expand Up @@ -2162,7 +2181,7 @@ export const DefaultApiFp = function (configuration?: Configuration) {
*/
runsGet(
options?: any
): (fetch?: FetchAPI, basePath?: string) => Promise<Array<string>> {
): (fetch?: FetchAPI, basePath?: string) => Promise<Runs> {
const localVarFetchArgs = DefaultApiFetchParamCreator(
configuration
).runsGet(options)
Expand Down
16 changes: 13 additions & 3 deletions tb_plugin/fe/src/api/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ paths:
content:
'*/*':
schema:
type: array
items:
type: string
$ref: '#/components/schemas/Runs'
/views:
get:
parameters:
Expand Down Expand Up @@ -453,6 +451,18 @@ paths:
type: object
components:
schemas:
Runs:
type: object
required:
- runs
- loading
properties:
runs:
type: array
items:
type: string
loading:
type: boolean
Performance:
type: object
required:
Expand Down
19 changes: 18 additions & 1 deletion tb_plugin/fe/src/app.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
*--------------------------------------------------------------------------------------------*/

import Card from '@material-ui/core/Card'
import CardContent from '@material-ui/core/CardContent'
import CardHeader from '@material-ui/core/CardHeader'
import ClickAwayListener from '@material-ui/core/ClickAwayListener'
import CssBaseline from '@material-ui/core/CssBaseline'
import Divider from '@material-ui/core/Divider'
Expand All @@ -15,6 +18,7 @@ import Select, { SelectProps } from '@material-ui/core/Select'
import { makeStyles } from '@material-ui/core/styles'
import ChevronLeftIcon from '@material-ui/icons/ChevronLeft'
import ChevronRightIcon from '@material-ui/icons/ChevronRight'
import Typography from '@material-ui/core/Typography'
import 'antd/es/button/style/css'
import 'antd/es/list/style/css'
import 'antd/es/table/style/css'
Expand Down Expand Up @@ -130,6 +134,7 @@ export const App = () => {

const [run, setRun] = React.useState<string>('')
const [runs, setRuns] = React.useState<string[]>([])
const [runsLoading, setRunsLoading] = React.useState(true)

const [workers, setWorkers] = React.useState<string[]>([])
const [worker, setWorker] = React.useState<string>('')
Expand All @@ -152,7 +157,8 @@ export const App = () => {
while (true) {
try {
const runs = await api.defaultApi.runsGet()
setRuns(runs)
setRuns(runs.runs)
setRunsLoading(runs.loading)
} catch (e) {
console.info('Cannot fetch runs: ', e)
}
Expand Down Expand Up @@ -248,6 +254,17 @@ export const App = () => {
}

const renderContent = () => {
if (!runsLoading && runs.length == 0) {
return (
<Card variant="outlined">
<CardHeader title="No Runs Found"></CardHeader>
<CardContent>
<Typography>There are not any runs in the log folder.</Typography>
</CardContent>
</Card>
)
}

if (!loaded || !run || !worker || !view || !span) {
return <FullCircularProgress />
}
Expand Down
11 changes: 9 additions & 2 deletions tb_plugin/test/test_tensorboard_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,22 @@ def _test_tensorboard(self, host, port, expected_runs, path_prefix):
try:
response = urllib.request.urlopen(run_link)
data = response.read()
if data == expected_runs:
runs = None
if data:
data = json.loads(data)
runs = data.get("runs")
if runs:
runs = '[{}]'.format(", ".join(['"{}"'.format(i) for i in runs]))
runs = runs.encode('utf-8')
if runs == expected_runs:
break
if retry_times % 10 == 0:
print("receive mismatched data, retrying", data)
time.sleep(2)
retry_times -= 1
if retry_times<0:
self.fail("Load run timeout")
except Exception:
except Exception as e:
if retry_times > 0:
continue
else:
Expand Down
45 changes: 28 additions & 17 deletions tb_plugin/torch_tb_profiler/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def __init__(self, context):
mp.set_start_method(start_method, force=True)
self.logdir = io.abspath(context.logdir.rstrip('/'))

self._is_active = None
self._is_active_initialized_event = threading.Event()
self._load_lock = threading.Lock()
self._load_threads = []

self._runs = OrderedDict()
self._runs_lock = threading.Lock()
Expand All @@ -76,8 +76,7 @@ def clean():
def is_active(self):
"""Returns whether there is relevant data for the plugin to process.
"""
self._is_active_initialized_event.wait()
return self._is_active
return True

def get_plugin_apps(self):
return {
Expand All @@ -104,13 +103,21 @@ def get_plugin_apps(self):
}

def frontend_metadata(self):
return base_plugin.FrontendMetadata(es_module_path="/index.js")
return base_plugin.FrontendMetadata(es_module_path="/index.js", disable_reload=True)

@wrappers.Request.application
def runs_route(self, request):
with self._runs_lock:
names = list(self._runs.keys())
return self.respond_as_json(names)

with self._load_lock:
loading = bool(self._load_threads)

data = {
"runs": names,
"loading": loading
}
return self.respond_as_json(data)

@wrappers.Request.application
def views_route(self, request):
Expand All @@ -130,7 +137,6 @@ def workers_route(self, request):
self._validate(run=name, view=view)
run = self._get_run(name)
self._check_run(run, name)
workers = run.get_workers(view)
return self.respond_as_json(run.get_workers(view))

@wrappers.Request.application
Expand Down Expand Up @@ -305,19 +311,22 @@ def _monitor_runs(self):
logger.debug("Scan run dir")
run_dirs = self._get_run_dirs()

has_dir = False
# Assume no deletion on run directories, trigger async load if find a new run
for run_dir in run_dirs:
# Set _is_active quickly based on file pattern match, don't wait for data loading
if not self._is_active:
self._is_active = True
self._is_active_initialized_event.set()

has_dir = True
if run_dir not in touched:
touched.add(run_dir)
logger.info("Find run directory %s", run_dir)
# Use threading to avoid UI stall and reduce data parsing time
t = threading.Thread(target=self._load_run, args=(run_dir,))
t.start()
with self._load_lock:
self._load_threads.append(t)

if not has_dir:
# handle directory removed case.
self._runs.clear()
except Exception as ex:
logger.warning("Failed to scan runs. Exception=%s", ex, exc_info=True)

Expand All @@ -338,11 +347,6 @@ def _receive_runs(self):
if is_new:
self._runs = OrderedDict(sorted(self._runs.items()))

# Update is_active
if not self._is_active:
self._is_active = True
self._is_active_initialized_event.set()

def _get_run_dirs(self):
"""Scan logdir, find PyTorch Profiler run directories.
A directory is considered to be a run if it contains 1 or more *.pt.trace.json[.gz].
Expand Down Expand Up @@ -371,6 +375,13 @@ def _load_run(self, run_dir):
except Exception as ex:
logger.warning("Failed to load run %s. Exception=%s", ex, name, exc_info=True)

t = threading.current_thread()
with self._load_lock:
try:
self._load_threads.remove(t)
except ValueError:
logger.warning("could not find the thread {}".format(run_dir))

def _get_run(self, name) -> Run:
with self._runs_lock:
return self._runs.get(name, None)
Expand Down
2 changes: 1 addition & 1 deletion tb_plugin/torch_tb_profiler/static/index.html

Large diffs are not rendered by default.