diff --git a/app.py b/app.py deleted file mode 100644 index 3a86d70..0000000 --- a/app.py +++ /dev/null @@ -1,72 +0,0 @@ -import dash_bootstrap_components as dbc -from dash import Dash, html, dcc -from collections import defaultdict - -from callbacks.update_res import * -from components.header import layout as header -from components.progress import layout as tomo_progress -from components.proteins import layout as protein_sts -from components.waitlist import layout as unlabelled_tomos -from components.annotators import layout as ranking -from components.composition import layout as composition -from components.popups import layout as popups - - - - -def create_app(): - external_stylesheets = [dbc.themes.BOOTSTRAP, - "assets/header-style.css", - "https://codepen.io/chriddyp/pen/bWLwgP.css", - "https://use.fontawesome.com/releases/v5.10.2/css/all.css"] - - app = Dash(__name__, external_stylesheets=external_stylesheets) - - browser_cache =html.Div( - id="no-display", - children=[ - dcc.Interval( - id='interval-component', - interval=20*1000, # clientside check in milliseconds, 10s - n_intervals=0 - ), - dcc.Store(id='tomogram-index', data=''), - dcc.Store(id='keybind-num', data=''), - dcc.Store(id='run-dt', data=defaultdict(list)) - ], - ) - - - app.layout = html.Div( - [ - header(), - popups(), - dbc.Container( - [ - dbc.Row( - [ - dbc.Col([tomo_progress(), - unlabelled_tomos() - ], - width=3), - dbc.Col(ranking(), width=3), - dbc.Col(composition(), width=3), - dbc.Col(protein_sts(), width=3), - ], - justify='center', - className="h-100", - ), - ], - fluid=True, - ), - html.Div(browser_cache) - ], - ) - return app - - - -if __name__ == "__main__": - dash_app = create_app() - dash_app.run_server(host="0.0.0.0", port=8000, debug=False) - diff --git a/callbacks/update_res.py b/callbacks/update_res.py deleted file mode 100644 index 6b08075..0000000 --- a/callbacks/update_res.py +++ /dev/null @@ -1,545 +0,0 @@ -import plotly.express as px -import dash_bootstrap_components as dbc -import json, time -import pandas as pd -from collections import defaultdict -from apscheduler.schedulers.background import BackgroundScheduler - -import time -from utils.copick_dataset import copick_dataset -from utils.figure_utils import ( - blank_fig, - draw_gallery -) -from utils.local_dataset import ( - local_dataset, - dirs, - dir2id, - COUNTER_FILE_PATH, - #CACHE_ROOT, -) -from dash import ( - html, - Input, - Output, - callback, - State, - ALL, - MATCH, - ctx, - dcc, - no_update -) -from dash.exceptions import PreventUpdate - - -from dash_iconify import DashIconify -import base64 - - - -def submission_list(i,j): - return dbc.ListGroupItem("{}: {}".format(i.split('.json')[0], j)) - - -import io -def parse_contents(contents, filename, date): - content_type, content_string = contents.split(',') - decoded = base64.b64decode(content_string) - fig = [] - try: - if 'csv' in filename: - # Assume that the user uploaded a CSV file - df = pd.read_csv( - io.StringIO(decoded.decode('utf-8'))) - df = df.sort_values(by=['Aggregate_Fbeta'], ascending=False) - df = df.reset_index(drop=True) - df['rank'] = df.index - #df = df[['File', 'Aggregate_Fbeta']] - dict_df = df.set_index('File')['Aggregate_Fbeta'].to_dict() - print(dict_df) - fig = px.scatter(df, x='rank', y='Aggregate_Fbeta', hover_name='File', title='Submitted model ranking') - elif 'xls' in filename: - # Assume that the user uploaded an excel file - df = pd.read_excel(io.BytesIO(decoded)) - df = df.sort_values(by=['Aggregate_Fbeta'], ascending=False) - df = df[['File', 'Aggregate_Fbeta']] - except Exception as e: - print(e) - return html.Div([ - 'There was an error processing this file.' - ]) - - return dbc.Card([ - dbc.CardHeader([DashIconify(icon="noto-v1:trophy", width=25, style={"margin": "5px"}), - 'Submitted model ranking' - ], - style={"font-weight": "bold"} - ), - dbc.CardBody(id='submission-rank', children=dcc.Graph(figure=fig), style={'overflowY': 'scroll'}) - ], - style={"height": '87vh'} - ) - - - -# 1st update of the internal states -local_dataset.refresh() - -#Scheduler -scheduler = BackgroundScheduler() # in-memory job stores -scheduler.add_job(func=local_dataset.refresh, trigger='interval', seconds=20) # interval should be larger than the time it takes to refresh, o.w. it will be report incomplete stats. -scheduler.start() - - -roundbutton = { - "border": 'transparent', - #"border-radius": "100%", - "padding": 0, - "backgroundColor": 'transparent', - "color": "black", - "textAlign": "center", - "display": "block", - "fontSize": 9, - "height": 9, - "width": 9, - "margin-left": 10, - "margin-top": 8, -} - - - -def candidate_list(i, j): - return dbc.ListGroupItem("{} (labeled by {} person)".format(dirs[i], j)) - -def ranking_list(i, j): - return dbc.ListGroupItem("{} {} tomograms".format(i, j)) - - - -############################################## Callbacks ############################################## -@callback( - Output("modal-help", "is_open"), - Input("button-help", "n_clicks"), - State("modal-help", "is_open"), - prevent_initial_call=True -) -def toggle_help_modal(n_clicks, is_open): - return not is_open - - -@callback( - Output("modal-results", "is_open"), - Input("button-results", "n_clicks"), - State("modal-results", "is_open"), - prevent_initial_call=True -) -def toggle_help_modal(n_clicks, is_open): - return not is_open - - -@callback(Output('output-data-upload', 'children'), - Input('upload-data', 'contents'), - State('upload-data', 'filename'), - State('upload-data', 'last_modified') - ) -def update_output(list_of_contents, list_of_names, list_of_dates): - if list_of_contents is not None: - children = [ - parse_contents(c, n, d) for c, n, d in - zip(list_of_contents, list_of_names, list_of_dates)] - return children - - -@callback( - Output("tomogram-index", "data"), - Input({"type": "tomogram-eval-bttn", "index": ALL}, "n_clicks"), - prevent_initial_call=True -) -def update_tomogram_index(n_clicks): - if any(n_clicks): - changed_id = [p['prop_id'] for p in ctx.triggered][0].split(".")[0] - if "index" in changed_id: - tomogram_index = json.loads(changed_id)["index"] - return tomogram_index - - -@callback( - Output("collapse1", "is_open"), - Output("collapse2", "is_open"), - Input("tabs", "active_tab"), - prevent_initial_call=True -) -def toggle_analysis_tabs(at): - if at == "tab-1": - return True, False - elif at == "tab-2": - return False, True - - -@callback( - Output("output-image-upload", "children", allow_duplicate=True), - Output("image-slider", "value"), - Output("particle-dropdown", "value"), - Output("modal-evaluation", "is_open"), - Output("tabs", "active_tab"), - Output("choose-results", "children"), - Input("tomogram-index", "data"), - prevent_initial_call=True -) -def reset_analysis_popup(tomogram_index): - msg = f"Choose results for {tomogram_index}" - if tomogram_index is not None: - return [], 0, None, True, "tab-1", msg - else: - return [], 0, None, False, "tab-1", msg - - -@callback( - Output("run-dt", "data"), - Input("tomogram-index", "data"), - prevent_initial_call=True -) -def load_tomogram_run(tomogram_index): - dt = defaultdict(list) - if tomogram_index is not None: - # takes 18s for VPN - t1 = time.time() - copick_dataset.load_curr_run(run_name=tomogram_index, sort_by_score=True) - # takes 0.2s - t2 = time.time() - print('find copick run in copick', t2-t1) - - - return dt - - -@callback( - Output("image-slider", "value", allow_duplicate=True), - Input("particle-dropdown", "value"), - Input("display-row", "value"), - Input("display-col", "value"), - prevent_initial_call=True -) -def reset_slider(value, nrow, ncol): - return 0 - - -@callback( - Output("output-image-upload", "children"), - Output("particle-dropdown", "options"), - Output("fig1", "figure"), - Output("image-slider", "max"), - Output("image-slider", "marks"), - Output("crop-label", "children"), - Output("image-slider", "value", allow_duplicate=True), - Output("keybind-num", "data"), - Input("tabs", "active_tab"), - Input("image-slider", "value"), - Input("crop-width", "value"), - Input("crop-avg", "value"), - Input("particle-dropdown", "value"), - Input("accept-bttn", "n_clicks"), - Input("reject-bttn", "n_clicks"), - Input("assign-bttn", "n_clicks"), - Input("username-analysis", "value"), - Input("keybind-event-listener", "event"), - Input("keybind-event-listener", "n_events"), - Input("display-row", "value"), - Input("display-col", "value"), - State("tomogram-index", "data"), - State("fig1", "figure"), - State("output-image-upload", "children"), - State("keybind-num", "data"), - State({'type': 'thumbnail-image', 'index': ALL}, 'n_clicks'), - State("assign-dropdown", "value"), - prevent_initial_call=True -) -def update_analysis( - at, - slider_value, - crop_width, - crop_avg, - particle, - accept_bttn, - reject_bttn, - assign_bttn, - copicklive_username, - keybind_event_listener, - n_events, - nrow, - ncol, - tomogram_index, - fig1, - fig2, - kbn, - thumbnail_image_select_value, - new_particle -): - pressed_key = None - if ctx.triggered_id == "keybind-event-listener": - #user is going to type in the class creation/edit modals and we don't want to trigger this callback using keys - pressed_key = ( - keybind_event_listener.get("key", None) if keybind_event_listener else None - ) - if not pressed_key: - raise PreventUpdate - else: - print(f'pressed_key {pressed_key}') - - slider_max = 0 - changed_id = [p['prop_id'] for p in ctx.triggered][0] - # takes 0.35s on mac3 - if tomogram_index: - #time.sleep(7) - #particle_dict = {k: k for k in sorted(set(copick_dataset.dt['pickable_object_name']))} - if at == "tab-1": - time.sleep(2) - particle_dict = {k: k for k in sorted(set(copick_dataset.dt['pickable_object_name']))} - df = pd.DataFrame.from_dict(copick_dataset.dt) - fig1 = px.scatter_3d(df, x='x', y='y', z='z', color='pickable_object_name', symbol='user_id', size='size', opacity=0.5) - return fig2, particle_dict, fig1, slider_max, {0: '0', slider_max: str(slider_max)}, no_update, no_update, no_update - elif at == "tab-2": - #new_particle = None - if pressed_key in [str(i+1) for i in range(len(local_dataset._im_dataset['name']))]: - new_particle = local_dataset._im_dataset['name'][int(pressed_key)-1] - elif pressed_key == 's': - new_particle = kbn - - copick_dataset.new_user_id(user_id=copicklive_username) - if ("display-row" in changed_id or\ - "display-col" in changed_id) or \ - particle in copick_dataset.points_per_obj: - if len(copick_dataset.points_per_obj[particle])%(nrow*ncol): - slider_max = len(copick_dataset.points_per_obj[particle])//(nrow*ncol) - else: - slider_max = len(copick_dataset.points_per_obj[particle])//(nrow*ncol) - 1 - - positions = [i for i in range(slider_value*nrow*ncol, min((slider_value+1)*nrow*ncol, len(copick_dataset.points_per_obj[particle])))] - # loading zarr takes 6-8s for VPN - particle_dict = {k: k for k in sorted(set(copick_dataset.dt['pickable_object_name']))} - dim_z, dim_y, dim_x = copick_dataset.tomogram.shape - msg = f"Image crop width (max {min(dim_x, dim_y)})" - if crop_width is not None: - half_width = crop_width//2 - if crop_avg is None: - crop_avg = 0 - fig2 = draw_gallery(run=tomogram_index, particle=particle, positions=positions, hw=half_width, avg=crop_avg, nrow=nrow, ncol=ncol) - - - selected = [i for i,v in enumerate(thumbnail_image_select_value) if v%2 == 1] - selected_point_ids = [positions[i] for i in selected] - if 'accept-bttn' in changed_id or pressed_key=='a': - copick_dataset.handle_accept_batch(selected_point_ids, particle) - elif 'reject-bttn' in changed_id or pressed_key=='d': - copick_dataset.handle_reject_batch(selected_point_ids, particle) - elif 'assign-bttn' in changed_id or pressed_key=='s': - copick_dataset.handle_assign_batch(selected_point_ids, particle, new_particle) - - # update figures - # if 'accept-bttn' in changed_id or \ - # 'reject-bttn' in changed_id or \ - # 'assign-bttn' in changed_id or \ - # pressed_key in ['a', 'd', 's']: - # slider_value += 1 - # positions = [i for i in range(slider_value*nrow*ncol, min((slider_value+1)*nrow*ncol, len(copick_dataset.points_per_obj[particle])))] - # fig2 = draw_gallery(run=tomogram_index, particle=particle, positions=positions, hw=half_width, avg=crop_avg, nrow=nrow, ncol=ncol) - - if 'assign-bttn' in changed_id or pressed_key == 's': - positions = [i for i in range(slider_value*nrow*ncol, min((slider_value+1)*nrow*ncol, len(copick_dataset.points_per_obj[particle])))] - fig2 = draw_gallery(run=tomogram_index, particle=particle, positions=positions, hw=half_width, avg=crop_avg, nrow=nrow, ncol=ncol) - - if pressed_key=='ArrowRight' and slider_value < slider_max: - slider_value += 1 - positions = [i for i in range(slider_value*nrow*ncol, min((slider_value+1)*nrow*ncol, len(copick_dataset.points_per_obj[particle])))] - fig2 = draw_gallery(run=tomogram_index, particle=particle, positions=positions, hw=half_width, avg=crop_avg, nrow=nrow, ncol=ncol) - elif pressed_key=='ArrowLeft' and slider_value: - slider_value -= 1 - positions = [i for i in range(slider_value*nrow*ncol, min((slider_value+1)*nrow*ncol, len(copick_dataset.points_per_obj[particle])))] - fig2 = draw_gallery(run=tomogram_index, particle=particle, positions=positions, hw=half_width, avg=crop_avg, nrow=nrow, ncol=ncol) - - return fig2, particle_dict, blank_fig(), slider_max, {0: '0', slider_max: str(slider_max)}, msg, slider_value, new_particle - else: - return fig2, dict(), blank_fig(), slider_max, {0: '0', slider_max: str(slider_max)}, no_update, no_update, no_update - - - - -@callback( - Output({'type': 'thumbnail-card', 'index': MATCH}, 'color'), - Input({'type': 'thumbnail-image', 'index': MATCH}, 'n_clicks'), - Input('select-all-bttn', 'n_clicks'), - Input('unselect-all-bttn', 'n_clicks'), - State("image-slider", "value"), - State("display-row", "value"), - State("display-col", "value"), - State("particle-dropdown", "value"), - State({'type': 'thumbnail-image', 'index': MATCH}, 'id'), -) -def select_thumbnail(value, - select_clicks, - unselect_clicks, - slider_value, - nrow, ncol, - particle, - comp_id): - ''' - This callback assigns a color to thumbnail cards in the following scenarios: - - An image has been selected, but no label has been assigned (blue) - - An image has been labeled (label color) - - An image has been unselected or unlabeled (no color) - Args: - value: Thumbnail card that triggered the callback (n_clicks) - Returns: - thumbnail_color: Color of thumbnail card - ''' - color = '' - colors = ['', 'success', 'danger', 'warning'] - positions = [i for i in range(slider_value*nrow*ncol, min((slider_value+1)*nrow*ncol, len(copick_dataset.points_per_obj[particle])))] - #print(f'positions {positions}') - selected = [copick_dataset.picked_points_mask[copick_dataset.points_per_obj[particle][i][0]] for i in positions] - #print(f'selected {selected}') - #print(f'comp_id {comp_id}') - color = colors[selected[int(comp_id['index'])]] - #print(f'color {color} value {value}') - if value is None or (ctx.triggered[0]['prop_id'] == 'unselect-all-bttn.n_clicks' and color==''): - return '' - if value % 2 == 1: - return 'primary' - else: - return color - - -@callback( - Output({'type': 'thumbnail-image', 'index': ALL}, 'n_clicks'), - - # Input({'type': 'label-button', 'index': ALL}, 'n_clicks_timestamp'), - # Input('un-label', 'n_clicks'), - Input('select-all-bttn', 'n_clicks'), - Input('unselect-all-bttn', 'n_clicks'), - - State({'type': 'thumbnail-image', 'index': ALL}, 'n_clicks'), - prevent_initial_call=True -) -def deselect(select_clicks, unselect_clicks, thumb_clicked): - ''' - This callback deselects a thumbnail card - Args: - label_button_trigger: Label button - unlabel_n_clicks: Un-label button - unlabel_all: Un-label all the images - thumb_clicked: Selected thumbnail card indice, e.g., [0,1,1,0,0,0] - Returns: - Modify the number of clicks for a specific thumbnail card - ''' - # if all(x is None for x in label_button_trigger) and unlabel_n_clicks is None and unlabel_all is None: - # return [no_update]*len(thumb_clicked) - if ctx.triggered[0]['prop_id'] == 'unselect-all-bttn.n_clicks': - print([0 for thumb in thumb_clicked]) - return [0 for thumb in thumb_clicked] - elif ctx.triggered[0]['prop_id'] == 'select-all-bttn.n_clicks': - return [1 for thumb in thumb_clicked] - - - - - -@callback( - Output("download-json", "data"), - Input("btn-download", "n_clicks"), - State("username", "value"), - prevent_initial_call=True, -) -def download_json(n_clicks, input_value): - input_value = '.'.join(input_value.split(' ')) - filename = 'copick_config_' + '_'.join(input_value.split('.')) + '.json' - local_dataset.config_file["user_id"] = input_value - return dict(content=json.dumps(local_dataset.config_file, indent=4), filename=filename) - - -@callback( - Output("download-txt", "data"), - Input("btn-download-txt", "n_clicks"), - #State("username", "value"), - prevent_initial_call=True, -) -def download_txt(n_clicks): - print(f'COUNTER_FILE_PATH 0 {COUNTER_FILE_PATH}') - if COUNTER_FILE_PATH: - with open(COUNTER_FILE_PATH) as f: - counter = json.load(f) - - if counter['repeat'] == 2: - counter['start'] += counter['tasks_per_person'] - counter['repeat'] = 0 - - counter['repeat'] += 1 - task_contents = '\n'.join(dirs[counter['start']:counter['start']+counter['tasks_per_person']]) - print(task_contents) - task_filename = 'task_recommendation.txt' - - with open(COUNTER_FILE_PATH, 'w') as f: - f.write(json.dumps(counter, indent=4)) - - return dict(content=task_contents, filename=task_filename) - - -@callback( - Output('proteins-histogram', 'figure'), - Output('waitlist', 'children'), - Output('rank', 'children'), - Output('total-labeled', 'children'), - Output('progress-bar', 'value'), - Output('progress-bar', 'label'), - Input('interval-component', 'n_intervals') -) -def update_results(n): - data = local_dataset.fig_data() - fig = px.bar(x=data['name'], - y=data['count'], - labels={'x': 'Objects', 'y':'Counts'}, - text_auto=True, - color = data['name'], - color_discrete_map = data['colors'], - ) - fig.update(layout_showlegend=False) - num_candidates = len(dirs) if len(dirs) < 100 else 100 - candidates = local_dataset.candidates(num_candidates, random_sampling=False) - num_per_person_ordered = local_dataset.num_per_person_ordered - label = f'Labeled {len(local_dataset.tomos_pickers)} out of 1000 tomograms' - bar_val = round(len(local_dataset.tomos_pickers)/1000*100, 1) - - return fig, \ - dbc.ListGroup([candidate_list(i, j) for i, j in candidates.items()], flush=True), \ - dbc.ListGroup([ranking_list(i, len(j)) for i, j in num_per_person_ordered.items()], numbered=True), \ - [label], \ - bar_val, \ - f'{bar_val}%' - - -@callback( - Output('composition', 'children'), - #Input('interval-component', 'n_intervals'), - Input('refresh-button', 'n_clicks') -) -def update_compositions(n): - progress_list = [] - composition_list = html.Div() - data = local_dataset.fig_data() - l = 1/len(data['colors'])*100 - obj_order = {name:i for i,name in enumerate(data['name'])} - tomograms = {k:v for k,v in sorted(local_dataset.tomograms.items(), key=lambda x: dir2id[x[0]])} - for tomogram,ps in tomograms.items(): - progress = [] - ps = [p for p in ps if p in obj_order] - ps = sorted(list(ps), key=lambda x: obj_order[x]) - for p in ps: - progress.append(dbc.Progress(value=l, color=data['colors'][p], bar=True)) - - bttn = html.Button(id={"type": "tomogram-eval-bttn", "index": tomogram}, className="fa fa-search", style=roundbutton) - progress_list.append(dbc.ListGroupItem(children=[dbc.Row([tomogram, bttn]), dbc.Progress(progress)], style={"border": 'transparent'})) - - composition_list = dbc.ListGroup(progress_list) - return composition_list - - diff --git a/components/popups.py b/components/popups.py deleted file mode 100644 index bee905e..0000000 --- a/components/popups.py +++ /dev/null @@ -1,219 +0,0 @@ -from dash import html, dcc -import dash_bootstrap_components as dbc -import plotly.graph_objects as go -from utils.local_dataset import local_dataset -from dash_extensions import EventListener - -def blank_fig(): - """ - Creates a blank figure with no axes, grid, or background. - """ - fig = go.Figure() - fig.update_layout(template=None) - fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False) - fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False) - - return fig - - -instructions = [dcc.Markdown(''' - Thanks for participating in CZII Pickathon! We highly encourage labeling all the 6 types of prteins in the tomogram. - ### Tools installation - #### ChimeraX - 1. Download and install [ChimeraX](https://www.cgl.ucsf.edu/chimerax/download.html) version 1.7.0+ - 2. Download the [copick plugin](https://cxtoolshed.rbvi.ucsf.edu/apps/chimeraxcopick) - 3. Open ChimeraX and install the copick wheel by typing in the command line: `toolshed install /path/to/your/ChimeraX_copick-0.1.3-py3-none-any.whl` - - #### Configuration files - Auto-generate the copick configuration file and a tomogram recomendation list for you (5 tomograms per file). - ''', - link_target='_blank' - ), - html.Div( - children=[ - dbc.Input(id='username', placeholder="Please input your name (e.g., john.doe)", type="text"), - dbc.Button("Download copick config file", id='btn-download', outline=True, color="primary", className="me-1"), - dbc.Button("Download recommendation file", id='btn-download-txt', outline=True, color="primary", className="me-1"), - dcc.Download(id="download-json"), - dcc.Download(id="download-txt"), - ], - className="d-grid gap-3 col-4 mx-auto", - ), - dcc.Markdown(''' - ### Particle picking - The default workflow for ChimeraX should be: - 1. Type the command `copick start /path/to/config`. It will take about 2-3 mins to load the entire dataset tree. - 2. Open a tomogram by navigating the tree and double-clicking. - 3. Select or double click a pre-picked list from the upper table (double click will load the list). - 4. Press the ▼ ▼ ▼ button to copy the contents to the "editable" lower table. - 5. Select the Copick tab at the top right corner and choose a tool in the `Place Particles` session. Start editing by right click. Your picking results will be automatically saved. - ''') - ] - - - -competition_results = [ - dcc.Upload( - id='upload-data', - children=html.Div([ - 'Drag and Drop or ', - html.A('Select Files') - ]), - style={ - 'width': '95%', - 'height': '60px', - 'lineHeight': '60px', - 'borderWidth': '1px', - 'borderStyle': 'dashed', - 'borderRadius': '5px', - 'textAlign': 'center', - 'margin': '10px' - }, - # Allow multiple files to be uploaded - multiple=True - ), - html.Div(id='output-data-upload'), -] - - - -tabs = html.Div( - [ - dbc.Tabs( - [ - dbc.Tab(label="Picked points visualization", tab_id="tab-1"), - dbc.Tab(label="2D Plane Inspection", tab_id="tab-2"), - #dbc.Tab(label="3D Volume Inspection", tab_id="tab-3"), - ], - id="tabs", - active_tab="tab-1", - ), - html.Div([ - dbc.Label("Choose results", id='choose-results', style={'margin-top': '35px', 'margin-left': '7px'}), - dcc.Dropdown(["Pickathon results"], 'Pickathon results', id='pick-dropdown', style={'width':'42%', 'justify-content': 'center', 'margin-bottom': '0px', 'margin-left': '4px'}), - dbc.Collapse(id="collapse1",is_open=False, children=dbc.Spinner(dcc.Graph(id='fig1', figure=blank_fig()), spinner_style={"width": "5rem", "height": "5rem"})), - dbc.Collapse(id="collapse2",is_open=False, children=dbc.Container([dbc.Row( - [ - dbc.Col([ - dbc.Label("Please input your name", style={'margin-top': '-20px'}), - dbc.Input(id='username-analysis', placeholder="e.g., john.doe", type="text", style={'width': '75%'}), - dbc.Label("Please select a particle type", className="mt-3"), - dcc.Dropdown(id='particle-dropdown', style={'width': '87%'}), - dbc.Label("Number of rows", className="mt-3"), - dcc.Input(id="display-row",type="number", placeholder="5", value =5, min=1, step=1), - dbc.Label("Number of columns", className="mt-3"), - dcc.Input(id="display-col",type="number", placeholder="4", value =4, min=1, step=1), - dbc.Label(id='crop-label', children="Image crop size (max 100)", className="mt-3"), - dcc.Input(id="crop-width",type="number", placeholder="30", value =60, min=1, step=1), - dbc.Label("Average ±N neigbor layers", className="mt-3"), - dcc.Input(id="crop-avg", type="number", placeholder="3", value =2, min=0, step=1), - dbc.Label("Page slider (press key < or >)", className="mt-3"), - html.Div(dcc.Slider( - id='image-slider', - min=0, - max=200, - value = 0, - step = 1, - updatemode='drag', - tooltip={"placement": "top", "always_visible": True}, - marks={0: '0', 199: '199'}, - ), style={'width':'72%', 'margin-top': '10px'}), - ], - width=3, - align="center" - ), - dbc.Col([ - html.Div(id='output-image-upload',children=[], style={"height":"70vh", 'overflowY': 'scroll'}), - # dcc.Graph(id='fig2', - # figure=blank_fig(), - # style={ - # "width": "100%", - # "height": "100%", - # }) - ], - width=5, - align="top", - ), - dbc.Col([ - dbc.Row([ - dbc.Col(dbc.Row(dbc.Button('Select All', id='select-all-bttn', style={'width': '50%'}, color='primary', className="me-1"), justify='end')), - dbc.Col(dbc.Row(dbc.Button('Unselect All', id='unselect-all-bttn', style={'width': '50%'}, color='primary', className="me-1"), justify='start')) - ], - justify='evenly', - style={'margin-bottom': '40px'} - ), - dbc.Row([ - dbc.Col(dbc.Row(dbc.Button('(D) Reject', id='reject-bttn', style={'width': '50%'}, color='danger', className="me-1"), justify='end')), - dbc.Col(dbc.Row(dbc.Button('(A) Accept', id='accept-bttn', style={'width': '50%'}, color='success', className="me-1"), justify='start')) - ], - justify='evenly', - style={'margin-bottom': '40px'} - ), - dbc.Row([ - dbc.Col(dbc.Row(dbc.Button('(S) Assign', id='assign-bttn', style={'width': '25%', 'margin-left':'90px'}, color='primary', className="me-1"), justify='start')), - #dbc.Col(dbc.Row(dbc.Button('Select All', id='select-all-bttn', style={'width': '50%'}, color='primary', className="me-1"), justify='start')), - #dbc.Col(dbc.Row(dbc.ListGroup([dbc.ListGroupItem(f'({str(i+1)}) {k}') for i,k in enumerate(local_dataset._im_dataset['name'])]), justify='start')) - ], - justify='evenly', - style={'margin-bottom': '5px'} - ), - dbc.Row([dbc.Col(dbc.Row(dcc.Dropdown(id='assign-dropdown', options={k:k for k in local_dataset._im_dataset['name']}, style={'width': '75%', 'margin-left':'-10px'}), justify='end')) - ], - justify='evenly') - ], - width=4, - align="right", - ), - ], - justify='center', - align="center", - className="h-100", - ), - ], - fluid=True, - ), - ), - #dbc.Collapse(id="collapse3",is_open=False, children=dcc.Graph(id='fig3', figure=blank_fig())), - EventListener( - events=[ - { - "event": "keydown", - "props": ["key", "ctrlKey", "ctrlKey"], - } - ], - id="keybind-event-listener", - ), - ]), - ] - ) - - - -def layout(): - return html.Div([ - dbc.Modal([ - dbc.ModalHeader(dbc.ModalTitle("Instructions")), - dbc.ModalBody(id='modal-body-help', children=instructions), - ], - id="modal-help", - is_open=False, - size="xl" - ), - dbc.Modal([ - dbc.ModalHeader(dbc.ModalTitle("Submission Results")), - dbc.ModalBody(id='modal-body-results', children=competition_results), - ], - id="modal-results", - is_open=False, - size="xl" - ), - dbc.Modal([ - #dbc.ModalHeader(dbc.ModalTitle("Tomogram Evaluation")), - dbc.ModalBody(id='modal-body-evaluation', children=tabs), - ], - id="modal-evaluation", - is_open=False, - centered=True, - size='xl' - ), - ]) diff --git a/copick_live/__init__.py b/copick_live/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/copick_live/app.py b/copick_live/app.py new file mode 100644 index 0000000..4fe2cba --- /dev/null +++ b/copick_live/app.py @@ -0,0 +1,73 @@ +import dash_bootstrap_components as dbc +from dash import Dash, html, dcc +from collections import defaultdict +import argparse + +from copick_live.callbacks.update_res import * +from copick_live.components.header import layout as header +from copick_live.components.progress import layout as tomo_progress +from copick_live.components.proteins import layout as protein_sts +from copick_live.components.waitlist import layout as unlabelled_tomos +from copick_live.components.annotators import layout as ranking +from copick_live.components.composition import layout as composition +from copick_live.components.popups import layout as popups + +from copick_live.config import get_config + +def create_app(config_path=None): + config = get_config(config_path) + external_stylesheets = [ + dbc.themes.BOOTSTRAP, + "assets/header-style.css", + "https://codepen.io/chriddyp/pen/bWLwgP.css", + "https://use.fontawesome.com/releases/v5.10.2/css/all.css", + ] + + initialize_app() + app = Dash(__name__, external_stylesheets=external_stylesheets) + + browser_cache = html.Div( + id="no-display", + children=[ + dcc.Interval( + id="interval-component", + interval=20 * 1000, # clientside check in milliseconds, 20s + n_intervals=0, + ), + dcc.Store(id="tomogram-index", data=""), + dcc.Store(id="keybind-num", data=""), + dcc.Store(id="run-dt", data=defaultdict(list)), + ], + ) + + app.layout = html.Div( + [ + header(), + popups(), + dbc.Container( + [ + dbc.Row( + [ + dbc.Col([tomo_progress(), unlabelled_tomos()], width=3), + dbc.Col(ranking(), width=3), + dbc.Col(composition(), width=3), + dbc.Col(protein_sts(), width=3), + ], + justify="center", + className="h-100", + ), + ], + fluid=True, + ), + html.Div(browser_cache), + ], + ) + return app + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Run the Dash application with a specific config path.') + parser.add_argument('--config-path', type=str, help='Path to the configuration file.', required=False) + args = parser.parse_args() + + dash_app = create_app(config_path=args.config_path) + dash_app.run_server(host="0.0.0.0", port=8000, debug=False) diff --git a/copick_live/app_experimental.py b/copick_live/app_experimental.py new file mode 100644 index 0000000..9855f65 --- /dev/null +++ b/copick_live/app_experimental.py @@ -0,0 +1,66 @@ +import dash_bootstrap_components as dbc +from dash import Dash, html, dcc +from collections import defaultdict +import argparse +from copick_live.components.header import layout as header +from copick_live.config import get_config +from copick_live.components.project_explorer import layout as project_explorer + +def create_app(config_path=None): + config = get_config(config_path) + external_stylesheets = [ + dbc.themes.BOOTSTRAP, + "assets/header-style.css", + "https://codepen.io/chriddyp/pen/bWLwgP.css", + "https://use.fontawesome.com/releases/v5.10.2/css/all.css", + ] + + app = Dash(__name__, external_stylesheets=external_stylesheets) + + browser_cache = html.Div( + id="no-display", + children=[ + dcc.Interval( + id="interval-component", + interval=20 * 1000, # clientside check in milliseconds, 20s + n_intervals=0, + ), + dcc.Store(id="tomogram-index", data=""), + dcc.Store(id="keybind-num", data=""), + dcc.Store(id="run-dt", data=defaultdict(list)), + ], + ) + + app.layout = html.Div( + [ + header(), + dbc.Container( + [ + dbc.Row( + [ + dbc.Col( + [ + html.H2("CoPick Project Explorer"), + project_explorer(), + ], + width=12, + ), + ], + className="mb-4", + ), + ], + fluid=True, + ), + html.Div(browser_cache), + ], + ) + + return app + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Run the experimental Dash application with a specific config path.') + parser.add_argument('--config-path', type=str, help='Path to the configuration file.', required=False) + args = parser.parse_args() + + dash_app = create_app(config_path=args.config_path) + dash_app.run_server(host="0.0.0.0", port=8000, debug=True) diff --git a/assets/czii.png b/copick_live/assets/czii.png similarity index 100% rename from assets/czii.png rename to copick_live/assets/czii.png diff --git a/assets/czii_logo.png b/copick_live/assets/czii_logo.png similarity index 100% rename from assets/czii_logo.png rename to copick_live/assets/czii_logo.png diff --git a/assets/header-style.css b/copick_live/assets/header-style.css similarity index 100% rename from assets/header-style.css rename to copick_live/assets/header-style.css diff --git a/copick_live/callbacks/__init__.py b/copick_live/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/copick_live/callbacks/update_res.py b/copick_live/callbacks/update_res.py new file mode 100644 index 0000000..bb9214b --- /dev/null +++ b/copick_live/callbacks/update_res.py @@ -0,0 +1,689 @@ +import plotly.express as px +import dash_bootstrap_components as dbc +import json, time +import pandas as pd +from collections import defaultdict +from apscheduler.schedulers.background import BackgroundScheduler + +import time +from copick_live.utils.copick_dataset import get_copick_dataset +from copick_live.utils.figure_utils import blank_fig, draw_gallery +from copick_live.utils.local_dataset import get_local_dataset +from copick_live.config import get_config +from dash import html, Input, Output, callback, State, ALL, MATCH, ctx, dcc, no_update +from dash.exceptions import PreventUpdate + +from dash_iconify import DashIconify +import base64 +import io + +def submission_list(i, j): + return dbc.ListGroupItem("{}: {}".format(i.split(".json")[0], j)) + +def parse_contents(contents, filename, date): + content_type, content_string = contents.split(",") + decoded = base64.b64decode(content_string) + fig = [] + try: + if "csv" in filename: + # Assume that the user uploaded a CSV file + df = pd.read_csv(io.StringIO(decoded.decode("utf-8"))) + df = df.sort_values(by=["Aggregate_Fbeta"], ascending=False) + df = df.reset_index(drop=True) + df["rank"] = df.index + dict_df = df.set_index("File")["Aggregate_Fbeta"].to_dict() + print(dict_df) + fig = px.scatter( + df, + x="rank", + y="Aggregate_Fbeta", + hover_name="File", + title="Submitted model ranking", + ) + elif "xls" in filename: + # Assume that the user uploaded an excel file + df = pd.read_excel(io.BytesIO(decoded)) + df = df.sort_values(by=["Aggregate_Fbeta"], ascending=False) + df = df[["File", "Aggregate_Fbeta"]] + except Exception as e: + print(e) + return html.Div(["There was an error processing this file."]) + + return dbc.Card( + [ + dbc.CardHeader( + [ + DashIconify( + icon="noto-v1:trophy", width=25, style={"margin": "5px"} + ), + "Submitted model ranking", + ], + style={"font-weight": "bold"}, + ), + dbc.CardBody( + id="submission-rank", + children=dcc.Graph(figure=fig), + style={"overflowY": "scroll"}, + ), + ], + style={"height": "87vh"}, + ) + +roundbutton = { + "border": "transparent", + "padding": 0, + "backgroundColor": "transparent", + "color": "black", + "textAlign": "center", + "display": "block", + "fontSize": 9, + "height": 9, + "width": 9, + "margin-left": 10, + "margin-top": 8, +} + +def candidate_list(i, j): + return dbc.ListGroupItem("{} (labeled by {} person)".format(i, j)) + +def ranking_list(i, j): + return dbc.ListGroupItem("{} {} tomograms".format(i, j)) + +############################################## Callbacks ############################################## +@callback( + Output("modal-help", "is_open"), + Input("button-help", "n_clicks"), + State("modal-help", "is_open"), + prevent_initial_call=True, +) +def toggle_help_modal(n_clicks, is_open): + return not is_open + + +@callback( + Output("modal-results", "is_open"), + Input("button-results", "n_clicks"), + State("modal-results", "is_open"), + prevent_initial_call=True, +) +def toggle_help_modal(n_clicks, is_open): + return not is_open + + +@callback( + Output("output-data-upload", "children"), + Input("upload-data", "contents"), + State("upload-data", "filename"), + State("upload-data", "last_modified"), +) +def update_output(list_of_contents, list_of_names, list_of_dates): + if list_of_contents is not None: + children = [ + parse_contents(c, n, d) + for c, n, d in zip(list_of_contents, list_of_names, list_of_dates) + ] + return children + + +@callback( + Output("tomogram-index", "data"), + Input({"type": "tomogram-eval-bttn", "index": ALL}, "n_clicks"), + prevent_initial_call=True, +) +def update_tomogram_index(n_clicks): + if any(n_clicks): + changed_id = [p["prop_id"] for p in ctx.triggered][0].split(".")[0] + if "index" in changed_id: + tomogram_index = json.loads(changed_id)["index"] + return tomogram_index + + +@callback( + Output("collapse1", "is_open"), + Output("collapse2", "is_open"), + Input("tabs", "active_tab"), + prevent_initial_call=True, +) +def toggle_analysis_tabs(at): + if at == "tab-1": + return True, False + elif at == "tab-2": + return False, True + + +@callback( + Output("output-image-upload", "children", allow_duplicate=True), + Output("image-slider", "value"), + Output("particle-dropdown", "value"), + Output("modal-evaluation", "is_open"), + Output("tabs", "active_tab"), + Output("choose-results", "children"), + Input("tomogram-index", "data"), + prevent_initial_call=True, +) +def reset_analysis_popup(tomogram_index): + msg = f"Choose results for {tomogram_index}" + if tomogram_index is not None: + return [], 0, None, True, "tab-1", msg + else: + return [], 0, None, False, "tab-1", msg + + +@callback( + Output("run-dt", "data"), Input("tomogram-index", "data"), prevent_initial_call=True +) +def load_tomogram_run(tomogram_index): + dt = defaultdict(list) + if tomogram_index is not None: + # takes 18s for VPN + t1 = time.time() + get_copick_dataset().load_curr_run(run_name=tomogram_index, sort_by_score=True) + # takes 0.2s + t2 = time.time() + print("find copick run in copick", t2 - t1) + + return dt + + +@callback( + Output("image-slider", "value", allow_duplicate=True), + Input("particle-dropdown", "value"), + Input("display-row", "value"), + Input("display-col", "value"), + prevent_initial_call=True, +) +def reset_slider(value, nrow, ncol): + return 0 + + +@callback( + Output("output-image-upload", "children"), + Output("particle-dropdown", "options"), + Output("fig1", "figure"), + Output("image-slider", "max"), + Output("image-slider", "marks"), + Output("crop-label", "children"), + Output("image-slider", "value", allow_duplicate=True), + Output("keybind-num", "data"), + Input("tabs", "active_tab"), + Input("image-slider", "value"), + Input("crop-width", "value"), + Input("crop-avg", "value"), + Input("particle-dropdown", "value"), + Input("accept-bttn", "n_clicks"), + Input("reject-bttn", "n_clicks"), + Input("assign-bttn", "n_clicks"), + Input("username-analysis", "value"), + Input("keybind-event-listener", "event"), + Input("keybind-event-listener", "n_events"), + Input("display-row", "value"), + Input("display-col", "value"), + State("tomogram-index", "data"), + State("fig1", "figure"), + State("output-image-upload", "children"), + State("keybind-num", "data"), + State({"type": "thumbnail-image", "index": ALL}, "n_clicks"), + State("assign-dropdown", "value"), + prevent_initial_call=True, +) +def update_analysis( + at, + slider_value, + crop_width, + crop_avg, + particle, + accept_bttn, + reject_bttn, + assign_bttn, + copicklive_username, + keybind_event_listener, + n_events, + nrow, + ncol, + tomogram_index, + fig1, + fig2, + kbn, + thumbnail_image_select_value, + new_particle, +): + pressed_key = None + if ctx.triggered_id == "keybind-event-listener": + # user is going to type in the class creation/edit modals and we don't want to trigger this callback using keys + pressed_key = ( + keybind_event_listener.get("key", None) if keybind_event_listener else None + ) + if not pressed_key: + raise PreventUpdate + else: + print(f"pressed_key {pressed_key}") + + slider_max = 0 + changed_id = [p["prop_id"] for p in ctx.triggered][0] + # takes 0.35s on mac3 + if tomogram_index: + # time.sleep(7) + # particle_dict = {k: k for k in sorted(set(get_copick_dataset().dt['pickable_object_name']))} + if at == "tab-1": + time.sleep(2) + particle_dict = { + k: k + for k in sorted(set(get_copick_dataset().dt["pickable_object_name"])) + } + df = pd.DataFrame.from_dict(get_copick_dataset().dt) + fig1 = px.scatter_3d( + df, + x="x", + y="y", + z="z", + color="pickable_object_name", + symbol="user_id", + size="size", + opacity=0.5, + ) + return ( + fig2, + particle_dict, + fig1, + slider_max, + {0: "0", slider_max: str(slider_max)}, + no_update, + no_update, + no_update, + ) + elif at == "tab-2": + # new_particle = None + if pressed_key in [ + str(i + 1) for i in range(len(get_local_dataset()._im_dataset["name"])) + ]: + new_particle = get_local_dataset()._im_dataset["name"][ + int(pressed_key) - 1 + ] + elif pressed_key == "s": + new_particle = kbn + + get_copick_dataset().new_user_id(user_id=copicklive_username) + if ( + "display-row" in changed_id or "display-col" in changed_id + ) or particle in get_copick_dataset().points_per_obj: + if len(get_copick_dataset().points_per_obj[particle]) % (nrow * ncol): + slider_max = len(get_copick_dataset().points_per_obj[particle]) // ( + nrow * ncol + ) + else: + slider_max = ( + len(get_copick_dataset().points_per_obj[particle]) + // (nrow * ncol) + - 1 + ) + + positions = [ + i + for i in range( + slider_value * nrow * ncol, + min( + (slider_value + 1) * nrow * ncol, + len(get_copick_dataset().points_per_obj[particle]), + ), + ) + ] + # loading zarr takes 6-8s for VPN + particle_dict = { + k: k + for k in sorted(set(get_copick_dataset().dt["pickable_object_name"])) + } + dim_z, dim_y, dim_x = get_copick_dataset().tomogram.shape + msg = f"Image crop width (max {min(dim_x, dim_y)})" + if crop_width is not None: + half_width = crop_width // 2 + if crop_avg is None: + crop_avg = 0 + copick_dataset = get_copick_dataset() # Get the dataset here + fig2 = draw_gallery( + copick_dataset, # Pass the dataset + run=tomogram_index, + particle=particle, + positions=positions, + hw=half_width, + avg=crop_avg, + nrow=nrow, + ncol=ncol, + ) + + selected = [ + i for i, v in enumerate(thumbnail_image_select_value) if v % 2 == 1 + ] + selected_point_ids = [positions[i] for i in selected] + if "accept-bttn" in changed_id or pressed_key == "a": + get_copick_dataset().handle_accept_batch(selected_point_ids, particle) + elif "reject-bttn" in changed_id or pressed_key == "d": + get_copick_dataset().handle_reject_batch(selected_point_ids, particle) + elif "assign-bttn" in changed_id or pressed_key == "s": + get_copick_dataset().handle_assign_batch( + selected_point_ids, particle, new_particle + ) + + # update figures + # if 'accept-bttn' in changed_id or \ + # 'reject-bttn' in changed_id or \ + # 'assign-bttn' in changed_id or \ + # pressed_key in ['a', 'd', 's']: + # slider_value += 1 + # positions = [i for i in range(slider_value*nrow*ncol, min((slider_value+1)*nrow*ncol, len(get_copick_dataset().points_per_obj[particle])))] + # fig2 = draw_gallery(run=tomogram_index, particle=particle, positions=positions, hw=half_width, avg=crop_avg, nrow=nrow, ncol=ncol) + + if "assign-bttn" in changed_id or pressed_key == "s": + positions = [ + i + for i in range( + slider_value * nrow * ncol, + min( + (slider_value + 1) * nrow * ncol, + len(get_copick_dataset().points_per_obj[particle]), + ), + ) + ] + fig2 = draw_gallery( + run=tomogram_index, + particle=particle, + positions=positions, + hw=half_width, + avg=crop_avg, + nrow=nrow, + ncol=ncol, + ) + + if pressed_key == "ArrowRight" and slider_value < slider_max: + slider_value += 1 + positions = [ + i + for i in range( + slider_value * nrow * ncol, + min( + (slider_value + 1) * nrow * ncol, + len(get_copick_dataset().points_per_obj[particle]), + ), + ) + ] + fig2 = draw_gallery( + run=tomogram_index, + particle=particle, + positions=positions, + hw=half_width, + avg=crop_avg, + nrow=nrow, + ncol=ncol, + ) + elif pressed_key == "ArrowLeft" and slider_value: + slider_value -= 1 + positions = [ + i + for i in range( + slider_value * nrow * ncol, + min( + (slider_value + 1) * nrow * ncol, + len(get_copick_dataset().points_per_obj[particle]), + ), + ) + ] + fig2 = draw_gallery( + run=tomogram_index, + particle=particle, + positions=positions, + hw=half_width, + avg=crop_avg, + nrow=nrow, + ncol=ncol, + ) + + return ( + fig2, + particle_dict, + blank_fig(), + slider_max, + {0: "0", slider_max: str(slider_max)}, + msg, + slider_value, + new_particle, + ) + else: + return ( + fig2, + dict(), + blank_fig(), + slider_max, + {0: "0", slider_max: str(slider_max)}, + no_update, + no_update, + no_update, + ) + + +@callback( + Output({"type": "thumbnail-card", "index": MATCH}, "color"), + Input({"type": "thumbnail-image", "index": MATCH}, "n_clicks"), + Input("select-all-bttn", "n_clicks"), + Input("unselect-all-bttn", "n_clicks"), + State("image-slider", "value"), + State("display-row", "value"), + State("display-col", "value"), + State("particle-dropdown", "value"), + State({"type": "thumbnail-image", "index": MATCH}, "id"), +) +def select_thumbnail( + value, select_clicks, unselect_clicks, slider_value, nrow, ncol, particle, comp_id +): + """ + This callback assigns a color to thumbnail cards in the following scenarios: + - An image has been selected, but no label has been assigned (blue) + - An image has been labeled (label color) + - An image has been unselected or unlabeled (no color) + Args: + value: Thumbnail card that triggered the callback (n_clicks) + Returns: + thumbnail_color: Color of thumbnail card + """ + color = "" + colors = ["", "success", "danger", "warning"] + positions = [ + i + for i in range( + slider_value * nrow * ncol, + min( + (slider_value + 1) * nrow * ncol, + len(get_copick_dataset().points_per_obj[particle]), + ), + ) + ] + # print(f'positions {positions}') + selected = [ + get_copick_dataset().picked_points_mask[ + get_copick_dataset().points_per_obj[particle][i][0] + ] + for i in positions + ] + # print(f'selected {selected}') + # print(f'comp_id {comp_id}') + color = colors[selected[int(comp_id["index"])]] + # print(f'color {color} value {value}') + if value is None or ( + ctx.triggered[0]["prop_id"] == "unselect-all-bttn.n_clicks" and color == "" + ): + return "" + if value % 2 == 1: + return "primary" + else: + return color + + +@callback( + Output({"type": "thumbnail-image", "index": ALL}, "n_clicks"), + # Input({'type': 'label-button', 'index': ALL}, 'n_clicks_timestamp'), + # Input('un-label', 'n_clicks'), + Input("select-all-bttn", "n_clicks"), + Input("unselect-all-bttn", "n_clicks"), + State({"type": "thumbnail-image", "index": ALL}, "n_clicks"), + prevent_initial_call=True, +) +def deselect(select_clicks, unselect_clicks, thumb_clicked): + """ + This callback deselects a thumbnail card + Args: + label_button_trigger: Label button + unlabel_n_clicks: Un-label button + unlabel_all: Un-label all the images + thumb_clicked: Selected thumbnail card indice, e.g., [0,1,1,0,0,0] + Returns: + Modify the number of clicks for a specific thumbnail card + """ + # if all(x is None for x in label_button_trigger) and unlabel_n_clicks is None and unlabel_all is None: + # return [no_update]*len(thumb_clicked) + if ctx.triggered[0]["prop_id"] == "unselect-all-bttn.n_clicks": + print([0 for thumb in thumb_clicked]) + return [0 for thumb in thumb_clicked] + elif ctx.triggered[0]["prop_id"] == "select-all-bttn.n_clicks": + return [1 for thumb in thumb_clicked] + + +@callback( + Output("download-json", "data"), + Input("btn-download", "n_clicks"), + State("username", "value"), + prevent_initial_call=True, +) +def download_json(n_clicks, input_value): + input_value = ".".join(input_value.split(" ")) + filename = "copick_config_" + "_".join(input_value.split(".")) + ".json" + copick_dataset = get_copick_dataset() + config = get_config() + config.config["user_id"] = input_value + return dict(content=json.dumps(config.config, indent=4), filename=filename) + + +@callback( + Output("download-txt", "data"), + Input("btn-download-txt", "n_clicks"), + prevent_initial_call=True, +) +def download_txt(n_clicks): + config = get_config() + counter_file_path = config.counter_file_path + if counter_file_path: + with open(counter_file_path) as f: + counter = json.load(f) + + if counter["repeat"] == 2: + counter["start"] += counter["tasks_per_person"] + counter["repeat"] = 0 + + counter["repeat"] += 1 + local_dataset = get_local_dataset() + task_contents = "\n".join( + local_dataset.dirs[ + counter["start"] : counter["start"] + counter["tasks_per_person"] + ] + ) + print(task_contents) + task_filename = "task_recommendation.txt" + + with open(counter_file_path, "w") as f: + f.write(json.dumps(counter, indent=4)) + + return dict(content=task_contents, filename=task_filename) + + +@callback( + Output("proteins-histogram", "figure"), + Output("waitlist", "children"), + Output("rank", "children"), + Output("total-labeled", "children"), + Output("progress-bar", "value"), + Output("progress-bar", "label"), + Input("interval-component", "n_intervals"), +) +def update_results(n): + local_dataset = get_local_dataset() + data = local_dataset.fig_data() + fig = px.bar(x=data['name'], + y=data['count'], + labels={'x': 'Objects', 'y':'Counts'}, + text_auto=True, + color=data['name'], + color_discrete_map=data['colors'], + ) + fig.update(layout_showlegend=False) + num_candidates = 100 + candidates = local_dataset.candidates(num_candidates, random_sampling=False) + num_per_person_ordered = local_dataset.num_per_person_ordered + label = f"Labeled {len(local_dataset.tomos_pickers)} out of 1000 tomograms" + bar_val = round(len(local_dataset.tomos_pickers) / 1000 * 100, 1) + + return ( + fig, + dbc.ListGroup( + [candidate_list(i, j) for i, j in candidates.items()], flush=True + ), + dbc.ListGroup( + [ranking_list(i, len(j)) for i, j in num_per_person_ordered.items()], + numbered=True, + ), + [label], + bar_val, + f"{bar_val}%", + ) + + +@callback( + Output("composition", "children"), + Input("refresh-button", "n_clicks"), +) +def update_compositions(n): + local_dataset = get_local_dataset() + progress_list = [] + composition_list = html.Div() + data = local_dataset.fig_data() + if len(data['colors']) > 0: + l = 1/len(data['colors'])*100 + else: + l = 0 + obj_order = {name: i for i, name in enumerate(data["name"])} + tomograms = { + k: v + for k, v in sorted( + local_dataset.tomograms.items(), key=lambda x: dir2id[x[0]] + ) + } + for tomogram, ps in tomograms.items(): + progress = [] + ps = [p for p in ps if p in obj_order] + ps = sorted(list(ps), key=lambda x: obj_order[x]) + for p in ps: + progress.append(dbc.Progress(value=l, color=data["colors"][p], bar=True)) + + bttn = html.Button( + id={"type": "tomogram-eval-bttn", "index": tomogram}, + className="fa fa-search", + style=roundbutton, + ) + progress_list.append( + dbc.ListGroupItem( + children=[dbc.Row([tomogram, bttn]), dbc.Progress(progress)], + style={"border": "transparent"}, + ) + ) + + composition_list = dbc.ListGroup(progress_list) + return composition_list + + +# Initialize scheduler function +def init_scheduler(): + scheduler = BackgroundScheduler() + scheduler.add_job(func=get_local_dataset().refresh, trigger="interval", seconds=20) + scheduler.start() + + +# This function can be called when the app starts +def initialize_app(): + get_local_dataset().refresh() + init_scheduler() diff --git a/copick_live/components/__init__.py b/copick_live/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/components/annotators.py b/copick_live/components/annotators.py similarity index 58% rename from components/annotators.py rename to copick_live/components/annotators.py index 13e7052..4f02f69 100644 --- a/components/annotators.py +++ b/copick_live/components/annotators.py @@ -1,5 +1,6 @@ import dash_bootstrap_components as dbc from dash_iconify import DashIconify +from dash import dcc def layout(): return dbc.Card([ @@ -8,12 +9,11 @@ def layout(): ], style={"font-weight": "bold"} ), - # dcc.Loading( - # id="loading-annotators", - # children=dbc.CardBody(id='rank', style={'overflowY': 'scroll'}), - # type="circle", - # ), - dbc.CardBody(id='rank', style={'overflowY': 'scroll'}) + dcc.Loading( + id="loading-annotators", + children=[dbc.CardBody(id='rank', style={'overflowY': 'scroll'})], + type="default", + ) ], style={"height": '87vh'} - ) \ No newline at end of file + ) diff --git a/components/composition.py b/copick_live/components/composition.py similarity index 58% rename from components/composition.py rename to copick_live/components/composition.py index 8e1b977..c9fafb4 100644 --- a/components/composition.py +++ b/copick_live/components/composition.py @@ -1,6 +1,6 @@ import dash_bootstrap_components as dbc from dash_iconify import DashIconify -from dash import html +from dash import html, dcc roundbutton = { @@ -24,28 +24,22 @@ def layout(): 'Evaluation', dbc.Button('Refresh List', id="refresh-button", - #outline=True, color="primary", style = {"text-transform": "none", "fontSize": "0.85em", "width": "25%","height": "85%", "margin-left": "40%"}, ) ], style={"font-weight": "bold"} ), - #dbc.Row(dbc.Button('Refresh', id="refresh-button", outline=True, color="primary", className="me-1", size="sm"), justify="center"), dbc.CardBody(id='card-tomogram-evaluation', children=[ - # html.Div(dbc.Button('Refresh List', - # id="refresh-button", - # outline=True, - # color="primary", - # className="me-1", - # style = {"text-transform": "none"}), - # style ={'display': 'flex', 'justify-content': 'center', 'margin': '3px'}, - # ), - html.Div(id='composition') + dcc.Loading( + id="loading-composition", + children=[html.Div(id='composition')], + type="default", + ) ], style={'overflowY': 'scroll'} ), ], style={"height": "87vh"}, - ) \ No newline at end of file + ) diff --git a/components/header.py b/copick_live/components/header.py similarity index 100% rename from components/header.py rename to copick_live/components/header.py diff --git a/copick_live/components/popups.py b/copick_live/components/popups.py new file mode 100644 index 0000000..95e4834 --- /dev/null +++ b/copick_live/components/popups.py @@ -0,0 +1,435 @@ +from dash import html, dcc +import dash_bootstrap_components as dbc +import plotly.graph_objects as go +from copick_live.utils.local_dataset import get_local_dataset +from dash_extensions import EventListener + + +def blank_fig(): + """ + Creates a blank figure with no axes, grid, or background. + """ + fig = go.Figure() + fig.update_layout(template=None) + fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False) + fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False) + + return fig + + +def get_instructions(): + return [ + dcc.Markdown( + """ + Thanks for participating in CZII Pickathon! We highly encourage labeling all the 6 types of prteins in the tomogram. + ### Tools installation + #### ChimeraX + 1. Download and install [ChimeraX](https://www.cgl.ucsf.edu/chimerax/download.html) version 1.7.0+ + 2. Download the [copick plugin](https://cxtoolshed.rbvi.ucsf.edu/apps/chimeraxcopick) + 3. Open ChimeraX and install the copick wheel by typing in the command line: `toolshed install /path/to/your/ChimeraX_copick-0.1.3-py3-none-any.whl` + + #### Configuration files + Auto-generate the copick configuration file and a tomogram recomendation list for you (5 tomograms per file). + """, + link_target="_blank", + ), + html.Div( + children=[ + dbc.Input( + id="username", + placeholder="Please input your name (e.g., john.doe)", + type="text", + ), + dbc.Button( + "Download copick config file", + id="btn-download", + outline=True, + color="primary", + className="me-1", + ), + dbc.Button( + "Download recommendation file", + id="btn-download-txt", + outline=True, + color="primary", + className="me-1", + ), + dcc.Download(id="download-json"), + dcc.Download(id="download-txt"), + ], + className="d-grid gap-3 col-4 mx-auto", + ), + dcc.Markdown( + """ + ### Particle picking + The default workflow for ChimeraX should be: + 1. Type the command `copick start /path/to/config`. It will take about 2-3 mins to load the entire dataset tree. + 2. Open a tomogram by navigating the tree and double-clicking. + 3. Select or double click a pre-picked list from the upper table (double click will load the list). + 4. Press the ▼ ▼ ▼ button to copy the contents to the "editable" lower table. + 5. Select the Copick tab at the top right corner and choose a tool in the `Place Particles` session. Start editing by right click. Your picking results will be automatically saved. + """ + ), + ] + + +competition_results = [ + dcc.Upload( + id="upload-data", + children=html.Div(["Drag and Drop or ", html.A("Select Files")]), + style={ + "width": "95%", + "height": "60px", + "lineHeight": "60px", + "borderWidth": "1px", + "borderStyle": "dashed", + "borderRadius": "5px", + "textAlign": "center", + "margin": "10px", + }, + # Allow multiple files to be uploaded + multiple=True, + ), + html.Div(id="output-data-upload"), +] + + +def get_tabs(): + return html.Div( + [ + dbc.Tabs( + [ + dbc.Tab(label="Picked points visualization", tab_id="tab-1"), + dbc.Tab(label="2D Plane Inspection", tab_id="tab-2"), + # dbc.Tab(label="3D Volume Inspection", tab_id="tab-3"), + ], + id="tabs", + active_tab="tab-1", + ), + html.Div( + [ + dbc.Label( + "Choose results", + id="choose-results", + style={"margin-top": "35px", "margin-left": "7px"}, + ), + dcc.Dropdown( + ["Pickathon results"], + "Pickathon results", + id="pick-dropdown", + style={ + "width": "42%", + "justify-content": "center", + "margin-bottom": "0px", + "margin-left": "4px", + }, + ), + dbc.Collapse( + id="collapse1", + is_open=False, + children=dcc.Loading( + id="loading-fig1", + children=[dcc.Graph(id="fig1", figure=blank_fig())], + type="default", + ), + ), + dbc.Collapse( + id="collapse2", + is_open=False, + children=dbc.Container( + [ + dbc.Row( + [ + dbc.Col( + [ + dbc.Label( + "Please input your name", + style={"margin-top": "-20px"}, + ), + dbc.Input( + id="username-analysis", + placeholder="e.g., john.doe", + type="text", + style={"width": "75%"}, + ), + dbc.Label( + "Please select a particle type", + className="mt-3", + ), + dcc.Dropdown( + id="particle-dropdown", + style={"width": "87%"}, + ), + dbc.Label( + "Number of rows", className="mt-3" + ), + dcc.Input( + id="display-row", + type="number", + placeholder="5", + value=5, + min=1, + step=1, + ), + dbc.Label( + "Number of columns", + className="mt-3", + ), + dcc.Input( + id="display-col", + type="number", + placeholder="4", + value=4, + min=1, + step=1, + ), + dbc.Label( + id="crop-label", + children="Image crop size (max 100)", + className="mt-3", + ), + dcc.Input( + id="crop-width", + type="number", + placeholder="30", + value=60, + min=1, + step=1, + ), + dbc.Label( + "Average ±N neigbor layers", + className="mt-3", + ), + dcc.Input( + id="crop-avg", + type="number", + placeholder="3", + value=2, + min=0, + step=1, + ), + dbc.Label( + "Page slider (press key < or >)", + className="mt-3", + ), + html.Div( + dcc.Slider( + id="image-slider", + min=0, + max=200, + value=0, + step=1, + updatemode="drag", + tooltip={ + "placement": "top", + "always_visible": True, + }, + marks={0: "0", 199: "199"}, + ), + style={ + "width": "72%", + "margin-top": "10px", + }, + ), + ], + width=3, + align="center", + ), + dbc.Col( + [ + dcc.Loading( + id="loading-output-image-upload", + children=[html.Div( + id="output-image-upload", + children=[], + style={ + "height": "70vh", + "overflowY": "scroll", + }, + )], + type="default", + ), + ], + width=5, + align="top", + ), + dbc.Col( + [ + dbc.Row( + [ + dbc.Col( + dbc.Row( + dbc.Button( + "Select All", + id="select-all-bttn", + style={ + "width": "50%" + }, + color="primary", + className="me-1", + ), + justify="end", + ) + ), + dbc.Col( + dbc.Row( + dbc.Button( + "Unselect All", + id="unselect-all-bttn", + style={ + "width": "50%" + }, + color="primary", + className="me-1", + ), + justify="start", + ) + ), + ], + justify="evenly", + style={"margin-bottom": "40px"}, + ), + dbc.Row( + [ + dbc.Col( + dbc.Row( + dbc.Button( + "(D) Reject", + id="reject-bttn", + style={ + "width": "50%" + }, + color="danger", + className="me-1", + ), + justify="end", + ) + ), + dbc.Col( + dbc.Row( + dbc.Button( + "(A) Accept", + id="accept-bttn", + style={ + "width": "50%" + }, + color="success", + className="me-1", + ), + justify="start", + ) + ), + ], + justify="evenly", + style={"margin-bottom": "40px"}, + ), + dbc.Row( + [ + dbc.Col( + dbc.Row( + dbc.Button( + "(S) Assign", + id="assign-bttn", + style={ + "width": "25%", + "margin-left": "90px", + }, + color="primary", + className="me-1", + ), + justify="start", + ) + ), + # dbc.Col(dbc.Row(dbc.Button('Select All', id='select-all-bttn', style={'width': '50%'}, color='primary', className="me-1"), justify='start')), + # dbc.Col(dbc.Row(dbc.ListGroup([dbc.ListGroupItem(f'({str(i+1)}) {k}') for i,k in enumerate(get_local_dataset()._im_dataset['name'])]), justify='start')) + ], + justify="evenly", + style={"margin-bottom": "5px"}, + ), + dbc.Row( + [ + dbc.Col( + dbc.Row( + dcc.Dropdown( + id="assign-dropdown", + options={ + k: k + for k in get_local_dataset()._im_dataset[ + "name" + ] + }, + style={ + "width": "75%", + "margin-left": "-10px", + }, + ), + justify="end", + ) + ) + ], + justify="evenly", + ), + ], + width=4, + align="right", + ), + ], + justify="center", + align="center", + className="h-100", + ), + ], + fluid=True, + ), + ), + # dbc.Collapse(id="collapse3",is_open=False, children=dcc.Graph(id='fig3', figure=blank_fig())), + EventListener( + events=[ + { + "event": "keydown", + "props": ["key", "ctrlKey", "ctrlKey"], + } + ], + id="keybind-event-listener", + ), + ] + ), + ] + ) + + +def layout(): + return html.Div( + [ + dbc.Modal( + [ + dbc.ModalHeader(dbc.ModalTitle("Instructions")), + dbc.ModalBody(id="modal-body-help", children=get_instructions()), + ], + id="modal-help", + is_open=False, + size="xl", + ), + dbc.Modal( + [ + dbc.ModalHeader(dbc.ModalTitle("Submission Results")), + dbc.ModalBody( + id="modal-body-results", children=competition_results + ), + ], + id="modal-results", + is_open=False, + size="xl", + ), + dbc.Modal( + [ + # dbc.ModalHeader(dbc.ModalTitle("Tomogram Evaluation")), + dbc.ModalBody(id="modal-body-evaluation", children=get_tabs()), + ], + id="modal-evaluation", + is_open=False, + centered=True, + size="xl", + ), + ] + ) diff --git a/components/progress.py b/copick_live/components/progress.py similarity index 100% rename from components/progress.py rename to copick_live/components/progress.py diff --git a/copick_live/components/project_explorer.py b/copick_live/components/project_explorer.py new file mode 100644 index 0000000..10116c6 --- /dev/null +++ b/copick_live/components/project_explorer.py @@ -0,0 +1,105 @@ +import dash +import dash_bootstrap_components as dbc +from dash import html, dcc, callback, Input, Output, State, ALL +from dash.exceptions import PreventUpdate +from copick_live.utils.copick_dataset import get_copick_dataset +import json + +def layout(): + return html.Div([ + dbc.Button("Load Project Structure", id="load-project-button", color="primary", className="mb-3"), + html.Div(id="project-structure-container"), + dcc.Store(id="project-structure-store", data={}), + ]) + +@callback( + Output("project-structure-container", "children"), + Output("project-structure-store", "data"), + Input("load-project-button", "n_clicks"), + State("project-structure-store", "data"), + prevent_initial_call=True +) +def load_project_structure(n_clicks, stored_data): + if n_clicks is None: + raise PreventUpdate + + copick_dataset = get_copick_dataset() + project_structure = {"name": "Root", "children": []} + + for run in copick_dataset.root.runs: + run_structure = {"name": run.name, "children": [ + {"name": "Picks", "children": [], "parent_run": run.name}, + {"name": "Segmentations", "children": [], "parent_run": run.name}, + {"name": "VoxelSpacing", "children": [], "parent_run": run.name} + ]} + project_structure["children"].append(run_structure) + + return render_structure(project_structure), project_structure + +@callback( + Output({"type": "expand-container", "index": ALL}, "children"), + Input({"type": "expand-button", "index": ALL}, "n_clicks"), + State({"type": "expand-container", "index": ALL}, "id"), + State("project-structure-store", "data"), + prevent_initial_call=True +) +def expand_node(n_clicks, container_ids, stored_data): + if not n_clicks or not any(n_clicks): + raise PreventUpdate + + triggered_id = json.loads(dash.callback_context.triggered[0]['prop_id'].split('.')[0]) + index_path = triggered_id['index'] + + node = stored_data + for idx in index_path.split(','): + node = node['children'][int(idx)] + + if 'loaded' not in node: + node['loaded'] = True + copick_dataset = get_copick_dataset() + + if node['name'] == 'Picks': + # Load picks data + run_name = node['parent_run'] + copick_dataset.load_curr_run(run_name=run_name) + for obj_name, points in copick_dataset.points_per_obj.items(): + node['children'].append({"name": f"{obj_name} ({len(points)})", "children": []}) + + elif node['name'] == 'Segmentations': + # Load segmentations data + run_name = node['parent_run'] + run = copick_dataset.root.get_run(run_name) + segmentations = run.get_segmentations() + for seg in segmentations: + node['children'].append({"name": seg.name, "children": []}) + + elif node['name'] == 'VoxelSpacing': + # Load voxel spacing data + run_name = node['parent_run'] + run = copick_dataset.root.get_run(run_name) + voxel_spacings = run.get_voxel_spacings() + for vs in voxel_spacings: + node['children'].append({"name": f"Spacing: {vs.spacing}", "children": [ + {"name": f"Tomogram: {vs.get_tomogram().name}", "children": []}, + {"name": f"CTF: {vs.get_ctf().name}", "children": []} + ]}) + + return [render_structure(node) if id['index'] == index_path else dash.no_update for id in container_ids] + +def render_structure(node, path=''): + children = [] + for i, child in enumerate(node.get('children', [])): + new_path = f"{path},{i}" if path else str(i) + expand_button = dbc.Button( + "▶", + id={"type": "expand-button", "index": new_path}, + size="sm", + className="mr-2" + ) if child.get('children') else None + + children.append(html.Div([ + expand_button, + child['name'], + html.Div(id={"type": "expand-container", "index": new_path}, style={'margin-left': '20px'}) + ])) + return children diff --git a/components/proteins.py b/copick_live/components/proteins.py similarity index 57% rename from components/proteins.py rename to copick_live/components/proteins.py index e2ab5a2..de19213 100644 --- a/components/proteins.py +++ b/copick_live/components/proteins.py @@ -1,4 +1,4 @@ -from dash import dcc +from dash import dcc, html import dash_bootstrap_components as dbc from dash_iconify import DashIconify @@ -10,7 +10,13 @@ def layout(): ], style={"font-weight": "bold"} ), - dbc.CardBody([dcc.Graph(id='proteins-histogram')]) + dbc.CardBody([ + dcc.Loading( + id="loading-proteins-histogram", + children=[dcc.Graph(id='proteins-histogram')], + type="default", + ) + ]) ], style={"height": '87vh'} - ) \ No newline at end of file + ) diff --git a/components/waitlist.py b/copick_live/components/waitlist.py similarity index 63% rename from components/waitlist.py rename to copick_live/components/waitlist.py index 526d5ff..ff62c79 100644 --- a/components/waitlist.py +++ b/copick_live/components/waitlist.py @@ -1,5 +1,6 @@ import dash_bootstrap_components as dbc from dash_iconify import DashIconify +from dash import dcc def layout(): return dbc.Card([ @@ -9,7 +10,11 @@ def layout(): ], style={"font-weight": "bold"} ), - dbc.CardBody(id='waitlist', style={'overflowY': 'scroll'}), + dcc.Loading( + id="loading-waitlist", + children=[dbc.CardBody(id='waitlist', style={'overflowY': 'scroll'})], + type="default", + ) ], style={"height": '72vh'} - ) \ No newline at end of file + ) diff --git a/copick_live/config.py b/copick_live/config.py new file mode 100644 index 0000000..8633121 --- /dev/null +++ b/copick_live/config.py @@ -0,0 +1,30 @@ +import json +import os + +class Config: + def __init__(self, config_path=None): + if config_path is None: + config_path = os.path.join(os.getcwd(), "copick_live.json") + + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found at {config_path}") + + with open(config_path, 'r') as f: + self.config = json.load(f) + + self.copick_config_path = self.config.get("copick_config_path") + self.counter_file_path = self.config.get("counter_file_path") + self.cache_root = self.config.get("cache_root") + self.album_mode = self.config.get("album_mode") + self.copick_live_version = self.config.get("copick_live_version") + + def get(self, key, default=None): + return self.config.get(key, default) + +config = None + +def get_config(config_path=None): + global config + if config is None or config_path: + config = Config(config_path) + return config diff --git a/copick_live/copick_live.json b/copick_live/copick_live.json new file mode 100644 index 0000000..153ddfd --- /dev/null +++ b/copick_live/copick_live.json @@ -0,0 +1,5 @@ +{ + "copick_config_path": "copick_config.json", + "counter_file_path": "counter_checkpoint_file.json", + "cache_root": "copicklive_cache/" +} diff --git a/copick_live/utils/__init__.py b/copick_live/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/copick_live/utils/copick_dataset.py b/copick_live/utils/copick_dataset.py new file mode 100644 index 0000000..ae2188d --- /dev/null +++ b/copick_live/utils/copick_dataset.py @@ -0,0 +1,235 @@ +import os +from copick.impl.filesystem import CopickRootFSSpec +from collections import defaultdict +import pandas as pd +import zarr +import threading +from copick_live.config import get_config + +class CopickDataset: + def __init__(self): + config = get_config() + self.root = CopickRootFSSpec.from_file(config.copick_config_path) + self.tomogram = None + self.run_name = None + self.current_point = None # current point index + self.current_point_obj = None # current point copick object + self.dt = defaultdict(list) + + # variables for storing points in the current run + self.all_points = [] # [point_obj,...] unique pick objs from all pickers + self._point_types = [] # ['ribosome',...] + self.points_per_obj = defaultdict( + list + ) # {'ribosome': [(0,0.12),(2,0.33),(3,0.27...],...} (index, score) + self.all_points_locations = ( + set() + ) # {(x,y,z),...} a mask to check if a point is duplicated + # variables for storing picked points in the current run + self.picked_points_mask = ( + [] + ) # [1, 0, 2, 3, ...] # 1: accept, 2: reject, 0: unassigned, 3: assigned new class + self._picked_id_per_obj = defaultdict(list) # {'ribosome': [0,3...],...} + self._picked_points_per_obj = defaultdict( + list + ) # {'ribosome': [point_obj...],...} + + self._logs = defaultdict( + list + ) # {'user_id':[], 'x': [], 'y':[], 'z':[], 'operation':['reject', 'accept', 'reassign'], 'start_class':[], 'end_class'[]} + self.tomogram_lock = threading.Lock() + self.tomogram_loaded = threading.Event() + + def _reset_states(self): + self.points_per_obj = defaultdict(list) + self._point_types = [] + self.all_points = [] + self.picked_points_mask = [] + self._picked_id_per_obj = defaultdict(list) + self._picked_points_per_obj = defaultdict(list) + self.all_points_locations = set() + self._logs = defaultdict(list) + self.dt = defaultdict(list) + + def load_curr_run(self, run_name=None, sort_by_score=False, reverse=False): + if run_name is not None: + self._reset_states() + self.run_name = run_name + self.run = self.root.get_run(self.run_name) + + # Start loading tomogram in a separate thread + threading.Thread(target=self._load_tomogram, args=(run_name,)).start() + + # Load other data + self._load_points() + + if sort_by_score: + self._sort_points(reverse) + + + def _load_tomogram(self, run_name): + with self.tomogram_lock: + _run = self.tomo_root.get_run(run_name) if self.tomo_root is not None else self.run + tomogram = _run.get_voxel_spacing(10).get_tomogram("denoised") + group = zarr.open(tomogram.zarr()) + _, array = list(group.arrays())[0] + self.tomogram = array[:] + self.tomogram_loaded.set() + + def _store_points(self, obj_name=None, session_id="18"): + if obj_name is not None: + _picks = self.run.get_picks( + object_name=obj_name, user_id=self.root.user_id, session_id=session_id + ) + if not _picks: + _picks = self.run.new_picks( + object_name=obj_name, + user_id=self.root.user_id, + session_id=session_id, + ) + _picks = self.run.get_picks( + object_name=obj_name, + user_id=self.root.user_id, + session_id=session_id, + ) + _pick_set = _picks[0] + _pick_set.points = self._picked_points_per_obj[obj_name] + _pick_set.store() + + def new_user_id(self, user_id=None): + if user_id is not None: + config = get_config() + config.config["user_id"] = user_id + + def load_curr_point(self, point_id=None, obj_name=None): + if point_id is not None and obj_name is not None: + self.pickable_obj_name = obj_name + print("Creating current pick point") + if len(self.points_per_obj[obj_name]): + self.current_point = self.points_per_obj[obj_name][point_id][ + 0 + ] # current point index + self.current_point_obj = self.all_points[self.current_point] + else: + self.current_point = None + self.current_point_obj = None + + def change_obj_name(self, obj_name=None, enable_log=True): + if enable_log: + self.log_operation( + operation="reasign", + old_obj_name=self._point_types[self.current_point], + new_obj_name=obj_name, + ) + + if obj_name is not None and self.current_point is not None: + self._point_types[self.current_point] = obj_name + + def _update_logs( + self, operation="placeholder", old_obj_name="NC", new_obj_name="NC" + ): + self._logs["run_name"].append(self.run_name) + self._logs["user_id"].append(self.root.config.user_id) + self._logs["x"].append(self.current_point_obj.location.x) + self._logs["y"].append(self.current_point_obj.location.y) + self._logs["z"].append(self.current_point_obj.location.z) + self._logs["operation"].append(operation) + self._logs["start_class"].append(old_obj_name) + self._logs["end_class"].append(new_obj_name) + + def log_operation(self, operation=None, old_obj_name=None, new_obj_name=None): + self._logs = defaultdict(list) + if os.path.exists("logs.csv") == False: + self._update_logs() + pd.DataFrame(self._logs).to_csv("logs.csv", sep="\t", index=False) + + df = pd.read_csv("logs.csv", sep="\t") + self._logs = df.to_dict("list") if len(df) else defaultdict(list) + + if ( + operation is not None + and old_obj_name is not None + and new_obj_name is not None + ): + self._update_logs(operation, old_obj_name, new_obj_name) + df = pd.DataFrame.from_dict(self._logs) + df.to_csv("logs.csv", sep="\t", index=False) + + def handle_accept(self): + if self.current_point is not None: + self.picked_points_mask[self.current_point] = 1 + obj_name = self._point_types[self.current_point] + print( + f"Accept, Object Type: {obj_name}, Run Name: {self.run_name}, Location: {self.current_point_obj.location}" + ) + self._point_types[self.current_point] = obj_name + if self.current_point not in self._picked_id_per_obj[obj_name]: + self._picked_id_per_obj[obj_name].append(self.current_point) + self._picked_points_per_obj[obj_name].append(self.current_point_obj) + self._store_points(obj_name) + + def handle_reject(self, enable_log=True): + if enable_log: + self.log_operation(operation="reject", old_obj_name="NC", new_obj_name="NC") + + if self.current_point is not None: + self.picked_points_mask[self.current_point] = 2 + try: + obj_name = self._point_types[self.current_point] + index = self._picked_id_per_obj[obj_name].index(self.current_point) + print( + f"reject point index {self.current_point}, index in the list {index}" + ) + self._picked_id_per_obj[obj_name].pop(index) + self._picked_points_per_obj[obj_name].pop(index) + self._store_points(obj_name) + except: + pass + + def handle_assign(self, new_bj_name=None): + self.handle_reject(enable_log=False) + self.change_obj_name(new_bj_name) + self.handle_accept() + # self.picked_points_mask[self.current_point] = 3 + + # EXPERIMENT, dangeraous, may incurr index errors! + # Only re-assignment changes the original states (initialized when load run) + new_list = [] + target = tuple() + old_obj_name = self.pickable_obj_name + print(old_obj_name, self.points_per_obj[old_obj_name]) + for item in self.points_per_obj[old_obj_name]: + if item[0] != self.current_point: + new_list.append(item) + else: + target = item + + self.points_per_obj[old_obj_name] = new_list + # add the new assigned point to the front + self.points_per_obj[new_bj_name].insert(0, target) + + def handle_accept_batch(self, point_ids=None, obj_name=None): + if point_ids is not None: + for point_id in point_ids: + self.load_curr_point(point_id=point_id, obj_name=obj_name) + self.handle_accept() + + def handle_reject_batch(self, point_ids=None, obj_name=None): + if point_ids is not None: + for point_id in point_ids: + self.load_curr_point(point_id=point_id, obj_name=obj_name) + self.handle_reject() + + def handle_assign_batch(self, point_ids=None, obj_name=None, new_bj_name=None): + if point_ids is not None and obj_name is not None and new_bj_name is not None: + for point_id in point_ids: + self.load_curr_point(point_id=point_id, obj_name=obj_name) + self.handle_assign(new_bj_name) + +copick_dataset = None + +def get_copick_dataset(): + global copick_dataset + if copick_dataset is None: + copick_dataset = CopickDataset() + return copick_dataset diff --git a/utils/figure_utils.py b/copick_live/utils/figure_utils.py similarity index 88% rename from utils/figure_utils.py rename to copick_live/utils/figure_utils.py index 3f1ca8a..c9b41fa 100644 --- a/utils/figure_utils.py +++ b/copick_live/utils/figure_utils.py @@ -8,7 +8,7 @@ import base64 import numpy as np -from utils.copick_dataset import copick_dataset +from copick_live.utils.copick_dataset import get_copick_dataset from functools import lru_cache @@ -34,14 +34,8 @@ def crop_image2d(image, copick_loc, hw, avg): #====================================== memoization ====================================== #@lru_cache(maxsize=128) # number of images -def prepare_images2d(run=None, particle=None, positions=[], hw=60, avg=2): +def prepare_images2d(copick_dataset, run=None, particle=None, positions=[], hw=60, avg=2): padded_image = np.pad(copick_dataset.tomogram, ((hw,hw), (hw,hw), (hw, hw)), 'constant') - # cache_dir = CACHE_ROOT + 'cache-directory/' - # os.makedirs(cache_dir, exist_ok=True) - # # Create an LRU cache for the store with a maximum size of 100 MB - # store = DirectoryStore(f'{cache_dir}{run}_2d_crops.zarr') - # #cache_store = LRUStoreCache(store, max_size=100 * 2**20) - # root = zarr.group(store=store, overwrite=True) cropped_image_batch = [] if particle in copick_dataset.points_per_obj and len(positions): point_ids = [copick_dataset.points_per_obj[particle][i][0] for i in positions] @@ -49,7 +43,7 @@ def prepare_images2d(run=None, particle=None, positions=[], hw=60, avg=2): for point_obj in point_objs: cropped_image = crop_image2d(padded_image, point_obj.location, hw, avg) cropped_image_batch.append(cropped_image) - + return np.array(cropped_image_batch) @@ -165,9 +159,9 @@ def draw_gallery_components(list_of_image_arr, n_rows, n_cols): return children -def draw_gallery(run=None, particle=None, positions=[], hw=60, avg=2, nrow=5, ncol=4): +def draw_gallery(copick_dataset, run=None, particle=None, positions=[], hw=60, avg=2, nrow=5, ncol=4): figures = [] - cropped_image_batch = prepare_images2d(run=run, particle=particle, positions=positions, hw=hw, avg=avg) + cropped_image_batch = prepare_images2d(copick_dataset, run=run, particle=particle, positions=positions, hw=hw, avg=avg) if len(cropped_image_batch): figures = draw_gallery_components(cropped_image_batch, nrow, ncol) - return figures \ No newline at end of file + return figures diff --git a/copick_live/utils/local_dataset.py b/copick_live/utils/local_dataset.py new file mode 100644 index 0000000..95bf24a --- /dev/null +++ b/copick_live/utils/local_dataset.py @@ -0,0 +1,159 @@ +import os, time +from copick_live.config import get_config +from copick.impl.filesystem import CopickRootFSSpec +import random, copy +from collections import defaultdict, deque +import json +import concurrent + +class LocalDataset: + def __init__(self): + config = get_config() + self.root = CopickRootFSSpec.from_file(config.copick_config_path) + self.counter_file_path = config.counter_file_path + + # output + self.proteins = defaultdict(int) # {'ribosome': 38, ...} + self.tomograms = defaultdict(set) #{'TS_1_1':{'ribosome', ...}, ...} + self.tomos_per_person = defaultdict(set) #{'john.doe':{'TS_1_1',...},...} + self.tomos_pickers = defaultdict(set) #{'Test_1_1': {john.doe,...}, ...} + self.num_per_person_ordered = dict() # {'Tom':5, 'Julie':3, ...} + + # hidden variables for updating candidate recommendations + self._all = set() + self._tomos_done = set() # labeled at least by 2 people + self._tomos_one_pick = set() # labeled only by 1 person + self._candidate_dict = defaultdict() # {1:1, 2:0, ...} + self._prepicks = set(['slab-picking', + 'pytom-template-match', + 'relion-refinement', + 'prepick', + 'ArtiaX', + 'default'] + ) + + xdata = [] + colors = dict() + for po in config.get("pickable_objects", []): + xdata.append(po["name"]) + colors[po["name"]] = po["color"] + + self._im_dataset = {'name': xdata, + 'count': [], + 'colors': colors + } + + def _reset(self): + self.proteins = defaultdict(int) + self._tomos_one_pick = set() #may remove some elems, therefore, empty before each check + + config = get_config() + xdata = [] + colors = dict() + for po in config.get("pickable_objects", []): + xdata.append(po["name"]) + colors[po["name"]] = po["color"] + + self._im_dataset = {'name': xdata, + 'count': [], + 'colors': colors + } + + + def refresh(self): + self._reset() + self._update_tomo_sts() + + + def _process_run(self, run): + for pick_set in run.get_picks(): + try: + pickable_object_name = pick_set.pickable_object_name + user_id = pick_set.user_id + run_name = run.name + points = pick_set.points + + if user_id not in self._prepicks and points and len(points): + self.proteins[pickable_object_name] += len(points) + self.tomos_per_person[user_id].add(run_name) + self.tomograms[run_name].add(pickable_object_name) + self.tomos_pickers[run_name].add(user_id) + except json.JSONDecodeError: + print(f"Error decoding JSON for pick set in run {run.name}") + except Exception as e: + print(f"Unexpected error processing run {run.name}: {e}") + + def _update_tomo_sts(self): + start = time.time() + runs = self.root.runs + self._all = set(range(len(runs))) + + with concurrent.futures.ThreadPoolExecutor() as executor: + executor.map(self._process_run, runs) + + print(f'{time.time()-start} s to check all files') + + for tomo, pickers in self.tomos_pickers.items(): + run_id = next((i for i, run in enumerate(runs) if run.name == tomo), None) + if run_id is not None: + if len(pickers) >= 2: + self._tomos_done.add(run_id) + elif len(pickers) == 1: + self._tomos_one_pick.add(run_id) + + self.num_per_person_ordered = dict(sorted(self.tomos_per_person.items(), key=lambda item: len(item[1]), reverse=True)) + + + def _update_candidates(self, n, random_sampling=True): + # remove candidates that should not be considered any more + _candidate_dict = defaultdict() + for candidate in self._candidate_dict.keys(): + if candidate in self._tomos_done: + continue + _candidate_dict[candidate] = self._candidate_dict[candidate] + self._candidate_dict = _candidate_dict + + # add candidates that have been picked once + if len(self._candidate_dict) < n: + for i in self._tomos_one_pick: + self._candidate_dict[i] = 1 + if len(self._candidate_dict) == n: + break + + # add candidates that have not been picked yet + if len(self._candidate_dict) < n: + residuals = self._all - self._tomos_done - self._tomos_one_pick + residuals = deque(residuals) + while residuals and len(self._candidate_dict) < n: + if random_sampling: + new_id = random.randint(0, len(residuals) - 1) + self._candidate_dict[residuals[new_id]] = 0 + del residuals[new_id] + else: + new_candidate = residuals.popleft() + self._candidate_dict[new_candidate] = 0 + + + def candidates(self, n: int, random_sampling=True) -> dict: + self._candidate_dict = {k: 0 for k in range(n)} if not random_sampling else {k: 0 for k in random.sample(range(len(self.root.runs)), n)} + self._update_candidates(n, random_sampling) + return {k: v for k, v in sorted(self._candidate_dict.items(), key=lambda x: x[1], reverse=True)} + + + def fig_data(self): + image_dataset = copy.deepcopy(self._im_dataset) + proteins = copy.deepcopy(self.proteins) + for name in image_dataset['name']: + image_dataset['count'].append(proteins[name]) + + image_dataset['colors'] = {k: 'rgba' + str(tuple(v)) for k, v in image_dataset['colors'].items()} + return image_dataset + + +local_dataset = None + +def get_local_dataset(): + global local_dataset + if local_dataset is None: + local_dataset = LocalDataset() + return local_dataset diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..0d0ab94 --- /dev/null +++ b/setup.py @@ -0,0 +1,30 @@ +from setuptools import setup, find_packages + +setup( + name="copick_live", + version="0.1.0", + description="A live CoPick application", + author="Zhuowen Zhao", + author_email="kevin.zhao@czii.org", + url="https://github.com/zhuowenzhao/copick_live", + packages=find_packages(), + include_package_data=True, + install_requires=[ + "dash==2.13.0", + "plotly==5.17.0", + "pandas", + "dash-extensions==1.0.1", + "dash-bootstrap-components==1.5.0", + "dash-iconify==0.1.2", + "Flask==2.2.5", + "numpy", + "apscheduler", + "pillow", + ], + dependency_links=[ + "git+https://github.com/uermel/copick.git#egg=copick" + ], + package_data={ + '': ['assets/*'], + }, +) diff --git a/utils/copick_dataset.py b/utils/copick_dataset.py deleted file mode 100644 index 1863732..0000000 --- a/utils/copick_dataset.py +++ /dev/null @@ -1,229 +0,0 @@ -import os -import configparser -from copick.impl.filesystem import CopickRootFSSpec -from collections import defaultdict -import pandas as pd -import zarr - - -class CopickDataset: - def __init__(self, copick_config_path: str=None, copick_config_path_tomogram: str=None): - self.root = CopickRootFSSpec.from_file(copick_config_path) if copick_config_path else None - self.tomo_root = CopickRootFSSpec.from_file(copick_config_path_tomogram) if copick_config_path_tomogram else None - self.tomogram = None - self.run_name = None - self.current_point = None # current point index - self.current_point_obj = None # current point copick object - self.dt = defaultdict(list) - - # variables for storing points in the current run - self.all_points = [] #[point_obj,...] unique pick objs from all pickers - self._point_types = [] #['ribosome',...] - self.points_per_obj = defaultdict(list) # {'ribosome': [(0,0.12),(2,0.33),(3,0.27...],...} (index, score) - self.all_points_locations = set() # {(x,y,z),...} a mask to check if a point is duplicated - # variables for storing picked points in the current run - self.picked_points_mask = [] #[1, 0, 2, 3, ...] # 1: accept, 2: reject, 0: unassigned, 3: assigned new class - self._picked_id_per_obj = defaultdict(list) # {'ribosome': [0,3...],...} - self._picked_points_per_obj = defaultdict(list) # {'ribosome': [point_obj...],...} - - self._logs = defaultdict(list) # {'user_id':[], 'x': [], 'y':[], 'z':[], 'operation':['reject', 'accept', 'reassign'], 'start_class':[], 'end_class'[]} - - - def _reset_states(self): - self.points_per_obj = defaultdict(list) - self._point_types = [] - self.all_points = [] - self.picked_points_mask = [] - self._picked_id_per_obj = defaultdict(list) - self._picked_points_per_obj = defaultdict(list) - self.all_points_locations = set() - self._logs = defaultdict(list) - self.dt = defaultdict(list) - - - - def load_curr_run(self, run_name=None, sort_by_score=False, reverse=False): - if run_name is not None: - self._reset_states() - self.run_name = run_name - self.run = self.root.get_run(self.run_name) - _run = self.tomo_root.get_run(self.run_name) if self.tomo_root is not None else self.run - for pick in self.run.picks: - for point in pick.points: - # all picks from indivial pickers to show in tab1, contain duplicated picks. - self.dt['pickable_object_name'].append(pick.pickable_object_name) - self.dt['user_id'].append(pick.user_id) - self.dt['x'].append(float(point.location.x)/10) - self.dt['y'].append(float(point.location.y)/10) - self.dt['z'].append(float(point.location.z)/10) - self.dt['size'].append(0.1) - if (point.location.x, point.location.y, point.location.z) not in self.all_points_locations: - self.points_per_obj[pick.pickable_object_name].append((len(self.all_points), point.score)) - self._point_types.append(pick.pickable_object_name) - self.all_points.append(point) - self.all_points_locations.add((point.location.x, point.location.y, point.location.z)) - - self.picked_points_mask = [0]*len(self.all_points) - if sort_by_score: - for k,values in self.points_per_obj.items(): - if len(values): - values.sort(key=lambda x: x[1], reverse=reverse) # reverse=Fasle, ascending order - - tomogram = _run.get_voxel_spacing(10).get_tomogram("denoised") - # Access the data - group = zarr.open(tomogram.zarr()) - _, array = list(group.arrays())[0] # highest resolution bin=0 - self.tomogram = array[:] - - - def _store_points(self, obj_name=None, session_id='18'): - if obj_name is not None: - _picks = self.run.get_picks(object_name=obj_name, user_id=self.root.user_id, session_id=session_id) - if not _picks: - _picks = self.run.new_picks(object_name=obj_name, user_id=self.root.user_id, session_id=session_id) - _picks = self.run.get_picks(object_name=obj_name, user_id=self.root.user_id, session_id=session_id) - _pick_set = _picks[0] - _pick_set.points = self._picked_points_per_obj[obj_name] - _pick_set.store() - - - def new_user_id(self, user_id=None): - if user_id is not None: - self.root.config.user_id = user_id - - - def load_curr_point(self, point_id=None, obj_name=None): - if point_id is not None and obj_name is not None: - self.pickable_obj_name = obj_name - print("Creating current pick point") - if len(self.points_per_obj[obj_name]): - self.current_point = self.points_per_obj[obj_name][point_id][0] # current point index - self.current_point_obj = self.all_points[self.current_point] - else: - self.current_point = None - self.current_point_obj = None - - - def change_obj_name(self, obj_name=None, enable_log=True): - if enable_log: - self.log_operation(operation='reasign', old_obj_name=self._point_types[self.current_point], new_obj_name=obj_name) - - if obj_name is not None and self.current_point is not None: - self._point_types[self.current_point] = obj_name - - - def _update_logs(self, operation='placeholder', old_obj_name='NC', new_obj_name='NC'): - self._logs['run_name'].append(self.run_name) - self._logs['user_id'].append(self.root.config.user_id) - self._logs['x'].append(self.current_point_obj.location.x) - self._logs['y'].append(self.current_point_obj.location.y) - self._logs['z'].append(self.current_point_obj.location.z) - self._logs['operation'].append(operation) - self._logs['start_class'].append(old_obj_name) - self._logs['end_class'].append(new_obj_name) - - - def log_operation(self, operation=None, old_obj_name=None, new_obj_name=None): - self._logs = defaultdict(list) - if os.path.exists('logs.csv') == False: - self._update_logs() - pd.DataFrame(self._logs).to_csv('logs.csv', sep='\t', index=False) - - df = pd.read_csv('logs.csv', sep='\t') - self._logs = df.to_dict('list') if len(df) else defaultdict(list) - - if operation is not None and old_obj_name is not None and new_obj_name is not None: - self._update_logs(operation, old_obj_name, new_obj_name) - df = pd.DataFrame.from_dict(self._logs) - df.to_csv('logs.csv', sep='\t', index=False) - - - def handle_accept(self): - if self.current_point is not None: - self.picked_points_mask[self.current_point] = 1 - obj_name = self._point_types[self.current_point] - print(f"Accept, Object Type: {obj_name}, Run Name: {self.run_name}, Location: {self.current_point_obj.location}") - self._point_types[self.current_point] = obj_name - if self.current_point not in self._picked_id_per_obj[obj_name]: - self._picked_id_per_obj[obj_name].append(self.current_point) - self._picked_points_per_obj[obj_name].append(self.current_point_obj) - self._store_points(obj_name) - - - def handle_reject(self, enable_log=True): - if enable_log: - self.log_operation(operation='reject', old_obj_name='NC', new_obj_name='NC') - - if self.current_point is not None: - self.picked_points_mask[self.current_point] = 2 - try: - obj_name =self._point_types[self.current_point] - index = self._picked_id_per_obj[obj_name].index(self.current_point) - print(f'reject point index {self.current_point}, index in the list {index}') - self._picked_id_per_obj[obj_name].pop(index) - self._picked_points_per_obj[obj_name].pop(index) - self._store_points(obj_name) - except: - pass - - - def handle_assign(self, new_bj_name=None): - self.handle_reject(enable_log=False) - self.change_obj_name(new_bj_name) - self.handle_accept() - #self.picked_points_mask[self.current_point] = 3 - - # EXPERIMENT, dangeraous, may incurr index errors! - # Only re-assignment changes the original states (initialized when load run) - new_list = [] - target = tuple() - old_obj_name = self.pickable_obj_name - print(old_obj_name, self.points_per_obj[old_obj_name]) - for item in self.points_per_obj[old_obj_name]: - if item[0] != self.current_point: - new_list.append(item) - else: - target = item - - self.points_per_obj[old_obj_name] = new_list - # add the new assigned point to the front - self.points_per_obj[new_bj_name].insert(0, target) - - - - def handle_accept_batch(self, point_ids=None, obj_name=None): - if point_ids is not None: - for point_id in point_ids: - self.load_curr_point(point_id=point_id, obj_name=obj_name) - self.handle_accept() - - - def handle_reject_batch(self, point_ids=None, obj_name=None): - if point_ids is not None: - for point_id in point_ids: - self.load_curr_point(point_id=point_id, obj_name=obj_name) - self.handle_reject() - - - def handle_assign_batch(self, point_ids=None, obj_name=None, new_bj_name=None): - if point_ids is not None and obj_name is not None and new_bj_name is not None: - for point_id in point_ids: - self.load_curr_point(point_id=point_id, obj_name=obj_name) - self.handle_assign(new_bj_name) - - - -copick_dataset = None -def get_copick_dataset(COPICKLIVE_CONFIG_PATH=None, COPICK_TEMPLATE_PATH=None): - global copick_dataset - if not copick_dataset: - if not COPICKLIVE_CONFIG_PATH or not COPICK_TEMPLATE_PATH: - config = configparser.ConfigParser() - config.read(os.path.join(os.getcwd(), "config.ini")) - if not COPICKLIVE_CONFIG_PATH: - COPICKLIVE_CONFIG_PATH = '%s' % config['copicklive_config']['COPICKLIVE_CONFIG_PATH'] - if not COPICK_TEMPLATE_PATH: - COPICK_TEMPLATE_PATH = '%s' % config['copick_template']['COPICK_TEMPLATE_PATH'] - copick_dataset = CopickDataset(copick_config_path=COPICKLIVE_CONFIG_PATH, copick_config_path_tomogram=COPICK_TEMPLATE_PATH) - -get_copick_dataset() \ No newline at end of file diff --git a/utils/local_dataset.py b/utils/local_dataset.py deleted file mode 100644 index e7bbc37..0000000 --- a/utils/local_dataset.py +++ /dev/null @@ -1,211 +0,0 @@ -import os, pathlib, time -import threading - -import random, json, copy, configparser -from collections import defaultdict, deque -import json, zarr - - -dirs = ['TS_'+str(i)+'_'+str(j) for i in range(1,100) for j in range(1,10)] -dir2id = {j:i for i,j in enumerate(dirs)} -dir_set = set(dirs) - - -# define a wrapper function -def threaded(fn): - def wrapper(*args, **kwargs): - thread = threading.Thread(target=fn, args=args, kwargs=kwargs) - thread.start() - return thread - return wrapper - - -class LocalDataset: - def __init__(self, local_file_path: str=None, config_path: str=None): - self.root = local_file_path - with open(config_path) as f: - self.config_file = json.load(f) - - # output - self.proteins = defaultdict(int) # {'ribosome': 38, ...} - self.tomograms = defaultdict(set) #{'TS_1_1':{'ribosome', ...}, ...} - self.tomos_per_person = defaultdict(set) #{'john.doe':{'TS_1_1',...},...} - self.tomos_pickers = defaultdict(set) #{'Test_1_1': {john.doe,...}, ...} - self.num_per_person_ordered = dict() # {'Tom':5, 'Julie':3, ...} - - # hidden variables for updating candidate recomendations - self._all = set([i for i in range(len(dirs))]) - self._tomos_done = set() # labeled at least by 2 people, {0, 1, 2} - self._tomos_one_pick = set() # labeled only by 1 person, {3,4,5,...} - self._candidate_dict = defaultdict() # {1:1, 2:0, ...} - self._prepicks = set(['slab-picking', - 'pytom-template-match', - 'relion-refinement', - 'prepick', - 'ArtiaX', - 'default'] - ) - - xdata = [] - colors = dict() - for po in self.config_file["pickable_objects"]: - xdata.append(po["name"]) - colors[po["name"]] = po["color"] - - self._im_dataset = {'name': xdata, - 'count': [], - 'colors': colors - } - - def _reset(self): - self.proteins = defaultdict(int) - self._tomos_one_pick = set() #may remove some elems, thereofore, empty before each check - - xdata = [] - colors = dict() - for po in self.config_file["pickable_objects"]: - xdata.append(po["name"]) - colors[po["name"]] = po["color"] - - self._im_dataset = {'name': xdata, - 'count': [], - 'colors': colors - } - - - def refresh(self): - self._reset() - self._update_tomo_sts() - - - @threaded - def _walk_dir(self, args): - r, s, e = args - for dir in dirs[s:e]: - dir_path = r + dir +'/Picks' - if os.path.exists(dir_path): - for json_file in pathlib.Path(dir_path).glob('*.json'): - try: - contents = json.load(open(json_file)) - if 'user_id' in contents and contents['user_id'] not in self._prepicks: - if 'pickable_object_name' in contents and \ - 'run_name' in contents and contents['run_name'] in dir_set and \ - 'points' in contents and contents['points'] and len(contents['points']): - self.proteins[contents['pickable_object_name']] += len(contents['points']) - self.tomos_per_person[contents['user_id']].add(contents['run_name']) - self.tomograms[contents['run_name']].add(contents['pickable_object_name']) - self.tomos_pickers[contents['run_name']].add(contents['user_id']) - except: - pass - - - def _update_tomo_sts(self): - start = time.time() - seg = round(len(dirs)/6) - args1 = (self.root, 0, seg) - args2 = (self.root, seg, seg*2) - args3 = (self.root, seg*2, seg*3) - args4 = (self.root, seg*3, seg*4) - args5 = (self.root, seg*4, seg*5) - args6 = (self.root, seg*5, len(dirs)) - - t1 = self._walk_dir(args1) - t2 = self._walk_dir(args2) - t3 = self._walk_dir(args3) - t4 = self._walk_dir(args4) - t5 = self._walk_dir(args5) - t6 = self._walk_dir(args6) - - t1.join() - t2.join() - t3.join() - t4.join() - t5.join() - t6.join() - print(f'{time.time()-start} s to check all files') - - for tomo,pickers in self.tomos_pickers.items(): - if len(pickers) >= 2: - self._tomos_done.add(dir2id[tomo]) - elif len(pickers) == 1: - self._tomos_one_pick.add(dir2id[tomo]) - - self.num_per_person_ordered = dict(sorted(self.tomos_per_person.items(), key=lambda item: len(item[1]), reverse=True)) - - - def _update_candidates(self, n, random_sampling=True): - # remove candidates that should not be considered any more - _candidate_dict = defaultdict() - for candidate in self._candidate_dict.keys(): - if candidate in self._tomos_done: - continue - _candidate_dict[candidate] = self._candidate_dict[candidate] - self._candidate_dict = _candidate_dict - - # add candidates that have been picked once - if len(self._candidate_dict) < n: - for i in self._tomos_one_pick: - self._candidate_dict[i] = 1 - if len(self._candidate_dict) == n: - break - - # add candidates that have not been picked yet - if len(self._candidate_dict) < n: - residuals = self._all - self._tomos_done - self._tomos_one_pick - residuals = deque(residuals) - while residuals and len(self._candidate_dict) < n: - if random_sampling: - new_id = random.randint(0,len(residuals)) - self._candidate_dict[residuals[new_id]] = 0 - del residuals[new_id] - else: - new_candidate = residuals.popleft() - self._candidate_dict[new_candidate] = 0 - - - def candidates(self, n: int, random_sampling=True) -> dict: - self._candidate_dict = {k:0 for k in range(n)} if not random_sampling else {k:0 for k in random.sample(range(len(dirs)), n)} - self._update_candidates(n, random_sampling) - return {k: v for k, v in sorted(self._candidate_dict.items(), key=lambda x: x[1], reverse=True)} - - - def fig_data(self): - image_dataset = copy.deepcopy(self._im_dataset) - proteins = copy.deepcopy(self.proteins) - for name in image_dataset['name']: - image_dataset['count'].append(proteins[name]) - - image_dataset['colors'] = {k:'rgba'+str(tuple(v)) for k,v in image_dataset['colors'].items()} - return image_dataset - - - -COUNTER_FILE_PATH = None -local_dataset = None -def get_local_dataset(LOCAL_FILE_PATH=None, COPICK_TEMPLATE_PATH=None, COUNTER_CHECKPOINT_PATH=None): - global local_dataset - global COUNTER_FILE_PATH - if not local_dataset: - if not LOCAL_FILE_PATH or not COPICK_TEMPLATE_PATH: - config = configparser.ConfigParser() - config.read(os.path.join(os.getcwd(), "config.ini")) - - if not LOCAL_FILE_PATH: - LOCAL_FILE_PATH = '%s' % config['local_picks']['PICK_FILE_PATH'] + 'ExperimentRuns/' - if not COPICK_TEMPLATE_PATH: - COPICK_TEMPLATE_PATH = '%s' % config['copick_template']['COPICK_TEMPLATE_PATH'] - - local_dataset = LocalDataset(local_file_path=LOCAL_FILE_PATH, config_path=COPICK_TEMPLATE_PATH) - - if not COUNTER_FILE_PATH: - if not COUNTER_CHECKPOINT_PATH: - config = configparser.ConfigParser() - config.read(os.path.join(os.getcwd(), "config.ini")) - - if not COUNTER_CHECKPOINT_PATH: - COUNTER_FILE_PATH = '%s' % config['counter_checkpoint']['COUNTER_FILE_PATH'] - else: - COUNTER_FILE_PATH = COUNTER_CHECKPOINT_PATH - - -get_local_dataset() \ No newline at end of file