# Source code for mesmerize.plotting.widgets.datapoint_tracer
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on June 15 2018
@author: kushal
Chatzigeorgiou Group
Sars International Centre for Marine Molecular Biology
GNU GENERAL PUBLIC LICENSE Version 3, 29 June 2007
"""
from .datapoint_tracer_pytemplate import *
from ...analysis.history_widget import HistoryTreeWidget
from ...pyqtgraphCore import ImageView, LinearRegionItem, mkColor, PlotDataItem
from uuid import UUID
import pandas as pd
import tifffile
import numpy as np
import pickle
from ...viewer.modules.roi_manager_modules.roi_types import CNMFROI, ManualROI, ScatterROI, VolCNMF, VolMultiCNMFROI
# from ...viewer.core import ViewerWorkEnv, ViewerUtils
from ...common import get_window_manager, get_project_manager
# from common import configuration
import os
from typing import Union, Optional
from ...common.utils import draw_graph
from ...analysis.data_types import HistoryTrace
from copy import deepcopy
# NOTE(review): not referenced anywhere in this module; the names look like
# per-datapoint UUID column names for region-based data (e.g. '_pf_uuid' for
# peak features) — confirm against the modules that import this constant.
region_data_types = ['_pf_uuid', '_ST_uuid']
class DatapointTracerWidget(QtWidgets.QWidget):
    """
    Widget for tracing a single datapoint back to its source.

    For the datapoint passed to :meth:`set_widget` it shows the datapoint UUID,
    its history trace, the fields of its DataFrame row, a plot of its curve data
    (optionally with highlighted timeline regions) and its ROI drawn over a max
    or std projection image of the parent sample.
    """

    def __init__(self):
        QtWidgets.QWidget.__init__(self)
        self.setWindowTitle('Datapoint Tracer')

        # state of the datapoint currently displayed; populated by set_widget()
        self.uuid = None
        self.row = None
        self.proj_path = None
        self.sample_id = None
        # key of the projection image currently displayed, used by set_image() to skip reloads
        self.previous_sample_id_projection = None
        self.history_trace = None
        self.peak_ix = None
        self.tstart = None
        self.tend = None
        self.is_3d = False
        # z-plane of the current volumetric ROI, only meaningful while is_3d is True
        self.zcenter = None
        # projection used by set_widget(); refreshed from the radio buttons on every call
        self.img_proj = 'max'

        self.ui = Ui_DatapointTracer()
        self.ui.setupUi(self)

        self.history_widget = HistoryTreeWidget(parent=self.ui.groupBoxInfo)
        self.ui.groupBoxInfo.layout().addWidget(self.history_widget)

        # the row's fields are displayed with the same tree widget type as the history trace
        self.pandas_series_widget = HistoryTreeWidget(parent=self.ui.groupBoxInfo)
        self.ui.groupBoxInfo.layout().addWidget(self.pandas_series_widget)

        # only the ImageView's image item is embedded, inside a view box of the UI's graphics view
        self.image_view = ImageView()
        self.image_view.tVals = np.arange(0, 100)
        self.image_item = self.image_view.getImageItem()
        self.view = self.ui.graphicsViewImage.addViewBox()
        self.view.setAspectLocked(True)
        self.view.addItem(self.image_item)

        self.peak_region = TimelineLinearRegion(self.ui.graphicsViewPlot)
        self.roi = None
        self.plot_data: np.ndarray = None
        self.plot_data_item: PlotDataItem = None

        self.ui.radioButtonMaxProjection.clicked.connect(lambda x: self.set_image('max'))
        self.ui.radioButtonSTDProjection.clicked.connect(lambda x: self.set_image('std'))
        self.ui.pushButtonOpenInViewer.clicked.connect(self.open_in_viewer)
        self.ui.pushButtonOpenAnalysisGraph.clicked.connect(self.open_analysis_graph)

    @staticmethod
    def _unwrap(value):
        """
        Return the element of a 1-element ``pandas.Series``, otherwise ``value`` unchanged.

        Column access on a single-row DataFrame yields a 1-element Series, while
        column access on a Series yields the value itself.
        """
        if isinstance(value, pd.Series):
            return value.item()
        return value

    @staticmethod
    def _row_to_dict(row) -> dict:
        """Return ``{column: value}`` for a single-row DataFrame or a Series."""
        if isinstance(row, pd.DataFrame):
            # DataFrame.to_dict() gives {column: {index: value}}; keep the only row
            ix = row.index.item()
            return {k: v[ix] for k, v in row.to_dict().items()}
        return row.to_dict()

    def _load_history_trace(self, history_trace: Optional[list]) -> list:
        """
        Return the sample's pre-processing history followed by ``history_trace``.

        The pre-processing history is the ``'history_trace'`` entry of the pickle
        found at the row's ImgInfoPath (relative to the project root).
        """
        if history_trace is None:
            history_trace = []

        img_info_path = os.path.join(self.proj_path, self._unwrap(self.row['ImgInfoPath']))

        # NOTE(review): unpickling can execute arbitrary code; this assumes the
        # project directory is trusted.
        with open(img_info_path, 'rb') as f:
            preprocess_history = pickle.load(f)['history_trace']

        return preprocess_history + history_trace

    def _update_plot(self, data_column_curve: str):
        """Plot the array stored in ``data_column_curve`` of the current row."""
        self.plot_data = self._unwrap(self.row[data_column_curve])

        if self.plot_data_item is None:
            # first datapoint: create the plot data item
            self.plot_data_item = self.ui.graphicsViewPlot.plot(self.plot_data)
        else:
            # reuse the existing item
            self.plot_data_item.clear()
            self.plot_data_item.setData(self.plot_data)

        self.plot_data_item.setPen('w', width=2)
        # draw the curve above the linear regions, which use z = 0
        self.plot_data_item.setZValue(1)

    def _update_roi(self, roi_color):
        """
        Replace the displayed ROI with the one described by the row's ROI_State.

        Also displays the projection image matching the ROI (a single z-plane for
        volumetric ROIs).

        :raises ValueError: if the ROI type is not supported
        """
        if self.roi is not None:
            self.roi.remove_from_viewer()
            self.roi = None

        roi_state = self._unwrap(self.row['ROI_State'])
        roi_type = roi_state['roi_type']

        if roi_type not in ('CNMFROI', 'ScatterROI', 'ManualROI', 'VolCNMF', 'VolMultiCNMFROI'):
            raise ValueError(f'Unsupported ROI type: {roi_type!r}')

        # Reset for every datapoint so a 2D datapoint shown after a volumetric
        # one does not look for a z-plane projection image.
        self.is_3d = roi_type in ('VolCNMF', 'VolMultiCNMFROI')
        if self.is_3d:
            self.zcenter = roi_state['zcenter']

        self.set_image(self.img_proj)

        if roi_type in ('CNMFROI', 'ScatterROI'):
            roi_class = {'CNMFROI': CNMFROI, 'ScatterROI': ScatterROI}[roi_type]
            self.roi = roi_class.from_state(curve_plot_item=None, view_box=self.view, state=roi_state)
            self.roi.get_roi_graphics_object().setBrush(mkColor(roi_color))

        elif roi_type == 'ManualROI':
            self.roi = ManualROI.from_state(curve_plot_item=None, view_box=self.view, state=roi_state)

        elif roi_type == 'VolCNMF':
            self.roi = VolCNMF.from_state(
                curve_plot_item=None, view_box=self.view, state=roi_state, zlevel=self.zcenter
            )
            self.roi.get_roi_graphics_object().setBrush(mkColor(roi_color))

        else:  # 'VolMultiCNMFROI'
            self.roi = VolMultiCNMFROI.from_state(
                curve_plot_item=None, view_box=self.view, state=roi_state, zlevel=self.zcenter
            )
            self.roi.get_roi_graphics_object().setPen(mkColor(roi_color))

        self.roi.add_to_viewer()

    def set_widget(self, datapoint_uuid: UUID, data_column_curve: str, row: pd.Series, proj_path: str,
                   history_trace: Optional[list] = None, peak_ix: Optional[int] = None, tstart: Optional[int] = None,
                   tend: Optional[int] = None, roi_color: Optional[Union[str, float, int, tuple]] = 'ff0000',
                   clear_linear_regions: bool = True):
        """
        Set the widget from the datapoint.

        :param datapoint_uuid: appropriate UUID for the datapoint (such as uuid_curve or _pfeature_uuid)
        :param data_column_curve: data column containing an array to plot
        :param row: DataFrame row that corresponds to the datapoint, either a single-row DataFrame or a Series
        :param proj_path: root dir of the project the datapoint comes from, used for finding max & std projections
        :param history_trace: history trace of the datablock the datapoint comes from. The pre-processing
                              history stored at the row's ImgInfoPath is prepended to it.
        :param peak_ix: Deprecated, ignored
        :param tstart: lower bounds for drawing `LinearRegionItem`
        :param tend: upper bounds for drawing `LinearRegionItem`
        :param roi_color: color for drawing the spatial bounds of the ROI (not applied to ManualROI)
        :param clear_linear_regions: remove previously drawn linear regions before drawing a new one
        :raises ValueError: if SampleID is not a str, ImgUUID is not a str, or the ROI type is unsupported
        """
        self.uuid = datapoint_uuid
        self.row = row
        self.proj_path = proj_path
        self.ui.label_zlevel.clear()

        self.sample_id = self._unwrap(self.row['SampleID'])
        if not isinstance(self.sample_id, str):
            raise ValueError('SampleID datatype is not str or pandas.Series. '
                             'Something is wrong, it is this datatype :' + str(type(self.sample_id)))

        self.history_trace = self._load_history_trace(history_trace)

        self.ui.lineEditUUID.setText(str(self.uuid))
        self.history_widget.fill_widget(self.history_trace)
        self.pandas_series_widget.fill_widget(self._row_to_dict(row))

        if self.ui.radioButtonMaxProjection.isChecked():
            self.img_proj = 'max'
        elif self.ui.radioButtonSTDProjection.isChecked():
            self.img_proj = 'std'

        if clear_linear_regions:
            self.peak_region.clear_all()

        if (tstart is not None) and (tend is not None):
            self.peak_region.add_linear_region(tstart, tend, color=mkColor('#a80035'))

        self._update_plot(data_column_curve)
        self._update_roi(roi_color)
        # peak_ix is deprecated and intentionally ignored

    def set_image(self, projection: str):
        """
        Set either the max or std projection image.

        For volumetric datapoints, the projection of plane ``self.zcenter`` is shown
        and the z-level label is updated.

        :param projection: one of either 'max' or 'std'
        :raises ValueError: if ``projection`` is invalid or ImgUUID is not a str
        """
        if self.row is None:
            # no datapoint has been set yet (e.g. a radio button was clicked first)
            return

        cache_key = f'{self.sample_id}{projection}'
        if self.is_3d:
            cache_key += f'-{self.zcenter}'

        # Volumetric images are always reloaded. This also re-populates the
        # z-level label, which set_widget() clears on every call.
        if cache_key == self.previous_sample_id_projection and not self.is_3d:
            return

        img_uuid = self._unwrap(self.row['ImgUUID'])
        if not isinstance(img_uuid, str):
            raise ValueError('Datatype for Projection Path must be pandas.Series or str, it is currently : ' + str(
                type(img_uuid)))

        if projection == 'max':
            suffix = '_max_proj'
        elif projection == 'std':
            suffix = '_std_proj'
        else:
            raise ValueError('Can only accept "max" and "std" arguments')

        self.ui.label_zlevel.clear()
        if self.is_3d:
            suffix += f'-{self.zcenter}'
            self.ui.label_zlevel.setText(f'Showing plane #: {self.zcenter} ')

        img_path = os.path.join(self.proj_path, 'images', f'{self.sample_id}-_-{img_uuid}{suffix}.tiff')
        img = tifffile.imread(img_path)

        # contrast: from the minimum up to median + 10 standard deviations, ignoring NaNs
        vmin = np.nanmin(img)
        vmax = np.nanmedian(img) + (10 * np.nanstd(img))
        self.image_view.setImage(img, axes={'x': 0, 'y': 1}, levels=(vmin, vmax))

        self.previous_sample_id_projection = cache_key

    def open_in_viewer(self):
        """
        Open the parent Sample of the current datapoint.
        """
        if self.row is None:
            # nothing to open before set_widget() has been called
            return
        w = get_window_manager().get_new_viewer_window()
        w.open_from_dataframe(proj_path=self.proj_path, row=self.row)

    def open_analysis_graph(self):
        """
        Draw the graph of the current datapoint's cleaned history trace.
        """
        if self.history_trace is None:
            # nothing to draw before set_widget() has been called
            return
        # deepcopy so that cleaning cannot modify self.history_trace
        cleaned = HistoryTrace.clean_history_trace(deepcopy(self.history_trace))
        draw_graph(cleaned, view=True)
class TimelineLinearRegion:
    """
    Keeps track of fixed (non-movable) ``LinearRegionItem`` highlights on a plot.

    Each region added through this object is remembered, so it can later be
    removed on its own or together with all the others.
    """

    def __init__(self, plot_widget: PlotWidget):
        self.plot_widget = plot_widget
        # regions currently drawn on plot_widget, in the order they were added
        self.linear_regions = []

    def add_linear_region(self, frame_start: int, frame_end: int, color: QtGui.QColor) -> LinearRegionItem:
        """
        Draw a region spanning ``frame_start`` to ``frame_end``.

        The region cannot be dragged, and its edges are bounded to the same span.

        :param frame_start: lower bound of the region
        :param frame_end: upper bound of the region
        :param color: brush used to fill the region
        :return: the ``LinearRegionItem`` that was added to the plot
        """
        region = LinearRegionItem(
            values=[frame_start, frame_end],
            brush=color,
            movable=False,
            bounds=[frame_start, frame_end],
        )
        # keep the region behind items with a higher z-value
        region.setZValue(0)

        # black edge lines
        for edge in (region.lines[0], region.lines[1]):
            edge.setPen(mkColor('#000000'))

        self.linear_regions.append(region)
        self.plot_widget.addItem(region)
        return region

    def del_linear_region(self, linear_region: LinearRegionItem):
        """
        Remove ``linear_region`` from the plot and schedule it for deletion.

        :raises ValueError: if ``linear_region`` is not tracked by this object
        """
        plot = self.plot_widget
        plot.removeItem(linear_region)
        linear_region.deleteLater()
        self.linear_regions.remove(linear_region)

    def clear_all(self):
        """Remove every tracked region from the plot and forget them."""
        plot = self.plot_widget
        for item in self.linear_regions:
            plot.removeItem(item)
            item.deleteLater()
        self.linear_regions = []