diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 470e2e5a5..f17941405 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -49,7 +49,7 @@ jobs: - name: build docs run: | cd docs - RTD_BUILD=1 make html SPHINXOPTS="-W --keep-going" + DOCS_BUILD=1 make html SPHINXOPTS="-W --keep-going" # set environment variable `DOCS_VERSION_DIR` to either the pr-branch name, "dev", or the release version tag - name: set output pr diff --git a/docs/source/api/graphics/LineGraphic.rst b/docs/source/api/graphics/LineGraphic.rst index 428e8ef56..867f1bfbb 100644 --- a/docs/source/api/graphics/LineGraphic.rst +++ b/docs/source/api/graphics/LineGraphic.rst @@ -25,6 +25,7 @@ Properties LineGraphic.axes LineGraphic.block_events LineGraphic.cmap + LineGraphic.color_mode LineGraphic.colors LineGraphic.data LineGraphic.deleted diff --git a/docs/source/api/graphics/ScatterGraphic.rst b/docs/source/api/graphics/ScatterGraphic.rst index cf8e1224d..f9dcd2487 100644 --- a/docs/source/api/graphics/ScatterGraphic.rst +++ b/docs/source/api/graphics/ScatterGraphic.rst @@ -25,6 +25,7 @@ Properties ScatterGraphic.axes ScatterGraphic.block_events ScatterGraphic.cmap + ScatterGraphic.color_mode ScatterGraphic.colors ScatterGraphic.data ScatterGraphic.deleted diff --git a/examples/events/cmap_event.py b/examples/events/cmap_event.py index 62913cb29..f01f06d6a 100644 --- a/examples/events/cmap_event.py +++ b/examples/events/cmap_event.py @@ -34,7 +34,7 @@ xs = np.linspace(0, 4 * np.pi, 100) ys = np.sin(xs) -figure["sine"].add_line(np.column_stack([xs, ys])) +figure["sine"].add_line(np.column_stack([xs, ys]), color_mode="vertex") # make a 2D gaussian cloud cloud_data = np.random.normal(0, scale=3, size=1000).reshape(500, 2) diff --git a/examples/gridplot/multigraphic_gridplot.py b/examples/gridplot/multigraphic_gridplot.py index cbf546e2a..0e89efcdc 100644 --- a/examples/gridplot/multigraphic_gridplot.py +++ b/examples/gridplot/multigraphic_gridplot.py @@ -106,7 +106,7 @@ def make_circle(center, radius: float, n_points: int = 75) -> np.ndarray: gaussian_cloud2 = np.random.multivariate_normal(mean, covariance, n_points) # add the scatter graphics to the figure -figure["scatter"].add_scatter(data=gaussian_cloud, sizes=2, cmap="jet") +figure["scatter"].add_scatter(data=gaussian_cloud, sizes=2, cmap="jet", color_mode="vertex") figure["scatter"].add_scatter(data=gaussian_cloud2, colors="r", sizes=2) figure.show() diff --git a/examples/guis/imgui_basic.py b/examples/guis/imgui_basic.py index 26b5603c0..26c2c0fca 100644 --- a/examples/guis/imgui_basic.py +++ b/examples/guis/imgui_basic.py @@ -29,10 +29,10 @@ figure = fpl.Figure(size=(700, 560)) # make some scatter points at every 10th point -figure[0, 0].add_scatter(data[::10], colors="cyan", sizes=15, name="sine-scatter", uniform_color=True) +figure[0, 0].add_scatter(data[::10], colors="cyan", sizes=15, name="sine-scatter") # place a line above the scatter -figure[0, 0].add_line(data, thickness=3, colors="r", name="sine-wave", uniform_color=True) +figure[0, 0].add_line(data, thickness=3, colors="r", name="sine-wave") class ImguiExample(EdgeWindow): diff --git a/examples/image/image_reshaping.py b/examples/image/image_reshaping.py new file mode 100644 index 000000000..23264bda1 --- /dev/null +++ b/examples/image/image_reshaping.py @@ -0,0 +1,50 @@ +""" +Image reshaping +=============== + +An example that shows replacement of the image data with new data of a different shape. Under the hood, this creates a +new buffer and a new array of Textures on the GPU that replace the older Textures. Creating a new buffer and textures +has a performance cost, so you should do this only if you need to or if the performance drawback is not a concern for +your use case. + +Note that the vmin-vmax is reset when you replace the buffers. +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'animate' + + +import numpy as np +import fastplotlib as fpl + +# create some data, diagonal sinusoidal bands +xs = np.linspace(0, 2300, 2300, dtype=np.float16) +full_data = np.vstack([np.cos(np.sqrt(xs + (np.pi / 2) * i)) * i for i in range(2_300)]) + +figure = fpl.Figure() + +image = figure[0, 0].add_image(full_data) + +figure.show() + +i, j = 1, 1 + + +def update(): + global i, j + # set the new image data as a subset of the full data + row = np.abs(np.sin(i)) * 2300 + col = np.abs(np.cos(i)) * 2300 + image.data = full_data[: int(row), : int(col)] + + i += 0.01 + j += 0.01 + + +figure.add_animations(update) + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/line/line_cmap.py b/examples/line/line_cmap.py index 3d2b5e8c9..6dfc1fe23 100644 --- a/examples/line/line_cmap.py +++ b/examples/line/line_cmap.py @@ -27,7 +27,7 @@ data=sine_data, thickness=10, cmap="plasma", - cmap_transform=sine_data[:, 1] + cmap_transform=sine_data[:, 1], ) # qualitative colormaps, useful for cluster labels or other types of categorical labels @@ -36,7 +36,7 @@ data=cosine_data, thickness=10, cmap="tab10", - cmap_transform=labels + cmap_transform=labels, ) figure.show() diff --git a/examples/line/line_cmap_more.py b/examples/line/line_cmap_more.py index c7c0d80f4..c6e811fb2 100644 --- a/examples/line/line_cmap_more.py +++ b/examples/line/line_cmap_more.py @@ -31,16 +31,35 @@ # set colormap by mapping data using a transform # here we map the color using the y-values of the sine data # i.e., the color is a function of sine(x) -line2 = figure[0, 0].add_line(sine, thickness=10, cmap="jet", cmap_transform=sine[:, 1], offset=(0, 4, 0)) +line2 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="jet", + cmap_transform=sine[:, 1], + offset=(0, 4, 0), +) # make a line and change the cmap afterward, here we are using the cosine instead fot the transform -line3 = figure[0, 0].add_line(sine, thickness=10, cmap="jet", cmap_transform=cosine[:, 1], offset=(0, 6, 0)) +line3 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="jet", + cmap_transform=cosine[:, 1], + offset=(0, 6, 0) +) + # change the cmap line3.cmap = "bwr" # use quantitative colormaps with categorical cmap_transforms labels = [0] * 25 + [1] * 5 + [2] * 50 + [3] * 20 -line4 = figure[0, 0].add_line(sine, thickness=10, cmap="tab10", cmap_transform=labels, offset=(0, 8, 0)) +line4 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="tab10", + cmap_transform=labels, + offset=(0, 8, 0), +) # some text labels for i in range(5): diff --git a/examples/line/line_colorslice.py b/examples/line/line_colorslice.py index b6865eadb..264f944f3 100644 --- a/examples/line/line_colorslice.py +++ b/examples/line/line_colorslice.py @@ -30,7 +30,8 @@ sine = figure[0, 0].add_line( data=sine_data, thickness=5, - colors="magenta" + colors="magenta", + color_mode="vertex", # initialize with same color across vertices, but we will change the per-vertex colors later ) # you can also use colormaps for lines! @@ -56,6 +57,7 @@ data=zeros_data, thickness=8, colors="w", + color_mode="vertex", # initialize with same color across vertices, but we will change the per-vertex colors later offset=(0, 10, 0) ) diff --git a/examples/line_collection/line_collection_slicing.py b/examples/line_collection/line_collection_slicing.py index f829a53c6..98ad97056 100644 --- a/examples/line_collection/line_collection_slicing.py +++ b/examples/line_collection/line_collection_slicing.py @@ -26,6 +26,7 @@ multi_data, thickness=[2, 10, 2, 5, 5, 5, 8, 8, 8, 9, 3, 3, 3, 4, 4], separation=4, + color_mode="vertex", # this will allow us to set per-vertex colors on each line metadatas=list(range(15)), # some metadata names=list("abcdefghijklmno"), # unique name for each line ) diff --git a/examples/machine_learning/kmeans.py b/examples/machine_learning/kmeans.py index f571882ce..4c49844f0 100644 --- a/examples/machine_learning/kmeans.py +++ b/examples/machine_learning/kmeans.py @@ -80,6 +80,7 @@ sizes=5, cmap="tab10", # use a qualitative cmap cmap_transform=kmeans.labels_, # color by the predicted cluster + uniform_size=False, ) # initial index diff --git a/examples/misc/buffer_replace_gc.py b/examples/misc/buffer_replace_gc.py new file mode 100644 index 000000000..e3b0ac104 --- /dev/null +++ b/examples/misc/buffer_replace_gc.py @@ -0,0 +1,91 @@ +""" +Buffer replacement garbage collection test +========================================== + +This is an example that used for a manual test to ensure that GPU VRAM is free when buffers are replaced. + +Use while monitoring VRAM usage with nvidia-smi +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'code' + + +from typing import Literal +import numpy as np +import fastplotlib as fpl +from fastplotlib.ui import EdgeWindow +from imgui_bundle import imgui + + +def generate_dataset(size: int) -> dict[str, np.ndarray]: + return { + "data": np.random.rand(size, 3), + "colors": np.random.rand(size, 4), + # TODO: there's a wgpu bind group issue with edge_colors, will figure out later + # "edge_colors": np.random.rand(size, 4), + "markers": np.random.choice(list("osD+x^v<>*"), size=size), + "sizes": np.random.rand(size) * 5, + "point_rotations": np.random.rand(size) * 180, + } + + +datasets = { + "init": generate_dataset(50_000), + "small": generate_dataset(100), + "large": generate_dataset(5_000_000), +} + + +class UI(EdgeWindow): + def __init__(self, figure): + super().__init__(figure=figure, size=200, location="right", title="UI") + init_data = datasets["init"] + self._figure["line"].add_line( + data=init_data["data"], colors=init_data["colors"], name="line" + ) + self._figure["scatter"].add_scatter( + **init_data, + uniform_size=False, + uniform_marker=False, + uniform_edge_color=False, + point_rotation_mode="vertex", + name="scatter", + ) + + def update(self): + for graphic in ["line", "scatter"]: + if graphic == "line": + features = ["data", "colors"] + + elif graphic == "scatter": + features = list(datasets["init"].keys()) + + for size in ["small", "large"]: + for fea in features: + if imgui.button(f"{size} - {graphic} - {fea}"): + self._replace(graphic, fea, size) + + def _replace( + self, + graphic: Literal["line", "scatter", "image"], + feature: Literal["data", "colors", "markers", "sizes", "point_rotations"], + size: Literal["small", "large"], + ): + new_value = datasets[size][feature] + + setattr(self._figure[graphic][graphic], feature, new_value) + + +figure = fpl.Figure(shape=(3, 1), size=(700, 1600), names=["line", "scatter", "image"]) +ui = UI(figure) +figure.add_gui(ui) + +figure.show() + + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/misc/lorenz_animation.py b/examples/misc/lorenz_animation.py index 20aee5d83..52a77a243 100644 --- a/examples/misc/lorenz_animation.py +++ b/examples/misc/lorenz_animation.py @@ -60,7 +60,12 @@ def lorenz(xyz, *, s=10, r=28, b=2.667): scatter_markers = list() for graphic in lorenz_line: - marker = figure[0, 0].add_scatter(graphic.data.value[0], sizes=16, colors=graphic.colors[0]) + marker = figure[0, 0].add_scatter( + graphic.data.value[0], + sizes=16, + colors=graphic.colors, + edge_colors="w", + ) scatter_markers.append(marker) # initialize time diff --git a/examples/misc/reshape_lines_scatters.py b/examples/misc/reshape_lines_scatters.py new file mode 100644 index 000000000..db8adb29e --- /dev/null +++ b/examples/misc/reshape_lines_scatters.py @@ -0,0 +1,92 @@ +""" +Change number of points in lines and scatters +============================================= + +This example sets lines and scatters with new data of a different shape, i.e. new data with more or fewer datapoints. +Internally, this creates new buffers for the feature that is being set (data, colors, markers, etc.). Note that there +are performance drawbacks to doing this, so it is recommended to maintain the same number of datapoints in a graphic +when possible. You only want to change the number of datapoints when it's really necessary, and you don't want to do +it constantly (such as tens or hundreds of times per second). + +This example is also useful for manually checking that GPU buffers are freed when they're no longer in use. Run this +example while monitoring VRAM usage with `nvidia-smi` +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'animate' + + +import numpy as np +import fastplotlib as fpl + +# create some data to start with +xs = np.linspace(0, 10 * np.pi, 100) +ys = np.sin(xs) + +data = np.column_stack([xs, ys]) + +# create a figure, add a line, scatter and line_stack +figure = fpl.Figure(shape=(3, 1), size=(700, 700)) + +line = figure[0, 0].add_line(data) + +scatter = figure[1, 0].add_scatter( + np.random.rand(100, 3), + colors=np.random.rand(100, 4), + markers=np.random.choice(list("osD+x^v<>*"), size=100), + sizes=(np.random.rand(100) + 1) * 3, + edge_colors=np.random.rand(100, 4), + point_rotations=np.random.rand(100) * 180, + uniform_marker=False, + uniform_size=False, + uniform_edge_color=False, + point_rotation_mode="vertex", +) + +line_stack = figure[2, 0].add_line_stack(np.stack([data] * 10), cmap="viridis") + +text = figure[0, 0].add_text(f"n_points: {100}", offset=(0, 1.5, 0), anchor="middle-left") + +figure.show(maintain_aspect=False) + +i = 0 + + +def update(): + # set a new larger or smaller data array on every render + global i + + # create new data + freq = np.abs(np.sin(i)) * 10 + n_points = int((freq * 20_000) + 10) + + xs = np.linspace(0, 10 * np.pi, n_points) + ys = np.sin(xs * freq) + + new_data = np.column_stack([xs, ys]) + + # update line data + line.data = new_data + + # update scatter data, colors, markers, etc. + scatter.data = np.random.rand(n_points, 3) + scatter.colors = np.random.rand(n_points, 4) + scatter.markers = np.random.choice(list("osD+x^v<>*"), size=n_points) + scatter.edge_colors = np.random.rand(n_points, 4) + scatter.point_rotations = np.random.rand(n_points) * 180 + + # update line stack data + line_stack.data = np.stack([new_data] * 10) + + text.text = f"n_points: {n_points}" + + i += 0.01 + + +figure.add_animations(update) + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/misc/scatter_animation.py b/examples/misc/scatter_animation.py index d37aea976..549059b65 100644 --- a/examples/misc/scatter_animation.py +++ b/examples/misc/scatter_animation.py @@ -37,7 +37,7 @@ figure = fpl.Figure(size=(700, 560)) subplot_scatter = figure[0, 0] # use an alpha value since this will be a lot of points -scatter = subplot_scatter.add_scatter(data=cloud, sizes=3, colors=colors, alpha=0.6) +scatter = subplot_scatter.add_scatter(data=cloud, sizes=3, uniform_size=False, colors=colors, alpha=0.6) def update_points(subplot): diff --git a/examples/misc/scatter_sizes_animation.py b/examples/misc/scatter_sizes_animation.py index 53a616a68..2092787f3 100644 --- a/examples/misc/scatter_sizes_animation.py +++ b/examples/misc/scatter_sizes_animation.py @@ -20,7 +20,7 @@ figure = fpl.Figure(size=(700, 560)) -figure[0, 0].add_scatter(data, sizes=sizes, name="sine") +figure[0, 0].add_scatter(data, sizes=sizes, uniform_size=False, name="sine") i = 0 diff --git a/examples/notebooks/quickstart.ipynb b/examples/notebooks/quickstart.ipynb index 7b7551588..61bcb6b06 100644 --- a/examples/notebooks/quickstart.ipynb +++ b/examples/notebooks/quickstart.ipynb @@ -719,8 +719,8 @@ "# we will add all the lines to the same subplot\n", "subplot = fig_lines[0, 0]\n", "\n", - "# plot sine wave, use a single color\n", - "sine = subplot.add_line(data=sine_data, thickness=5, colors=\"magenta\")\n", + "# plot sine wave, use a single color for now, but we will set per-vertex colors later\n", + "sine = subplot.add_line(data=sine_data, thickness=5, colors=\"magenta\", color_mode=\"vertex\")\n", "\n", "# you can also use colormaps for lines!\n", "cosine = subplot.add_line(data=cosine_data, thickness=12, cmap=\"autumn\")\n", diff --git a/examples/scatter/scatter_iris.py b/examples/scatter/scatter_iris.py index b9df16026..fc228e5bf 100644 --- a/examples/scatter/scatter_iris.py +++ b/examples/scatter/scatter_iris.py @@ -35,6 +35,7 @@ cmap="tab10", cmap_transform=clusters_labels, markers=markers, + uniform_marker=False, ) figure.show() diff --git a/examples/scatter/scatter_size.py b/examples/scatter/scatter_size.py index 30d3e6ea3..2b3899dbe 100644 --- a/examples/scatter/scatter_size.py +++ b/examples/scatter/scatter_size.py @@ -35,7 +35,7 @@ ) # add a set of scalar sizes non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5 -figure["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, colors="red") +figure["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, uniform_size=False, colors="red") for graph in figure: graph.auto_scale(maintain_aspect=True) diff --git a/examples/scatter/scatter_validate.py b/examples/scatter/scatter_validate.py index abddffee0..45f0a177c 100644 --- a/examples/scatter/scatter_validate.py +++ b/examples/scatter/scatter_validate.py @@ -41,6 +41,7 @@ uniform_edge_color=False, edge_colors=["w"] * 3 + ["orange"] * 3 + ["blue"] * 3 + ["green"], markers=list("osD+x^v<>*"), + uniform_marker=False, edge_width=2.0, sizes=20, uniform_size=True, @@ -64,6 +65,7 @@ sine, markers="s", sizes=xs * 5, + uniform_size=False, offset=(0, 2, 0) ) diff --git a/examples/scatter/spinning_spiral.py b/examples/scatter/spinning_spiral.py index 89e74eaec..4f947970a 100644 --- a/examples/scatter/spinning_spiral.py +++ b/examples/scatter/spinning_spiral.py @@ -34,7 +34,14 @@ canvas_kwargs={"max_fps": 500, "vsync": False} ) -spiral = figure[0, 0].add_scatter(data, cmap="viridis_r", edge_colors=None, alpha=0.5, sizes=sizes) +spiral = figure[0, 0].add_scatter( + data, + cmap="viridis_r", + edge_colors=None, + alpha=0.5, + sizes=sizes, + uniform_size=False, +) # pre-generate normally distributed data to jitter the points before each render jitter = np.random.normal(scale=0.001, size=n * 3).reshape((n, 3)) diff --git a/fastplotlib/graphics/__init__.py b/fastplotlib/graphics/__init__.py index 3d01e4a35..8734a5e72 100644 --- a/fastplotlib/graphics/__init__.py +++ b/fastplotlib/graphics/__init__.py @@ -7,7 +7,7 @@ from .mesh import MeshGraphic, SurfaceGraphic, PolygonGraphic from .text import TextGraphic from .line_collection import LineCollection, LineStack - +from .scatter_collection import ScatterCollection __all__ = [ "Graphic", @@ -22,4 +22,5 @@ "TextGraphic", "LineCollection", "LineStack", + "ScatterCollection", ] diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py index 47673cbc0..e0602e4e3 100644 --- a/fastplotlib/graphics/_base.py +++ b/fastplotlib/graphics/_base.py @@ -178,6 +178,7 @@ def __init__( self._alpha_mode = AlphaMode(alpha_mode) self._visible = Visible(visible) self._block_events = False + self._block_handlers = list() self._axes: Axes = None @@ -274,6 +275,11 @@ def block_events(self) -> bool: def block_events(self, value: bool): self._block_events = value + @property + def block_handlers(self) -> list: + """Used to block event handlers for a graphic and prevent recursion.""" + return self._block_handlers + @property def world_object(self) -> pygfx.WorldObject: """Associated pygfx WorldObject. Always returns a proxy, real object cannot be accessed directly.""" @@ -440,6 +446,9 @@ def _handle_event(self, callback, event: pygfx.Event): if self.block_events: return + if callback in self._block_handlers: + return + if event.type in self._features: # for feature events event._target = self.world_object diff --git a/fastplotlib/graphics/_positions_base.py b/fastplotlib/graphics/_positions_base.py index af7d7badb..763f5e775 100644 --- a/fastplotlib/graphics/_positions_base.py +++ b/fastplotlib/graphics/_positions_base.py @@ -1,4 +1,6 @@ -from typing import Any, Sequence +from numbers import Real +from typing import Any, Sequence, Literal +from warnings import warn import numpy as np @@ -18,12 +20,20 @@ class PositionsGraphic(Graphic): @property def data(self) -> VertexPositions: - """Get or set the graphic's data""" + """ + Get or set the graphic's data. + + Note that if the number of datapoints does not match the number of + current datapoints a new buffer is automatically allocated. This can + have performance drawbacks when you have a very large number of datapoints. + This is usually fine as long as you don't need to do it hundreds of times + per second. + """ return self._data @data.setter def data(self, value): - self._data[:] = value + self._data.set_value(self, value) @property def colors(self) -> VertexColors | pygfx.Color: @@ -36,11 +46,59 @@ def colors(self) -> VertexColors | pygfx.Color: @colors.setter def colors(self, value: str | np.ndarray | Sequence[float] | Sequence[str]): + self._colors.set_value(self, value) + + @property + def color_mode(self) -> Literal["uniform", "vertex"]: + """ + Get or set the color mode. Note that after setting the color_mode, you will have to set the `colors` + as well for switching between 'uniform' and 'vertex' modes. + """ + return self.world_object.material.color_mode + + @color_mode.setter + def color_mode(self, mode: Literal["uniform", "vertex"]): + valid = ("uniform", "vertex") + if mode not in valid: + raise ValueError(f"`color_mode` must be one of : {valid}") + if mode == "vertex" and isinstance(self._colors, UniformColor): + # uniform -> vertex + # need to make a new vertex buffer and get rid of uniform buffer + new_colors = self._create_colors_buffer(self._colors.value, "vertex") + # we can't clear world_object.material.color so just set the colors buffer on the geometry + # this doesn't really matter anyways since the lingering uniform color takes up just a few bytes + self.world_object.geometry.colors = new_colors._fpl_buffer + + elif mode == "uniform" and isinstance(self._colors, VertexColors): + # vertex -> uniform + # use first vertex color and spit out a warning + warn( + "changing `color_mode` from vertex -> uniform, will use first vertex color " + "for the uniform and discard the remaining color values" + ) + new_colors = self._create_colors_buffer(self._colors.value[0], "uniform") + self.world_object.geometry.colors = None + self.world_object.material.color = new_colors.value + + # clear out cmap + self._cmap.clear_event_handlers() + self._cmap = None + + else: + # no change, return + return + + # restore event handlers onto the new colors feature + new_colors._event_handlers[:] = self._colors._event_handlers + self._colors.clear_event_handlers() + # this should trigger gc + self._colors = new_colors + + # this is created so that cmap can be set later if isinstance(self._colors, VertexColors): - self._colors[:] = value + self._cmap = VertexCmap(self._colors, cmap_name=None, transform=None) - elif isinstance(self._colors, UniformColor): - self._colors.set_value(self, value) + self.world_object.material.color_mode = mode @property def cmap(self) -> VertexCmap: @@ -53,8 +111,8 @@ def cmap(self) -> VertexCmap: @cmap.setter def cmap(self, name: str): - if self._cmap is None: - raise BufferError("Cannot use cmap with uniform_colors=True") + if self.color_mode == "uniform": + raise ValueError("cannot use `cmap` with `color_mode` = 'uniform'") self._cmap[:] = name @@ -71,14 +129,68 @@ def size_space(self): def size_space(self, value: str): self._size_space.set_value(self, value) + def _create_colors_buffer(self, colors, color_mode) -> UniformColor | VertexColors: + # creates either a UniformColor or VertexColors based on the given `colors` and `color_mode` + # if `color_mode` = "auto", returns {UniformColor | VertexColor} based on what the `colors` arg represents + # if `color_mode` = "uniform", it verifies that the user `colors` input represents just 1 color + # if `color_mode` = "vertex", always returns VertexColors regardless of whether `colors` represents >= 1 colors + + if isinstance(colors, VertexColors): + if color_mode == "uniform": + raise ValueError( + "if a `VertexColors` instance is provided for `colors`, " + "`color_mode` must be 'vertex' or 'auto', not 'uniform'" + ) + # share buffer with existing colors instance + new_colors = colors + # blank colormap instance + self._cmap = VertexCmap(new_colors, cmap_name=None, transform=None) + + else: + # determine if a single or multiple colors were passed and decide color mode + if isinstance(colors, (pygfx.Color, str)) or ( + len(colors) in [3, 4] and all(isinstance(v, Real) for v in colors) + ): + # one color specified as a str or pygfx.Color, or one color specified with RGB(A) values + if color_mode in ("auto", "uniform"): + new_colors = UniformColor(colors) + else: + new_colors = VertexColors( + colors, n_colors=self._data.value.shape[0] + ) + + elif all(isinstance(c, (str, pygfx.Color)) for c in colors): + # sequence of colors + if color_mode == "uniform": + raise ValueError( + "You passed `color_mode` = 'uniform', but specified a sequence of multiple colors. Use " + "`color_mode` = 'auto' or 'vertex' for multiple colors." + ) + new_colors = VertexColors(colors, n_colors=self._data.value.shape[0]) + + elif len(colors) > 4: + # sequence of multiple colors, must again ensure color_mode is not uniform + if color_mode == "uniform": + raise ValueError( + "You passed `color_mode` = 'uniform', but specified a sequence of multiple colors. Use " + "`color_mode` = 'auto' or 'vertex' for multiple colors." + ) + new_colors = VertexColors(colors, n_colors=self._data.value.shape[0]) + else: + raise ValueError( + "`colors` must be a str, pygfx.Color, array, list or tuple indicating an RGB(A) color, or a " + "sequence of str, pygfx.Color, or array of shape [n_datapoints, 3 | 4]" + ) + + return new_colors + def __init__( self, data: Any, colors: str | np.ndarray | tuple[float] | list[float] | list[str] = "w", - uniform_color: bool = False, cmap: str | VertexCmap = None, cmap_transform: np.ndarray = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", *args, **kwargs, @@ -86,22 +198,31 @@ def __init__( if isinstance(data, VertexPositions): self._data = data else: - self._data = VertexPositions(data, isolated_buffer=isolated_buffer) + self._data = VertexPositions(data) if cmap_transform is not None and cmap is None: raise ValueError("must pass `cmap` if passing `cmap_transform`") + valid = ("auto", "uniform", "vertex") + + # default _cmap is None + self._cmap = None + + if color_mode not in valid: + raise ValueError(f"`color_mode` must be one of {valid}") + if cmap is not None: # if a cmap is specified it overrides colors argument - if uniform_color: - raise TypeError("Cannot use cmap if uniform_color=True") + if color_mode == "uniform": + raise ValueError( + "if a `cmap` is provided, `color_mode` must be 'vertex' or 'auto', not 'uniform'" + ) if isinstance(cmap, str): # make colors from cmap if isinstance(colors, VertexColors): # share buffer with existing colors instance for the cmap self._colors = colors - self._colors._shared += 1 else: # create vertex colors buffer self._colors = VertexColors("w", n_colors=self._data.value.shape[0]) @@ -115,34 +236,18 @@ def __init__( # use existing cmap instance self._cmap = cmap self._colors = cmap._vertex_colors + else: raise TypeError( "`cmap` argument must be a cmap name or an existing `VertexCmap` instance" ) else: # no cmap given - if isinstance(colors, VertexColors): - # share buffer with existing colors instance - self._colors = colors - self._colors._shared += 1 - # blank colormap instance + self._colors = self._create_colors_buffer(colors, color_mode) + + # this is created so that cmap can be set later + if isinstance(self._colors, VertexColors): self._cmap = VertexCmap(self._colors, cmap_name=None, transform=None) - else: - if uniform_color: - if not isinstance(colors, str): # not a single color - if not len(colors) in [3, 4]: # not an RGB(A) array - raise TypeError( - "must pass a single color if using `uniform_colors=True`" - ) - self._colors = UniformColor(colors) - self._cmap = None - else: - self._colors = VertexColors( - colors, n_colors=self._data.value.shape[0] - ) - self._cmap = VertexCmap( - self._colors, cmap_name=None, transform=None - ) self._size_space = SizeSpace(size_space) super().__init__(*args, **kwargs) diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index 779310476..68fe54c33 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -1,5 +1,6 @@ +import weakref from warnings import warn -from typing import Literal +from typing import Callable import numpy as np from numpy.typing import NDArray @@ -78,7 +79,7 @@ def block_events(self, val: bool): """ self._block_events = val - def add_event_handler(self, handler: callable): + def add_event_handler(self, handler: Callable): """ Add an event handler. All added event handlers are called when this feature changes. @@ -89,7 +90,7 @@ def add_event_handler(self, handler: callable): Parameters ---------- - handler: callable + handler: Callable a function to call when this feature changes """ @@ -102,7 +103,7 @@ def add_event_handler(self, handler: callable): self._event_handlers.append(handler) - def remove_event_handler(self, handler: callable): + def remove_event_handler(self, handler: Callable): """ Remove a registered event ``handler``. @@ -137,32 +138,28 @@ class BufferManager(GraphicFeature): def __init__( self, - data: NDArray | pygfx.Buffer, - buffer_type: Literal["buffer", "texture", "texture-array"] = "buffer", - isolated_buffer: bool = True, + data: NDArray | pygfx.Buffer | None, **kwargs, ): super().__init__(**kwargs) - if isolated_buffer and not isinstance(data, pygfx.Resource): - # useful if data is read-only, example: memmaps - bdata = np.zeros(data.shape, dtype=data.dtype) - bdata[:] = data[:] - else: - # user's input array is used as the buffer - bdata = data - - if isinstance(data, pygfx.Resource): - # already a buffer, probably used for - # managing another BufferManager, example: VertexCmap manages VertexColors - self._buffer = data - elif buffer_type == "buffer": - self._buffer = pygfx.Buffer(bdata) + + # if data is None, then the BufferManager just provides a view into an existing buffer + # example: VertexCmap is basically a view into VertexColors + if data is not None: + if isinstance(data, pygfx.Resource): + # already a buffer, probably used for + # managing another BufferManager, example: VertexCmap manages VertexColors + self._fpl_buffer = data + else: + # create a buffer + bdata = np.empty(data.shape, dtype=data.dtype) + bdata[:] = data[:] + + self._fpl_buffer = pygfx.Buffer(bdata) else: - raise ValueError( - "`data` must be a pygfx.Buffer instance or `buffer_type` must be one of: 'buffer' or 'texture'" - ) + self._fpl_buffer = None - self._event_handlers: list[callable] = list() + self._event_handlers: list[Callable] = list() @property def value(self) -> np.ndarray: @@ -174,9 +171,10 @@ def set_value(self, graphic, value): self[:] = value @property - def buffer(self) -> pygfx.Buffer | pygfx.Texture: - """managed buffer""" - return self._buffer + def buffer(self) -> pygfx.Buffer: + """managed buffer, returns a weakref proxy""" + # the user should never create their own references to the buffer + return weakref.proxy(self._fpl_buffer) @property def __array_interface__(self): @@ -320,7 +318,7 @@ def __repr__(self): def block_reentrance(set_value): # decorator to block re-entrant set_value methods # useful when creating complex, circular, bidirectional event graphs - def set_value_wrapper(self: GraphicFeature, graphic_or_key, value): + def set_value_wrapper(self: GraphicFeature, graphic_or_key, value, **kwargs): """ wraps GraphicFeature.set_value @@ -336,7 +334,7 @@ def set_value_wrapper(self: GraphicFeature, graphic_or_key, value): try: # block re-execution of set_value until it has *fully* finished executing self._reentrant_block = True - set_value(self, graphic_or_key, value) + set_value(self, graphic_or_key, value, **kwargs) except Exception as exc: # raise original exception raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant! diff --git a/fastplotlib/graphics/features/_image.py b/fastplotlib/graphics/features/_image.py index 648f79bc8..cb66bb1ef 100644 --- a/fastplotlib/graphics/features/_image.py +++ b/fastplotlib/graphics/features/_image.py @@ -33,7 +33,7 @@ class TextureArray(GraphicFeature): }, ] - def __init__(self, data, isolated_buffer: bool = True, property_name: str = "data"): + def __init__(self, data, property_name: str = "data"): super().__init__(property_name=property_name) data = self._fix_data(data) @@ -41,13 +41,9 @@ def __init__(self, data, isolated_buffer: bool = True, property_name: str = "dat shared = pygfx.renderers.wgpu.get_shared() self._texture_limit_2d = shared.device.limits["max-texture-dimension-2d"] - if isolated_buffer: - # useful if data is read-only, example: memmaps - self._value = np.zeros(data.shape, dtype=data.dtype) - self.value[:] = data[:] - else: - # user's input array is used as the buffer - self._value = data + # create a new buffer + self._value = np.zeros(data.shape, dtype=data.dtype) + self.value[:] = data[:] # data start indices for each Texture self._row_indices = np.arange( diff --git a/fastplotlib/graphics/features/_mesh.py b/fastplotlib/graphics/features/_mesh.py index 7355acb4e..776d77ce4 100644 --- a/fastplotlib/graphics/features/_mesh.py +++ b/fastplotlib/graphics/features/_mesh.py @@ -51,18 +51,14 @@ class MeshIndices(VertexPositions): }, ] - def __init__( - self, data: Any, isolated_buffer: bool = True, property_name: str = "indices" - ): + def __init__(self, data: Any, property_name: str = "indices"): """ Manages the vertex indices buffer shown in the graphic. Supports fancy indexing if the data array also supports it. """ data = self._fix_data(data) - super().__init__( - data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data, property_name=property_name) def _fix_data(self, data): if data.ndim != 2 or data.shape[1] not in (3, 4): diff --git a/fastplotlib/graphics/features/_positions.py b/fastplotlib/graphics/features/_positions.py index 295d22417..7b67e6bd7 100644 --- a/fastplotlib/graphics/features/_positions.py +++ b/fastplotlib/graphics/features/_positions.py @@ -39,7 +39,6 @@ def __init__( self, colors: str | pygfx.Color | np.ndarray | Sequence[float] | Sequence[str], n_colors: int, - isolated_buffer: bool = True, property_name: str = "colors", ): """ @@ -57,9 +56,56 @@ def __init__( """ data = parse_colors(colors, n_colors) - super().__init__( - data=data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data=data, property_name=property_name) + + def set_value( + self, + graphic, + value: str | pygfx.Color | np.ndarray | Sequence[float] | Sequence[str], + ): + """set the entire array, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + # TODO: Refactor this triage so it's more elegant + + # first make sure it's not representing one color + skip = False + if isinstance(value, np.ndarray): + if (value.shape in ((3,), (4,))) and ( + np.issubdtype(value.dtype, np.floating) + or np.issubdtype(value.dtype, np.integer) + ): + # represents one color + skip = True + elif isinstance(value, (list, tuple)): + if len(value) in (3, 4) and all( + [isinstance(v, (float, int)) for v in value] + ): + # represents one color + skip = True + + # check if the number of elements matches current buffer size + if not skip and self.buffer.data.shape[0] != len(value): + # parse the new colors + new_colors = parse_colors(value, len(value)) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(new_colors) + graphic.world_object.geometry.colors = self._fpl_buffer + + if len(self._event_handlers) < 1: + return + + event_info = { + "key": slice(None), + "value": new_colors, + "user_value": value, + } + + event = GraphicFeatureEvent(self._property_name, info=event_info) + self._call_event_handlers(event) + return + + self[:] = value @block_reentrance def __setitem__( @@ -231,18 +277,14 @@ class VertexPositions(BufferManager): }, ] - def __init__( - self, data: Any, isolated_buffer: bool = True, property_name: str = "data" - ): + def __init__(self, data: Any, property_name: str = "data"): """ Manages the vertex positions buffer shown in the graphic. Supports fancy indexing if the data array also supports it. """ data = self._fix_data(data) - super().__init__( - data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data, property_name=property_name) def _fix_data(self, data): if data.ndim == 1: @@ -261,13 +303,42 @@ def _fix_data(self, data): return to_gpu_supported_dtype(data) + def set_value(self, graphic, value): + """Sets the entire array, creates new buffer if necessary""" + if isinstance(value, np.ndarray): + if self.buffer.data.shape[0] != value.shape[0]: + # number of items doesn't match, create a new buffer + + # if data is not 3D + if value.ndim == 1: + # _fix_data creates a new array so we don't need to re-allocate with np.zeros + bdata = self._fix_data(value) + + elif value.shape[1] == 2: + # _fix_data creates a new array so we don't need to re-allocate with np.zeros + bdata = self._fix_data(value) + + elif value.shape[1] == 3: + # need to allocate a buffer to use here + bdata = np.empty(value.shape, dtype=np.float32) + bdata[:] = value[:] + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(bdata) + graphic.world_object.geometry.positions = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, key: int | slice | np.ndarray[int | bool] | tuple[slice, ...], value: np.ndarray | float | list[float], ): - # directly use the key to slice the buffer + # directly use the key to slice the buffer and set the values self.buffer.data[key] = value # _update_range handles parsing the key to @@ -306,7 +377,7 @@ def __init__( provides a way to set colormaps with arbitrary transforms """ - super().__init__(data=vertex_colors.buffer, property_name=property_name) + super().__init__(data=None, property_name=property_name) self._vertex_colors = vertex_colors self._cmap_name = cmap_name @@ -331,6 +402,10 @@ def __init__( # set vertex colors from cmap self._vertex_colors[:] = colors + @property + def buffer(self) -> pygfx.Buffer: + return self._vertex_colors.buffer + @block_reentrance def __setitem__(self, key: slice, cmap_name): if not isinstance(key, slice): diff --git a/fastplotlib/graphics/features/_scatter.py b/fastplotlib/graphics/features/_scatter.py index 16671ef89..36c8527be 100644 --- a/fastplotlib/graphics/features/_scatter.py +++ b/fastplotlib/graphics/features/_scatter.py @@ -100,6 +100,37 @@ def searchsorted_markers_to_int_array(markers_str_array: np.ndarray[str]): return marker_int_searchsorted_vals[indices] +def parse_markers_init(markers: str | Sequence[str] | np.ndarray, n_datapoints: int): + # first validate then allocate buffers + + if isinstance(markers, str): + markers = user_input_to_marker(markers) + + elif isinstance(markers, (tuple, list, np.ndarray)): + validate_user_markers_array(markers) + + # allocate buffers + markers_int_array = np.zeros(n_datapoints, dtype=np.int32) + + marker_str_length = max(map(len, list(pygfx.MarkerShape))) + + markers_readable_array = np.empty(n_datapoints, dtype=f" np.ndarray[str]: @@ -200,6 +200,25 @@ def _set_markers_arrays(self, key, value, n_markers): "new markers value must be a str, Sequence or np.ndarray of new marker values" ) + def set_value(self, graphic, value): + """set all the markers, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != len(value): + # need to create a new buffer + markers_int_array, self._markers_readable_array = parse_markers_init( + value, len(value) + ) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(markers_int_array) + graphic.world_object.geometry.markers = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + + return + + self[:] = value + @block_reentrance def __setitem__( self, @@ -414,18 +433,15 @@ def __init__( self, rotations: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, - isolated_buffer: bool = True, property_name: str = "point_rotations", ): """ Manages rotations buffer of scatter points. """ - sizes = self._fix_sizes(rotations, n_datapoints) - super().__init__( - data=sizes, isolated_buffer=isolated_buffer, property_name=property_name - ) + sizes = self._fix_rotations(rotations, n_datapoints) + super().__init__(data=sizes, property_name=property_name) - def _fix_sizes( + def _fix_rotations( self, sizes: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, @@ -454,6 +470,22 @@ def _fix_sizes( return sizes + def set_value(self, graphic, value): + """set all rotations, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != value.shape[0]: + # need to create a new buffer + value = self._fix_rotations(value, len(value)) + data = np.empty(shape=(len(value),), dtype=np.float32) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(data) + graphic.world_object.geometry.rotations = self._fpl_buffer + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, @@ -488,16 +520,13 @@ def __init__( self, sizes: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, - isolated_buffer: bool = True, property_name: str = "sizes", ): """ Manages sizes buffer of scatter points. """ sizes = self._fix_sizes(sizes, n_datapoints) - super().__init__( - data=sizes, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data=sizes, property_name=property_name) def _fix_sizes( self, @@ -533,6 +562,23 @@ def _fix_sizes( return sizes + def set_value(self, graphic, value): + """set all sizes, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != len(value): + # create new buffer + value = self._fix_sizes(value, len(value)) + data = np.empty(shape=(len(value),), dtype=np.float32) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(data) + graphic.world_object.geometry.sizes = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, diff --git a/fastplotlib/graphics/features/_selection_features.py b/fastplotlib/graphics/features/_selection_features.py index 9b30dd70c..1f049f0cb 100644 --- a/fastplotlib/graphics/features/_selection_features.py +++ b/fastplotlib/graphics/features/_selection_features.py @@ -118,7 +118,7 @@ def axis(self) -> str: return self._axis @block_reentrance - def set_value(self, selector, value: Sequence[float]): + def set_value(self, selector, value: Sequence[float], *, change: str = "full"): """ Set start, stop range of selector @@ -182,7 +182,9 @@ def set_value(self, selector, value: Sequence[float]): if len(self._event_handlers) < 1: return - event = GraphicFeatureEvent(self._property_name, {"value": self.value}) + event = GraphicFeatureEvent( + self._property_name, {"value": self.value, "change": change} + ) event.get_selected_indices = selector.get_selected_indices event.get_selected_data = selector.get_selected_data diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index 9c86d25fc..729562b06 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -22,7 +22,6 @@ class VectorPositions(GraphicFeature): def __init__( self, positions: np.ndarray, - isolated_buffer: bool = True, property_name: str = "positions", ): """ @@ -111,7 +110,6 @@ class VectorDirections(GraphicFeature): def __init__( self, directions: np.ndarray, - isolated_buffer: bool = True, property_name: str = "directions", ): """Manages vector field positions by managing the mesh instance buffer's full transform matrix""" diff --git a/fastplotlib/graphics/features/_volume.py b/fastplotlib/graphics/features/_volume.py index ec4c4052a..532065fb7 100644 --- a/fastplotlib/graphics/features/_volume.py +++ b/fastplotlib/graphics/features/_volume.py @@ -34,7 +34,7 @@ class TextureArrayVolume(GraphicFeature): }, ] - def __init__(self, data, isolated_buffer: bool = True): + def __init__(self, data): super().__init__(property_name="data") data = self._fix_data(data) @@ -43,13 +43,9 @@ def __init__(self, data, isolated_buffer: bool = True): self._texture_size_limit = shared.device.limits["max-texture-dimension-3d"] - if isolated_buffer: - # useful if data is read-only, example: memmaps - self._value = np.zeros(data.shape, dtype=data.dtype) - self.value[:] = data[:] - else: - # user's input array is used as the buffer - self._value = data + # create a new buffer that will be used for the texture data + self._value = np.zeros(data.shape, dtype=data.dtype) + self.value[:] = data[:] # data start indices for each Texture self._row_indices = np.arange( diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 44bffcedc..7b670d531 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -1,6 +1,7 @@ import math from typing import * +import numpy as np import pygfx from ..utils import quick_min_max @@ -102,7 +103,6 @@ def __init__( cmap: str = "plasma", interpolation: str = "nearest", cmap_interpolation: str = "linear", - isolated_buffer: bool = True, **kwargs, ): """ @@ -130,12 +130,6 @@ def __init__( cmap_interpolation: str, optional, default "linear" colormap interpolation method, one of "nearest" or "linear" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. - kwargs: additional keyword arguments passed to :class:`.Graphic` @@ -143,7 +137,7 @@ def __init__( super().__init__(**kwargs) - world_object = pygfx.Group() + group = pygfx.Group() if isinstance(data, TextureArray): # share buffer @@ -151,7 +145,7 @@ def __init__( else: # create new texture array to manage buffer # texture array that manages the multiple textures on the GPU that represent this image - self._data = TextureArray(data, isolated_buffer=isolated_buffer) + self._data = TextureArray(data) if (vmin is None) or (vmax is None): _vmin, _vmax = quick_min_max(self.data.value) @@ -165,21 +159,28 @@ def __init__( self._vmax = ImageVmax(vmax) self._interpolation = ImageInterpolation(interpolation) + self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) # set map to None for RGB images - if self._data.value.ndim > 2: + if self._data.value.ndim == 3: self._cmap = None _map = None - else: + + elif self._data.value.ndim == 2: # use TextureMap for grayscale images self._cmap = ImageCmap(cmap) - self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) _map = pygfx.TextureMap( self._cmap.texture, filter=self._cmap_interpolation.value, wrap="clamp-to-edge", ) + else: + raise ValueError( + f"ImageGraphic `data` must have 2 dimensions for grayscale images, or 3 dimensions for RGB(A) images.\n" + f"You have passed a a data array with: {self._data.value.ndim} dimensions, " + f"and of shape: {self._data.value.shape}" + ) # one common material is used for every Texture chunk self._material = pygfx.ImageBasicMaterial( @@ -189,6 +190,14 @@ def __init__( pick_write=True, ) + # create the _ImageTile world objects, add to group + for tile in self._create_tiles(): + group.add(tile) + + self._set_world_object(group) + + def _create_tiles(self) -> list[_ImageTile]: + tiles = list() # iterate through each texture chunk and create # an _ImageTile, offset the tile using the data indices for texture, chunk_index, data_slice in self._data: @@ -209,17 +218,58 @@ def __init__( img.world.x = data_col_start img.world.y = data_row_start - world_object.add(img) + tiles.append(img) - self._set_world_object(world_object) + return tiles @property def data(self) -> TextureArray: - """Get or set the image data""" + """ + Get or set the image data. + + Note that if the shape of the new data array does not equal the shape of + current data array, a new set of GPU Textures are automatically created. + This can have performance drawbacks when you have a ver large images. + This is usually fine as long as you don't need to do it hundreds of times + per second. + """ return self._data @data.setter def data(self, data): + if isinstance(data, np.ndarray): + # check if a new buffer is required + if self._data.value.shape != data.shape: + # create new TextureArray + self._data = TextureArray(data) + + # cmap based on if rgb or grayscale + if self._data.value.ndim > 2: + self._cmap = None + + # must be None if RGB(A) + self._material.map = None + else: + if self.cmap is None: # have switched from RGBA -> grayscale image + # create default cmap + self._cmap = ImageCmap("plasma") + self._material.map = pygfx.TextureMap( + self._cmap.texture, + filter=self._cmap_interpolation.value, + wrap="clamp-to-edge", + ) + + self._material.clim = quick_min_max(self.data.value) + + # clear image tiles + self.world_object.clear() + + # create new tiles + for tile in self._create_tiles(): + self.world_object.add(tile) + + return + self._data[:] = data @property @@ -232,8 +282,6 @@ def cmap(self) -> str | None: if self._cmap is not None: return self._cmap.value - return None - @cmap.setter def cmap(self, name: str): if self.data.value.ndim > 2: @@ -269,7 +317,7 @@ def interpolation(self, value: str): @property def cmap_interpolation(self) -> str: - """cmap interpolation method""" + """cmap interpolation method, 'linear' or 'nearest'. Used only for grayscale images""" return self._cmap_interpolation.value @cmap_interpolation.setter diff --git a/fastplotlib/graphics/image_volume.py b/fastplotlib/graphics/image_volume.py index db8f29eaa..3d2d064e8 100644 --- a/fastplotlib/graphics/image_volume.py +++ b/fastplotlib/graphics/image_volume.py @@ -113,7 +113,6 @@ def __init__( substep_size: float = 0.1, emissive: str | tuple | np.ndarray = (0, 0, 0), shininess: int = 30, - isolated_buffer: bool = True, **kwargs, ): """ @@ -170,11 +169,6 @@ def __init__( How shiny the specular highlight is; a higher value gives a sharper highlight. Used only if `mode` = "iso" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then set the data, useful if the - data arrays are ready-only such as memmaps. If False, the input array is itself used as the - buffer - useful if the array is large. - kwargs additional keyword arguments passed to :class:`.Graphic` @@ -188,7 +182,7 @@ def __init__( super().__init__(**kwargs) - world_object = pygfx.Group() + group = pygfx.Group() if isinstance(data, TextureArrayVolume): # share existing buffer @@ -196,7 +190,7 @@ def __init__( else: # create new texture array to manage buffer # texture array that manages the textures on the GPU that represent this image volume - self._data = TextureArrayVolume(data, isolated_buffer=isolated_buffer) + self._data = TextureArrayVolume(data) if (vmin is None) or (vmax is None): _vmin, _vmax = quick_min_max(self.data.value) @@ -210,18 +204,24 @@ def __init__( self._vmax = ImageVmax(vmax) self._interpolation = ImageInterpolation(interpolation) + self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) - # TODO: I'm assuming RGB volume images aren't supported??? # use TextureMap for grayscale images self._cmap = ImageCmap(cmap) - self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) - self._texture_map = pygfx.TextureMap( self._cmap.texture, filter=self._cmap_interpolation.value, wrap="clamp-to-edge", ) + if self._data.value.ndim not in (3, 4): + raise ValueError( + f"ImageVolumeGraphic `data` must have 3 dimensions for grayscale images, " + f"or 4 dimensions for RGB(A) images.\n" + f"You have passed a a data array with: {self._data.value.ndim} dimensions, " + f"and of shape: {self._data.value.shape}" + ) + self._plane = VolumeSlicePlane(plane) self._threshold = VolumeIsoThreshold(threshold) self._step_size = VolumeIsoStepSize(step_size) @@ -237,6 +237,15 @@ def __init__( self._mode = VolumeRenderMode(mode) + # create tiles + for tile in self._create_tiles(): + group.add(tile) + + self._set_world_object(group) + + def _create_tiles(self) -> list[_VolumeTile]: + tiles = list() + # iterate through each texture chunk and create # a _VolumeTile, offset the tile using the data indices for texture, chunk_index, data_slice in self._data: @@ -259,9 +268,9 @@ def __init__( vol.world.x = data_col_start vol.world.y = data_row_start - world_object.add(vol) + tiles.append(vol) - self._set_world_object(world_object) + return tiles @property def data(self) -> TextureArrayVolume: @@ -270,6 +279,21 @@ def data(self) -> TextureArrayVolume: @data.setter def data(self, data): + if isinstance(data, np.ndarray): + # check if a new buffer is required + if self._data.value.shape != data.shape: + # create new TextureArray + self._data = TextureArrayVolume(data) + + # clear image tiles + self.world_object.clear() + + # create new tiles + for tile in self._create_tiles(): + self.world_object.add(tile) + + return + self._data[:] = data @property @@ -283,7 +307,7 @@ def mode(self, mode: str): @property def cmap(self) -> str: - """Get or set colormap name""" + """Get or set colormap name, only used for grayscale images""" return self._cmap.value @cmap.setter diff --git a/fastplotlib/graphics/line.py b/fastplotlib/graphics/line.py index a4f42704f..bba10b10f 100644 --- a/fastplotlib/graphics/line.py +++ b/fastplotlib/graphics/line.py @@ -18,6 +18,7 @@ UniformColor, VertexCmap, SizeSpace, + UniformRotations, ) from ..utils import quick_min_max @@ -36,10 +37,9 @@ def __init__( data: Any, thickness: float = 2.0, colors: str | np.ndarray | Sequence = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: np.ndarray | Sequence = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", **kwargs, ): @@ -61,15 +61,19 @@ def __init__( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default ``False`` - if True, uses a uniform buffer for the line color, - basically saves GPU VRAM when the entire line has a single color - cmap: str, optional Apply a colormap to the line instead of assigning colors manually, this overrides any argument passed to "colors". For supported colormaps see the ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap @@ -84,10 +88,9 @@ def __init__( super().__init__( data=data, colors=colors, - uniform_color=uniform_color, cmap=cmap, cmap_transform=cmap_transform, - isolated_buffer=isolated_buffer, + color_mode=color_mode, size_space=size_space, **kwargs, ) @@ -102,8 +105,8 @@ def __init__( aa = kwargs.get("alpha_mode", "auto") in ("blend", "weighted_blend") - if uniform_color: - geometry = pygfx.Geometry(positions=self._data.buffer) + if isinstance(self._colors, UniformColor): + geometry = pygfx.Geometry(positions=self._data._fpl_buffer) material = MaterialCls( aa=aa, thickness=self.thickness, @@ -123,7 +126,7 @@ def __init__( depth_compare="<=", ) geometry = pygfx.Geometry( - positions=self._data.buffer, colors=self._colors.buffer + positions=self._data._fpl_buffer, colors=self._colors._fpl_buffer ) world_object: pygfx.Line = pygfx.Line(geometry=geometry, material=material) diff --git a/fastplotlib/graphics/line_collection.py b/fastplotlib/graphics/line_collection.py index d08231f7d..5ec56777e 100644 --- a/fastplotlib/graphics/line_collection.py +++ b/fastplotlib/graphics/line_collection.py @@ -128,14 +128,13 @@ def __init__( data: np.ndarray | List[np.ndarray], thickness: float | Sequence[float] = 2.0, colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", - uniform_colors: bool = False, cmap: Sequence[str] | str = None, cmap_transform: np.ndarray | List = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - isolated_buffer: bool = True, kwargs_lines: list[dict] = None, **kwargs, ): @@ -170,6 +169,9 @@ def __init__( cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + The color mode for each line in the collection. See `color_mode` in :class:`.LineGraphic` for details. + name: str, optional name of the line collection as a whole @@ -320,11 +322,10 @@ def __init__( data=d, thickness=_s, colors=_c, - uniform_color=uniform_colors, cmap=_cmap, + color_mode=color_mode, name=_name, metadata=_m, - isolated_buffer=isolated_buffer, **kwargs_lines, ) @@ -560,7 +561,6 @@ def __init__( names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - isolated_buffer: bool = True, separation: float = 10.0, separation_axis: str = "y", kwargs_lines: list[dict] = None, @@ -634,7 +634,6 @@ def __init__( names=names, metadata=metadata, metadatas=metadatas, - isolated_buffer=isolated_buffer, kwargs_lines=kwargs_lines, **kwargs, ) diff --git a/fastplotlib/graphics/mesh.py b/fastplotlib/graphics/mesh.py index 0e1ac42a3..efe03c57b 100644 --- a/fastplotlib/graphics/mesh.py +++ b/fastplotlib/graphics/mesh.py @@ -38,7 +38,6 @@ def __init__( mapcoords: Any = None, cmap: str | dict | pygfx.Texture | pygfx.TextureMap | np.ndarray = None, clim: tuple[float, float] = None, - isolated_buffer: bool = True, **kwargs, ): """ @@ -77,12 +76,6 @@ def __init__( Both 1D and 2D colormaps are supported, though the mapcoords has to match the dimensionality. An image can also be used, this is basically a 2D colormap. - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. In almost all cases this should be ``True``. - **kwargs passed to :class:`.Graphic` @@ -93,16 +86,12 @@ def __init__( if isinstance(positions, VertexPositions): self._positions = positions else: - self._positions = VertexPositions( - positions, isolated_buffer=isolated_buffer, property_name="positions" - ) + self._positions = VertexPositions(positions, property_name="positions") if isinstance(positions, MeshIndices): self._indices = indices else: - self._indices = MeshIndices( - indices, isolated_buffer=isolated_buffer, property_name="indices" - ) + self._indices = MeshIndices(indices, property_name="indices") self._cmap = MeshCmap(cmap) @@ -139,7 +128,7 @@ def __init__( ) geometry = pygfx.Geometry( - positions=self._positions.buffer, indices=self._indices._buffer + positions=self._positions.buffer, indices=self._indices._fpl_buffer ) valid_modes = ["basic", "phong", "slice"] diff --git a/fastplotlib/graphics/scatter.py b/fastplotlib/graphics/scatter.py index a2e696a82..b9cacf908 100644 --- a/fastplotlib/graphics/scatter.py +++ b/fastplotlib/graphics/scatter.py @@ -40,12 +40,12 @@ def __init__( self, data: Any, colors: str | np.ndarray | Sequence[float] | Sequence[str] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: np.ndarray = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", mode: Literal["markers", "simple", "gaussian", "image"] = "markers", markers: str | np.ndarray | Sequence[str] = "o", - uniform_marker: bool = False, + uniform_marker: bool = True, custom_sdf: str = None, edge_colors: str | np.ndarray | pygfx.Color | Sequence[float] = "black", uniform_edge_color: bool = True, @@ -53,10 +53,9 @@ def __init__( image: np.ndarray = None, point_rotations: float | np.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", - sizes: float | np.ndarray | Sequence[float] = 1, - uniform_size: bool = False, + sizes: float | np.ndarray | Sequence[float] = 5, + uniform_size: bool = True, size_space: str = "screen", - isolated_buffer: bool = True, **kwargs, ): """ @@ -72,18 +71,23 @@ def __init__( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default False - if True, uses a uniform buffer for the scatter point colors. Useful if you need to - save GPU VRAM when all points have the same color. - cmap: str, optional apply a colormap to the scatter instead of assigning colors manually, this - overrides any argument passed to "colors". For supported colormaps see the - ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + overrides any argument passed to "colors". + For supported colormaps see the ``cmap`` library catalogue: + https://cmap-docs.readthedocs.io/en/stable/catalog/ cmap_transform: 1D array-like or list of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + mode: one of: "markers", "simple", "gaussian", "image", default "markers" The scatter points mode, cannot be changed after the graphic has been created. @@ -103,9 +107,10 @@ def __init__( * Emojis: "❤️♠️♣️♦️💎💍✳️📍". * A string containing the value "custom". In this case, WGSL code defined by ``custom_sdf`` will be used. - uniform_marker: bool, default False - Use the same marker for all points. Only valid when `mode` is "markers". Useful if you need to use - the same marker for all points and want to save GPU RAM. + uniform_marker: bool, default ``True`` + If ``True``, use the same marker for all points. Only valid when `mode` is "markers". + Useful if you need to use the same marker for all points and want to save GPU RAM. If ``False``, you can + set per-vertex markers. custom_sdf: str = None, The SDF code for the marker shape when the marker is set to custom. @@ -125,8 +130,9 @@ def __init__( edge_colors: str | np.ndarray | pygfx.Color | Sequence[float], default "black" edge color of the markers, used when `mode` is "markers" - uniform_edge_color: bool, default True - Set the same edge color for all markers. Useful for saving GPU RAM. + uniform_edge_color: bool, default ``True`` + Set the same edge color for all markers. Useful for saving GPU RAM. Set to ``False`` for per-vertex edge + colors edge_width: float = 1.0, Width of the marker edges. used when `mode` is "markers". @@ -147,17 +153,13 @@ def __init__( sizes: float or iterable of float, optional, default 1.0 sizes of the scatter points - uniform_size: bool, default False - if True, uses a uniform buffer for the scatter point sizes. Useful if you need to - save GPU VRAM when all points have the same size. + uniform_size: bool, default ``False`` + if ``True``, uses a uniform buffer for the scatter point sizes. Useful if you need to + save GPU VRAM when all points have the same size. Set to ``False`` if you need per-vertex sizes. size_space: str, default "screen" coordinate space in which the size is expressed, one of ("screen", "world", "model") - isolated_buffer: bool, default True - whether the buffers should be isolated from the user input array. - Generally always ``True``, ``False`` is for rare advanced use if you have large arrays. - kwargs passed to :class:`.Graphic` @@ -166,17 +168,16 @@ def __init__( super().__init__( data=data, colors=colors, - uniform_color=uniform_color, cmap=cmap, cmap_transform=cmap_transform, - isolated_buffer=isolated_buffer, + color_mode=color_mode, size_space=size_space, **kwargs, ) n_datapoints = self.data.value.shape[0] - geo_kwargs = {"positions": self._data.buffer} + geo_kwargs = {"positions": self._data._fpl_buffer} aa = kwargs.get("alpha_mode", "auto") in ("blend", "weighted_blend") @@ -214,7 +215,7 @@ def __init__( self._markers = VertexMarkers(markers, n_datapoints) - geo_kwargs["markers"] = self._markers.buffer + geo_kwargs["markers"] = self._markers._fpl_buffer if edge_colors is None: # interpret as no edge color @@ -237,7 +238,7 @@ def __init__( edge_colors, n_datapoints, property_name="edge_colors" ) material_kwargs["edge_color_mode"] = pygfx.ColorMode.vertex - geo_kwargs["edge_colors"] = self._edge_colors.buffer + geo_kwargs["edge_colors"] = self._edge_colors._fpl_buffer self._edge_width = EdgeWidth(edge_width) material_kwargs["edge_width"] = self._edge_width.value @@ -274,12 +275,12 @@ def __init__( self._size_space = SizeSpace(size_space) - if uniform_color: + if isinstance(self._colors, UniformColor): material_kwargs["color_mode"] = pygfx.ColorMode.uniform material_kwargs["color"] = self.colors else: material_kwargs["color_mode"] = pygfx.ColorMode.vertex - geo_kwargs["colors"] = self.colors.buffer + geo_kwargs["colors"] = self.colors._fpl_buffer if uniform_size: material_kwargs["size_mode"] = pygfx.SizeMode.uniform @@ -288,14 +289,14 @@ def __init__( else: material_kwargs["size_mode"] = pygfx.SizeMode.vertex self._sizes = VertexPointSizes(sizes, n_datapoints=n_datapoints) - geo_kwargs["sizes"] = self.sizes.buffer + geo_kwargs["sizes"] = self.sizes._fpl_buffer match point_rotation_mode: case pygfx.enums.RotationMode.vertex: self._point_rotations = VertexRotations( point_rotations, n_datapoints=n_datapoints ) - geo_kwargs["rotations"] = self._point_rotations.buffer + geo_kwargs["rotations"] = self._point_rotations._fpl_buffer case pygfx.enums.RotationMode.uniform: self._point_rotations = UniformRotations(point_rotations) @@ -338,10 +339,8 @@ def markers(self, value: str | np.ndarray[str] | Sequence[str]): raise AttributeError( f"scatter plot is: {self.mode}. The mode must be 'markers' to set the markers" ) - if isinstance(self._markers, VertexMarkers): - self._markers[:] = value - elif isinstance(self._markers, UniformMarker): - self._markers.set_value(self, value) + + self._markers.set_value(self, value) @property def edge_colors(self) -> str | pygfx.Color | VertexColors | None: @@ -359,12 +358,7 @@ def edge_colors(self, value: str | np.ndarray | Sequence[str] | Sequence[float]) raise AttributeError( f"scatter plot is: {self.mode}. The mode must be 'markers' to set the edge_colors" ) - - if isinstance(self._edge_colors, VertexColors): - self._edge_colors[:] = value - - elif isinstance(self._edge_colors, UniformEdgeColor): - self._edge_colors.set_value(self, value) + self._edge_colors.set_value(self, value) @property def edge_width(self) -> float | None: @@ -406,11 +400,7 @@ def point_rotations(self, value: float | np.ndarray[float]): f"it be 'uniform' or 'vertex' to set the `point_rotations`" ) - if isinstance(self._point_rotations, VertexRotations): - self._point_rotations[:] = value - - elif isinstance(self._point_rotations, UniformRotations): - self._point_rotations.set_value(self, value) + self._point_rotations.set_value(self, value) @property def image(self) -> TextureArray | None: @@ -437,8 +427,4 @@ def sizes(self) -> VertexPointSizes | float: @sizes.setter def sizes(self, value): - if isinstance(self._sizes, VertexPointSizes): - self._sizes[:] = value - - elif isinstance(self._sizes, UniformSize): - self._sizes.set_value(self, value) + self._sizes.set_value(self, value) diff --git a/fastplotlib/graphics/scatter_collection.py b/fastplotlib/graphics/scatter_collection.py new file mode 100644 index 000000000..b8e7556ad --- /dev/null +++ b/fastplotlib/graphics/scatter_collection.py @@ -0,0 +1,636 @@ +from typing import * + +import numpy as np + +import pygfx + +from ..utils import parse_cmap_values +from ._collection_base import CollectionIndexer, GraphicCollection, CollectionFeature +from .scatter import ScatterGraphic +from .selectors import ( + LinearRegionSelector, + LinearSelector, + RectangleSelector, + PolygonSelector, +) + + +class _ScatterCollectionProperties: + """Mix-in class for ScatterCollection properties""" + + @property + def colors(self) -> CollectionFeature: + """get or set colors of scatters in the collection""" + return CollectionFeature(self.graphics, "colors") + + @colors.setter + def colors(self, values: str | np.ndarray | tuple[float] | list[float] | list[str]): + if isinstance(values, str): + # set colors of all scatter to one str color + for g in self: + g.colors = values + return + + elif all(isinstance(v, str) for v in values): + # individual str colors for each scatter + if not len(values) == len(self): + raise IndexError + + for g, v in zip(self.graphics, values): + g.colors = v + + return + + if isinstance(values, np.ndarray): + if values.ndim == 2: + # assume individual colors for each + for g, v in zip(self, values): + g.colors = v + return + + elif len(values) == 4: + # assume RGBA + self.colors[:] = values + + else: + # assume individual colors for each + for g, v in zip(self, values): + g.colors = v + + @property + def data(self) -> CollectionFeature: + """get or set data of lines in the collection""" + return CollectionFeature(self.graphics, "data") + + @data.setter + def data(self, values): + for g, v in zip(self, values): + g.data = v + + @property + def cmap(self) -> CollectionFeature: + """ + Get or set a cmap along the scatter collection. + + Optionally set using a tuple ("cmap", ) to set the transform. + Example: + + scatter_collection.cmap = ("jet", sine_transform_vals, 0.7) + + """ + return CollectionFeature(self.graphics, "cmap") + + @cmap.setter + def cmap(self, args): + if isinstance(args, str): + name = args + transform = None + elif len(args) == 1: + name = args[0] + transform = None + elif len(args) == 2: + name, transform = args + else: + raise ValueError( + "Too many values for cmap (note that alpha is deprecated, set alpha on the graphic instead)" + ) + + self.colors = parse_cmap_values( + n_colors=len(self), cmap_name=name, transform=transform + ) + + +class ScatterCollectionIndexer(CollectionIndexer, _ScatterCollectionProperties): + """Indexer for scatter collections""" + + pass + + +class ScatterCollection(GraphicCollection, _ScatterCollectionProperties): + _child_type = ScatterGraphic + _indexer = ScatterCollectionIndexer + + def __init__( + self, + data: np.ndarray | List[np.ndarray], + colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", + cmap: Sequence[str] | str = None, + cmap_transform: np.ndarray | List = None, + sizes: float | Sequence[float] = 5.0, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Sequence[Any] | np.ndarray = None, + kwargs_lines: list[dict] = None, + **kwargs, + ): + """ + Create a collection of :class:`.ScatterGraphic` + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + meatadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + kwargs_lines: list[dict], optional + list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + """ + + super().__init__(name=name, metadata=metadata, **kwargs) + + if names is not None: + if len(names) != len(data): + raise ValueError( + f"len(names) != len(data)\n{len(names)} != {len(data)}" + ) + + if metadatas is not None: + if len(metadatas) != len(data): + raise ValueError( + f"len(metadata) != len(data)\n{len(metadatas)} != {len(data)}" + ) + + if kwargs_lines is not None: + if len(kwargs_lines) != len(data): + raise ValueError( + f"len(kwargs_lines) != len(data)\n" + f"{len(kwargs_lines)} != {len(data)}" + ) + + self._cmap_transform = cmap_transform + self._cmap_str = cmap + + # cmap takes priority over colors + if cmap is not None: + # cmap across lines + if isinstance(cmap, str): + colors = parse_cmap_values( + n_colors=len(data), cmap_name=cmap, transform=cmap_transform + ) + single_color = False + cmap = None + + elif isinstance(cmap, (tuple, list)): + if len(cmap) != len(data): + raise ValueError( + "cmap argument must be a single cmap or a list of cmaps " + "with the same length as the data" + ) + single_color = False + else: + raise ValueError( + "cmap argument must be a single cmap or a list of cmaps " + "with the same length as the data" + ) + else: + if isinstance(colors, np.ndarray): + # single color for all lines in the collection as RGBA + if colors.shape in [(3,), (4,)]: + single_color = True + + # colors specified for each line as array of shape [n_lines, RGBA] + elif colors.shape == (len(data), 4): + single_color = False + + else: + raise ValueError( + f"numpy array colors argument must be of shape (4,) or (n_lines, 4)." + f"You have pass the following shape: {colors.shape}" + ) + + elif isinstance(colors, str): + if colors == "random": + colors = np.random.rand(len(data), 3) + single_color = False + else: + # parse string color + single_color = True + colors = pygfx.Color(colors) + + elif isinstance(colors, (tuple, list)): + if len(colors) == 4: + # single color specified as (R, G, B, A) tuple or list + if all([isinstance(c, (float, int)) for c in colors]): + single_color = True + + elif len(colors) == len(data): + # colors passed as list/tuple of colors, such as list of string + single_color = False + + else: + raise ValueError( + "tuple or list colors argument must be a single color represented as [R, G, B, A], " + "or must be a tuple/list of colors represented by a string with the same length as the data" + ) + + if kwargs_lines is None: + kwargs_lines = dict() + + self._set_world_object(pygfx.Group()) + + for i, d in enumerate(data): + if cmap is None: + _cmap = None + + if single_color: + _c = colors + else: + _c = colors[i] + else: + _cmap = cmap[i] + _c = None + + if metadatas is not None: + _m = metadatas[i] + else: + _m = None + + if names is not None: + _name = names[i] + else: + _name = None + + lg = ScatterGraphic( + data=d, + colors=_c, + sizes=sizes, + cmap=_cmap, + name=_name, + metadata=_m, + **kwargs_lines, + ) + + self.add_graphic(lg) + + def __getitem__(self, item) -> ScatterCollectionIndexer: + return super().__getitem__(item) + + def add_linear_selector( + self, selection: float = None, padding: float = 0.0, axis: str = "x", **kwargs + ) -> LinearSelector: + """ + Adds a linear selector. + + Parameters + ---------- + Parameters + ---------- + selection: float, optional + selected point on the linear selector, computed from data if not provided + + axis: str, default "x" + axis that the selector resides on + + padding: float, default 0.0 + Extra padding to extend the linear selector along the orthogonal axis to make it easier to interact with. + + kwargs + passed to :class:`.LinearSelector` + + Returns + ------- + LinearSelector + + """ + + bounds_init, limits, size, center = self._get_linear_selector_init_args( + axis, padding + ) + + if selection is None: + selection = bounds_init[0] + + selector = LinearSelector( + selection=selection, + limits=limits, + axis=axis, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def add_linear_region_selector( + self, + selection: tuple[float, float] = None, + padding: float = 0.0, + axis: str = "x", + **kwargs, + ) -> LinearRegionSelector: + """ + Add a :class:`.LinearRegionSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: (float, float), optional + the starting bounds of the linear region selector, computed from data if not provided + + axis: str, default "x" + axis that the selector resides on + + padding: float, default 0.0 + Extra padding to extend the linear region selector along the orthogonal axis to make it easier to interact with. + + kwargs + passed to ``LinearRegionSelector`` + + Returns + ------- + LinearRegionSelector + linear selection graphic + + """ + + bounds_init, limits, size, center = self._get_linear_selector_init_args( + axis, padding + ) + + if selection is None: + selection = bounds_init + + # create selector + selector = LinearRegionSelector( + selection=selection, + limits=limits, + size=size, + center=center, + axis=axis, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + # PlotArea manages this for garbage collection etc. just like all other Graphics + # so we should only work with a proxy on the user-end + return selector + + def add_rectangle_selector( + self, + selection: tuple[float, float, float] = None, + **kwargs, + ) -> RectangleSelector: + """ + Add a :class:`.RectangleSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: (float, float, float, float), optional + initial (xmin, xmax, ymin, ymax) of the selection + """ + bbox = self.world_object.get_world_bounding_box() + + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + value_25px = (xmax - xmin) / 4 + + ydata = np.array(self.data[:, 1]) + ymin = np.floor(ydata.min()).astype(int) + + ymax = np.ptp(bbox[:, 1]) + + if selection is None: + selection = (xmin, value_25px, ymin, ymax) + + limits = (xmin, xmax, ymin - (ymax * 1.5 - ymax), ymax * 1.5) + + selector = RectangleSelector( + selection=selection, + limits=limits, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def add_polygon_selector( + self, + selection: List[tuple[float, float]] = None, + **kwargs, + ) -> PolygonSelector: + """ + Add a :class:`.PolygonSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: List of positions, optional + Initial points for the polygon. If not given or None, you'll start drawing the selection (clicking adds points to the polygon). + """ + bbox = self.world_object.get_world_bounding_box() + + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + + ydata = np.array(self.data[:, 1]) + ymin = np.floor(ydata.min()).astype(int) + + ymax = np.ptp(bbox[:, 1]) + + limits = (xmin, xmax, ymin - (ymax * 1.5 - ymax), ymax * 1.5) + + selector = PolygonSelector( + selection, + limits, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def _get_linear_selector_init_args(self, axis, padding): + # use bbox to get size and center + bbox = self.world_object.get_world_bounding_box() + + if axis == "x": + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + value_25p = (xmax - xmin) / 4 + + bounds = (xmin, value_25p) + limits = (xmin, xmax) + # size from orthogonal axis + size = np.ptp(bbox[:, 1]) * 1.5 + # center on orthogonal axis + center = bbox[:, 1].mean() + + elif axis == "y": + ydata = np.array(self.data[:, 1]) + xmin, xmax = (np.nanmin(ydata), np.nanmax(ydata)) + value_25p = (xmax - xmin) / 4 + + bounds = (xmin, value_25p) + limits = (xmin, xmax) + + size = np.ptp(bbox[:, 0]) * 1.5 + # center on orthogonal axis + center = bbox[:, 0].mean() + + return bounds, limits, size, center + + +axes = {"x": 0, "y": 1, "z": 2} + + +class ScatterStack(ScatterCollection): + def __init__( + self, + data: List[np.ndarray], + thickness: float | Iterable[float] = 2.0, + colors: str | Iterable[str] | np.ndarray | Iterable[np.ndarray] = "w", + cmap: Iterable[str] | str = None, + cmap_transform: np.ndarray | List = None, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Sequence[Any] | np.ndarray = None, + isolated_buffer: bool = True, + separation: float = 0.0, + separation_axis: str = "y", + kwargs_lines: list[dict] = None, + **kwargs, + ): + """ + Create a stack of :class:`.LineGraphic` that are separated along the "x" or "y" axis. + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + thickness: float or Iterable of float, default 2.0 + | if ``float``, single thickness will be used for all lines + | if ``list`` of ``float``, each value will apply to the individual lines + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + metadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + separation: float, default 0.0 + space in between each line graphic in the stack + + separation_axis: str, default "y" + axis in which the line graphics in the stack should be separated + + + kwargs_lines: list[dict], optional + list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + """ + super().__init__( + data=data, + thickness=thickness, + colors=colors, + cmap=cmap, + cmap_transform=cmap_transform, + name=name, + names=names, + metadata=metadata, + metadatas=metadatas, + isolated_buffer=isolated_buffer, + kwargs_lines=kwargs_lines, + **kwargs, + ) + + self._sepration_axis = separation_axis + self._separation = separation + + self.separation = separation + + @property + def separation(self) -> float: + """distance between each line in the stack, in world space""" + return self._separation + + @separation.setter + def separation(self, value: float): + separation = float(value) + + axis_zero = 0 + for i, line in enumerate(self.graphics): + if self._sepration_axis == "x": + line.offset = (axis_zero, *line.offset[1:]) + + elif self._sepration_axis == "y": + line.offset = (line.offset[0], axis_zero, line.offset[2]) + + axis_zero = ( + axis_zero + line.data.value[:, axes[self._sepration_axis]].max() + separation + ) + + self._separation = value diff --git a/fastplotlib/graphics/selectors/_linear_region.py b/fastplotlib/graphics/selectors/_linear_region.py index 70a8dffa8..8a8583ae9 100644 --- a/fastplotlib/graphics/selectors/_linear_region.py +++ b/fastplotlib/graphics/selectors/_linear_region.py @@ -472,9 +472,9 @@ def _move_graphic(self, move_info: MoveInfo): if move_info.source == self._edges[0]: # change only left or bottom bound new_min = min(cur_min + delta, cur_max) - self._selection.set_value(self, (new_min, cur_max)) + self._selection.set_value(self, (new_min, cur_max), change="min") elif move_info.source == self._edges[1]: # change only right or top bound new_max = max(cur_max + delta, cur_min) - self._selection.set_value(self, (cur_min, new_max)) + self._selection.set_value(self, (cur_min, new_max), change="max") diff --git a/fastplotlib/graphics/utils.py b/fastplotlib/graphics/utils.py index 6be5aefc4..f32d80809 100644 --- a/fastplotlib/graphics/utils.py +++ b/fastplotlib/graphics/utils.py @@ -1,13 +1,16 @@ from contextlib import contextmanager +from typing import Callable, Iterable from ._base import Graphic @contextmanager -def pause_events(*graphics: Graphic): +def pause_events(*graphics: Graphic, event_handlers: Iterable[Callable] = None): """ Context manager for pausing Graphic events. + Optionally pass in only specific event handlers which are blocked. Other events for the graphic will not be blocked. + Examples -------- @@ -30,8 +33,14 @@ def pause_events(*graphics: Graphic): original_vals = [g.block_events for g in graphics] for g in graphics: - g.block_events = True + if event_handlers is not None: + g.block_handlers.extend([e for e in event_handlers]) + else: + g.block_events = True yield for g, value in zip(graphics, original_vals): - g.block_events = value + if event_handlers is not None: + g.block_handlers.clear() + else: + g.block_events = value diff --git a/fastplotlib/layouts/_figure.py b/fastplotlib/layouts/_figure.py index 79b5be3a8..00b915b1f 100644 --- a/fastplotlib/layouts/_figure.py +++ b/fastplotlib/layouts/_figure.py @@ -539,7 +539,7 @@ def _render(self, draw=True): # call the animation functions before render self._call_animate_functions(self._animate_funcs_pre) - for subplot in self: + for subplot in self._subplots.ravel(): subplot._render() # overlay render pass @@ -606,14 +606,14 @@ def show( sidecar_kwargs = dict() # flip y-axis if ImageGraphics are present - for subplot in self: + for subplot in self._subplots.ravel(): for g in subplot.graphics: if isinstance(g, ImageGraphic): subplot.camera.local.scale_y *= -1 break if autoscale: - for subplot in self: + for subplot in self._subplots.ravel(): if maintain_aspect is None: _maintain_aspect = subplot.camera.maintain_aspect else: @@ -622,7 +622,7 @@ def show( # set axes visibility if False if not axes_visible: - for subplot in self: + for subplot in self._subplots.ravel(): subplot.axes.visible = False # parse based on canvas type @@ -646,15 +646,15 @@ def show( elif self.canvas.__class__.__name__ == "OffscreenRenderCanvas": # for test and docs gallery screenshots self._fpl_reset_layout() - for subplot in self: + for subplot in self._subplots.ravel(): subplot.axes.update_using_camera() # render call is blocking only on github actions for some reason, # but not for rtd build, this is a workaround # for CI tests, the render call works if it's in test_examples # but it is necessary for the gallery images too so that's why this check is here - if "RTD_BUILD" in os.environ.keys(): - if os.environ["RTD_BUILD"] == "1": + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": self._render() else: # assume GLFW @@ -770,7 +770,7 @@ def clear_animations(self, removal: str = None): def clear(self): """Clear all Subplots""" - for subplot in self: + for subplot in self._subplots.ravel(): subplot.clear() def export_numpy(self, rgb: bool = False) -> np.ndarray: @@ -929,18 +929,20 @@ def __getitem__(self, index: str | int | tuple[int, int]) -> Subplot: return subplot raise IndexError(f"no subplot with given name: {index}") + if isinstance(index, (int, np.integer)): + return self._subplots.ravel()[index] + if isinstance(self.layout, GridLayout): return self._subplots[index[0], index[1]] - return self._subplots[index] + raise TypeError( + f"Can index figure using subplot name, numerical subplot index, or a " + f"tuple[int, int] if the layout is a grid" + ) def __iter__(self): - self._current_iter = iter(range(len(self))) - return self - - def __next__(self) -> Subplot: - pos = self._current_iter.__next__() - return self._subplots.ravel()[pos] + for subplot in self._subplots.ravel(): + yield subplot def __len__(self): """number of subplots""" @@ -955,6 +957,6 @@ def __repr__(self): return ( f"fastplotlib.{self.__class__.__name__}" f" Subplots:\n" - f"\t{newline.join(subplot.__str__() for subplot in self)}" + f"\t{newline.join(subplot.__str__() for subplot in self._subplots.ravel())}" f"\n" ) diff --git a/fastplotlib/layouts/_graphic_methods_mixin.py b/fastplotlib/layouts/_graphic_methods_mixin.py index 06a4c7517..bd01855bd 100644 --- a/fastplotlib/layouts/_graphic_methods_mixin.py +++ b/fastplotlib/layouts/_graphic_methods_mixin.py @@ -33,7 +33,6 @@ def add_image( cmap: str = "plasma", interpolation: str = "nearest", cmap_interpolation: str = "linear", - isolated_buffer: bool = True, **kwargs, ) -> ImageGraphic: """ @@ -62,12 +61,6 @@ def add_image( cmap_interpolation: str, optional, default "linear" colormap interpolation method, one of "nearest" or "linear" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. - kwargs: additional keyword arguments passed to :class:`.Graphic` @@ -81,7 +74,6 @@ def add_image( cmap, interpolation, cmap_interpolation, - isolated_buffer, **kwargs, ) @@ -100,7 +92,6 @@ def add_image_volume( substep_size: float = 0.1, emissive: str | tuple | numpy.ndarray = (0, 0, 0), shininess: int = 30, - isolated_buffer: bool = True, **kwargs, ) -> ImageVolumeGraphic: """ @@ -158,11 +149,6 @@ def add_image_volume( How shiny the specular highlight is; a higher value gives a sharper highlight. Used only if `mode` = "iso" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then set the data, useful if the - data arrays are ready-only such as memmaps. If False, the input array is itself used as the - buffer - useful if the array is large. - kwargs additional keyword arguments passed to :class:`.Graphic` @@ -183,7 +169,6 @@ def add_image_volume( substep_size, emissive, shininess, - isolated_buffer, **kwargs, ) @@ -192,14 +177,13 @@ def add_line_collection( data: Union[numpy.ndarray, List[numpy.ndarray]], thickness: Union[float, Sequence[float]] = 2.0, colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", - uniform_colors: bool = False, cmap: Union[Sequence[str], str] = None, cmap_transform: Union[numpy.ndarray, List] = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, - isolated_buffer: bool = True, kwargs_lines: list[dict] = None, **kwargs, ) -> LineCollection: @@ -235,6 +219,9 @@ def add_line_collection( cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + The color mode for each line in the collection. See `color_mode` in :class:`.LineGraphic` for details. + name: str, optional name of the line collection as a whole @@ -261,14 +248,13 @@ def add_line_collection( data, thickness, colors, - uniform_colors, cmap, cmap_transform, + color_mode, name, names, metadata, metadatas, - isolated_buffer, kwargs_lines, **kwargs, ) @@ -278,10 +264,9 @@ def add_line( data: Any, thickness: float = 2.0, colors: Union[str, numpy.ndarray, Sequence] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: Union[numpy.ndarray, Sequence] = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", **kwargs, ) -> LineGraphic: @@ -304,15 +289,19 @@ def add_line( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default ``False`` - if True, uses a uniform buffer for the line color, - basically saves GPU VRAM when the entire line has a single color - cmap: str, optional Apply a colormap to the line instead of assigning colors manually, this overrides any argument passed to "colors". For supported colormaps see the ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap @@ -329,10 +318,9 @@ def add_line( data, thickness, colors, - uniform_color, cmap, cmap_transform, - isolated_buffer, + color_mode, size_space, **kwargs, ) @@ -348,7 +336,6 @@ def add_line_stack( names: list[str] = None, metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, - isolated_buffer: bool = True, separation: float = 10.0, separation_axis: str = "y", kwargs_lines: list[dict] = None, @@ -425,7 +412,6 @@ def add_line_stack( names, metadata, metadatas, - isolated_buffer, separation, separation_axis, kwargs_lines, @@ -448,7 +434,6 @@ def add_mesh( | numpy.ndarray ) = None, clim: tuple[float, float] = None, - isolated_buffer: bool = True, **kwargs, ) -> MeshGraphic: """ @@ -488,12 +473,6 @@ def add_mesh( Both 1D and 2D colormaps are supported, though the mapcoords has to match the dimensionality. An image can also be used, this is basically a 2D colormap. - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. In almost all cases this should be ``True``. - **kwargs passed to :class:`.Graphic` @@ -509,7 +488,6 @@ def add_mesh( mapcoords, cmap, clim, - isolated_buffer, **kwargs, ) @@ -570,16 +548,94 @@ def add_polygon( PolygonGraphic, data, mode, colors, mapcoords, cmap, clim, **kwargs ) + def add_scatter_collection( + self, + data: Union[numpy.ndarray, List[numpy.ndarray]], + colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", + cmap: Union[Sequence[str], str] = None, + cmap_transform: Union[numpy.ndarray, List] = None, + sizes: Union[float, Sequence[float]] = 5.0, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Union[Sequence[Any], numpy.ndarray] = None, + kwargs_lines: list[dict] = None, + **kwargs, + ) -> ScatterCollection: + """ + + Create a collection of :class:`.ScatterGraphic` + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + meatadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + kwargs_lines: list[dict], optional + list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + + """ + return self._create_graphic( + ScatterCollection, + data, + colors, + cmap, + cmap_transform, + sizes, + name, + names, + metadata, + metadatas, + kwargs_lines, + **kwargs, + ) + def add_scatter( self, data: Any, colors: Union[str, numpy.ndarray, Sequence[float], Sequence[str]] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: numpy.ndarray = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", mode: Literal["markers", "simple", "gaussian", "image"] = "markers", markers: Union[str, numpy.ndarray, Sequence[str]] = "o", - uniform_marker: bool = False, + uniform_marker: bool = True, custom_sdf: str = None, edge_colors: Union[ str, pygfx.utils.color.Color, numpy.ndarray, Sequence[float] @@ -589,10 +645,9 @@ def add_scatter( image: numpy.ndarray = None, point_rotations: float | numpy.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", - sizes: Union[float, numpy.ndarray, Sequence[float]] = 1, - uniform_size: bool = False, + sizes: Union[float, numpy.ndarray, Sequence[float]] = 5, + uniform_size: bool = True, size_space: str = "screen", - isolated_buffer: bool = True, **kwargs, ) -> ScatterGraphic: """ @@ -609,18 +664,23 @@ def add_scatter( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default False - if True, uses a uniform buffer for the scatter point colors. Useful if you need to - save GPU VRAM when all points have the same color. - cmap: str, optional apply a colormap to the scatter instead of assigning colors manually, this - overrides any argument passed to "colors". For supported colormaps see the - ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + overrides any argument passed to "colors". + For supported colormaps see the ``cmap`` library catalogue: + https://cmap-docs.readthedocs.io/en/stable/catalog/ cmap_transform: 1D array-like or list of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + mode: one of: "markers", "simple", "gaussian", "image", default "markers" The scatter points mode, cannot be changed after the graphic has been created. @@ -640,9 +700,10 @@ def add_scatter( * Emojis: "❤️♠️♣️♦️💎💍✳️📍". * A string containing the value "custom". In this case, WGSL code defined by ``custom_sdf`` will be used. - uniform_marker: bool, default False - Use the same marker for all points. Only valid when `mode` is "markers". Useful if you need to use - the same marker for all points and want to save GPU RAM. + uniform_marker: bool, default ``True`` + If ``True``, use the same marker for all points. Only valid when `mode` is "markers". + Useful if you need to use the same marker for all points and want to save GPU RAM. If ``False``, you can + set per-vertex markers. custom_sdf: str = None, The SDF code for the marker shape when the marker is set to custom. @@ -662,8 +723,9 @@ def add_scatter( edge_colors: str | np.ndarray | pygfx.Color | Sequence[float], default "black" edge color of the markers, used when `mode` is "markers" - uniform_edge_color: bool, default True - Set the same edge color for all markers. Useful for saving GPU RAM. + uniform_edge_color: bool, default ``True`` + Set the same edge color for all markers. Useful for saving GPU RAM. Set to ``False`` for per-vertex edge + colors edge_width: float = 1.0, Width of the marker edges. used when `mode` is "markers". @@ -684,17 +746,13 @@ def add_scatter( sizes: float or iterable of float, optional, default 1.0 sizes of the scatter points - uniform_size: bool, default False - if True, uses a uniform buffer for the scatter point sizes. Useful if you need to - save GPU VRAM when all points have the same size. + uniform_size: bool, default ``False`` + if ``True``, uses a uniform buffer for the scatter point sizes. Useful if you need to + save GPU VRAM when all points have the same size. Set to ``False`` if you need per-vertex sizes. size_space: str, default "screen" coordinate space in which the size is expressed, one of ("screen", "world", "model") - isolated_buffer: bool, default True - whether the buffers should be isolated from the user input array. - Generally always ``True``, ``False`` is for rare advanced use if you have large arrays. - kwargs passed to :class:`.Graphic` @@ -704,9 +762,9 @@ def add_scatter( ScatterGraphic, data, colors, - uniform_color, cmap, cmap_transform, + color_mode, mode, markers, uniform_marker, @@ -720,7 +778,6 @@ def add_scatter( sizes, uniform_size, size_space, - isolated_buffer, **kwargs, ) diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index f83dcfbcb..405a01546 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -233,7 +233,10 @@ def controller(self, new_controller: str | pygfx.Controller): # pygfx plans on refactoring viewports anyways if self.parent is not None: if self.parent.__class__.__name__.endswith("Figure"): - for subplot in self.parent: + # always use figure._subplots.ravel() in internal fastplotlib code + # otherwise if we use `for subplot in figure`, this could conflict + # with a user's iterator where they are doing `for subplot in figure` !!! + for subplot in self.parent._subplots.ravel(): if subplot.camera in cameras_list: new_controller.register_events(subplot.viewport) subplot._controller = new_controller diff --git a/fastplotlib/tools/_histogram_lut.py b/fastplotlib/tools/_histogram_lut.py index d651137da..8edfb046b 100644 --- a/fastplotlib/tools/_histogram_lut.py +++ b/fastplotlib/tools/_histogram_lut.py @@ -6,424 +6,412 @@ import pygfx -from ..utils import subsample_array +from ..utils import subsample_array, RenderQueue from ..graphics import LineGraphic, ImageGraphic, ImageVolumeGraphic, TextGraphic from ..graphics.utils import pause_events from ..graphics._base import Graphic +from ..graphics.features import GraphicFeatureEvent from ..graphics.selectors import LinearRegionSelector -def _get_image_graphic_events(image_graphic: ImageGraphic) -> list[str]: - """Small helper function to return the relevant events for an ImageGraphic""" - events = ["vmin", "vmax"] +def _format_value(value: float): + abs_val = abs(value) + if abs_val < 0.01 or abs_val > 9_999: + return f"{value:.2e}" + else: + return f"{value:.2f}" - if not image_graphic.data.value.ndim > 2: - events.append("cmap") - # if RGB(A), do not add cmap - - return events - - -# TODO: This is a widget, we can think about a BaseWidget class later if necessary class HistogramLUTTool(Graphic): _fpl_support_tooltip = False def __init__( self, - data: np.ndarray, - images: ( - ImageGraphic - | ImageVolumeGraphic - | Sequence[ImageGraphic | ImageVolumeGraphic] - ), - nbins: int = 100, - flank_divisor: float = 5.0, + histogram: tuple[np.ndarray, np.ndarray], + images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] | None = None, **kwargs, ): """ - HistogramLUT tool that can be used to control the vmin, vmax of ImageGraphics or ImageVolumeGraphics. - If used to control multiple images or image volumes it is assumed that they share a representation of - the same data, and that their histogram, vmin, and vmax are identical. For example, displaying a - ImageVolumeGraphic and several images that represent slices of the same volume data. + A histogram tool that allows adjusting the vmin, vmax of images. + Also allows changing the cmap LUT for grayscale images and displays a colorbar. Parameters ---------- - data: np.ndarray - - images: ImageGraphic | ImageVolumeGraphic | tuple[ImageGraphic | ImageVolumeGraphic] - - nbins: int, defaut 100. - Total number of bins used in the histogram + histogram: tuple[np.ndarray, np.ndarray] + [frequency, bin_edges], must be 100 bins - flank_divisor: float, default 5.0. - Fraction of empty histogram bins on the tails of the distribution set `np.inf` for no flanks + images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] + the images that are managed by the histogram tool - kwargs: passed to ``Graphic`` + kwargs: + passed to ``Graphic`` """ - super().__init__(**kwargs) - - self._nbins = nbins - self._flank_divisor = flank_divisor - - if isinstance(images, (ImageGraphic, ImageVolumeGraphic)): - images = (images,) - elif isinstance(images, Sequence): - if not all( - [isinstance(ig, (ImageGraphic, ImageVolumeGraphic)) for ig in images] - ): - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - else: - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - self._images = images + super().__init__(**kwargs) - self._data = weakref.proxy(data) + if len(histogram) != 2: + raise TypeError - self._scale_factor: float = 1.0 + self._block_reentrance = False + self._images = list() - hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) + self._bin_centers_flanked = np.zeros(120, dtype=np.float64) + self._freq_flanked = np.zeros(120, dtype=np.float32) - line_data = np.column_stack([hist_scaled, edges_flanked]) + # 100 points for the histogram, 10 points on each side for the flank + line_data = np.column_stack( + [np.zeros(120, dtype=np.float32), np.arange(0, 120)] + ) - self._histogram_line = LineGraphic( - line_data, colors=(0.8, 0.8, 0.8), alpha_mode="solid", offset=(0, 0, -1) + # line that displays the histogram + self._line = LineGraphic( + line_data, colors=(0.8, 0.8, 0.8), alpha_mode="solid", offset=(1, 0, 0) + ) + self._line.world_object.local.scale_x = -1 + + # vmin, vmax selector + self._selector = LinearRegionSelector( + selection=(10, 110), + limits=(0, 119), + size=1.5, + center=0.5, # frequency data are normalized between 0-1 + axis="y", + parent=self._line, ) - bounds = (edges[0] * self._scale_factor, edges[-1] * self._scale_factor) - limits = (edges_flanked[0], edges_flanked[-1]) - size = 120 # since it's scaled to 100 - origin = (hist_scaled.max() / 2, 0) + self._selector.add_event_handler(self._selector_event_handler, "selection") - self._linear_region_selector = LinearRegionSelector( - selection=bounds, - limits=limits, - size=size, - center=origin[0], - axis="y", - parent=self._histogram_line, + self._colorbar = ImageGraphic( + data=np.zeros([120, 2]), interpolation="linear", offset=(1.5, 0, 0) ) - self._vmin = self.images[0].vmin - self._vmax = self.images[0].vmax + # make the colorbar thin + self._colorbar.world_object.local.scale_x = 0.15 + self._colorbar.add_event_handler(self._open_cmap_picker, "click") - # there will be a small difference with the histogram edges so this makes them both line up exactly - self._linear_region_selector.selection = ( - self._vmin * self._scale_factor, - self._vmax * self._scale_factor, + # colorbar ruler + self._ruler = pygfx.Ruler( + end_pos=(0, 119, 0), + alpha_mode="solid", + render_queue=RenderQueue.axes, + tick_side="right", + tick_marker="tick_right", + tick_format=self._ruler_tick_map, + min_tick_distance=10, ) + self._ruler.local.x = 1.75 - vmin_str, vmax_str = self._get_vmin_vmax_str() + # TODO: need to auto-scale using the text so it appears nicely, will do later + self._ruler.visible = False self._text_vmin = TextGraphic( - text=vmin_str, + text="", font_size=16, - offset=(0, 0, 0), anchor="top-left", outline_color="black", outline_thickness=0.5, alpha_mode="solid", ) - + # this is to make sure clicking text doesn't conflict with the selector tool + # since the text appears near the selector tool self._text_vmin.world_object.material.pick_write = False self._text_vmax = TextGraphic( - text=vmax_str, + text="", font_size=16, - offset=(0, 0, 0), anchor="bottom-left", outline_color="black", outline_thickness=0.5, alpha_mode="solid", ) - self._text_vmax.world_object.material.pick_write = False - widget_wo = pygfx.Group() - widget_wo.add( - self._histogram_line.world_object, - self._linear_region_selector.world_object, + # add all the world objects to a pygfx.Group + wo = pygfx.Group() + wo.add( + self._line.world_object, + self._selector.world_object, + self._colorbar.world_object, + self._ruler, self._text_vmin.world_object, self._text_vmax.world_object, ) + self._set_world_object(wo) - self._set_world_object(widget_wo) + # for convenience, a list that stores all the graphics managed by the histogram LUT tool + self._children = [ + self._line, + self._selector, + self._colorbar, + self._text_vmin, + self._text_vmax, + ] - self.world_object.local.scale_x *= -1 + # set histogram + self.histogram = histogram - self._text_vmin.offset = (-120, self._linear_region_selector.selection[0], 0) + # set the images + self.images = images - self._text_vmax.offset = (-120, self._linear_region_selector.selection[1], 0) + def _fpl_add_plot_area_hook(self, plot_area): + self._plot_area = plot_area - self._linear_region_selector.add_event_handler( - self._linear_region_handler, "selection" - ) + for child in self._children: + # need all of them to call the add_plot_area_hook so that events are connected correctly + # example, the linear region selector needs all the canvas events to be connected + child._fpl_add_plot_area_hook(plot_area) - ig_events = _get_image_graphic_events(self.images[0]) + if hasattr(self._plot_area, "size"): + # if it's in a dock area + self._plot_area.size = 80 - for ig in self.images: - ig.add_event_handler(self._image_cmap_handler, *ig_events) + # disable the controller in this plot area + self._plot_area.controller.enabled = False + self._plot_area.auto_scale(maintain_aspect=False) - # colorbar for grayscale images - if self.images[0].cmap is not None: - self._colorbar: ImageGraphic = self._make_colorbar(edges_flanked) - self._colorbar.add_event_handler(self._open_cmap_picker, "click") + # tick text for colorbar ruler doesn't show without this call + self._ruler.update(plot_area.camera, plot_area.canvas.get_logical_size()) - self.world_object.add(self._colorbar.world_object) - else: - self._colorbar = None - self._cmap = None + def _ruler_tick_map(self, bin_index, *args): + return f"{self._bin_centers_flanked[int(bin_index)]:.2f}" - def _make_colorbar(self, edges_flanked) -> ImageGraphic: - # use the histogram edge values as data for an - # image with 2 columns, this will be our colorbar! - colorbar_data = np.column_stack( - [ - np.linspace( - edges_flanked[0], edges_flanked[-1], ceil(np.ptp(edges_flanked)) - ) - ] - * 2 - ).astype(np.float32) - - colorbar_data /= self._scale_factor - - cbar = ImageGraphic( - data=colorbar_data, - vmin=self.vmin, - vmax=self.vmax, - cmap=self.images[0].cmap, - interpolation="linear", - offset=(-55, edges_flanked[0], -1), - ) + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray]: + """histogram [frequency, bin_centers]. Frequency is flanked by 10 zeros on both sides""" + return self._freq_flanked, self._bin_centers_flanked - cbar.world_object.world.scale_x = 20 - self._cmap = self.images[0].cmap + @histogram.setter + def histogram( + self, histogram: tuple[np.ndarray, np.ndarray], limits: tuple[int, int] = None + ): + """set histogram with pre-compuated [frequency, edges], must have exactly 100 bins""" - return cbar + freq, edges = histogram - def _get_vmin_vmax_str(self) -> tuple[str, str]: - if self.vmin < 0.001 or self.vmin > 99_999: - vmin_str = f"{self.vmin:.2e}" - else: - vmin_str = f"{self.vmin:.2f}" + if freq.max() > 0: + # if the histogram is made from an empty array, then the max freq will be 0 + # we don't want to divide by 0 because then we just get nans + freq = freq / freq.max() - if self.vmax < 0.001 or self.vmax > 99_999: - vmax_str = f"{self.vmax:.2e}" - else: - vmax_str = f"{self.vmax:.2f}" + bin_centers = 0.5 * (edges[1:] + edges[:-1]) - return vmin_str, vmax_str + step = bin_centers[1] - bin_centers[0] - def _fpl_add_plot_area_hook(self, plot_area): - self._plot_area = plot_area - self._linear_region_selector._fpl_add_plot_area_hook(plot_area) - self._histogram_line._fpl_add_plot_area_hook(plot_area) + under_flank = np.linspace(bin_centers[0] - step * 10, bin_centers[0] - step, 10) + over_flank = np.linspace( + bin_centers[-1] + step, bin_centers[-1] + step * 10, 10 + ) + self._bin_centers_flanked[:] = np.concatenate( + [under_flank, bin_centers, over_flank] + ) + + self._freq_flanked[10:110] = freq - self._plot_area.auto_scale() - self._plot_area.controller.enabled = True + self._line.data[:, 0] = self._freq_flanked + self._colorbar.data = np.column_stack( + [self._bin_centers_flanked, self._bin_centers_flanked] + ) - def _calculate_histogram(self, data): + # self.vmin, self.vmax = bin_centers[0], bin_centers[-1] - # get a subsampled view of this array - data_ss = subsample_array(data, max_size=int(1e6)) # 1e6 is default - hist, edges = np.histogram(data_ss, bins=self._nbins) + if hasattr(self, "plot_area"): + self._ruler.update( + self._plot_area.camera, self._plot_area.canvas.get_logical_size() + ) - # used if data ptp <= 10 because event things get weird - # with tiny world objects due to floating point error - # so if ptp <= 10, scale up by a factor - data_interval = edges[-1] - edges[0] - self._scale_factor: int = max(1, 100 * int(10 / data_interval)) + @property + def images(self) -> tuple[ImageGraphic | ImageVolumeGraphic, ...] | None: + """get or set the managed images""" + return tuple(self._images) - edges = edges * self._scale_factor + @images.setter + def images(self, new_images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] | None): + self._disconnect_images() + self._images.clear() - bin_width = edges[1] - edges[0] + if new_images is None: + return - flank_nbins = int(self._nbins / self._flank_divisor) - flank_size = flank_nbins * bin_width + if isinstance(new_images, (ImageGraphic, ImageVolumeGraphic)): + new_images = [new_images] - flank_left = np.arange(edges[0] - flank_size, edges[0], bin_width) - flank_right = np.arange( - edges[-1] + bin_width, edges[-1] + flank_size, bin_width - ) + if not all( + [ + isinstance(image, (ImageGraphic, ImageVolumeGraphic)) + for image in new_images + ] + ): + raise TypeError - edges_flanked = np.concatenate((flank_left, edges, flank_right)) + for image in new_images: + if image.cmap is not None: + self._colorbar.visible = True + break + else: + self._colorbar.visible = False - hist_flanked = np.concatenate( - (np.zeros(flank_nbins), hist, np.zeros(flank_nbins)) - ) + self._images = list(new_images) - # scale 0-100 to make it easier to see - # float32 data can produce unnecessarily high values - hist_scale_value = hist_flanked.max() - if np.allclose(hist_scale_value, 0): - hist_scale_value = 1 - hist_scaled = hist_flanked / (hist_scale_value / 100) + # reset vmin, vmax using first image + self.vmin = self._images[0].vmin + self.vmax = self._images[0].vmax - if edges_flanked.size > hist_scaled.size: - # we don't care about accuracy here so if it's off by 1-2 bins that's fine - edges_flanked = edges_flanked[: hist_scaled.size] + if self._images[0].cmap is not None: + self._colorbar.cmap = self._images[0].cmap - return hist, edges, hist_scaled, edges_flanked + # connect event handlers + for image in self._images: + image.add_event_handler(self._image_event_handler, "vmin", "vmax") + image.add_event_handler(self._disconnect_images, "deleted") + if image.cmap is not None: + image.add_event_handler( + self._image_event_handler, "vmin", "vmax", "cmap" + ) - def _linear_region_handler(self, ev): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - selected_ixs = self._linear_region_selector.selection - vmin, vmax = selected_ixs[0], selected_ixs[1] - vmin, vmax = vmin / self._scale_factor, vmax / self._scale_factor - self.vmin, self.vmax = vmin, vmax + def _disconnect_images(self, *args): + """disconnect event handlers of the managed images""" + for image in self._images: + for ev, handlers in image.event_handlers: + if self._image_event_handler in handlers: + image.remove_event_handler(self._image_event_handler, ev) - def _image_cmap_handler(self, ev): - setattr(self, ev.type, ev.info["value"]) + def _image_event_handler(self, ev): + """when the image vmin, vmax, or cmap changes it will update the HistogramLUTTool""" + new_value = ev.info["value"] + setattr(self, ev.type, new_value) @property def cmap(self) -> str: - return self._cmap + """get or set the colormap, only for grayscale images""" + return self._colorbar.cmap @cmap.setter def cmap(self, name: str): - if self._colorbar is None: + if self._block_reentrance: return - with pause_events(*self.images): - for ig in self.images: - ig.cmap = name + if name is None: + return - self._cmap = name + self._block_reentrance = True + try: self._colorbar.cmap = name + with pause_events( + *self._images, event_handlers=[self._image_event_handler] + ): + for image in self._images: + if image.cmap is None: + # rgb(a) images have no cmap + continue + + image.cmap = name + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False + @property def vmin(self) -> float: - return self._vmin + """get or set the vmin, the lower contrast limit""" + # no offset or rotation so we can directly use the world space selection value + index = int(self._selector.selection[0]) + return self._bin_centers_flanked[index] @vmin.setter def vmin(self, value: float): - with pause_events(self._linear_region_selector, *self.images): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - self._linear_region_selector.selection = ( - value * self._scale_factor, - self._linear_region_selector.selection[1], - ) - for ig in self.images: - ig.vmin = value + if self._block_reentrance: + return + self._block_reentrance = True + try: + index_min = np.searchsorted(self._bin_centers_flanked, value) + with pause_events( + self._selector, + *self._images, + event_handlers=[ + self._selector_event_handler, + self._image_event_handler, + ], + ): + self._selector.selection = (index_min, self._selector.selection[1]) - self._vmin = value - if self._colorbar is not None: - self._colorbar.vmin = value + self._colorbar.vmin = value - vmin_str, vmax_str = self._get_vmin_vmax_str() - self._text_vmin.offset = (-120, self._linear_region_selector.selection[0], 0) - self._text_vmin.text = vmin_str + self._text_vmin.text = _format_value(value) + self._text_vmin.offset = (-0.45, self._selector.selection[0], 0) + + for image in self._images: + image.vmin = value + + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False @property def vmax(self) -> float: - return self._vmax + """get or set the vmax, the upper contrast limit""" + # no offset or rotation so we can directly use the world space selection value + index = int(self._selector.selection[1]) + return self._bin_centers_flanked[index] @vmax.setter def vmax(self, value: float): - with pause_events(self._linear_region_selector, *self.images): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - self._linear_region_selector.selection = ( - self._linear_region_selector.selection[0], - value * self._scale_factor, - ) - - for ig in self.images: - ig.vmax = value - - self._vmax = value - if self._colorbar is not None: - self._colorbar.vmax = value - - vmin_str, vmax_str = self._get_vmin_vmax_str() - self._text_vmax.offset = (-120, self._linear_region_selector.selection[1], 0) - self._text_vmax.text = vmax_str - - def set_data(self, data, reset_vmin_vmax: bool = True): - hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) - - line_data = np.column_stack([hist_scaled, edges_flanked]) - - # set x and y vals - self._histogram_line.data[:, :2] = line_data - - bounds = (edges[0], edges[-1]) - limits = (edges_flanked[0], edges_flanked[-11]) - origin = (hist_scaled.max() / 2, 0) - - if reset_vmin_vmax: - # reset according to the new data - self._linear_region_selector.limits = limits - self._linear_region_selector.selection = bounds - else: - with pause_events(self._linear_region_selector, *self.images): - # don't change the current selection - self._linear_region_selector.limits = limits - - self._data = weakref.proxy(data) - - if self._colorbar is not None: - self._colorbar.clear_event_handlers() - self.world_object.remove(self._colorbar.world_object) - - if self.images[0].cmap is not None: - self._colorbar: ImageGraphic = self._make_colorbar(edges_flanked) - self._colorbar.add_event_handler(self._open_cmap_picker, "click") + if self._block_reentrance: + return - self.world_object.add(self._colorbar.world_object) - else: - self._colorbar = None - self._cmap = None + self._block_reentrance = True + try: + index_max = np.searchsorted(self._bin_centers_flanked, value) + with pause_events( + self._selector, + *self._images, + event_handlers=[ + self._selector_event_handler, + self._image_event_handler, + ], + ): + self._selector.selection = (self._selector.selection[0], index_max) - # reset plotarea dims - self._plot_area.auto_scale() + self._colorbar.vmax = value - @property - def images(self) -> tuple[ImageGraphic | ImageVolumeGraphic]: - return self._images + self._text_vmax.text = _format_value(value) + self._text_vmax.offset = (-0.45, self._selector.selection[1], 0) - @images.setter - def images(self, images): - if isinstance(images, (ImageGraphic, ImageVolumeGraphic)): - images = (images,) - elif isinstance(images, Sequence): - if not all( - [isinstance(ig, (ImageGraphic, ImageVolumeGraphic)) for ig in images] - ): - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - else: - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) + for image in self._images: + image.vmax = value - if self._images is not None: - for ig in self._images: - # cleanup events from current image graphics - ig_events = _get_image_graphic_events(ig) - ig.remove_event_handler(self._image_cmap_handler, *ig_events) + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False - self._images = images + def _selector_event_handler(self, ev: GraphicFeatureEvent): + """when the selector's selctor has changed, it will update the vmin, vmax, or both""" + selection = ev.info["value"] + index_min = int(selection[0]) + vmin = self._bin_centers_flanked[index_min] - ig_events = _get_image_graphic_events(self._images[0]) + index_max = int(selection[1]) + vmax = self._bin_centers_flanked[index_max] - for ig in self.images: - ig.add_event_handler(self._image_cmap_handler, *ig_events) + match ev.info["change"]: + case "min": + self.vmin = vmin + case "max": + self.vmax = vmax + case _: + self.vmin, self.vmax = vmin, vmax def _open_cmap_picker(self, ev): + """open imgui cmap picker""" # check if right click if ev.button != 2: return @@ -433,7 +421,11 @@ def _open_cmap_picker(self, ev): self._plot_area.get_figure().open_popup("colormap-picker", pos, lut_tool=self) def _fpl_prepare_del(self): - self._linear_region_selector._fpl_prepare_del() - self._histogram_line._fpl_prepare_del() - del self._histogram_line - del self._linear_region_selector + """cleanup, need to disconnect events and remove image references for proper garbage collection""" + self._disconnect_images() + self._images.clear() + + for i in range(len(self._children)): + g = self._children.pop(0) + g._fpl_prepare_del() + del g diff --git a/fastplotlib/ui/_base.py b/fastplotlib/ui/_base.py index 3e763e08c..9767cf76f 100644 --- a/fastplotlib/ui/_base.py +++ b/fastplotlib/ui/_base.py @@ -123,8 +123,9 @@ def size(self) -> int | None: @size.setter def size(self, value): if not isinstance(value, int): - raise TypeError + raise TypeError(f"{self.__class__.__name__}.size must be an ") self._size = value + self._set_rect() @property def location(self) -> str: @@ -153,6 +154,7 @@ def height(self) -> int: def _set_rect(self, *args): self._x, self._y, self._width, self._height = self.get_rect() + self._figure._fpl_reset_layout() def get_rect(self) -> tuple[int, int, int, int]: """ diff --git a/fastplotlib/ui/right_click_menus/_colormap_picker.py b/fastplotlib/ui/right_click_menus/_colormap_picker.py index a80e5b2aa..9df26dcdc 100644 --- a/fastplotlib/ui/right_click_menus/_colormap_picker.py +++ b/fastplotlib/ui/right_click_menus/_colormap_picker.py @@ -154,7 +154,8 @@ def update(self): self._texture_height = (imgui.get_font_size()) - 2 if imgui.menu_item("Reset vmin-vmax", "", False)[0]: - self._lut_tool.images[0].reset_vmin_vmax() + for image in self._lut_tool.images: + image.reset_vmin_vmax() # add all the cmap options for cmap_type in COLORMAP_NAMES.keys(): diff --git a/fastplotlib/utils/__init__.py b/fastplotlib/utils/__init__.py index dd527ca67..8001ae375 100644 --- a/fastplotlib/utils/__init__.py +++ b/fastplotlib/utils/__init__.py @@ -6,6 +6,7 @@ from .gpu import enumerate_adapters, select_adapter, print_wgpu_report from ._plot_helpers import * from .enums import * +from ._protocols import ArrayProtocol @dataclass diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py new file mode 100644 index 000000000..7ae63ed67 --- /dev/null +++ b/fastplotlib/utils/_protocols.py @@ -0,0 +1,15 @@ +from typing import Protocol, runtime_checkable + + +ARRAY_LIKE_ATTRS = ["shape", "ndim", "__getitem__"] + + +@runtime_checkable +class ArrayProtocol(Protocol): + @property + def ndim(self) -> int: ... + + @property + def shape(self) -> tuple[int, ...]: ... + + def __getitem__(self, key): ... diff --git a/fastplotlib/widgets/image_widget/__init__.py b/fastplotlib/widgets/image_widget/__init__.py index 70a1aa8ae..dc5daea55 100644 --- a/fastplotlib/widgets/image_widget/__init__.py +++ b/fastplotlib/widgets/image_widget/__init__.py @@ -2,6 +2,7 @@ if IMGUI: from ._widget import ImageWidget + from ._processor import NDImageProcessor else: diff --git a/fastplotlib/widgets/image_widget/_nd_iw_backup.py b/fastplotlib/widgets/image_widget/_nd_iw_backup.py new file mode 100644 index 000000000..7db265c0c --- /dev/null +++ b/fastplotlib/widgets/image_widget/_nd_iw_backup.py @@ -0,0 +1,1007 @@ +from typing import Callable, Sequence, Literal +from warnings import warn + +import numpy as np + +from rendercanvas import BaseRenderCanvas + +from ...layouts import ImguiFigure as Figure +from ...graphics import ImageGraphic, ImageVolumeGraphic +from ...utils import calculate_figure_shape, quick_min_max, ArrayProtocol +from ...tools import HistogramLUTTool +from ._sliders import ImageWidgetSliders +from ._processor import NDImageProcessor, WindowFuncCallable +from ._properties import ImageWidgetProperty, Indices + + +IMGUI_SLIDER_HEIGHT = 49 + + +class ImageWidget: + def __init__( + self, + data: ArrayProtocol | Sequence[ArrayProtocol | None] | None, + processors: NDImageProcessor | Sequence[NDImageProcessor] = NDImageProcessor, + n_display_dims: Literal[2, 3] | Sequence[Literal[2, 3]] = 2, + slider_dim_names: Sequence[str] | None = None, # dim names left -> right + rgb: bool | Sequence[bool] = False, + cmap: str | Sequence[str] = "plasma", + window_funcs: ( + tuple[WindowFuncCallable | None, ...] + | WindowFuncCallable + | None + | Sequence[ + tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None + ] + ) = None, + window_sizes: ( + tuple[int | None, ...] | Sequence[tuple[int | None, ...] | None] + ) = None, + window_order: tuple[int, ...] | Sequence[tuple[int, ...] | None] = None, + spatial_func: ( + Callable[[ArrayProtocol], ArrayProtocol] + | Sequence[Callable[[ArrayProtocol], ArrayProtocol]] + | None + ) = None, + sliders_dim_order: Literal["right", "left"] = "right", + figure_shape: tuple[int, int] = None, + names: Sequence[str] = None, + figure_kwargs: dict = None, + histogram_widget: bool = True, + histogram_init_quantile: int = (0, 100), + graphic_kwargs: dict | Sequence[dict] = None, + ): + """ + This widget facilitates high-level navigation through image stacks, which are arrays containing one or more + images. It includes sliders for key dimensions such as "t" (time) and "z", enabling users to smoothly navigate + through one or multiple image stacks simultaneously. + + Allowed dimensions orders for each image stack: Note that each has a an optional (c) channel which refers to + RGB(A) a channel. So this channel should be either 3 or 4. + + Parameters + ---------- + data: ArrayProtocol | Sequence[ArrayProtocol | None] | None + array-like or a list of array-like, each array must have a minimum of 2 dimensions + + processors: NDImageProcessor | Sequence[NDImageProcessor], default NDImageProcessor + The image processors used for each n-dimensional data array + + n_display_dims: Literal[2, 3] | Sequence[Literal[2, 3]], default 2 + number of display dimensions + + slider_dim_names: Sequence[str], optional + optional list/tuple of names for each slider dim + + rgb: bool | Sequence[bool], default + whether or not each data array represents RGB(A) images + + figure_shape: Optional[Tuple[int, int]] + manually provide the shape for the Figure, otherwise the number of rows and columns is estimated + + figure_kwargs: dict, optional + passed to ``Figure`` + + names: Optional[str] + gives names to the subplots + + histogram_widget: bool, default False + make histogram LUT widget for each subplot + + rgb: bool | list[bool], default None + bool or list of bool for each input data array in the ImageWidget, indicating whether the corresponding + data arrays are grayscale or RGB(A). + + graphic_kwargs: Any + passed to each ImageGraphic in the ImageWidget figure subplots + + """ + + if figure_kwargs is None: + figure_kwargs = dict() + + if isinstance(data, ArrayProtocol) or (data is None): + data = [data] + + elif isinstance(data, (list, tuple)): + # verify that it's a list of np.ndarray + if not all([isinstance(d, ArrayProtocol) or d is None for d in data]): + raise TypeError( + f"`data` must be an array-like type or a list/tuple of array-like or None. " + f"You have passed the following type {type(data)}" + ) + + else: + raise TypeError( + f"`data` must be an array-like type or a list/tuple of array-like or None. " + f"You have passed the following type {type(data)}" + ) + + if issubclass(processors, NDImageProcessor): + processors = [processors] * len(data) + + elif isinstance(processors, (tuple, list)): + if not all([issubclass(p, NDImageProcessor) for p in processors]): + raise TypeError( + f"`processors` must be a `NDImageProcess` class, a subclass of `NDImageProcessor`, or a " + f"list/tuple of `NDImageProcess` subclasses. You have passed: {processors}" + ) + + else: + raise TypeError( + f"`processors` must be a `NDImageProcess` class, a subclass of `NDImageProcessor`, or a " + f"list/tuple of `NDImageProcess` subclasses. You have passed: {processors}" + ) + + # subplot layout + if figure_shape is None: + if "shape" in figure_kwargs: + figure_shape = figure_kwargs["shape"] + else: + figure_shape = calculate_figure_shape(len(data)) + + # Regardless of how figure_shape is computed, below code + # verifies that figure shape is large enough for the number of image arrays passed + if figure_shape[0] * figure_shape[1] < len(data): + original_shape = (figure_shape[0], figure_shape[1]) + figure_shape = calculate_figure_shape(len(data)) + warn( + f"Original `figure_shape` was: {original_shape} " + f" but data length is {len(data)}" + f" Resetting figure shape to: {figure_shape}" + ) + + elif isinstance(rgb, bool): + rgb = [rgb] * len(data) + + if not all([isinstance(v, bool) for v in rgb]): + raise TypeError( + f"`rgb` parameter must be a bool or a Sequence of bool, you have passed: {rgb}" + ) + + if not len(rgb) == len(data): + raise ValueError( + f"len(rgb) != len(data), {len(rgb)} != {len(data)}. These must be equal" + ) + + if names is not None: + if not all([isinstance(n, str) for n in names]): + raise TypeError("optional argument `names` must be a Sequence of str") + + if len(names) != len(data): + raise ValueError( + "number of `names` for subplots must be same as the number of data arrays" + ) + + # verify window funcs + if window_funcs is None: + win_funcs = [None] * len(data) + + elif callable(window_funcs) or all( + [callable(f) or f is None for f in window_funcs] + ): + # across all data arrays + # one window function defined for all dims, or window functions defined per-dim + win_funcs = [window_funcs] * len(data) + + # if the above two clauses didn't trigger, then window_funcs defined per-dim, per data array + elif len(window_funcs) != len(data): + raise IndexError + else: + win_funcs = window_funcs + + # verify window sizes + if window_sizes is None: + win_sizes = [window_sizes] * len(data) + + elif isinstance(window_sizes, int): + win_sizes = [window_sizes] * len(data) + + elif all([isinstance(size, int) or size is None for size in window_sizes]): + # window sizes defined per-dim across all data arrays + win_sizes = [window_sizes] * len(data) + + elif len(window_sizes) != len(data): + # window sizes defined per-dim, per data array + raise IndexError + else: + win_sizes = window_sizes + + # verify window orders + if window_order is None: + win_order = [None] * len(data) + + elif all([isinstance(o, int) for o in order]): + # window order defined per-dim across all data arrays + win_order = [window_order] * len(data) + + elif len(window_order) != len(data): + raise IndexError + + else: + win_order = window_order + + # verify spatial_func + if spatial_func is None: + spatial_func = [None] * len(data) + + elif callable(spatial_func): + # same spatial_func for all data arrays + spatial_func = [spatial_func] * len(data) + + elif len(spatial_func) != len(data): + raise IndexError + + else: + spatial_func = spatial_func + + # verify number of display dims + if isinstance(n_display_dims, (int, np.integer)): + n_display_dims = [n_display_dims] * len(data) + + elif isinstance(n_display_dims, (tuple, list)): + if not all([isinstance(n, (int, np.integer)) for n in n_display_dims]): + raise TypeError + + if len(n_display_dims) != len(data): + raise IndexError + else: + raise TypeError + + n_display_dims = tuple(n_display_dims) + + if sliders_dim_order not in ("right",): + raise ValueError( + f"Only 'right' slider dims order is currently supported, you passed: {sliders_dim_order}" + ) + self._sliders_dim_order = sliders_dim_order + + self._slider_dim_names = None + self.slider_dim_names = slider_dim_names + + self._histogram_widget = histogram_widget + + # make NDImageArrays + self._image_processors: list[NDImageProcessor] = list() + for i in range(len(data)): + cls = processors[i] + image_processor = cls( + data=data[i], + rgb=rgb[i], + n_display_dims=n_display_dims[i], + window_funcs=win_funcs[i], + window_sizes=win_sizes[i], + window_order=win_order[i], + spatial_func=spatial_func[i], + compute_histogram=self._histogram_widget, + ) + + self._image_processors.append(image_processor) + + self._data = ImageWidgetProperty(self, "data") + self._rgb = ImageWidgetProperty(self, "rgb") + self._n_display_dims = ImageWidgetProperty(self, "n_display_dims") + self._window_funcs = ImageWidgetProperty(self, "window_funcs") + self._window_sizes = ImageWidgetProperty(self, "window_sizes") + self._window_order = ImageWidgetProperty(self, "window_order") + self._spatial_func = ImageWidgetProperty(self, "spatial_func") + + if len(set(n_display_dims)) > 1: + # assume user wants one controller for 2D images and another for 3D image volumes + n_subplots = np.prod(figure_shape) + controller_ids = [0] * n_subplots + controller_types = ["panzoom"] * n_subplots + + for i in range(len(data)): + if n_display_dims[i] == 2: + controller_ids[i] = 1 + else: + controller_ids[i] = 2 + controller_types[i] = "orbit" + + # needs to be a list of list + controller_ids = [controller_ids] + + else: + controller_ids = "sync" + controller_types = None + + figure_kwargs_default = { + "controller_ids": controller_ids, + "controller_types": controller_types, + "names": names, + } + + # update the default kwargs with any user-specified kwargs + # user specified kwargs will overwrite the defaults + figure_kwargs_default.update(figure_kwargs) + figure_kwargs_default["shape"] = figure_shape + + if graphic_kwargs is None: + graphic_kwargs = [dict()] * len(data) + + elif isinstance(graphic_kwargs, dict): + graphic_kwargs = [graphic_kwargs] * len(data) + + elif len(graphic_kwargs) != len(data): + raise IndexError + + if cmap is None: + cmap = [None] * len(data) + + elif isinstance(cmap, str): + cmap = [cmap] * len(data) + + elif not all([isinstance(c, str) for c in cmap]): + raise TypeError(f"`cmap` must be a or a list/tuple of ") + + self._figure: Figure = Figure(**figure_kwargs_default) + + self._indices = Indices(list(0 for i in range(self.n_sliders)), self) + + for i, subplot in zip(range(len(self._image_processors)), self.figure): + image_data = self._get_image( + self._image_processors[i], tuple(self._indices) + ) + + if image_data is None: + # this subplot/data array is blank, skip + continue + + # next 20 lines are just vmin, vmax parsing + vmin_specified, vmax_specified = None, None + if "vmin" in graphic_kwargs[i].keys(): + vmin_specified = graphic_kwargs[i].pop("vmin") + if "vmax" in graphic_kwargs[i].keys(): + vmax_specified = graphic_kwargs[i].pop("vmax") + + if (vmin_specified is None) or (vmax_specified is None): + # if either vmin or vmax are not specified, calculate an estimate by subsampling + vmin_estimate, vmax_estimate = quick_min_max( + self._image_processors[i].data + ) + + # decide vmin, vmax passed to ImageGraphic constructor based on whether it's user specified or now + if vmin_specified is None: + # user hasn't specified vmin, use estimated value + vmin = vmin_estimate + else: + # user has provided a specific value, use that + vmin = vmin_specified + + if vmax_specified is None: + vmax = vmax_estimate + else: + vmax = vmax_specified + else: + # both vmin and vmax are specified + vmin, vmax = vmin_specified, vmax_specified + + graphic_kwargs[i]["cmap"] = cmap[i] + + if self._image_processors[i].n_display_dims == 2: + # create an Image + graphic = ImageGraphic( + data=image_data, + name="image_widget_managed", + vmin=vmin, + vmax=vmax, + **graphic_kwargs[i], + ) + elif self._image_processors[i].n_display_dims == 3: + # create an ImageVolume + graphic = ImageVolumeGraphic( + data=image_data, + name="image_widget_managed", + vmin=vmin, + vmax=vmax, + **graphic_kwargs[i], + ) + subplot.camera.fov = 50 + + subplot.add_graphic(graphic) + + self._reset_histogram(subplot, self._image_processors[i]) + + self._sliders_ui = ImageWidgetSliders( + figure=self.figure, + size=57 + (IMGUI_SLIDER_HEIGHT * self.n_sliders), + location="bottom", + title="ImageWidget Controls", + image_widget=self, + ) + + self.figure.add_gui(self._sliders_ui) + + self._indices_changed_handlers = set() + + self._reentrant_block = False + + @property + def data(self) -> ImageWidgetProperty[ArrayProtocol | None]: + """get or set the nd-image data arrays""" + return self._data + + @data.setter + def data(self, new_data: Sequence[ArrayProtocol | None]): + if isinstance(new_data, ArrayProtocol) or new_data is None: + new_data = [new_data] * len(self._image_processors) + + if len(new_data) != len(self._image_processors): + raise IndexError + + # if the data array hasn't been changed + # graphics will not be reset for this data index + skip_indices = list() + + for i, (new_data, image_processor) in enumerate( + zip(new_data, self._image_processors) + ): + if new_data is image_processor.data: + skip_indices.append(i) + continue + + image_processor.data = new_data + + self._reset(skip_indices) + + @property + def rgb(self) -> ImageWidgetProperty[bool]: + """get or set the rgb toggle for each data array""" + return self._rgb + + @rgb.setter + def rgb(self, rgb: Sequence[bool]): + if isinstance(rgb, bool): + rgb = [rgb] * len(self._image_processors) + + if len(rgb) != len(self._image_processors): + raise IndexError + + # if the rgb option hasn't been changed + # graphics will not be reset for this data index + skip_indices = list() + + for i, (new, image_processor) in enumerate(zip(rgb, self._image_processors)): + if image_processor.rgb == new: + skip_indices.append(i) + continue + + image_processor.rgb = new + + self._reset(skip_indices) + + @property + def n_display_dims(self) -> ImageWidgetProperty[Literal[2, 3]]: + """Get or set the number of display dimensions for each data array, 2 is a 2D image, 3 is a 3D volume image""" + return self._n_display_dims + + @n_display_dims.setter + def n_display_dims(self, new_ndd: Sequence[Literal[2, 3]] | Literal[2, 3]): + if isinstance(new_ndd, (int, np.integer)): + if new_ndd == 2 or new_ndd == 3: + new_ndd = [new_ndd] * len(self._image_processors) + else: + raise ValueError + + if len(new_ndd) != len(self._image_processors): + raise IndexError + + if not all([(n == 2) or (n == 3) for n in new_ndd]): + raise ValueError + + # if the n_display_dims hasn't been changed for this data array + # graphics will not be reset for this data array index + skip_indices = list() + + # first update image arrays + for i, (image_processor, new) in enumerate( + zip(self._image_processors, new_ndd) + ): + if new > image_processor.max_n_display_dims: + raise IndexError( + f"number of display dims exceeds maximum number of possible " + f"display dimensions: {image_processor.max_n_display_dims}, for array at index: " + f"{i} with shape: {image_processor.shape}, and rgb set to: {image_processor.rgb}" + ) + + if image_processor.n_display_dims == new: + skip_indices.append(i) + else: + image_processor.n_display_dims = new + + self._reset(skip_indices) + + @property + def window_funcs(self) -> ImageWidgetProperty[tuple[WindowFuncCallable | None] | None]: + """get or set the window functions""" + return self._window_funcs + + @window_funcs.setter + def window_funcs(self, new_funcs: Sequence[WindowFuncCallable | None] | None): + if callable(new_funcs) or new_funcs is None: + new_funcs = [new_funcs] * len(self._image_processors) + + if len(new_funcs) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_funcs", new_funcs) + + @property + def window_sizes(self) -> ImageWidgetProperty[tuple[int | None, ...] | None]: + """get or set the window sizes""" + return self._window_sizes + + @window_sizes.setter + def window_sizes( + self, new_sizes: Sequence[tuple[int | None, ...] | int | None] | int | None + ): + if isinstance(new_sizes, int) or new_sizes is None: + # same window for all data arrays + new_sizes = [new_sizes] * len(self._image_processors) + + if len(new_sizes) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_sizes", new_sizes) + + @property + def window_order(self) -> ImageWidgetProperty[tuple[int, ...] | None]: + """get or set order in which window functions are applied over dimensions""" + return self._window_order + + @window_order.setter + def window_order(self, new_order: Sequence[tuple[int, ...]]): + if new_order is None: + new_order = [new_order] * len(self._image_processors) + + if all([isinstance(order, (int, np.integer))] for order in new_order): + # same order specified across all data arrays + new_order = [new_order] * len(self._image_processors) + + if len(new_order) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_order", new_order) + + @property + def spatial_func(self) -> ImageWidgetProperty[Callable | None]: + """Get or set a spatial_func that operates on the spatial dimensions of the 2D or 3D image""" + return self._spatial_func + + @spatial_func.setter + def spatial_func(self, funcs: Callable | Sequence[Callable] | None): + if callable(funcs) or funcs is None: + funcs = [funcs] * len(self._image_processors) + + if len(funcs) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("spatial_func", funcs) + + def _set_image_processor_funcs(self, attr, new_values): + """sets window_funcs, window_sizes, window_order, or spatial_func and updates displayed data and histograms""" + for new, image_processor, subplot in zip( + new_values, self._image_processors, self.figure + ): + if getattr(image_processor, attr) == new: + continue + + setattr(image_processor, attr, new) + + # window functions and spatial functions will only change the histogram + # they do not change the collections of dimensions, so we don't need to call _reset_dimensions + # they also do not change the image graphic, so we do not need to call _reset_image_graphics + self._reset_histogram(subplot, image_processor) + + # update the displayed image data in the graphics + self.indices = self.indices + + @property + def indices(self) -> ImageWidgetProperty[int]: + """ + Get or set the current indices. + + Returns + ------- + indices: ImageWidgetProperty[int] + integer index for each slider dimension + + """ + return self._indices + + @indices.setter + def indices(self, new_indices: Sequence[int]): + if self._reentrant_block: + return + + try: + self._reentrant_block = True # block re-execution until new_indices has *fully* completed execution + + if len(new_indices) != self.n_sliders: + raise IndexError( + f"len(new_indices) != ImageWidget.n_sliders, {len(new_indices)} != {self.n_sliders}. " + f"The length of the new_indices must be the same as the number of sliders" + ) + + if any([i < 0 for i in new_indices]): + raise IndexError( + f"only positive index values are supported, you have passed: {new_indices}" + ) + + for image_processor, graphic in zip(self._image_processors, self.graphics): + new_data = self._get_image(image_processor, indices=new_indices) + if new_data is None: + continue + + graphic.data = new_data + + self._indices._fpl_set(new_indices) + + # call any event handlers + for handler in self._indices_changed_handlers: + handler(tuple(self.indices)) + + except Exception as exc: + # raise original exception + raise exc # indices setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._reentrant_block = False + + @property + def histogram_widget(self) -> bool: + """show or hide the histograms""" + return self._histogram_widget + + @histogram_widget.setter + def histogram_widget(self, show_histogram: bool): + if not isinstance(show_histogram, bool): + raise TypeError( + f"`histogram_widget` can be set with a bool, you have passed: {show_histogram}" + ) + + for subplot, image_processor in zip(self.figure, self._image_processors): + image_processor.compute_histogram = show_histogram + self._reset_histogram(subplot, image_processor) + + @property + def n_sliders(self) -> int: + """number of sliders""" + return max([a.n_slider_dims for a in self._image_processors]) + + @property + def bounds(self) -> tuple[int, ...]: + """The max bound across all dimensions across all data arrays""" + # initialize with 0 + bounds = [0] * self.n_sliders + + # TODO: implement left -> right slider dims ordering, right now it's only right -> left + # in reverse because dims go left <- right + for i, dim in enumerate(range(-1, -self.n_sliders - 1, -1)): + # across each dim + for array in self._image_processors: + if i > array.n_slider_dims - 1: + continue + # across each data array + # dims go left <- right + bounds[dim] = max(array.slider_dims_shape[dim], bounds[dim]) + + return bounds + + @property + def slider_dim_names(self) -> tuple[str, ...]: + return self._slider_dim_names + + @slider_dim_names.setter + def slider_dim_names(self, names: Sequence[str]): + if names is None: + self._slider_dim_names = None + return + + if not all([isinstance(n, str) for n in names]): + raise TypeError(f"`slider_dim_names` must be set with a list/tuple of , you passed: {names}") + + if len(set(names)) != len(names): + raise ValueError( + f"`slider_dim_names` must be unique, you passed: {names}" + ) + + self._slider_dim_names = tuple(names) + + def _get_image( + self, image_processor: NDImageProcessor, indices: Sequence[int] + ) -> ArrayProtocol: + """Get a processed 2d or 3d image from the NDImage at the given indices""" + n = image_processor.n_slider_dims + + if self._sliders_dim_order == "right": + return image_processor.get(indices[-n:]) + + elif self._sliders_dim_order == "left": + # TODO: left -> right is not fully implemented yet in ImageWidget + return image_processor.get(indices[:n]) + + def _reset_dimensions(self): + """reset the dimensions w.r.t. current collection of NDImageProcessors""" + # TODO: implement left -> right slider dims ordering, right now it's only right -> left + # add or remove dims from indices + # trim any excess dimensions + while len(self._indices) > self.n_sliders: + # remove outer most dims first + self._indices.pop_dim() + self._sliders_ui.pop_dim() + + # add any new dimensions that aren't present + while len(self.indices) < self.n_sliders: + # insert right -> left + self._indices.push_dim() + self._sliders_ui.push_dim() + + self._sliders_ui.size = 57 + (IMGUI_SLIDER_HEIGHT * self.n_sliders) + + def _reset_image_graphics(self, subplot, image_processor): + """delete and create a new image graphic if necessary""" + new_image = self._get_image(image_processor, indices=tuple(self.indices)) + if new_image is None: + if "image_widget_managed" in subplot: + # delete graphic from this subplot if present + subplot.delete_graphic(subplot["image_widget_managed"]) + # skip this subplot + return + + # check if a graphic exists + if "image_widget_managed" in subplot: + # create a new graphic only if the Texture buffer shape doesn't match + if subplot["image_widget_managed"].data.value.shape == new_image.shape: + return + + # keep cmap + cmap = subplot["image_widget_managed"].cmap + if cmap is None: + # ex: going from rgb -> grayscale + cmap = "plasma" + # delete graphic since it will be replaced + subplot.delete_graphic(subplot["image_widget_managed"]) + else: + # default cmap + cmap = "plasma" + + if image_processor.n_display_dims == 2: + g = subplot.add_image( + data=new_image, cmap=cmap, name="image_widget_managed" + ) + + # set camera orthogonal to the xy plane, flip y axis + subplot.camera.set_state( + { + "position": [0, 0, -1], + "rotation": [0, 0, 0, 1], + "scale": [1, -1, 1], + "reference_up": [0, 1, 0], + "fov": 0, + "depth_range": None, + } + ) + + subplot.controller = "panzoom" + subplot.axes.intersection = None + subplot.auto_scale() + + elif image_processor.n_display_dims == 3: + g = subplot.add_image_volume( + data=new_image, cmap=cmap, name="image_widget_managed" + ) + subplot.camera.fov = 50 + subplot.controller = "orbit" + + # make sure all 3D dimension camera scales are positive + # MIP rendering doesn't work with negative camera scales + for dim in ["x", "y", "z"]: + if getattr(subplot.camera.local, f"scale_{dim}") < 0: + setattr(subplot.camera.local, f"scale_{dim}", 1) + + subplot.auto_scale() + + def _reset_histogram(self, subplot, image_processor): + """reset the histogram""" + if not self._histogram_widget: + subplot.docks["right"].size = 0 + return + + if image_processor.histogram is None: + # no histogram available for this processor + # either there is no data array in this subplot, + # or a histogram routine does not exist for this processor + subplot.docks["right"].size = 0 + return + + if "image_widget_managed" not in subplot: + # no image in this subplot + subplot.docks["right"].size = 0 + return + + image = subplot["image_widget_managed"] + + if "histogram_lut" in subplot.docks["right"]: + hlut: HistogramLUTTool = subplot.docks["right"]["histogram_lut"] + hlut.histogram = image_processor.histogram + hlut.images = image + if subplot.docks["right"].size < 1: + subplot.docks["right"].size = 80 + + else: + # need to make one + hlut = HistogramLUTTool( + histogram=image_processor.histogram, + images=image, + name="histogram_lut", + ) + + subplot.docks["right"].add_graphic(hlut) + subplot.docks["right"].size = 80 + + self.reset_vmin_vmax() + + def _reset(self, skip_data_indices: tuple[int, ...] = None): + if skip_data_indices is None: + skip_data_indices = tuple() + + # reset the slider indices according to the new collection of dimensions + self._reset_dimensions() + # update graphics where display dims have changed accordings to indices + for i, (subplot, image_processor) in enumerate( + zip(self.figure, self._image_processors) + ): + if i in skip_data_indices: + continue + + self._reset_image_graphics(subplot, image_processor) + self._reset_histogram(subplot, image_processor) + + # force an update + self.indices = self.indices + + @property + def figure(self) -> Figure: + """ + ``Figure`` used by `ImageWidget`. + """ + return self._figure + + @property + def graphics(self) -> list[ImageGraphic]: + """List of ``ImageWidget`` managed graphics.""" + iw_managed = list() + for subplot in self.figure: + if "image_widget_managed" in subplot: + iw_managed.append(subplot["image_widget_managed"]) + else: + iw_managed.append(None) + return tuple(iw_managed) + + @property + def cmap(self) -> tuple[str | None, ...]: + """get the cmaps, or set the cmap across all images""" + return tuple(g.cmap for g in self.graphics) + + @cmap.setter + def cmap(self, name: str): + for g in self.graphics: + if g is None: + # no data at this index + continue + + if g.cmap is None: + # if rgb + continue + + g.cmap = name + + def add_event_handler(self, handler: callable, event: str = "indices"): + """ + Register an event handler. + + Currently the only event that ImageWidget supports is "indices". This event is + emitted whenever the indices of the ImageWidget changes. + + Parameters + ---------- + handler: callable + callback function, must take a tuple of int as the only argument. This tuple will be the `indices` + + event: str, "indices" + the only supported event is "indices" + + Example + ------- + + .. code-block:: py + + def my_handler(indices): + print(indices) + # example prints: (100, 15) if the data has 2 slider dimensions with sliders at positions 100, 15 + + # create an image widget + iw = ImageWidget(...) + + # add event handler + iw.add_event_handler(my_handler) + + """ + if event != "indices": + raise ValueError("`indices` is the only event supported by `ImageWidget`") + + self._indices_changed_handlers.add(handler) + + def remove_event_handler(self, handler: callable): + """Remove a registered event handler""" + self._indices_changed_handlers.remove(handler) + + def clear_event_handlers(self): + """Clear all registered event handlers""" + self._indices_changed_handlers.clear() + + def reset_vmin_vmax(self): + """ + Reset the vmin and vmax w.r.t. the full data + """ + for image_processor, subplot in zip(self._image_processors, self.figure): + if "histogram_lut" not in subplot.docks["right"]: + continue + + if image_processor.histogram is None: + continue + + hlut = subplot.docks["right"]["histogram_lut"] + hlut.histogram = image_processor.histogram + + edges = image_processor.histogram[1] + + hlut.vmin, hlut.vmax = edges[0], edges[-1] + + def reset_vmin_vmax_frame(self): + """ + Resets the vmin vmax and HistogramLUT widgets w.r.t. the current data shown in the + ImageGraphic instead of the data in the full data array. For example, if a post-processing + function is used, the range of values in the ImageGraphic can be very different from the + range of values in the full data array. + """ + + for subplot, image_processor in zip(self.figure, self._image_processors): + if "histogram_lut" not in subplot.docks["right"]: + continue + + if image_processor.histogram is None: + continue + + hlut = subplot.docks["right"]["histogram_lut"] + # set the data using the current image graphic data + image = subplot["image_widget_managed"] + freqs, edges = np.histogram(image.data.value, bins=100) + hlut.histogram = (freqs, edges) + hlut.vmin, hlut.vmax = edges[0], edges[-1] + + def show(self, **kwargs): + """ + Show the widget. + + Parameters + ---------- + + kwargs: Any + passed to `Figure.show()`t + + Returns + ------- + BaseRenderCanvas + In Qt or GLFW, the canvas window containing the Figure will be shown. + In jupyter, it will display the plot in the output cell or sidecar. + + """ + + return self.figure.show(**kwargs) + + def close(self): + """Close Widget""" + self.figure.close() diff --git a/fastplotlib/widgets/image_widget/_processor.py b/fastplotlib/widgets/image_widget/_processor.py new file mode 100644 index 000000000..0dce84a5e --- /dev/null +++ b/fastplotlib/widgets/image_widget/_processor.py @@ -0,0 +1,519 @@ +import inspect +from typing import Literal, Callable +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS + + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDImageProcessor: + def __init__( + self, + data: ArrayLike | None, + n_display_dims: Literal[2, 3] = 2, + rgb: bool = False, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_sizes: tuple[int | None, ...] | int = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + compute_histogram: bool = True, + ): + """ + An ND image that supports computing window functions, and functions over spatial dimensions. + + Parameters + ---------- + data: ArrayLike + array-like data, must have 2 or more dimensions + + n_display_dims: int, 2 or 3, default 2 + number of display dimensions + + rgb: bool, default False + whether the image data is RGB(A) or not + + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable, optional + A function or a ``tuple`` of functions that are applied to a rolling window of the data. + + You can provide unique window functions for each dimension. If you want to apply a window function + only to a subset of the dimensions, put ``None`` to indicate no window function for a given dimension. + + A "window function" must take ``axis`` argument, which is an ``int`` that specifies the axis along which + the window function is applied. It must also take a ``keepdims`` argument which is a ``bool``. The window + function **must** return an array that has the same number of dimensions as the original ``data`` array, + therefore the size of the dimension along which the window was applied will reduce to ``1``. + + The output array-like type from a window function **must** support a ``.squeeze()`` method, but the + function itself should NOT squeeze the output array. + + window_sizes: tuple[int | None, ...], optional + ``tuple`` of ``int`` that specifies the window size for each dimension. + + window_order: tuple[int, ...] | None, optional + order in which to apply the window functions, by default just applies it from the left-most dim to the + right-most slider dim. + + spatial_func: Callable[[ArrayLike], ArrayLike] | None, optional + A function that is applied on the _spatial_ dimensions of the data array, i.e. the last 2 or 3 dimensions. + This function is applied after the window functions (if present). + + compute_histogram: bool, default True + Compute a histogram of the data, auto re-computes if window function propties or spatial_func changes. + Disable if slow. + + """ + # set as False until data, window funcs stuff and spatial func is all set + self._compute_histogram = False + + self.data = data + self.n_display_dims = n_display_dims + self.rgb = rgb + + self.window_funcs = window_funcs + self.window_sizes = window_sizes + self.window_order = window_order + + self._spatial_func = spatial_func + + self._compute_histogram = compute_histogram + self._recompute_histogram() + + @property + def data(self) -> ArrayLike | None: + """get or set the data array""" + return self._data + + @data.setter + def data(self, data: ArrayLike): + # check that all array-like attributes are present + if data is None: + self._data = None + return + + if not isinstance(data, ArrayProtocol): + raise TypeError( + f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" + f"{ARRAY_LIKE_ATTRS}, or they must be `None`" + ) + + if data.ndim < 2: + raise IndexError( + f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" + ) + + self._data = data + self._recompute_histogram() + + @property + def ndim(self) -> int: + if self.data is None: + return 0 + + return self.data.ndim + + @property + def shape(self) -> tuple[int, ...]: + if self._data is None: + return tuple() + + return self.data.shape + + @property + def rgb(self) -> bool: + """whether or not the data is rgb(a)""" + return self._rgb + + @rgb.setter + def rgb(self, rgb: bool): + if not isinstance(rgb, bool): + raise TypeError + + if rgb and self.ndim < 3: + raise IndexError( + f"require 3 or more dims for RGB, you have: {self.ndim} dims" + ) + + self._rgb = rgb + + @property + def n_slider_dims(self) -> int: + """number of slider dimensions""" + if self._data is None: + return 0 + + return self.ndim - self.n_display_dims - int(self.rgb) + + @property + def slider_dims(self) -> tuple[int, ...] | None: + """tuple indicating the slider dimension indices""" + if self.n_slider_dims == 0: + return None + + return tuple(range(self.n_slider_dims)) + + @property + def slider_dims_shape(self) -> tuple[int, ...] | None: + if self.n_slider_dims == 0: + return None + + return tuple(self.shape[i] for i in self.slider_dims) + + @property + def n_display_dims(self) -> Literal[2, 3]: + """get or set the number of display dimensions, `2` for 2D image and `3` for volume images""" + return self._n_display_dims + + # TODO: make n_display_dims settable, requires thinking about inserting and poping indices in ImageWidget + @n_display_dims.setter + def n_display_dims(self, n: Literal[2, 3]): + if not (n == 2 or n == 3): + raise ValueError( + f"`n_display_dims` must be an with a value of 2 or 3, you have passed: {n}" + ) + self._n_display_dims = n + self._recompute_histogram() + + @property + def max_n_display_dims(self) -> int: + """maximum number of possible display dims""" + # min 2, max 3, accounts for if data is None and ndim is 0 + return max(2, min(3, self.ndim - int(self.rgb))) + + @property + def display_dims(self) -> tuple[int, int] | tuple[int, int, int]: + """tuple indicating the display dimension indices""" + return tuple(range(self.data.ndim))[self.n_slider_dims :] + + @property + def window_funcs( + self, + ) -> tuple[WindowFuncCallable | None, ...] | None: + """get or set window functions, see docstring for details""" + return self._window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, + ): + if window_funcs is None: + self._window_funcs = None + return + + if callable(window_funcs): + window_funcs = (window_funcs,) + + # if all are None + if all([f is None for f in window_funcs]): + self._window_funcs = None + return + + self._validate_window_func(window_funcs) + + self._window_funcs = tuple(window_funcs) + self._recompute_histogram() + + def _validate_window_func(self, funcs): + if isinstance(funcs, (tuple, list)): + for f in funcs: + if f is None: + pass + elif callable(f): + sig = inspect.signature(f) + + if "axis" not in sig.parameters or "keepdims" not in sig.parameters: + raise TypeError( + f"Each window function must take an `axis` and `keepdims` argument, " + f"you passed: {f} with the following function signature: {sig}" + ) + else: + raise TypeError( + f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" + ) + + if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " + f"and you passed {len(funcs)} `window_funcs`: {funcs}" + ) + + @property + def window_sizes(self) -> tuple[int | None, ...] | None: + """get or set window sizes used for the corresponding window functions, see docstring for details""" + return self._window_sizes + + @window_sizes.setter + def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): + if window_sizes is None: + self._window_sizes = None + return + + if isinstance(window_sizes, int): + window_sizes = (window_sizes,) + + # if all are None + if all([w is None for w in window_sizes]): + self._window_sizes = None + return + + if not all([isinstance(w, (int)) or w is None for w in window_sizes]): + raise TypeError( + f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" + ) + + if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_sizes` must be the same as the number of slider dims, " + f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " + f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" + ) + + # make all window sizes are valid numbers + _window_sizes = list() + for i, w in enumerate(window_sizes): + if w is None: + _window_sizes.append(None) + continue + + if w < 0: + raise ValueError( + f"negative window size passed, all `window_sizes` must be positive " + f"integers or `None`, you passed: {_window_sizes}" + ) + + if w == 0 or w == 1: + # this is not a real window, set as None + w = None + + elif w % 2 == 0: + # odd window sizes makes most sense + warn( + f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" + ) + w += 1 + + _window_sizes.append(w) + + self._window_sizes = tuple(_window_sizes) + self._recompute_histogram() + + @property + def window_order(self) -> tuple[int, ...] | None: + """get or set dimension order in which window functions are applied""" + return self._window_order + + @window_order.setter + def window_order(self, order: tuple[int] | None): + if order is None: + self._window_order = None + return + + if order is not None: + if not all([d <= self.n_slider_dims for d in order]): + raise IndexError( + f"all `window_order` entries must be <= n_slider_dims\n" + f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" + ) + + if not all([d >= 0 for d in order]): + raise IndexError( + f"all `window_order` entires must be >= 0, you have passed: {order}" + ) + + self._window_order = tuple(order) + self._recompute_histogram() + + @property + def spatial_func(self) -> Callable[[ArrayLike], ArrayLike] | None: + """get or set a spatial_func function, see docstring for details""" + return self._spatial_func + + @spatial_func.setter + def spatial_func(self, func: Callable[[ArrayLike], ArrayLike] | None): + if not (callable(func) or func is not None): + raise TypeError( + f"`spatial_func` must be a callable or `None`, you have passed: {func}" + ) + + self._spatial_func = func + self._recompute_histogram() + + @property + def compute_histogram(self) -> bool: + return self._compute_histogram + + @compute_histogram.setter + def compute_histogram(self, compute: bool): + if compute: + if self._compute_histogram is False: + # compute a histogram + self._recompute_histogram() + self._compute_histogram = True + else: + self._compute_histogram = False + self._histogram = None + + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray] | None: + """ + an estimate of the histogram of the data, (histogram_values, bin_edges). + + returns `None` if `compute_histogram` is `False` + """ + return self._histogram + + def _apply_window_function(self, indices: tuple[int, ...]) -> ArrayLike: + """applies the window functions for each dimension specified""" + # window size for each dim + winds = self._window_sizes + # window function for each dim + funcs = self._window_funcs + + if winds is None or funcs is None: + # no window funcs or window sizes, just slice data and return + # clamp to max bounds + indexer = list() + for dim, i in enumerate(indices): + i = min(self.shape[dim] - 1, i) + indexer.append(i) + + return self.data[tuple(indexer)] + + # order in which window funcs are applied + order = self._window_order + + if order is not None: + # remove any entries in `window_order` where the specified dim + # has a window function or window size specified as `None` + # example: + # window_sizes = (3, 2) + # window_funcs = (np.mean, None) + # order = (0, 1) + # `1` is removed from the order since that window_func is `None` + order = tuple( + d for d in order if winds[d] is not None and funcs[d] is not None + ) + else: + # sequential order + order = list() + for d in range(self.n_slider_dims): + if winds[d] is not None and funcs[d] is not None: + order.append(d) + + # the final indexer which will be used on the data array + indexer = list() + + for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): + # clamp i within the max bounds + i = min(self.shape[dim_index] - 1, i) + + if (w is not None) and (f is not None): + # specify slice window if both window size and function for this dim are not None + hw = int((w - 1) / 2) # half window + + # start index cannot be less than 0 + start = max(0, i - hw) + + # stop index cannot exceed the bounds of this dimension + stop = min(self.shape[dim_index] - 1, i + hw) + + s = slice(start, stop, 1) + else: + s = slice(i, i + 1, 1) + + indexer.append(s) + + # apply indexer to slice data with the specified windows + data_sliced = self.data[tuple(indexer)] + + # finally apply the window functions in the specified order + for dim in order: + f = funcs[dim] + + data_sliced = f(data_sliced, axis=dim, keepdims=True) + + return data_sliced + + def get(self, indices: tuple[int, ...]) -> ArrayLike | None: + """ + Get the data at the given index, process data through the window functions. + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + indices: tuple[int, ...] + Get the processed data at this index. Must provide a value for each dimension. + Example: get((100, 5)) + + """ + if self.data is None: + return None + + if self.n_slider_dims != 0: + if len(indices) != self.n_slider_dims: + raise IndexError( + f"Must specify index for every slider dim, you have specified an index: {indices}\n" + f"But there are: {self.n_slider_dims} slider dims." + ) + # get output after processing through all window funcs + # squeeze to remove all dims of size 1 + window_output = self._apply_window_function(indices).squeeze() + else: + # data is a static image or volume + window_output = self.data + + # apply spatial_func + if self.spatial_func is not None: + final_output = self.spatial_func(window_output) + if final_output.ndim != (self.n_display_dims + int(self.rgb)): + raise IndexError( + f"Final output after of the `spatial_func` must match the number of display dims." + f"Output after `spatial_func` returned an array with {final_output.ndim} dims and " + f"of shape: {final_output.shape}, expected {self.n_display_dims} dims" + ) + else: + # check that output ndim after window functions matches display dims + final_output = window_output + if final_output.ndim != (self.n_display_dims + int(self.rgb)): + raise IndexError( + f"Final output after of the `window_funcs` must match the number of display dims." + f"Output after `window_funcs` returned an array with {window_output.ndim} dims and " + f"of shape: {window_output.shape}{' with rgb(a) channels' if self.rgb else ''}, " + f"expected {self.n_display_dims} dims" + ) + + return final_output + + def _recompute_histogram(self): + """ + + Returns + ------- + (histogram_values, bin_edges) + + """ + if not self._compute_histogram or self.data is None: + self._histogram = None + return + + if self.spatial_func is not None: + # don't subsample spatial dims if a spatial function is used + # spatial functions often operate on the spatial dims, ex: a gaussian kernel + # so their results require the full spatial resolution, the histogram of a + # spatially subsampled image will be very different + ignore_dims = self.display_dims + else: + ignore_dims = None + + sub = subsample_array(self.data, ignore_dims=ignore_dims) + sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] + + self._histogram = np.histogram(sub_real, bins=100) diff --git a/fastplotlib/widgets/image_widget/_properties.py b/fastplotlib/widgets/image_widget/_properties.py new file mode 100644 index 000000000..060314439 --- /dev/null +++ b/fastplotlib/widgets/image_widget/_properties.py @@ -0,0 +1,139 @@ +from pprint import pformat +from typing import Iterable + +import numpy as np + +from ._processor import NDImageProcessor + + +class ImageWidgetProperty: + __class_getitem__ = classmethod(type(list[int])) + + def __init__( + self, + image_widget, + attribute: str, + ): + self._image_widget = image_widget + self._image_processors: list[NDImageProcessor] = image_widget._image_processors + self._attribute = attribute + + def _get_key(self, key: slice | int | np.integer | str) -> int | slice: + if not isinstance(key, (slice | int, np.integer, str)): + raise TypeError( + f"can index `{self._attribute}` only with a , , or a indicating the subplot name." + f"You tried to index with: {key}" + ) + + if isinstance(key, str): + for i, subplot in enumerate(self._image_widget.figure): + if subplot.name == key: + key = i + break + else: + raise IndexError(f"No subplot with given name: {key}") + + return key + + def __getitem__(self, key): + key = self._get_key(key) + # return image processor attribute at this index + if isinstance(key, (int, np.integer)): + return getattr(self._image_processors[key], self._attribute) + + # if it's a slice + processors = self._image_processors[key] + + return tuple(getattr(p, self._attribute) for p in processors) + + def __setitem__(self, key, value): + key = self._get_key(key) + + # get the values from the ImageWidget property + new_values = list(getattr(p, self._attribute) for p in self._image_processors) + + # set the new value at this slice + new_values[key] = value + + # call the setter + setattr(self._image_widget, self._attribute, new_values) + + def __iter__(self): + for image_processor in self._image_processors: + yield getattr(image_processor, self._attribute) + + def __repr__(self): + return f"{self._attribute}: {pformat(self[:])}" + + def __eq__(self, other): + return self[:] == other + + +class Indices: + def __init__( + self, + indices: list[int], + image_widget, + ): + self._data = indices + + self._image_widget = image_widget + + def __iter__(self): + for i in self._data: + yield i + + def _parse_key(self, key: int | np.integer | str) -> int: + if not isinstance(key, (int, np.integer, str)): + raise TypeError( + f"indices can only be indexed with or types, you have used: {key}" + ) + + if isinstance(key, str): + # get integer index from user's names + names = self._image_widget._slider_dim_names + if key not in names: + raise KeyError( + f"dim with name: {key} not found in slider_dim_names, current names are: {names}" + ) + + key = names.index(key) + + return key + + def __getitem__(self, key: int | np.integer | str) -> int | tuple[int]: + if isinstance(key, str): + key = self._parse_key(key) + + return self._data[key] + + def __setitem__(self, key, value): + key = self._parse_key(key) + + if not isinstance(value, (int, np.integer)): + raise TypeError( + f"indices values can only be set with integers, you have tried to set the value: {value}" + ) + + new_indices = list(self._data) + new_indices[key] = value + + self._image_widget.indices = new_indices + + def _fpl_set(self, values): + self._data[:] = values + + def pop_dim(self): + self._data.pop(0) + + def push_dim(self): + self._data.insert(0, 0) + + def __len__(self): + return len(self._data) + + def __eq__(self, other): + return self._data == other + + def __repr__(self): + return f"indices: {self._data}" diff --git a/fastplotlib/widgets/image_widget/_sliders.py b/fastplotlib/widgets/image_widget/_sliders.py index 393b13273..1945b8cfb 100644 --- a/fastplotlib/widgets/image_widget/_sliders.py +++ b/fastplotlib/widgets/image_widget/_sliders.py @@ -11,50 +11,66 @@ def __init__(self, figure, size, location, title, image_widget): super().__init__(figure=figure, size=size, location=location, title=title) self._image_widget = image_widget + n_sliders = self._image_widget.n_sliders + # whether or not a dimension is in play mode - self._playing: dict[str, bool] = {"t": False, "z": False} + self._playing: list[bool] = [False] * n_sliders # approximate framerate for playing - self._fps: dict[str, int] = {"t": 20, "z": 20} + self._fps: list[int] = [20] * n_sliders + # framerate converted to frame time - self._frame_time: dict[str, float] = {"t": 1 / 20, "z": 1 / 20} + self._frame_time: list[float] = [1 / 20] * n_sliders # last timepoint that a frame was displayed from a given dimension - self._last_frame_time: dict[str, float] = {"t": 0, "z": 0} + self._last_frame_time: list[float] = [perf_counter()] * n_sliders + # loop playback self._loop = False - if "RTD_BUILD" in os.environ.keys(): - if os.environ["RTD_BUILD"] == "1": - self._playing["t"] = True + # auto-plays the ImageWidget's left-most dimension in docs galleries + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": + self._playing[0] = True self._loop = True - def set_index(self, dim: str, index: int): - """set the current_index of the ImageWidget""" + self.pause = False + + def pop_dim(self): + """pop right most dim""" + i = 0 # len(self._image_widget.indices) - 1 + for l in [self._playing, self._fps, self._frame_time, self._last_frame_time]: + l.pop(i) + + def push_dim(self): + """push a new dim""" + self._playing.insert(0, False) + self._fps.insert(0, 20) + self._frame_time.insert(0, 1 / 20) + self._last_frame_time.insert(0, perf_counter()) + + def set_index(self, dim: int, new_index: int): + """set the index of the ImageWidget""" # make sure the max index for this dim is not exceeded - max_index = self._image_widget._dims_max_bounds[dim] - 1 - if index > max_index: + max_index = self._image_widget.bounds[dim] - 1 + if new_index > max_index: if self._loop: # loop back to index zero if looping is enabled - index = 0 + new_index = 0 else: # if looping not enabled, stop playing this dimension self._playing[dim] = False return - # set current_index - self._image_widget.current_index = {dim: min(index, max_index)} + # set new index + new_indices = list(self._image_widget.indices) + new_indices[dim] = new_index + self._image_widget.indices = new_indices def update(self): """called on every render cycle to update the GUI elements""" - # store the new index of the image widget ("t" and "z") - new_index = dict() - - # flag if the index changed - flag_index_changed = False - # reset vmin-vmax using full orig data if imgui.button(label=fa.ICON_FA_CIRCLE_HALF_STROKE + fa.ICON_FA_FILM): self._image_widget.reset_vmin_vmax() @@ -72,7 +88,7 @@ def update(self): now = perf_counter() # buttons and slider UI elements for each dim - for dim in self._image_widget.slider_dims: + for dim in range(self._image_widget.n_sliders): imgui.push_id(f"{self._id_counter}_{dim}") if self._playing[dim]: @@ -83,7 +99,7 @@ def update(self): # if in play mode and enough time has elapsed w.r.t. the desired framerate, increment the index if now - self._last_frame_time[dim] >= self._frame_time[dim]: - self.set_index(dim, self._image_widget.current_index[dim] + 1) + self.set_index(dim, self._image_widget.indices[dim] + 1) self._last_frame_time[dim] = now else: @@ -97,12 +113,12 @@ def update(self): imgui.same_line() # step back one frame button if imgui.button(label=fa.ICON_FA_BACKWARD_STEP) and not self._playing[dim]: - self.set_index(dim, self._image_widget.current_index[dim] - 1) + self.set_index(dim, self._image_widget.indices[dim] - 1) imgui.same_line() # step forward one frame button if imgui.button(label=fa.ICON_FA_FORWARD_STEP) and not self._playing[dim]: - self.set_index(dim, self._image_widget.current_index[dim] + 1) + self.set_index(dim, self._image_widget.indices[dim] + 1) imgui.same_line() # stop button @@ -137,10 +153,15 @@ def update(self): self._fps[dim] = value self._frame_time[dim] = 1 / value - val = self._image_widget.current_index[dim] - vmax = self._image_widget._dims_max_bounds[dim] - 1 + val = self._image_widget.indices[dim] + vmax = self._image_widget.bounds[dim] - 1 + + dim_name = dim + if self._image_widget._slider_dim_names is not None: + if dim < len(self._image_widget._slider_dim_names): + dim_name = self._image_widget._slider_dim_names[dim] - imgui.text(f"{dim}: ") + imgui.text(f"dim '{dim_name}:' ") imgui.same_line() # so that slider occupies full width imgui.set_next_item_width(self.width * 0.85) @@ -154,18 +175,12 @@ def update(self): # slider for this dimension changed, index = imgui.slider_int( - f"{dim}", v=val, v_min=0, v_max=vmax, flags=flags + f"d: {dim}", v=val, v_min=0, v_max=vmax, flags=flags ) - new_index[dim] = index - - # if the slider value changed for this dimension - flag_index_changed |= changed + if changed: + new_indices = list(self._image_widget.indices) + new_indices[dim] = index + self._image_widget.indices = new_indices imgui.pop_id() - - if flag_index_changed: - # if any slider dim changed set the new index of the image widget - self._image_widget.current_index = new_index - - self.size = int(imgui.get_window_height()) diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py new file mode 100644 index 000000000..70c2e7621 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -0,0 +1,2 @@ +from .processor_base import NDProcessor +from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py new file mode 100644 index 000000000..03bb0e8f7 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py @@ -0,0 +1,23 @@ +import importlib + +from .core import NDPositions, NDPositionsProcessor + +class Extras: + pass + +ndp_extras = Extras() + + +for optional in ["pandas", "zarr"]: + try: + importlib.import_module(optional) + except ImportError: + pass + else: + module = importlib.import_module(f"._{optional}", "fastplotlib.widgets.nd_widget._nd_positions") + cls = getattr(module, f"NDPP_{optional.capitalize()}") + setattr( + ndp_extras, + f"NDPP_{optional.capitalize()}", + cls + ) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py new file mode 100644 index 000000000..de26c8a9d --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -0,0 +1,94 @@ +import numpy as np +import pandas as pd + +from .core import NDPositionsProcessor + + +class NDPP_Pandas(NDPositionsProcessor): + def __init__( + self, + data: pd.DataFrame, + columns: list[tuple[str, str] | tuple[str, str, str]], + tooltip_columns: list[str] = None, + max_display_datapoints: int = 1_000, + **kwargs, + ): + data = data + + self._columns = columns + + if tooltip_columns is not None: + if len(tooltip_columns) != len(self.columns): + raise ValueError + self._tooltip_columns = tooltip_columns + self._tooltip = True + else: + self._tooltip_columns = None + self._tooltip = False + + super().__init__( + data=data, + max_display_datapoints=max_display_datapoints, + **kwargs, + ) + + @property + def data(self) -> pd.DataFrame: + return self._data + + def _validate_data(self, data: pd.DataFrame): + if not isinstance(data, pd.DataFrame): + raise TypeError + + return data + + @property + def columns(self) -> list[tuple[str, str] | tuple[str, str, str]]: + return self._columns + + @property + def multi(self) -> bool: + return True + + @multi.setter + def multi(self, v): + pass + + @property + def shape(self) -> tuple[int, ...]: + # n_graphical_elements, n_timepoints, 2 + return len(self.columns), self.data.index.size, 2 + + @property + def ndim(self) -> int: + return len(self.shape) + + @property + def n_slider_dims(self) -> int: + return 1 + + @property + def tooltip(self) -> bool: + return self._tooltip + + def tooltip_format(self, n: int, p: int): + # datapoint index w.r.t. full data + p += self._slices[-1].start + return str(self.data[self._tooltip_columns[n]][p]) + + def get(self, indices: tuple[float | int, ...]) -> np.ndarray: + if not isinstance(indices, tuple): + raise TypeError(".get() must receive a tuple of float | int indices") + # assume no additional slider dims, only time slider dim + self._slices = self._get_dw_slices(indices) + + + gdata_shape = len(self.columns), self._slices[-1].stop - self._slices[-1].start, 3 + gdata = np.zeros(shape=gdata_shape, dtype=np.float32) + + for i, col in enumerate(self.columns): + gdata[i, :, :len(col)] = np.column_stack( + [self.data[c][self._slices[-1]] for c in col] + ) + + return gdata diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py b/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py new file mode 100644 index 000000000..fb3bb7015 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py @@ -0,0 +1,4 @@ +# placeholder + +class NDPP_Zarr: + pass diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py new file mode 100644 index 000000000..b95916ce8 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -0,0 +1,522 @@ +from functools import partial +from typing import Literal, Callable, Any, Type +from warnings import warn + +import numpy as np +from numpy.lib.stride_tricks import sliding_window_view + +from ....utils import subsample_array, ArrayProtocol + +from ....graphics import ( + Graphic, + ImageGraphic, + LineGraphic, + LineStack, + LineCollection, + ScatterGraphic, + ScatterCollection, +) +from ..processor_base import NDProcessor, WindowFuncCallable + + +# TODO: Maybe get rid of n_display_dims in NDProcessor, +# we will know the display dims automatically here from the last dim +# so maybe we only need it for images? +class NDPositionsProcessor(NDProcessor): + def __init__( + self, + data: Any, + multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points + display_window: int | float | None = 100, # window for n_datapoints dim only + max_display_datapoints: int = 1_000, + datapoints_window_func: Callable | None = None, + datapoints_window_size: int | None = None, + **kwargs, + ): + self._display_window = display_window + self._max_display_datapoints = max_display_datapoints + + # TOOD: this does data validation twice and is a bit messy, cleanup + self._data = self._validate_data(data) + self.multi = multi + + super().__init__(data=data, **kwargs) + + self._datapoints_window_func = datapoints_window_func + self._datapoints_window_size = datapoints_window_size + + def _validate_data(self, data: ArrayProtocol): + # TODO: determine right validation shape etc. + return data + + @property + def display_window(self) -> int | float | None: + """display window in the reference units for the n_datapoints dim""" + return self._display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + if dw is None: + self._display_window = None + + elif not isinstance(dw, (int, float)): + raise TypeError + + self._display_window = dw + + @property + def max_display_datapoints(self) -> int: + return self._max_display_datapoints + + @max_display_datapoints.setter + def max_display_datapoints(self, n: int): + if not isinstance(n, (int, np.integer)): + raise TypeError + if n < 2: + raise ValueError + + self._max_display_datapoints = n + + @property + def multi(self) -> bool: + return self._multi + + @multi.setter + def multi(self, m: bool): + if m and self.data.ndim < 3: + # p is p-datapoints, n is how many lines to show simultaneously (for line collection/stack) + raise ValueError( + "ndim must be >= 3 for multi, shape must be [s1..., sn, n, p, 2 | 3]" + ) + + self._multi = m + + @property + def slider_dims(self) -> tuple[int, ...]: + """slider dimensions""" + return tuple(range(self.ndim - 2 - int(self.multi))) + (self.ndim - 2,) + + @property + def n_slider_dims(self) -> int: + return self.ndim - 1 - int(self.multi) + + # TODO: validation for datapoints_window_func and size + @property + def datapoints_window_func(self) -> tuple[Callable, str] | None: + """ + Callable and str indicating which dims to apply window function along: + 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz' + '""" + return self._datapoints_window_func + + @property + def datapoints_window_size(self) -> Callable | None: + return self._datapoints_window_size + + def _apply_window_functions(self, indices: tuple[int, ...]): + """applies the window functions for each dimension specified""" + # window size for each dim + winds = self._window_sizes + # window function for each dim + funcs = self._window_funcs + + # TODO: use tuple of None for window funcs and sizes to indicate all None, instead of just None + # print(winds) + # print(funcs) + # + # if winds is None or funcs is None: + # # no window funcs or window sizes, just slice data and return + # # clamp to max bounds + # indexer = list() + # print(indices) + # print(self.shape) + # for dim, i in enumerate(indices): + # i = min(self.shape[dim] - 1, i) + # indexer.append(i) + # + # return self.data[tuple(indexer)] + + # order in which window funcs are applied + order = self._window_order + + if order is not None: + # remove any entries in `window_order` where the specified dim + # has a window function or window size specified as `None` + # example: + # window_sizes = (3, 2) + # window_funcs = (np.mean, None) + # order = (0, 1) + # `1` is removed from the order since that window_func is `None` + order = tuple( + d for d in order if winds[d] is not None and funcs[d] is not None + ) + else: + # sequential order + order = list() + for d in range(self.n_slider_dims): + if winds[d] is not None and funcs[d] is not None: + order.append(d) + + # the final indexer which will be used on the data array + indexer = list() + + for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): + # clamp i within the max bounds + i = min(self.shape[dim_index] - 1, i) + + if (w is not None) and (f is not None): + # specify slice window if both window size and function for this dim are not None + hw = int((w - 1) / 2) # half window + + # start index cannot be less than 0 + start = max(0, i - hw) + + # stop index cannot exceed the bounds of this dimension + stop = min(self.shape[dim_index], i + hw) + + s = slice(start, stop, 1) + else: + s = slice(i, i + 1, 1) + + indexer.append(s) + + # apply indexer to slice data with the specified windows + data_sliced = self.data[tuple(indexer)] + + # finally apply the window functions in the specified order + for dim in order: + f = funcs[dim] + + data_sliced = f(data_sliced, axis=dim, keepdims=True) + + return data_sliced + + def _get_dw_slices(self, indices) -> tuple[slice] | tuple[slice, slice]: + # given indices, return slice using display window + + # display window is interpreted using the index mapping for the `p` dim + dw = self.display_window + + if dw is None: + # just map p dimension at this index and return + index_p = self.index_mappings[-1](indices[-1]) + return (slice(index_p, index_p + 1),) + + # display window is in reference units, apply display window and then map to array indices + # clamp w.r.t. 0 and processor shape `p` dim + hw = dw / 2 + index_p_start = max(self.index_mappings[-1](indices[-1] - hw), 0) + index_p_stop = min(self.index_mappings[-1](indices[-1] + hw), self.shape[-2]) + if index_p_start >= index_p_stop: + index_p_stop = index_p_start + 1 + + slices = [slice(index_p_start, index_p_stop)] + + if self.multi: + slices.insert(0, slice(None)) + + return tuple(slices) + + def get(self, indices: tuple[Any, ...]): + """ + slices through all slider dims and outputs an array that can be used to set graphic data + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + """ + # apply any slider index mappings + indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) + + if len(indices) > 1: + # there are dims in addition to the n_datapoints dim + # apply window funcs + # window_output array should be of shape [n_datapoints, 2 | 3] + window_output = self._apply_window_functions(indices[:-1]).squeeze() + else: + window_output = self.data + + # TODO: window function on the `p` n_datapoints dimension + + if self.display_window is not None: + slices = self._get_dw_slices(indices) + + # if self.display_window is not None: + # # display window is interpreted using the index mapping for the `p` dim + # dw = self.index_mappings[-1](self.display_window) + # + # if dw == 1: + # slices = [slice(indices[-1], indices[-1] + 1)] + # + # else: + # # half window size + # hw = dw // 2 + # + # # for now assume just a single index provided that indicates x axis value + # start = max(indices[-1] - hw, 0) + # stop = start + dw + # # also add window size of `p` dim so window_func output has the same number of datapoints + # if ( + # self.datapoints_window_func is not None + # and self.datapoints_window_size is not None + # ): + # stop += self.datapoints_window_size - 1 + # # TODO: pad with constant if we're using a window func and the index is near the end + # + # # TODO: uncomment this once we have resizeable buffers!! + # # stop = min(indices[-1] + hw, self.shape[-2]) + # + # slices = [slice(start, stop)] + # + # if self.multi: + # # n - 2 dim is n_lines or n_scatters + # slices.insert(0, slice(None)) + + # data that will be used for the graphical representation + # a copy is made, if there were no window functions then this is a view of the original data + graphic_data = window_output[tuple(slices)] + + dw = self.index_mappings[-1](self.display_window) + + # apply window function on the `p` n_datapoints dim + if ( + self.datapoints_window_func is not None + and self.datapoints_window_size is not None + # if there are too many points to efficiently compute the window func + # applying a window func also requires making a copy so that's a further performance hit + and (dw < self.max_display_datapoints * 2) + ): + # get windows + + # graphic_data will be of shape: [n, p + (ws - 1), 2 | 3] + # where: + # n - number of lines, scatters, heatmap rows + # p - number of datapoints/samples + + wf = self.datapoints_window_func[0] + apply_dims = self.datapoints_window_func[1] + ws = self.datapoints_window_size + + # apply user's window func + # result will be of shape [n, p, 2 | 3] + if apply_dims == "all": + # windows will be of shape [n, p, 1 | 2 | 3, ws] + windows = sliding_window_view(graphic_data, ws, axis=-2) + return wf(windows, axis=-1) + + # map user dims str to tuple of numerical dims + dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims)) + + # windows will be of shape [n, p, 1 | 2 | 3, ws] + windows = sliding_window_view( + graphic_data[..., dims], ws, axis=-2 + ).squeeze() + + # make a copy because we need to modify it + graphic_data = graphic_data.copy() + + # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary + # we need to slice upto dw since we add the `datapoints_window_size` above + graphic_data[..., :dw, dims] = wf(windows, axis=-1).reshape( + graphic_data.shape[0], dw, len(dims) + ) + + return graphic_data[ + ..., : dw : max(1, dw // self.max_display_datapoints), : + ] + + return graphic_data[ + ..., + : graphic_data.shape[-2] : max( + 1, graphic_data.shape[-2] // self.max_display_datapoints + ), + :, + ] + + +class NDPositions: + def __init__( + self, + data: Any, + *args, + graphic: Type[ + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ImageGraphic + ], + processor: type[NDPositionsProcessor] = NDPositionsProcessor, + multi: bool = False, + display_window: int = 10, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + index_mappings: tuple[Callable[[Any], int] | None] | None = None, + max_display_datapoints: int = 1_000, + graphic_kwargs: dict = None, + processor_kwargs: dict = None, + ): + if issubclass(graphic, LineCollection): + multi = True + + if processor_kwargs is None: + processor_kwargs = dict() + + self._processor = processor( + data, + *args, + multi=multi, + display_window=display_window, + max_display_datapoints=max_display_datapoints, + window_funcs=window_funcs, + window_sizes=window_sizes, + index_mappings=index_mappings, + **processor_kwargs, + ) + + self._processor.p_max = 1_000 + + self._indices = tuple([0] * self._processor.n_slider_dims) + + self._create_graphic(graphic) + + @property + def processor(self) -> NDPositionsProcessor: + return self._processor + + @property + def graphic( + self, + ) -> ( + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ImageGraphic + ): + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @graphic.setter + def graphic(self, graphic_type): + if isinstance(self.graphic, graphic_type): + return + + plot_area = self._graphic._plot_area + plot_area.delete_graphic(self._graphic) + + self._create_graphic(graphic_type) + plot_area.add_graphic(self._graphic) + + @property + def indices(self) -> tuple: + return self._indices + + @indices.setter + def indices(self, indices): + data_slice = self.processor.get(indices) + + if isinstance(self.graphic, (LineGraphic, ScatterGraphic)): + self.graphic.data[:, : data_slice.shape[-1]] = data_slice + + elif isinstance(self.graphic, (LineCollection, ScatterCollection)): + for g, new_data in zip(self.graphic.graphics, data_slice): + if g.data.value.shape[0] != new_data.shape[0]: + # will replace buffer internally + g.data = new_data + else: + # if data are only xy, set only xy + g.data[:, : new_data.shape[1]] = new_data + + elif isinstance(self.graphic, ImageGraphic): + image_data, x0, x_scale = self._create_heatmap_data(data_slice) + self.graphic.data = image_data + self.graphic.offset = (x0, *self.graphic.offset[1:]) + + self._indices = indices + + def _tooltip_handler(self, graphic, pick_info): + if isinstance(self.graphic, (LineCollection, ScatterCollection)): + # get graphic within the collection + n_index = np.argwhere(self.graphic.graphics == graphic).item() + p_index = pick_info["vertex_index"] + return self.processor.tooltip_format(n_index, p_index) + + def _create_graphic( + self, + graphic_cls: Type[ + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ImageGraphic + ], + ): + if not issubclass(graphic_cls, Graphic): + raise TypeError + + data_slice = self.processor.get(self.indices) + + if issubclass(graphic_cls, ImageGraphic): + if not self.processor.multi: + raise ValueError + + if self.processor.shape[-1] != 2: + raise ValueError + + image_data, x0, x_scale = self._create_heatmap_data(data_slice) + self._graphic = graphic_cls( + image_data, offset=(x0, 0, -1), scale=(x_scale, 1, 1) + ) + + else: + if issubclass(graphic_cls, LineStack): + kwargs = {"separation": 0.0} + else: + kwargs = dict() + self._graphic = graphic_cls(data_slice, **kwargs) + + if self.processor.tooltip: + if isinstance(self._graphic, (LineCollection, ScatterCollection)): + for g in self._graphic.graphics: + g.tooltip_format = partial(self._tooltip_handler, g) + + def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: + """return [n_rows, n_cols] shape data""" + # assumes x vals in every row is the same, otherwise a heatmap representation makes no sense + x = data_slice[0, :, 0] # get x from just the first row + + # check if we need to interpolate + norm = np.linalg.norm(np.diff(np.diff(x))) / x.size + + if norm > 1e-6: + # x is not uniform upto float32 precision, must interpolate + x_uniform = np.linspace(x[0], x[-1], num=x.size) + y_interp = np.zeros(shape=data_slice[..., 1].shape, dtype=np.float32) + + # this for loop is actually slightly faster than numpy.apply_along_axis() + for i in range(data_slice.shape[0]): + y_interp[i] = np.interp(x_uniform, x, data_slice[i, :, 1]) + + else: + # x is sufficiently uniform + y_interp = data_slice[..., 1] + + # assume all x values are the same + x_scale = data_slice[:, -1, 0][0] / data_slice.shape[1] + + x0 = data_slice[0, 0, 0] + + return y_interp, x0, x_scale + + @property + def display_window(self) -> int | float | None: + """display window in the reference units for the n_datapoints dim""" + return self.processor.display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + self.processor.display_window = dw + self.indices = self.indices diff --git a/fastplotlib/widgets/nd_widget/nd_image.py b/fastplotlib/widgets/nd_widget/nd_image.py new file mode 100644 index 000000000..4972db9d5 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/nd_image.py @@ -0,0 +1,13 @@ +from typing import Literal + +from .processor_base import NDProcessor + + +class NDImageProcessor(NDProcessor): + @property + def n_display_dims(self) -> Literal[2, 3]: + pass + + def _validate_n_display_dims(self, n_display_dims): + if n_display_dims not in (2, 3): + raise ValueError("`n_display_dims` must be") diff --git a/fastplotlib/widgets/nd_widget/processor_base.py b/fastplotlib/widgets/nd_widget/processor_base.py new file mode 100644 index 000000000..a1cd5311c --- /dev/null +++ b/fastplotlib/widgets/nd_widget/processor_base.py @@ -0,0 +1,251 @@ +import inspect +from typing import Literal, Callable, Any +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +def identity(index: int) -> int: + return index + + +class NDProcessor: + def __init__( + self, + data, + n_display_dims: Literal[2, 3] = 2, + index_mappings: tuple[Callable[[Any], int] | None, ...] | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + self._data = self._validate_data(data) + self._index_mappings = tuple(self._validate_index_mappings(index_mappings)) + + self.window_funcs = window_funcs + self.window_sizes = window_sizes + self.window_order = window_order + + @property + def data(self) -> ArrayProtocol: + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + self._data = self._validate_data(data) + + @property + def shape(self) -> tuple[int, ...]: + return self.data.shape + + @property + def ndim(self) -> int: + return len(self.shape) + + def _validate_data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + raise TypeError("`data` must implement the ArrayProtocol") + + return data + + @property + def tooltip(self) -> bool: + """ + whether or not a custom tooltip formatter method exists + """ + return False + + def tooltip_format(self, *args) -> str | None: + """ + Override in subclass to format custom tooltips + """ + return None + + @property + def slider_dims(self): + raise NotImplementedError + + @property + def n_slider_dims(self): + raise NotImplementedError + + @property + def window_funcs( + self, + ) -> tuple[WindowFuncCallable | None, ...] | None: + """get or set window functions, see docstring for details""" + return self._window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, + ): + if window_funcs is None: + self._window_funcs = tuple([None] * self.n_slider_dims) + return + + if callable(window_funcs): + window_funcs = (window_funcs,) + + # if all are None + # if all([f is None for f in window_funcs]): + # self._window_funcs = tuple(window_funcs) + # return + + self._validate_window_func(window_funcs) + + self._window_funcs = tuple(window_funcs) + # self._recompute_histogram() + + def _validate_window_func(self, funcs): + if isinstance(funcs, (tuple, list)): + for f in funcs: + if f is None: + pass + elif callable(f): + sig = inspect.signature(f) + + if "axis" not in sig.parameters or "keepdims" not in sig.parameters: + raise TypeError( + f"Each window function must take an `axis` and `keepdims` argument, " + f"you passed: {f} with the following function signature: {sig}" + ) + else: + raise TypeError( + f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" + ) + + if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " + f"and you passed {len(funcs)} `window_funcs`: {funcs}" + ) + + @property + def window_sizes(self) -> tuple[int | None, ...] | None: + """get or set window sizes used for the corresponding window functions, see docstring for details""" + return self._window_sizes + + @window_sizes.setter + def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): + if window_sizes is None: + self._window_sizes = tuple([None] * self.n_slider_dims) + return + + if isinstance(window_sizes, int): + window_sizes = (window_sizes,) + + # if all are None + if all([w is None for w in window_sizes]): + self._window_sizes = None + return + + if not all([isinstance(w, (int)) or w is None for w in window_sizes]): + raise TypeError( + f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" + ) + + # if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): + # raise IndexError( + # f"number of `window_sizes` must be the same as the number of slider dims, " + # f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " + # f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" + # ) + + # make all window sizes are valid numbers + _window_sizes = list() + for i, w in enumerate(window_sizes): + if w is None: + _window_sizes.append(None) + continue + + if w < 0: + raise ValueError( + f"negative window size passed, all `window_sizes` must be positive " + f"integers or `None`, you passed: {_window_sizes}" + ) + + if w == 0 or w == 1: + # this is not a real window, set as None + w = None + + elif w % 2 == 0: + # odd window sizes makes most sense + warn( + f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" + ) + w += 1 + + _window_sizes.append(w) + + self._window_sizes = tuple(_window_sizes) + + @property + def window_order(self) -> tuple[int, ...] | None: + """get or set dimension order in which window functions are applied""" + return self._window_order + + @window_order.setter + def window_order(self, order: tuple[int] | None): + if order is None: + self._window_order = None + return + + if order is not None: + if not all([d <= self.n_slider_dims for d in order]): + raise IndexError( + f"all `window_order` entries must be <= n_slider_dims\n" + f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" + ) + + if not all([d >= 0 for d in order]): + raise IndexError( + f"all `window_order` entires must be >= 0, you have passed: {order}" + ) + + self._window_order = tuple(order) + + @property + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: + pass + + # @property + # def slider_dims(self) -> tuple[int, ...] | None: + # pass + + @property + def index_mappings(self) -> tuple[Callable[[Any], int]]: + return self._index_mappings + + @index_mappings.setter + def index_mappings(self, maps: tuple[Callable[[Any], int] | None] | None): + self._index_mappings = tuple(self._validate_index_mappings(maps)) + + def _validate_index_mappings(self, maps): + if maps is None: + return tuple([identity] * self.n_slider_dims) + + if len(maps) != self.n_slider_dims: + raise IndexError + + _maps = list() + for m in maps: + if m is None: + _maps.append(identity) + elif callable(m): + _maps.append(identity) + else: + raise TypeError + + return tuple(maps) + + def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: + pass diff --git a/tests/test_colors_buffer_manager.py b/tests/test_colors_buffer_manager.py index 7b1aef16a..f9d56189e 100644 --- a/tests/test_colors_buffer_manager.py +++ b/tests/test_colors_buffer_manager.py @@ -48,10 +48,10 @@ def test_int(test_graphic): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors global EVENT_RETURN_VALUE @@ -98,10 +98,10 @@ def test_tuple(test_graphic, slice_method): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors global EVENT_RETURN_VALUE @@ -190,10 +190,10 @@ def test_slice(color_input, slice_method: dict, test_graphic: bool): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors diff --git a/tests/test_markers_buffer_manager.py b/tests/test_markers_buffer_manager.py index 65ead392e..488bed194 100644 --- a/tests/test_markers_buffer_manager.py +++ b/tests/test_markers_buffer_manager.py @@ -46,10 +46,10 @@ def test_create_buffer(test_graphic): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) vertex_markers = scatter.markers assert isinstance(vertex_markers, VertexMarkers) - assert vertex_markers.buffer is scatter.world_object.geometry.markers + assert vertex_markers._fpl_buffer is scatter.world_object.geometry.markers else: vertex_markers = VertexMarkers(MARKERS1, len(data)) @@ -68,7 +68,7 @@ def test_int(test_graphic, index: int): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) scatter.add_event_handler(event_handler, "markers") vertex_markers = scatter.markers else: @@ -108,7 +108,7 @@ def test_slice(test_graphic, slice_method): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) scatter.add_event_handler(event_handler, "markers") vertex_markers = scatter.markers diff --git a/tests/test_point_rotations_buffer_manager.py b/tests/test_point_rotations_buffer_manager.py index ec5fdbe0f..50ee88984 100644 --- a/tests/test_point_rotations_buffer_manager.py +++ b/tests/test_point_rotations_buffer_manager.py @@ -35,7 +35,7 @@ def test_create_buffer(test_graphic): scatter = fig[0, 0].add_scatter(data, point_rotation_mode="vertex", point_rotations=ROTATIONS1) vertex_rotations = scatter.point_rotations assert isinstance(vertex_rotations, VertexRotations) - assert vertex_rotations.buffer is scatter.world_object.geometry.rotations + assert vertex_rotations._fpl_buffer is scatter.world_object.geometry.rotations else: vertex_rotations = VertexRotations(ROTATIONS1, len(data)) diff --git a/tests/test_positions_data_buffer_manager.py b/tests/test_positions_data_buffer_manager.py index e2582d4ba..cc550abf0 100644 --- a/tests/test_positions_data_buffer_manager.py +++ b/tests/test_positions_data_buffer_manager.py @@ -57,7 +57,7 @@ def test_int(test_graphic): graphic = fig[0, 0].add_scatter(data=data) points = graphic.data - assert graphic.data.buffer is graphic.world_object.geometry.positions + assert graphic.data._fpl_buffer is graphic.world_object.geometry.positions global EVENT_RETURN_VALUE graphic.add_event_handler(event_handler, "data") else: diff --git a/tests/test_positions_graphics.py b/tests/test_positions_graphics.py index 31c001888..4bc93b626 100644 --- a/tests/test_positions_graphics.py +++ b/tests/test_positions_graphics.py @@ -37,12 +37,12 @@ def test_sizes_slice(): @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("b")]) -@pytest.mark.parametrize("uniform_color", [True, False]) -def test_uniform_color(graphic_type, colors, uniform_color): +@pytest.mark.parametrize("color_mode", ["uniform", "vertex"]) +def test_color_mode(graphic_type, colors, color_mode): fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -54,7 +54,7 @@ def test_uniform_color(graphic_type, colors, uniform_color): elif graphic_type == "scatter": graphic = fig[0, 0].add_scatter(data=data, **kwargs) - if uniform_color: + if color_mode == "uniform": assert isinstance(graphic._colors, UniformColor) assert isinstance(graphic.colors, pygfx.Color) if colors is None: @@ -130,17 +130,17 @@ def test_positions_graphics_data( @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("r")]) -@pytest.mark.parametrize("uniform_color", [None, False]) +@pytest.mark.parametrize("color_mode", ["vertex"]) def test_positions_graphic_vertex_colors( graphic_type, colors, - uniform_color, + color_mode, ): # test different ways of passing vertex colors fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -153,10 +153,9 @@ def test_positions_graphic_vertex_colors( graphic = fig[0, 0].add_scatter(data=data, **kwargs) # color per vertex - # uniform colors is default False, or set to False - assert isinstance(graphic._colors, VertexColors) - assert isinstance(graphic.colors, VertexColors) - assert len(graphic.colors) == len(graphic.data) + assert isinstance(graphic._colors, VertexColors) + assert isinstance(graphic.colors, VertexColors) + assert len(graphic.colors) == len(graphic.data) if colors is None: # default @@ -179,7 +178,7 @@ def test_positions_graphic_vertex_colors( @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("r")]) -@pytest.mark.parametrize("uniform_color", [None, False]) +@pytest.mark.parametrize("color_mode", ["auto", "vertex"]) @pytest.mark.parametrize("cmap", ["jet"]) @pytest.mark.parametrize( "cmap_transform", [None, [3, 5, 2, 1, 0, 6, 9, 7, 4, 8], np.arange(9, -1, -1)] @@ -187,7 +186,7 @@ def test_positions_graphic_vertex_colors( def test_cmap( graphic_type, colors, - uniform_color, + color_mode, cmap, cmap_transform, ): @@ -195,7 +194,7 @@ def test_cmap( fig = fpl.Figure() kwargs = dict() - for kwarg in ["cmap", "cmap_transform", "colors", "uniform_color"]: + for kwarg in ["cmap", "cmap_transform", "colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -220,7 +219,8 @@ def test_cmap( # make sure buffer is identical # cmap overrides colors argument - assert graphic.colors.buffer is graphic.cmap.buffer + # use __repr__.__self__ to get the real reference from the cmap feature instead of the weakref proxy + assert graphic.colors._fpl_buffer is graphic.cmap.buffer.__repr__.__self__ npt.assert_almost_equal(graphic.cmap.value, truth) npt.assert_almost_equal(graphic.colors.value, truth) @@ -261,14 +261,14 @@ def test_cmap( "colors", [None, *generate_color_inputs("multi")] ) # cmap arg overrides colors @pytest.mark.parametrize( - "uniform_color", [True] # none of these will work with a uniform buffer + "color_mode", ["uniform"] # none of these will work with a uniform buffer ) -def test_incompatible_cmap_color_args(graphic_type, cmap, colors, uniform_color): +def test_incompatible_cmap_color_args(graphic_type, cmap, colors, color_mode): # test incompatible cmap args fig = fpl.Figure() kwargs = dict() - for kwarg in ["cmap", "colors", "uniform_color"]: + for kwarg in ["cmap", "colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -276,24 +276,24 @@ def test_incompatible_cmap_color_args(graphic_type, cmap, colors, uniform_color) data = generate_positions_spiral_data("xy") if graphic_type == "line": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_line(data=data, **kwargs) elif graphic_type == "scatter": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_scatter(data=data, **kwargs) @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [*generate_color_inputs("multi")]) @pytest.mark.parametrize( - "uniform_color", [True] # none of these will work with a uniform buffer + "color_mode", ["uniform"] # none of these will work with a uniform buffer ) -def test_incompatible_color_args(graphic_type, colors, uniform_color): +def test_incompatible_color_args(graphic_type, colors, color_mode): # test incompatible color args fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -301,16 +301,15 @@ def test_incompatible_color_args(graphic_type, colors, uniform_color): data = generate_positions_spiral_data("xy") if graphic_type == "line": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_line(data=data, **kwargs) elif graphic_type == "scatter": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_scatter(data=data, **kwargs) @pytest.mark.parametrize("sizes", [None, 5.0, np.linspace(3, 8, 10, dtype=np.float32)]) -@pytest.mark.parametrize("uniform_size", [None, False]) -def test_sizes(sizes, uniform_size): +def test_sizes(sizes): # test scatter sizes fig = fpl.Figure() @@ -322,7 +321,7 @@ def test_sizes(sizes, uniform_size): data = generate_positions_spiral_data("xy") - graphic = fig[0, 0].add_scatter(data=data, **kwargs) + graphic = fig[0, 0].add_scatter(data=data, uniform_size=False, **kwargs) assert isinstance(graphic.sizes, VertexPointSizes) assert isinstance(graphic._sizes, VertexPointSizes) diff --git a/tests/test_replace_buffer.py b/tests/test_replace_buffer.py new file mode 100644 index 000000000..a9d0ffe41 --- /dev/null +++ b/tests/test_replace_buffer.py @@ -0,0 +1,155 @@ +import gc +import weakref + +import pytest +import numpy as np +from itertools import product + +import fastplotlib as fpl +from .utils_textures import MAX_TEXTURE_SIZE, check_texture_array, check_image_graphic + +# These are only de-referencing tests for positions graphics, and ImageGraphic +# they do not test that VRAM gets free, for now this can only be checked manually +# with the tests in examples/misc/buffer_replace_gc.py + + +@pytest.mark.parametrize("graphic_type", ["line", "scatter"]) +@pytest.mark.parametrize("new_buffer_size", [50, 150]) +def test_replace_positions_buffer(graphic_type, new_buffer_size): + fig = fpl.Figure() + + # create some data with an initial shape + orig_datapoints = 100 + + xs = np.linspace(0, 2 * np.pi, orig_datapoints) + ys = np.sin(xs) + zs = np.cos(xs) + + data = np.column_stack([xs, ys, zs]) + + # add add_line or add_scatter method + adder = getattr(fig[0, 0], f"add_{graphic_type}") + + if graphic_type == "scatter": + kwargs = { + "markers": np.random.choice(list("osD+x^v<>*"), size=orig_datapoints), + "uniform_marker": False, + "sizes": np.abs(ys), + "uniform_size": False, + # TODO: skipping edge_colors for now since that causes a WGPU bind group error that we will figure out later + # anyways I think changing buffer sizes in combination with per-vertex edge colors is a literal edge-case + "point_rotations": zs * 180, + "point_rotation_mode": "vertex", + } + else: + kwargs = dict() + + # add a line or scatter graphic + graphic = adder(data=data, colors=np.random.rand(orig_datapoints, 4), **kwargs) + + fig.show() + + # weakrefs to the original buffers + # these should raise a ReferenceError when the corresponding feature is replaced with data of a different shape + orig_data_buffer = weakref.proxy(graphic.data._fpl_buffer) + orig_colors_buffer = weakref.proxy(graphic.colors._fpl_buffer) + + buffers = [orig_data_buffer, orig_colors_buffer] + + # extra buffers for the scatters + if graphic_type == "scatter": + for attr in ["markers", "sizes", "point_rotations"]: + buffers.append(weakref.proxy(getattr(graphic, attr)._fpl_buffer)) + + # create some new data that requires a different buffer shape + xs = np.linspace(0, 15 * np.pi, new_buffer_size) + ys = np.sin(xs) + zs = np.cos(xs) + + new_data = np.column_stack([xs, ys, zs]) + + # set data that requires a larger buffer and check that old buffer is no longer referenced + graphic.data = new_data + graphic.colors = np.random.rand(new_buffer_size, 4) + + if graphic_type == "scatter": + # changes values so that new larger buffers must be allocated + graphic.markers = np.random.choice(list("osD+x^v<>*"), size=new_buffer_size) + graphic.sizes = np.abs(zs) + graphic.point_rotations = ys * 180 + + # make sure old original buffers are de-referenced + for i in range(len(buffers)): + with pytest.raises(ReferenceError) as fail: + buffers[i] + pytest.fail( + f"GC failed for buffer: {buffers[i]}, " + f"with referrers: {gc.get_referrers(buffers[i].__repr__.__self__)}" + ) + + +# test all combination of dims that require TextureArrays of shapes 1x1, 1x2, 1x3, 2x3, 3x3 etc. +@pytest.mark.parametrize( + "new_buffer_size", list(product(*[[(500, 1), (1200, 2), (2200, 3)]] * 2)) +) +def test_replace_image_buffer(new_buffer_size): + # make an image with some starting shape + orig_size = (1_500, 1_500) + + data = np.random.rand(*orig_size) + + fig = fpl.Figure() + image = fig[0, 0].add_image(data) + + # the original Texture buffers that represent the individual image tiles + orig_buffers = [ + weakref.proxy(image.data.buffer.ravel()[i]) + for i in range(image.data.buffer.size) + ] + orig_shape = image.data.buffer.shape + + fig.show() + + # dimensions for a new image + new_dims = [v[0] for v in new_buffer_size] + + # the number of tiles required in each dim/shape of the TextureArray + new_shape = tuple(v[1] for v in new_buffer_size) + + # make the new data and set the image + new_data = np.random.rand(*new_dims) + image.data = new_data + + # test that old Texture buffers are de-referenced + for i in range(len(orig_buffers)): + with pytest.raises(ReferenceError) as fail: + orig_buffers[i] + pytest.fail( + f"GC failed for buffer: {orig_buffers[i]}, of shape: {orig_shape}" + f"with referrers: {gc.get_referrers(orig_buffers[i].__repr__.__self__)}" + ) + + # check new texture array + check_texture_array( + data=new_data, + ta=image.data, + buffer_size=np.prod(new_shape), + buffer_shape=new_shape, + row_indices_size=new_shape[0], + col_indices_size=new_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (new_data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (new_data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), + ) + + # check that new image tiles are arranged correctly + check_image_graphic(image.data, image) diff --git a/tests/test_scatter_graphic.py b/tests/test_scatter_graphic.py index a61681f24..930d8c495 100644 --- a/tests/test_scatter_graphic.py +++ b/tests/test_scatter_graphic.py @@ -133,7 +133,7 @@ def test_edge_colors(edge_colors): npt.assert_almost_equal(scatter.edge_colors.value, MULTI_COLORS_TRUTH) assert ( - scatter.edge_colors.buffer is scatter.world_object.geometry.edge_colors + scatter.edge_colors._fpl_buffer is scatter.world_object.geometry.edge_colors ) # test changes, don't need to test extensively here since it's tested in the main VertexColors test diff --git a/tests/test_texture_array.py b/tests/test_texture_array.py index 6220f2fe5..01abb9a97 100644 --- a/tests/test_texture_array.py +++ b/tests/test_texture_array.py @@ -2,14 +2,9 @@ from numpy import testing as npt import pytest -import pygfx - import fastplotlib as fpl from fastplotlib.graphics.features import TextureArray -from fastplotlib.graphics.image import _ImageTile - - -MAX_TEXTURE_SIZE = 1024 +from .utils_textures import MAX_TEXTURE_SIZE, check_texture_array, check_image_graphic def make_data(n_rows: int, n_cols: int) -> np.ndarray: @@ -25,50 +20,6 @@ def make_data(n_rows: int, n_cols: int) -> np.ndarray: return np.vstack([sine * i for i in range(n_rows)]).astype(np.float32) -def check_texture_array( - data: np.ndarray, - ta: TextureArray, - buffer_size: int, - buffer_shape: tuple[int, int], - row_indices_size: int, - col_indices_size: int, - row_indices_values: np.ndarray, - col_indices_values: np.ndarray, -): - - npt.assert_almost_equal(ta.value, data) - - assert ta.buffer.size == buffer_size - assert ta.buffer.shape == buffer_shape - - assert all([isinstance(texture, pygfx.Texture) for texture in ta.buffer.ravel()]) - - assert ta.row_indices.size == row_indices_size - assert ta.col_indices.size == col_indices_size - npt.assert_array_equal(ta.row_indices, row_indices_values) - npt.assert_array_equal(ta.col_indices, col_indices_values) - - # make sure chunking is correct - for texture, chunk_index, data_slice in ta: - assert ta.buffer[chunk_index] is texture - chunk_row, chunk_col = chunk_index - - data_row_start_index = chunk_row * MAX_TEXTURE_SIZE - data_col_start_index = chunk_col * MAX_TEXTURE_SIZE - - data_row_stop_index = min( - data.shape[0], data_row_start_index + MAX_TEXTURE_SIZE - ) - data_col_stop_index = min( - data.shape[1], data_col_start_index + MAX_TEXTURE_SIZE - ) - - row_slice = slice(data_row_start_index, data_row_stop_index) - col_slice = slice(data_col_start_index, data_col_stop_index) - - assert data_slice == (row_slice, col_slice) - - def check_set_slice(data, ta, row_slice, col_slice): ta[row_slice, col_slice] = 1 npt.assert_almost_equal(ta[row_slice, col_slice], 1) @@ -85,17 +36,6 @@ def make_image_graphic(data) -> fpl.ImageGraphic: return fig[0, 0].add_image(data) -def check_image_graphic(texture_array, graphic): - # make sure each ImageTile has the right texture - for (texture, chunk_index, data_slice), img in zip( - texture_array, graphic.world_object.children - ): - assert isinstance(img, _ImageTile) - assert img.geometry.grid is texture - assert img.world.x == data_slice[1].start - assert img.world.y == data_slice[0].start - - @pytest.mark.parametrize("test_graphic", [False, True]) def test_small_texture(test_graphic): # tests TextureArray with dims that requires only 1 texture @@ -162,15 +102,27 @@ def test_wide(test_graphic): else: ta = TextureArray(data) + ta_shape = (2, 3) + check_texture_array( data, ta=ta, - buffer_size=6, - buffer_shape=(2, 3), - row_indices_size=2, - col_indices_size=3, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: @@ -189,15 +141,27 @@ def test_tall(test_graphic): else: ta = TextureArray(data) + ta_shape = (3, 2) + check_texture_array( data, ta=ta, - buffer_size=6, - buffer_shape=(3, 2), - row_indices_size=3, - col_indices_size=2, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: @@ -216,15 +180,27 @@ def test_square(test_graphic): else: ta = TextureArray(data) + ta_shape = (3, 3) + check_texture_array( data, ta=ta, - buffer_size=9, - buffer_shape=(3, 3), - row_indices_size=3, - col_indices_size=3, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: diff --git a/tests/utils_textures.py b/tests/utils_textures.py new file mode 100644 index 000000000..f40a7371c --- /dev/null +++ b/tests/utils_textures.py @@ -0,0 +1,64 @@ +import numpy as np +import pygfx +from numpy import testing as npt + +from fastplotlib.graphics.features import TextureArray +from fastplotlib.graphics.image import _ImageTile + + +MAX_TEXTURE_SIZE = 1024 + + +def check_texture_array( + data: np.ndarray, + ta: TextureArray, + buffer_size: int, + buffer_shape: tuple[int, int], + row_indices_size: int, + col_indices_size: int, + row_indices_values: np.ndarray, + col_indices_values: np.ndarray, +): + + npt.assert_almost_equal(ta.value, data) + + assert ta.buffer.size == buffer_size + assert ta.buffer.shape == buffer_shape + + assert all([isinstance(texture, pygfx.Texture) for texture in ta.buffer.ravel()]) + + assert ta.row_indices.size == row_indices_size + assert ta.col_indices.size == col_indices_size + npt.assert_array_equal(ta.row_indices, row_indices_values) + npt.assert_array_equal(ta.col_indices, col_indices_values) + + # make sure chunking is correct + for texture, chunk_index, data_slice in ta: + assert ta.buffer[chunk_index] is texture + chunk_row, chunk_col = chunk_index + + data_row_start_index = chunk_row * MAX_TEXTURE_SIZE + data_col_start_index = chunk_col * MAX_TEXTURE_SIZE + + data_row_stop_index = min( + data.shape[0], data_row_start_index + MAX_TEXTURE_SIZE + ) + data_col_stop_index = min( + data.shape[1], data_col_start_index + MAX_TEXTURE_SIZE + ) + + row_slice = slice(data_row_start_index, data_row_stop_index) + col_slice = slice(data_col_start_index, data_col_stop_index) + + assert data_slice == (row_slice, col_slice) + + +def check_image_graphic(texture_array, graphic): + # make sure each ImageTile has the right texture + for (texture, chunk_index, data_slice), img in zip( + texture_array, graphic.world_object.children + ): + assert isinstance(img, _ImageTile) + assert img.geometry.grid is texture + assert img.world.x == data_slice[1].start + assert img.world.y == data_slice[0].start