I stumbled upon an interesting Stack Overflow question that was linked from an issue on the Polars GitHub repo. The OP asked for a pure Polars solution. At the time of answering the question, Polars did not have support for non-equi joins, and any solution using it would have been pretty cumbersome.
I’m more of a right-tool-for-the-job person, so I tried to find a better solution.
Suppose we have a dataset that captures the arrival and departure times of trucks at a station, along with each truck’s ID.
import polars as pl  # if you don't have polars, run
                     # pip install 'polars[all]'
data = pl.from_repr("""
┌─────────────────────┬─────────────────────┬─────┐
│ arrival_time        ┆ departure_time      ┆ ID  │
│ ---                 ┆ ---                 ┆ --- │
│ datetime[μs]        ┆ datetime[μs]        ┆ str │
╞═════════════════════╪═════════════════════╪═════╡
│ 2023-01-01 06:23:47 ┆ 2023-01-01 06:25:08 ┆ A1  │
│ 2023-01-01 06:26:42 ┆ 2023-01-01 06:28:02 ┆ A1  │
│ 2023-01-01 06:30:20 ┆ 2023-01-01 06:35:01 ┆ A5  │
│ 2023-01-01 06:32:06 ┆ 2023-01-01 06:33:48 ┆ A6  │
│ 2023-01-01 06:33:09 ┆ 2023-01-01 06:36:01 ┆ B3  │
│ 2023-01-01 06:34:08 ┆ 2023-01-01 06:39:49 ┆ C3  │
│ 2023-01-01 06:36:40 ┆ 2023-01-01 06:38:34 ┆ A6  │
│ 2023-01-01 06:37:43 ┆ 2023-01-01 06:40:48 ┆ A5  │
│ 2023-01-01 06:39:48 ┆ 2023-01-01 06:46:10 ┆ A6  │
└─────────────────────┴─────────────────────┴─────┘
""")

We want to identify the number of trucks docked at any given time within a threshold of 1 minute prior to the arrival time of a truck and 1 minute after the departure of a truck. Equivalently, this means that we need to calculate the number of trucks within a specific window for each row of the data.
Before we find a general solution to this problem, let’s consider a specific row to understand it better:

"""
┌─────────────────────┬─────────────────────┬─────┐
│ arrival_time        ┆ departure_time      ┆ ID  │
│ ---                 ┆ ---                 ┆ --- │
│ datetime[μs]        ┆ datetime[μs]        ┆ str │
╞═════════════════════╪═════════════════════╪═════╡
│ 2023-01-01 06:32:06 ┆ 2023-01-01 06:33:48 ┆ A6  │
└─────────────────────┴─────────────────────┴─────┘
"""

For this row, we need to find the number of trucks present between 2023-01-01 06:31:06 (1 minute prior to the arrival_time) and 2023-01-01 06:34:48 (1 minute after the departure_time). Manually going through the original dataset, we see that B3, C3, A6 and A5 are the truck IDs that qualify - each of them is at the station at some point between 2023-01-01 06:31:06 and 2023-01-01 06:34:48.
There are several cases in which a truck qualifies as present in the overlap window defined by a particular row. For the example above, we have the following (this visualization generalizes, because for each row we can calculate the overlap window relative to the arrival and departure times without much difficulty):
[Figure: Cases 1-5 - the ways a truck’s docked interval can relate to a row’s overlap window]
Take some time to absorb these cases - it’s important for the part where we write the code for the solution. Note that we need to actually tell our algorithm to filter only for Cases 2, 3 and 4, since Cases 1 and 5 will not satisfy our requirements.
In theory, we can use any language capable of defining rules that meet the algorithmic requirements outlined in the section above. Why choose SQL? It’s often able to convey the logic of an algorithm elegantly; and while it does come with excessive verbosity at times, it doesn’t here.
Note that we run SQL in Python with almost no setup or boilerplate code - so this is a Python-based solution as well (although not quite Pythonic!).
Once again, in theory, any SQL package or language can be used. Few, however, match the ease of use that DuckDB provides: it installs with a single command (pip install duckdb) and it can query Pandas and Polars DataFrames in place, all with mind-blowing speed that stands shoulder-to-shoulder with Polars. We’ll also use a few advanced SQL concepts, noted below.
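To give a sense of that near-zero setup, here is a minimal sketch (it assumes the data Polars DataFrame defined above is in scope; the column selection is arbitrary):

import duckdb as db

# DuckDB's replacement scans pick up in-scope Pandas/Polars DataFrames by name,
# so `data` can be queried directly with no registration step.
db.query("SELECT ID, arrival_time FROM data LIMIT 3").show()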
This should be a familiar, albeit not often used, concept - a join of a table with itself is a self join. There are few cases where such an operation makes sense, and this happens to be one of them.
A key concept that we’ll use is the idea of joining on a range of values rather than a specific value. That is, instead of the usual LEFT JOIN ON A.column = B.column, we can do LEFT JOIN ON A.column <= B.column, allowing one row in table A to match multiple rows in B. DuckDB has a blog post that outlines this kind of join in detail, including its fast implementation.
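As a toy illustration of a range condition (a sketch with made-up tables, not part of the truck data), note how one left-hand row fans out to every right-hand row that satisfies the inequality:

import duckdb as db

# Each threshold t matches every reading r with r >= t - one row matches many.
db.query("""
    WITH thresholds(t) AS (VALUES (1), (5)),
         readings(r)   AS (VALUES (2), (6), (9))
    SELECT t, r
    FROM thresholds
    LEFT JOIN readings ON r >= t
    ORDER BY t, r
""").show()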
LIST columns

DuckDB has first-class support for LIST columns - that is, each row in a LIST column can hold a varying number of elements (much like a Python list), but all elements must have the exact same datatype (like R’s vector). Using LIST columns allows us to eschew an additional GROUP BY operation on top of a WHERE filter or SELECT DISTINCT operation, since we can perform those directly on the LIST column itself.
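Here is a minimal sketch of that idea (the toy values are made up): LIST collects a group’s values into a list, and LIST_DISTINCT / LIST_UNIQUE then operate on that list directly.

import duckdb as db

db.query("""
    WITH trucks(station, ID) AS (VALUES ('S1', 'A1'), ('S1', 'A1'), ('S1', 'B3'))
    SELECT
        station
        ,LIST(ID)                AS all_ids       -- e.g. [A1, A1, B3]
        ,LIST_DISTINCT(LIST(ID)) AS distinct_ids  -- e.g. [A1, B3]
        ,LIST_UNIQUE(LIST(ID))   AS distinct_n    -- 2
    FROM trucks
    GROUP BY station
""").show()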
Dates can be rather difficult to handle well in most tools and languages, with several packages purpose-built to make handling them easier - lubridate from the tidyverse is a stellar example. Thankfully, DuckDB provides a similar Swiss-army-knife set of tools to deal with them, including INTERVALs (a special data type that represents a period of time independent of specific time values), which can modify TIMESTAMP values through addition or subtraction.
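For instance (a quick sketch using one arbitrary docking interval from the dataset), shifting a TIMESTAMP by a fixed INTERVAL is plain addition and subtraction:

import duckdb as db

# Widen a docking interval by one minute on each side.
db.query("""
    SELECT
        TIMESTAMP '2023-01-01 06:32:06' - INTERVAL 1 MINUTE AS window_open
        ,TIMESTAMP '2023-01-01 06:33:48' + INTERVAL 1 MINUTE AS window_close
""").show()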
Okay - that was a lot of background. Let’s have at it! The query by itself in SQL is (see immediately below for runnable code in Python):
SELECT
    A.arrival_time
    ,A.departure_time
    ,A.window_open
    ,A.window_close
    ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
    ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count

FROM (
    SELECT *
        ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
        ,departure_time + (INTERVAL 1 MINUTE) AS window_close
    FROM data) A

LEFT JOIN (
    SELECT *
        ,DATEDIFF('seconds', arrival_time, departure_time) AS duration
    FROM data) B

ON ((B.arrival_time <= A.window_open AND
     (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
    (B.arrival_time >= A.window_open AND
     B.departure_time <= A.window_close) OR
    (B.arrival_time >= A.window_open AND
     (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close))
GROUP BY 1, 2, 3, 4

A small, succinct query such as this will need a bit of explanation to take it all in. Here’s one below, reproducible in Python (make sure to install duckdb first!).
import duckdb as db
db.query("""
    SELECT
        A.arrival_time
        ,A.departure_time
        ,A.window_open
        ,A.window_close
        -- LIST aggregates the values into a LIST column
        -- and LIST_DISTINCT finds the unique values in it
        ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
        -- finally, LIST_UNIQUE calculates the unique number of values in it
        ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count

    FROM (
        SELECT
            *
            ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
            ,departure_time + (INTERVAL 1 MINUTE) AS window_close
        FROM data -- remember we defined data as the Polars DataFrame with our truck station data
    ) A

    LEFT JOIN (
        SELECT
            *
            -- This is the time, in seconds, between the arrival and departure of
            -- each truck PER ROW in the original data-frame
            ,DATEDIFF('seconds', arrival_time, departure_time) AS duration
        FROM data -- this is where we perform a self-join
    ) B

    ON (
        -- Case 2 in the diagram;
        (B.arrival_time <= A.window_open AND
         -- Adding the duration here makes sure that the second interval
         -- is at least ENDING AFTER the start of the overlap window
         (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR

        -- Case 3 in the diagram - the simplest of all five cases
        (B.arrival_time >= A.window_open AND
         B.departure_time <= A.window_close) OR

        -- Case 4 in the diagram;
        (B.arrival_time >= A.window_open AND
         -- Subtracting the duration here makes sure that the second interval
         -- STARTS BEFORE the end of the overlap window.
         (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close)
    )
    GROUP BY 1, 2, 3, 4
""")

The output of this query is:
"""
┌─────────────────────┬─────────────────────┬─────────────────────┬───┬──────────────────┬────────────────────┐
│ arrival_time        │ departure_time      │ window_open         │ … │ docked_trucks    │ docked_truck_count │
│ timestamp           │ timestamp           │ timestamp           │   │ varchar[]        │ uint64             │
├─────────────────────┼─────────────────────┼─────────────────────┼───┼──────────────────┼────────────────────┤
│ 2023-01-01 06:23:47 │ 2023-01-01 06:25:08 │ 2023-01-01 06:22:47 │ … │ [A1]             │ 1                  │
│ 2023-01-01 06:26:42 │ 2023-01-01 06:28:02 │ 2023-01-01 06:25:42 │ … │ [A1]             │ 1                  │
│ 2023-01-01 06:30:20 │ 2023-01-01 06:35:01 │ 2023-01-01 06:29:20 │ … │ [B3, C3, A6, A5] │ 4                  │
│ 2023-01-01 06:32:06 │ 2023-01-01 06:33:48 │ 2023-01-01 06:31:06 │ … │ [B3, C3, A6, A5] │ 4                  │
│ 2023-01-01 06:33:09 │ 2023-01-01 06:36:01 │ 2023-01-01 06:32:09 │ … │ [B3, C3, A6, A5] │ 4                  │
│ 2023-01-01 06:34:08 │ 2023-01-01 06:39:49 │ 2023-01-01 06:33:08 │ … │ [B3, C3, A6, A5] │ 4                  │
│ 2023-01-01 06:36:40 │ 2023-01-01 06:38:34 │ 2023-01-01 06:35:40 │ … │ [A5, A6, C3, B3] │ 4                  │
│ 2023-01-01 06:37:43 │ 2023-01-01 06:40:48 │ 2023-01-01 06:36:43 │ … │ [A5, A6, C3]     │ 3                  │
│ 2023-01-01 06:39:48 │ 2023-01-01 06:46:10 │ 2023-01-01 06:38:48 │ … │ [A6, A5, C3]     │ 3                  │
├─────────────────────┴─────────────────────┴─────────────────────┴───┴──────────────────┴────────────────────┤
│ 9 rows                                                                                   6 columns (5 shown) │
└──────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
"""
We clearly see the strengths of DuckDB in how succinctly we were able to express this operation. We also see how DuckDB seamlessly integrates with an existing Pandas or Polars pipeline with zero conversion cost. In fact, we can convert the result back to a Polars or Pandas DataFrame by appending .pl() or .df() after the closing bracket - db.query(...).pl() for Polars and db.query(...).df() for Pandas.
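A quick sketch of both conversions (assuming a reasonably recent duckdb version, where the .pl() and .df() relation methods are available):

result_pl = db.query("SELECT * FROM data LIMIT 2").pl()  # Polars DataFrame
result_pd = db.query("SELECT * FROM data LIMIT 2").df()  # Pandas DataFrame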
Now that we’ve understood the logic that goes into the query, let’s try to optimize the algorithm. We have the three conditions:
-- Case 2 in the diagram
(B.arrival_time <= A.window_open AND
 (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
-- Case 3 in the diagram
(B.arrival_time >= A.window_open AND
 B.departure_time <= A.window_close) OR
-- Case 4 in the diagram
(B.arrival_time >= A.window_open AND
 (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close)

What is common between these three conditions? It takes a while to see it, but it becomes clear that all of these cases require the truck’s interval to start before the overlap window closes, and to end after the overlap window opens. This can be simplified to just:

B.arrival_time <= A.window_close AND
B.departure_time >= A.window_open

making our query much simpler!
We’ve removed the need for the duration calculation altogether now. Therefore, we can write:
SELECT
    A.arrival_time
    ,A.departure_time
    ,A.window_open
    ,A.window_close
    ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
    ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count

FROM (
    SELECT *
        ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
        ,departure_time + (INTERVAL 1 MINUTE) AS window_close
    FROM data) A

LEFT JOIN data B

ON (
    B.arrival_time <= A.window_close AND
    B.departure_time >= A.window_open
)
GROUP BY 1, 2, 3, 4

Can we simplify this even further?
I think the SQL query in the section above is already very easy to read. However, it is a little clunky overall, and we can leverage DuckDB’s extensive optimizations to improve legibility by rewriting the query as a cross join:
SELECT
    A.arrival_time
    ,A.departure_time
    ,A.arrival_time - (INTERVAL 1 MINUTE) AS window_open
    ,A.departure_time + (INTERVAL 1 MINUTE) AS window_close
    ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
    ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
FROM data A, data B
WHERE B.arrival_time <= window_close
AND B.departure_time >= window_open
GROUP BY 1, 2, 3, 4

Why does this work? Before optimization, this is what the query plan in DuckDB looks like:
+ 1"""
+ 2┌───────────────────────────┐
+ 3│ PROJECTION │
+ 4│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+ 5│ 0 │
+ 6│ 1 │
+ 7│ 2 │
+ 8│ 3 │
+ 9│ docked_trucks │
+10│ docked_truck_count │
+11└─────────────┬─────────────┘
+12┌─────────────┴─────────────┐
+13│ AGGREGATE │
+14│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+15│ arrival_time │
+16│ departure_time │
+17│ window_open │
+18│ window_close │
+19│ list(ID) │
+20└─────────────┬─────────────┘
+21┌─────────────┴─────────────┐
+22│ FILTER │
+23│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+24│ (arrival_time <= │
+25│(departure_time + to_m... │
+26│ AS BIGINT)))) │
+27│ (departure_time >= │
+28│(arrival_time - to_min... │
+29│ AS BIGINT)))) │
+30└─────────────┬─────────────┘
+31┌─────────────┴─────────────┐
+32│ CROSS_PRODUCT ├──────────────┐
+33└─────────────┬─────────────┘ │
+34┌─────────────┴─────────────┐┌─────────────┴─────────────┐
+35│ ARROW_SCAN ││ ARROW_SCAN │
+36└───────────────────────────┘└───────────────────────────┘
+37""" After optimization, the CROSS_PRODUCT is automatically optimized to an interval join!
"""
┌───────────────────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│             0             │
│             1             │
│             2             │
│             3             │
│       docked_trucks       │
│     docked_truck_count    │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│         AGGREGATE         │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│        arrival_time       │
│       departure_time      │
│        window_open        │
│        window_close       │
│          list(ID)         │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│      COMPARISON_JOIN      │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│           INNER           │
│ ((departure_time + '00:01 │
│     :00'::INTERVAL) >=    ├──────────────┐
│       arrival_time)       │              │
│((arrival_time - '00:01:00'│              │
│      ::INTERVAL) <=       │              │
│      departure_time)      │              │
└─────────────┬─────────────┘              │
┌─────────────┴─────────────┐┌─────────────┴─────────────┐
│         ARROW_SCAN        ││         ARROW_SCAN        │
└───────────────────────────┘└───────────────────────────┘
"""

So, in effect, we’re exploiting a feature of DuckDB that lets us write our queries in a suboptimal manner for greater readability, and letting the optimizer do a good chunk of the work for us. I wouldn’t recommend relying on this in general, because not all SQL engine optimizers will be able to find an efficient route to these calculations for large datasets.
I’m glad you asked. Here’s the DuckDB page explaining EXPLAIN (heh). Here’s the code I used:
import duckdb as db
db.sql("SET EXPLAIN_OUTPUT='all';")
print(db.query("""
EXPLAIN
SELECT
    A.arrival_time
    ,A.departure_time
    ,A.arrival_time - (INTERVAL 1 MINUTE) AS window_open
    ,A.departure_time + (INTERVAL 1 MINUTE) AS window_close
    ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
    ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
FROM data A, data B
WHERE B.arrival_time <= window_close
AND B.departure_time >= window_open
GROUP BY 1, 2, 3, 4
""").pl()[1, 1])

The data.table way

data.table is an R package that has historically been ahead of its time, in both speed and features. Development has taken a hit recently, but will likely pick back up. It’s my favourite package on all fronts for data manipulation, but it suffers simply from the lack of broader R support across the ML and DL space.
The foverlaps function

If this kind of overlapping join is common, shouldn’t someone have developed a package for it? Turns out, data.table has, and with very specific constraints that make it the perfect solution to our problem (if you don’t mind switching over to R, that is).
The foverlaps function has these requirements:
- The data.table objects have to be keyed for automatic recognition of columns.
- It supports several overlap types, including within overlap and matching start and end windows.
- The key must specify the start and end points of the overlapping window. This isn’t a problem for us now, but it does restrict future uses where we may want non-equi joins on other conditions.

Without further ado:
library(data.table)
library(lubridate)

######### BOILERPLATE CODE, NO LOGIC HERE ####################
arrival_time = as_datetime(c(
  '2023-01-01 06:23:47.000000', '2023-01-01 06:26:42.000000',
  '2023-01-01 06:30:20.000000', '2023-01-01 06:32:06.000000',
  '2023-01-01 06:33:09.000000', '2023-01-01 06:34:08.000000',
  '2023-01-01 06:36:40.000000', '2023-01-01 06:37:43.000000',
  '2023-01-01 06:39:48.000000'))
departure_time = as_datetime(c(
  '2023-01-01 06:25:08.000000', '2023-01-01 06:28:02.000000',
  '2023-01-01 06:35:01.000000', '2023-01-01 06:33:48.000000',
  '2023-01-01 06:36:01.000000', '2023-01-01 06:39:49.000000',
  '2023-01-01 06:38:34.000000', '2023-01-01 06:40:48.000000',
  '2023-01-01 06:46:10.000000'))
ID = c('A1', 'A1', 'A5', 'A6', 'B3', 'C3', 'A6', 'A5', 'A6')

DT = data.table(
  arrival_time = arrival_time,
  departure_time = departure_time,
  ID = ID)
######### BOILERPLATE CODE, NO LOGIC HERE ####################

# A copy(DT) creates a copy of a data.table that isn't linked
# to the original one, so that changes in it don't reflect in
# the original DT object.
# The `:=` allows assignment by reference (i.e. "in place").
DT_with_windows = copy(DT)[, `:=`(
  window_start = arrival_time - minutes(1),
  window_end = departure_time + minutes(1))]

# This step is necessary for the second table but not the first;
# we key both data.tables to make the foverlaps code very succinct.
setkeyv(DT, c("arrival_time", "departure_time"))
setkeyv(DT_with_windows, c("window_start", "window_end"))

# The foverlaps function returns a data.table, so we can simply apply
# the usual data.table syntax on it!
# Since some columns have the same name in both data.tables,
# the latter table's columns are prefixed with "i." to avoid conflicts.
foverlaps(DT, DT_with_windows)[
  , .(docked_trucks = list(unique(i.ID)),
      docked_truck_count = uniqueN(i.ID))
  , .(arrival_time, departure_time)]

This provides us the output:
          arrival_time      departure_time docked_trucks docked_truck_count
                <POSc>              <POSc>        <list>              <int>
1: 2023-01-01 06:23:47 2023-01-01 06:25:08            A1                  1
2: 2023-01-01 06:26:42 2023-01-01 06:28:02            A1                  1
3: 2023-01-01 06:30:20 2023-01-01 06:35:01   A5,A6,B3,C3                  4
4: 2023-01-01 06:32:06 2023-01-01 06:33:48   A5,A6,B3,C3                  4
5: 2023-01-01 06:33:09 2023-01-01 06:36:01   A5,A6,B3,C3                  4
6: 2023-01-01 06:34:08 2023-01-01 06:39:49   A5,A6,B3,C3                  4
7: 2023-01-01 06:36:40 2023-01-01 06:38:34   B3,C3,A6,A5                  4
8: 2023-01-01 06:37:43 2023-01-01 06:40:48      C3,A6,A5                  3
9: 2023-01-01 06:39:48 2023-01-01 06:46:10      C3,A5,A6                  3

data.table offers a wonderful, nearly one-stop solution that doesn’t require you to write the query logic yourself, but it has a major problem for a lot of users - it requires you to switch your codebase over to R, while a lot of your tasks may be in Python or in an SQL pipeline. So, what do you do?
Consider the effort of maintaining an additional dependency (i.e. R) for your analytics pipeline, and the effort you’ll need to invest to run R from Python - or to run an R script in your pipeline and pull its output back into the pipeline - and make your call.
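If you do go that route, one low-tech option (a sketch only - the script and file names here are hypothetical) is to shell out to Rscript from the Python side of the pipeline and read its output back in:

import subprocess
import polars as pl

# Hypothetical setup: truck_overlaps.R runs the foverlaps logic above and
# writes its result to overlap_counts.csv.
subprocess.run(["Rscript", "truck_overlaps.R"], check=True)
overlaps = pl.read_csv("overlap_counts.csv")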
In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find a representative dataset to train a model on such that it generalizes well.
Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
Here’s what you need to do:
Generally, three things come to mind:
This data can be practically anything that can be represented as a 2D matrix.
There are exceptions. Raw image data (as numbers) might get difficult, because even if you flatten it, there will be significant correlation between features. For example, a face can appear practically anywhere in an image, and all pixels centered around the face will be highly correlated, even if they are on different rows. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and take the 1D representation of each image from the final hidden layer before the output. Other data will need further processing along similar lines.
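A rough sketch of that workaround (not the approach used later in this post - the model choice and sizes here are illustrative): an ImageNet-pretrained CNN with its classification head removed turns each image into a fixed-length vector.

import torch
import torchvision

# Pretrained ResNet-50 with the final classification layer replaced by an
# identity, so the forward pass returns the 2048-dim penultimate features.
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.fc = torch.nn.Identity()
model.eval()

images = torch.rand(4, 3, 224, 224)   # stand-in batch of 4 RGB images
with torch.no_grad():
    features = model(images)          # shape (4, 2048): rows of a 2D matrix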
For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters >> indicating the command used; the output appears without that prefix.
>> import polars as pl
>> data = pl.read_csv("archive/shopmania.csv")
>> data
shape: (313_705, 4)
┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
│ product_ID ┆ product_title                                        ┆ category_ID ┆ category_label │
│ ---        ┆ ---                                                  ┆ ---         ┆ ---            │
│ i64        ┆ str                                                  ┆ i64         ┆ str            │
╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
│ 2          ┆ twilight central park print                          ┆ 2           ┆ Collectibles   │
│ 3          ┆ fox print                                            ┆ 2           ┆ Collectibles   │
│ 4          ┆ circulo de papel wall art                            ┆ 2           ┆ Collectibles   │
│ 5          ┆ hidden path print                                    ┆ 2           ┆ Collectibles   │
│ …          ┆ …                                                    ┆ …           ┆ …              │
│ 313703     ┆ deago anti fog swimming diving full face mask        ┆ 229         ┆ Water Sports   │
│            ┆ surface snorkel scuba fr gopro black s/m             ┆             ┆                │
│ 313704     ┆ etc buys full face gopro compatible snorkel scuba    ┆ 229         ┆ Water Sports   │
│            ┆ diving mask blue large/xtralarge blue                ┆             ┆                │
│ 313705     ┆ men 039 s full face breathe free diving snorkel mask ┆ 229         ┆ Water Sports   │
│            ┆ scuba optional hd camera blue mask only adult men    ┆             ┆                │
│ 313706     ┆ women 039 s full face breathe free diving snorkel    ┆ 229         ┆ Water Sports   │
│            ┆ mask scuba optional hd camera black mask only        ┆             ┆                │
│            ┆ children and women                                   ┆             ┆                │
└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘

The data documentation on Kaggle states:
The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurrences.
data = (
    data
    .filter(pl.count().over("category_ID") == 10000)
)

You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
>> data.get_column("category_label").unique()
shape: (17,)
Series: 'category_label' [str]
[
    "Kitchen & Dining"
    "Scarves and wraps"
    "Handbags & Wallets"
    "Rugs Tapestry & Linens"
    "Cell Phones Accessories"
    "Men's Clothing"
    "Jewelry"
    "Belts"
    "Men Lingerie"
    "Crafts"
    "Football"
    "Medical Supplies"
    "Adult"
    "Hunting"
    "Women's Clothing"
    "Pet Supply"
    "Office Supplies"
]

Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
Craft a small representative sample for each category.
Why small? It’ll make the model faster to train - and keep the training data manageable in size.
I mentioned earlier that we need to represent the data as a 2D matrix for the technique I have in mind to work. How can I translate a list of texts into a matrix? The answer’s rather simple: use SentenceTransformers to get each string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers capture the semantic meaning of sentences rather well (assuming you use a good model suited to the language the data is in) - they are trained on sentence similarity, after all.
SentenceTransformer embeddings

This part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
import sentence_transformers
# See list of models at www.sbert.net/docs/pretrained_models.html
ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
title_embeddings = (
    ST.encode(
        data.get_column("product_title").to_list(),
        show_progress_bar=True, convert_to_tensor=True)
    .numpy())

This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is there to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available; if it does, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the all-MiniLM-L6-v2 model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium-scale datasets (even those bordering on 10,000 rows and tens of columns) tends to be slow. A primary driver of this is the need to calculate all, or nearly all, pairwise distances between data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it suffices to understand that ANN algorithms take shortcuts to give you, if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones is available here.
I’ll explain why we’re in nearest neighbor territory in due course.
To build the database, all we need is the title_embeddings matrix.
import faiss
def create_index(title_embeddings):
    d = title_embeddings.shape[1]     # Number of dimensions
    ann_index = faiss.IndexFlatL2(d)  # Index using Euclidean (L2) distance
    ann_index.add(title_embeddings)   # Build the index

    return ann_index                  # Faiss considers databases an "index"

This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by category (or label). So let’s design a function that selects only the data for a particular category and then creates the database. We’ll need three pieces of information from this function:
1. The faiss database itself.
2. The data that was used to build the faiss database.
3. The indices that map rows of the faiss database back to rows of the original data.

(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
import faiss
import numpy as np
import polars as pl

def create_index(label):
    faiss_indices = (
        data  # this needs to be an argument if you want to create a generic function
        .with_row_count("row_idx")
        .filter(pl.col("category_label") == label)
        .get_column("row_idx")
        .to_list()
    )

    faiss_data = title_embeddings[faiss_indices]
    d = faiss_data.shape[1]          # Number of dimensions
    faiss_DB = faiss.IndexFlatIP(d)  # Index using Inner Product
    faiss.normalize_L2(faiss_data)   # Normalized L2 with Inner Product search = cosine similarity
    # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
    # If using Euclidean or other distance, we'll have to spend some time finding a good range
    # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
    faiss_DB.add(faiss_data)         # Build the index

    return faiss_DB, faiss_data, faiss_indices

To proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the index for which nearest neighbors are being searched, the indices of said nearest neighbors, and the distance between them. In network graph parlance, this kind of data is called an edge list, i.e. a list of pairs of nodes that are connected, along with any additional information that specifies a property (in this case, distance) of the edge that connects these nodes.
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map rows of the faiss database back to rows of the original `data`
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
    # To map the indices back to the original strings
    title_name_map = {i: x for i, x in
                      data.with_row_count("row_idx").select("row_idx", "product_title").rows()}
    distances, neighbors = faiss_DB.search(faiss_data, k)

    return (
        pl.DataFrame({
            "from": faiss_indices})
        .with_columns(
            pl.Series("to", neighbors),
            pl.Series("distance", distances))
        .explode("to", "distance")
        .with_columns(
            pl.col("from")
            .map_dict(title_name_map),
            pl.col("to")
            .map_dict(faiss_indices_map)
            .map_dict(title_name_map))
        .filter(pl.col("from") != pl.col("to"))
    )

The next step in the process is to create a network graph using the edge list. But why?
Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are each sufficiently similar to B within a particular minimum threshold, then A will be connected to C through B! Hopefully the small sketch and visual below will help.
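Here’s a minimal toy illustration of that transitivity (the node names are made up, not product titles): A and C are never directly linked, yet they end up in the same connected component because both connect to B.

import networkx as nx

# Only A-B, B-C and D-E survived the similarity threshold; there is no A-C edge.
graph = nx.Graph([("A", "B"), ("B", "C"), ("D", "E")])

# A, B and C are grouped together via the path through B.
print(list(nx.connected_components(graph)))  # e.g. [{'A', 'B', 'C'}, {'D', 'E'}]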

What happens when such a concept is extended to many data points? Not all of them will be connected, because we’re applying a minimum threshold that they have to meet. This is the only heuristic part of this rather fast process. Here’s one more helpful visual:

Very starry-night-esque vibes here. Let’s get to the code.
import networkx as nx
def get_cluster_map(label, k=5, min_cosine_distance=0.95):
    edge_list = (
        get_edge_list(label, k=k)
        .filter(pl.col("distance") >= min_cosine_distance)
    )
    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}

Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phones Accessories.
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)

Make sure to configure the following if your results aren’t good enough:
- the k value (the number of neighbors searched for each point), and
- the min_cosine_distance value if you want bigger clusters.

There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
>> clusters[3]
['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']

Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
>> clusters[6]
['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
 ...
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']

This isn’t that hard (although it may take more than a moment). Just iterate it for each category!
clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]

I get it - you often want a solution that “just works”. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
import sentence_transformers
import faiss
import polars as pl
import numpy as np

# Data is read here. You download the files from Kaggle here:
# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
data = pl.read_csv("archive/shopmania.csv", new_columns=[
    "product_ID", "product_title", "category_ID", "category_label"])
data = (
    data
    .filter(pl.count().over("category_ID") == 10000)
    .with_row_count("row_idx")
)


# See list of models at www.sbert.net/docs/pretrained_models.html
ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
title_embeddings = (
    ST.encode(
        data.get_column("product_title").to_list(),
        # I'm on a MacBook; you should use `cuda` or `cpu`
        # if you've got different hardware.
        device="mps",
        show_progress_bar=True, convert_to_tensor=True)
    .cpu().numpy())

# Code to create a FAISS index
def create_index(label):
    faiss_indices = (
        data  # this needs to be an argument if you want to create a generic function
        .filter(pl.col("category_label") == label)
        .get_column("row_idx")
        .to_list()
    )

    faiss_data = title_embeddings[faiss_indices]
    d = faiss_data.shape[1]          # Number of dimensions
    faiss_DB = faiss.IndexFlatIP(d)  # Index using Inner Product
    faiss.normalize_L2(faiss_data)   # Normalized L2 with Inner Product search = cosine similarity
    faiss_DB.add(faiss_data)         # Build the index

    return faiss_DB, faiss_data, faiss_indices

# Code to create an edge-list
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map rows of the faiss database back to rows of the original `data`
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
    # To map the indices back to the original strings
    title_name_map = {i: x for i, x in data.select("row_idx", "product_title").rows()}
    distances, neighbors = faiss_DB.search(faiss_data, k)

    return (
        pl.DataFrame({
            "from": faiss_indices})
        .with_columns(
            pl.Series("to", neighbors),
            pl.Series("distance", distances))
        .explode("to", "distance")
        .with_columns(
            pl.col("from")
            .map_dict(title_name_map),
            pl.col("to")
            .map_dict(faiss_indices_map)
            .map_dict(title_name_map))
        .filter(pl.col("from") != pl.col("to"))
    )

# Code to extract components from a Network Graph
import networkx as nx
def get_cluster_map(label, k=5, min_cosine_distance=0.95):
    edge_list = (
        get_edge_list(label, k=k)
        .filter(pl.col("distance") >= min_cosine_distance)
    )
    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}

# Example call to a single category to obtain its clusters
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
# Example call to **all** categories to obtain all clusters
clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]

If you want an algorithmic way of looking at this approach:
it boils down to this: embed every product title with SentenceTransformers, index the embeddings with an approximate nearest neighbor library (e.g. faiss) that allows you fast nearest neighbor searches, use cosine similarity for an easy threshold determination step, keep only the edges above your similarity threshold, and treat the connected components of the resulting graph as your clusters.
#representative #samples #faiss #approximate #nearest #neighbor #network #graph #networkx #polars #category
When I worked in healthcare consulting, I often spent a LOT of my time creating PowerPoint presentations (decks in consulting lingo - not even slide decks). However, it was rather repetitive. Thus, was born PowerPointSnap.
I’ll write this down as pointers.
The project is available on this Github repo. The instructions to install it are available there, but here’s the down-low:
Frankly, a LOT. The base concept of this tool is:
Here’s a non-exhaustive list of all the options available.
This is the part of the interface that can be used for shapes (which include charts and tables).

To use, first select a shape object, click on “Set”. Then, choose the object you want to Snap its properties to (see how I got the inspiration for the name?). You should be able to copy all compatible properties - if something is not copy-able, the tool will show an error, and then let you exit.
Note that it’s probably not a good idea to apply a property of a shape to a table - if you want to make the entire table orange, there are probably better built-in ways to do it than to use Snap.
Charts are also supported, with dedicated features for it.

What do these features do? You should be able to hover over the option and get a tooltip that shows what it’s capable of, but here’s another summary just in case:
The next two options deserve their own section.
Your immediate senior in a consulting environment would frown at your chart, and then exclaim, “I think that’s too many labels for the data points. Can you show them every two/three/four labels? I know this is manual work, but it’s a one time thing!”
It’s never a one time affair. But don’t worry, we have this nice feature to help us. If you click on the Customize Label option, you will get this (without the “Set” option):
Never mind the rather unfriendly legend entries. They’re just here to demonstrate that you can pull off the following kinds of wacky tricks with your own chart!
Of course, visuals will do it more justice. For example, look at this image:

Here’s what you can do:
This is what your results should look like:

Of course, getting those calculations right is a whole different thing that will need some work.
Oftentimes, you have two tables that show similar values… you know the drill. Here’s what you can do in a scenario such as this:

This is what the Tables section of the tool looks like:

To align these tables together,
Here’s what you’ll end up with:

Pretty neat, eh?
+]]>In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten it, there’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they sit on different rows of the image. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
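As a rough illustration of that workaround - a sketch only, assuming torchvision is installed, with ResNet-18 as just one convenient pretrained backbone - the image-embedding step could look something like this:

import torch
import torchvision.models as models
import torchvision.transforms as T

# Load a network pretrained on a generic task (ImageNet classification)
# and drop its final classification layer, keeping the pooled features.
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
feature_extractor = torch.nn.Sequential(*list(backbone.children())[:-1])
feature_extractor.eval()

preprocess = T.Compose([
    T.Resize(256), T.CenterCrop(224), T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def embed_image(pil_image):
    # Returns a 1D vector (512 values for ResNet-18) for one image;
    # stacking these vectors gives the 2D matrix the technique needs.
    with torch.no_grad():
        batch = preprocess(pil_image).unsqueeze(0)
        return feature_extractor(batch).flatten().numpy()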
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+NOTE: whenever I want to show an output along with the code I used for it, the command is prefixed with the characters >>, and the output appears without that prefix.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurrences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
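For a concrete comparison, a rough Pandas equivalent of the 10,000-rows-per-category filter above might look like this (a sketch; it assumes the CSV loads with the same column names):

import pandas as pd

data_pd = pd.read_csv("archive/shopmania.csv")
# Keep only the categories that appear exactly 10,000 times.
data_pd = data_pd.groupby("category_ID").filter(lambda g: len(g) == 10000)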
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2 model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
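If you’re not sure what hardware you’re on, a small device-selection guard keeps the encoding step from failing on machines without a GPU. This is a sketch only - it uses the smaller model from the note above and assumes a reasonably recent PyTorch:

import torch
import sentence_transformers

# Pick the best available device: a CUDA GPU, Apple's MPS, or plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

ST = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2")
title_embeddings = (
    ST.encode(
        data.get_column("product_title").to_list(),
        device=device, show_progress_bar=True, convert_to_tensor=True)
    .cpu()      # a no-op when already on the CPU, so this is safe everywhere
    .numpy())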
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this is the need to calculate all, or nearly all, distances between all pairs of data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it suffices to understand that ANN algorithms take shortcuts to give you, if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones is available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Euclidean (L2) distance
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"
This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by category (or label). So let’s design a function that selects only the data belonging to a particular category, and then creates the database. We’ll need three pieces of information from this function: (1) the faiss database itself, (2) the embedding data used to build the faiss database, and (3) the row indices of the original data points that went into the faiss database. (2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = faiss_data.shape[1] # Number of dimensions (of the embeddings)
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(faiss_data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indices
To proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to each. Let’s then write a function to get the following information: the index of the point whose nearest neighbors are being searched, the indices of said nearest neighbors, and the distances between them. In network graph parlance, this kind of data is called an edge list, i.e. a list of pairs of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
1def get_edge_list(label, k=5):
+ 2 faiss_DB, faiss_data, faiss_indices = create_index(label)
+ 3 # To map positions in the per-category index back to row indices in `data`
+ 4 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+ 5 # To map the indices back to the original strings
+ 6 title_name_map = {i: x for i,x in data.with_row_count("row_idx").select("row_idx", "product_title").rows()}
+ 7 distances, neighbors = faiss_DB.search(faiss_data, k)
+ 8
+ 9 return (
+10 pl.DataFrame({
+11 "from": faiss_indices})
+12 .with_columns(
+13 pl.Series("to", neighbors),
+14 pl.Series("distance", distances))
+15 .explode("to", "distance")
+16 .with_columns(
+17 pl.col("from")
+18 .map_dict(title_name_map),
+19 pl.col("to")
+20 .map_dict(faiss_indices_map)
+21 .map_dict(title_name_map))
+22 .filter(pl.col("from") != pl.col("to"))
+23 ) The next step in the process is to create a network graph using the edge-list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are each sufficiently similar to B within a particular minimum threshold, then A will be connected to C through B! Hopefully the small visual below helps.
+
What happens when such a concept is extended to many data points? Not all of them would be connected - because we’re applying a minimum threshold that they have to meet. This is the only heuristic part of the rather fast process. Here’s one more helpful visual:
+
Very Starry Night-esque vibes here. Let’s get to the code.
+ + + + + +1import networkx as nx
+2def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+3 edge_list = (
+4 get_edge_list(label, k=k)
+5 .filter(pl.col("distance") >= min_cosine_distance)
+6 )
+7 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+8 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phone Accessories.
1clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)Make sure to configure the following if your results aren’t good enough:
Adjust the min_cosine_distance value if you want bigger clusters - a lower threshold keeps more edges in the graph and therefore produces larger connected components.
There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
1>> clusters[3]
+2['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
+3 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
+4 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
+5 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
+6 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
+7 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
+ + + + + + 1>>> clusters[6]
+ 2['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
+ 3 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
+ 4 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
+ 5 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
+ 6 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
+ 7 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
+ 8 ...
+ 9 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
+10 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
+11 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
+12 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
+13 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
+14 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']This isn’t that hard (although it may take more than a moment). Just iterate it for each category!
+ + + + + +1clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]I get it - you often want a solution that “just works”. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
+ 1import sentence_transformers
+ 2import faiss
+ 3import polars as pl
+ 4import numpy as np
+ 5
+ 6# Data is read here. You download the files from Kaggle here:
+ 7# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
+ 8data = pl.read_csv("archive/shopmania.csv", new_columns=[
+ 9 "product_ID", "product_title", "category_ID", "category_label"])
+10data = (
+11 data
+12 .filter(pl.count().over("category_ID") == 10000)
+13 .with_row_count("row_idx")
+14)
+15
+16
+17# See list of models at www.sbert.net/docs/pretrained_models.html
+18ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+19title_embeddings = (
+20 ST.encode(
+21 data.get_column("product_title").to_list(),
+22 # I'm on a MacBook, you should use `cuda` or `cpu`
+23 # if you've got different hardware.
+24 device="mps",
+25 show_progress_bar=True, convert_to_tensor=True)
+26 .cpu().numpy())
+27
+28# Code to create a FAISS index
+29def create_index(label):
+30 faiss_indices = (
+31 data # this needs to be an argument if you want to create a generic function
+32 .filter(pl.col("category_label") == label)
+33 .get_column("row_idx")
+34 .to_list()
+35 )
+36
+37 faiss_data = title_embeddings[faiss_indices]
+38 d = faiss_data.shape[1] # Number of dimensions
+39 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+40 faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
+41 faiss_DB.add(faiss_data) # Build the index
+42
+43 return faiss_DB, faiss_data, faiss_indices
+44
+45# Code to create an edge-list
+46def get_edge_list(label, k=5):
+47 faiss_DB, faiss_data, faiss_indices = create_index(label)
+48 # To map positions in the per-category index back to row indices in `data`
+49 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+50 # To map the indices back to the original strings
+51 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+52 distances, neighbors = faiss_DB.search(faiss_data, k)
+53
+54 return (
+55 pl.DataFrame({
+56 "from": faiss_indices})
+57 .with_columns(
+58 pl.Series("to", neighbors),
+59 pl.Series("distance", distances))
+60 .explode("to", "distance")
+61 .with_columns(
+62 pl.col("from")
+63 .map_dict(title_name_map),
+64 pl.col("to")
+65 .map_dict(faiss_indices_map)
+66 .map_dict(title_name_map))
+67 .filter(pl.col("from") != pl.col("to"))
+68 )
+69
+70# Code to extract components from a Network Graph
+71import networkx as nx
+72def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+73 edge_list = (
+74 get_edge_list(label, k=k)
+75 .filter(pl.col("distance") >= min_cosine_distance)
+76 )
+77 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+78 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}
+79
+80# Example call to a single category to obtain its clusters
+81clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+82# Example call to **all** categories to obtain all clusters
+83clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]If you want to write down an algorithmic way of looking at this approach,
it boils down to this: embed every product title with SentenceTransformers, index the embeddings with an approximate nearest neighbor library (e.g. faiss) that allows you fast nearest neighbor searches, use cosine similarity for an easy threshold determination step, keep only the edges above your similarity threshold, and treat the connected components of the resulting graph as your clusters.
I stumbled upon an interesting Stackoverflow question that was linked via an issue on Polars github repo. The OP asked for a pure Polars solution. At the time of answering the question Polars did not have support for non-equi joins, and any solution using it would be pretty cumbersome.
+I’m more of a right-tool-for-the-job person, so I tried to find a better solution.
+Suppose we have a dataset that captures the arrival and departure times of trucks at a station, along with the truck’s ID.
+ + + + + + 1import polars as pl # if you don't have polars, run
+ 2 # pip install 'polars[all]'
+ 3data = pl.from_repr("""
+ 4┌─────────────────────┬─────────────────────┬─────┐
+ 5│ arrival_time ┆ departure_time ┆ ID │
+ 6│ --- ┆ --- ┆ --- │
+ 7│ datetime[μs] ┆ datetime[μs] ┆ str │
+ 8╞═════════════════════╪═════════════════════╪═════╡
+ 9│ 2023-01-01 06:23:47 ┆ 2023-01-01 06:25:08 ┆ A1 │
+10│ 2023-01-01 06:26:42 ┆ 2023-01-01 06:28:02 ┆ A1 │
+11│ 2023-01-01 06:30:20 ┆ 2023-01-01 06:35:01 ┆ A5 │
+12│ 2023-01-01 06:32:06 ┆ 2023-01-01 06:33:48 ┆ A6 │
+13│ 2023-01-01 06:33:09 ┆ 2023-01-01 06:36:01 ┆ B3 │
+14│ 2023-01-01 06:34:08 ┆ 2023-01-01 06:39:49 ┆ C3 │
+15│ 2023-01-01 06:36:40 ┆ 2023-01-01 06:38:34 ┆ A6 │
+16│ 2023-01-01 06:37:43 ┆ 2023-01-01 06:40:48 ┆ A5 │
+17│ 2023-01-01 06:39:48 ┆ 2023-01-01 06:46:10 ┆ A6 │
+18└─────────────────────┴─────────────────────┴─────┘
+19""")We want to identify the number of trucks docked at any given time within a threshold of 1 minute prior to the arrival time of a truck, and 1 minute after the departure of a truck. Equivalently, this means that we need to calculate the number of trucks within a specific window for each row of the data.
+Before we find a general solution to this problem, let’s consider a specific row to understand the problem better:
+ + + + + +1"""
+2┌─────────────────────┬─────────────────────┬─────┐
+3│ arrival_time ┆ departure_time ┆ ID │
+4│ --- ┆ --- ┆ --- │
+5│ datetime[μs] ┆ datetime[μs] ┆ str │
+6╞═════════════════════╪═════════════════════╪═════╡
+7│ 2023-01-01 06:32:06 ┆ 2023-01-01 06:33:48 ┆ A6 │
+8└─────────────────────┴─────────────────────┴─────┘
+9"""For this row, we need to find the number of trucks that are there between 2023-01-01 06:31:06 (1 minute prior to the arrival_time and 2023-01-01 06:34:48 (1 minute post the departure_time). Manually going through the original dataset, we see that B3, C3, A6 and A5 are the truck IDs that qualify - they all are at the station in a duration that is between 2023-01-01 06:31:06 and 2023-01-01 06:34:48.
There are many cases that will qualify a truck to be present in the overlap window defined by a particular row. Specifically for the example above, we have (this visualization is generalizable, because for each row we can calculate without much difficulty the overlap window relative to the arrival and departure times):
+
Take some time to absorb these cases - it’s important for the part where we write the code for the solution. Note that we need to actually tell our algorithm to filter only for Cases 2, 3 and 4, since Cases 1 and 5 will not satisfy our requirements.
+In theory, we can use any language that has the capability to define rules that meet our algorithmic requirements outlined in the above section to find the solution. Why choose SQL? It’s often able to convey elegantly the logic that was used to execute the algorithm; and while it does come with excessive verbosity at times, it doesn’t quite in this case.
+Note here that we run SQL in Python with almost no setup or boilerplate code - so this is a Python based solution as well (although not quite Pythonic!).
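As a quick illustration of how little setup is involved (a sketch using the data DataFrame defined above), this is genuinely all it takes to run SQL against an in-memory Polars DataFrame:

import duckdb as db

# DuckDB resolves `data` to the in-scope Polars DataFrame by name -
# no registration, connection objects, or conversion required.
db.query("SELECT ID, arrival_time, departure_time FROM data ORDER BY arrival_time").show()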
+Once again, in theory, any SQL package or language can be used. Far too few, however, meet the ease of use that DuckDB provides: a one-line install (pip install duckdb), direct querying of in-memory DataFrames, and mind-blowing speed that stands shoulder-to-shoulder with Polars. We’ll also use a few advanced SQL concepts, noted below.
+This should be a familiar, albeit not often used concept - a join of a table with itself is a self join. There are few cases where such an operation would make sense, and this happens to be one of them.
A key concept that we’ll use is the idea of joining on a range of values rather than a specific value. That is, instead of the usual LEFT JOIN ON A.column = B.column, we can do LEFT JOIN ON A.column <= B.column, allowing one row in table A to match multiple rows in B. DuckDB has a blog post that outlines this join in detail, including a fast implementation.
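To make both ideas concrete before the full query, here is a stripped-down self range join on the same table (a sketch; the count itself isn’t the answer to our problem, it just shows the mechanics): data appears twice, and the inequality lets one row of A match many rows of B.

import duckdb as db

# A self join (the same table aliased as A and B) with a range predicate:
# each visit in A is paired with every visit in B that started no later.
db.query("""
    SELECT A.ID, A.arrival_time, COUNT(*) AS earlier_or_same_arrivals
    FROM data A
    LEFT JOIN data B
      ON B.arrival_time <= A.arrival_time
    GROUP BY 1, 2
    ORDER BY A.arrival_time
""").show()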
LIST columns
DuckDB has first class support for LIST columns - that is, each row in a LIST column can have a varying length (much like a Python list), but must have the exact same datatype (like R’s vector). Using LIST columns allows us to eschew an additional GROUP BY operation on top of a WHERE filter or SELECT DISTINCT operation, since we can perform those directly on the LIST column itself.
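A tiny illustration of the three list functions used later (a sketch on the same table):

import duckdb as db

# LIST collects a column into a single list value; LIST_DISTINCT
# de-duplicates that list, and LIST_UNIQUE counts its distinct elements.
db.query("""
    SELECT
        LIST(ID)                 AS all_ids
        ,LIST_DISTINCT(LIST(ID)) AS distinct_ids
        ,LIST_UNIQUE(LIST(ID))   AS distinct_id_count
    FROM data
""").show()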
Dates can be rather difficult to handle well in most tools and languages, with several packages purpose-built to make handling them easier - lubridate from the tidyverse is a stellar example. Thankfully, DuckDB provides a similar Swiss-army-knife set of tools to deal with them, including INTERVALs (a special data type that represents a period of time independent of specific time values) that can be added to or subtracted from TIMESTAMP values.
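And the date arithmetic really is this direct (a sketch; these are exactly the window columns the main query builds):

import duckdb as db

# INTERVAL values can be added to or subtracted from TIMESTAMP columns.
db.query("""
    SELECT
        ID
        ,arrival_time   - (INTERVAL 1 MINUTE) AS window_open
        ,departure_time + (INTERVAL 1 MINUTE) AS window_close
    FROM data
""").show()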
Okay - that was a lot of background. Let’s have at it! The query by itself in SQL is (see immediately below for runnable code in Python):
+ + + + + + 1SELECT
+ 2 A.arrival_time
+ 3 ,A.departure_time
+ 4 ,A.window_open
+ 5 ,A.window_close
+ 6 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+ 7 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+ 8
+ 9FROM (
+10 SELECT *
+11 ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
+12 ,departure_time + (INTERVAL 1 MINUTE) AS window_close
+13 FROM data) A
+14
+15LEFT JOIN (
+16 SELECT *
+17 ,DATEDIFF('seconds', arrival_time, departure_time) AS duration
+18 FROM data) B
+19
+20ON ((B.arrival_time <= A.window_open AND
+21 (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
+22 (B.arrival_time >= A.window_open AND
+23 B.departure_time <= A.window_close) OR
+24 (B.arrival_time >= A.window_open AND
+25 (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close))
+26GROUP BY 1, 2, 3, 4A small, succinct query such as this will need a bit of explanation to take it all in. Here’s one below, reproducible in Python (make sure to install duckdb first!).
1import duckdb as db
+ 2db.query("""
+ 3 SELECT
+ 4 A.arrival_time
+ 5 ,A.departure_time
+ 6 ,A.window_open
+ 7 ,A.window_close
+ 8 -- LIST aggregates the values into a LIST column
+ 9 -- and LIST_DISTINCT finds the unique values in it
+10 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+11 -- finally, LIST_UNIQUE calculates the unique number of values in it
+12 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+13
+14 FROM (
+15 SELECT
+16 *
+17 ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
+18 ,departure_time + (INTERVAL 1 MINUTE) AS window_close
+19 FROM data -- remember we defined data as the Polars DataFrame with our truck station data
+20 ) A
+21
+22 LEFT JOIN (
+23 SELECT
+24 *
+25 -- This is the time, in seconds between the arrival and departure of
+26 -- each truck PER ROW in the original data-frame
+27 ,DATEDIFF('seconds', arrival_time, departure_time) AS duration
+28 FROM data -- this is where we perform a self-join
+29 ) B
+30
+31 ON (
+32 -- Case 2 in the diagram;
+33 (B.arrival_time <= A.window_open AND
+34 -- Adding the duration here makes sure that the second interval
+35 -- is at least ENDING AFTER the start of the overlap window
+36 (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
+37
+38 -- Case 3 in the diagram - the simplest of all five cases
+39 (B.arrival_time >= A.window_open AND
+40 B.departure_time <= A.window_close) OR
+41
+42 -- Case 4 in the digram;
+43 (B.arrival_time >= A.window_open AND
+44 -- Subtracting the duration here makes sure that the second interval
+45 -- STARTS BEFORE the end of the overlap window.
+46 (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close)
+47 )
+48 GROUP BY 1, 2, 3, 4
+49""")The output of this query is:
+ + + + + +"""
+┌─────────────────────┬─────────────────────┬─────────────────────┬───┬──────────────────┬────────────────────┐
+│ arrival_time │ departure_time │ window_open │ … │ docked_trucks │ docked_truck_count │
+│ timestamp │ timestamp │ timestamp │ │ varchar[] │ uint64 │
+├─────────────────────┼─────────────────────┼─────────────────────┼───┼──────────────────┼────────────────────┤
+│ 2023-01-01 06:23:47 │ 2023-01-01 06:25:08 │ 2023-01-01 06:22:47 │ … │ [A1] │ 1 │
+│ 2023-01-01 06:26:42 │ 2023-01-01 06:28:02 │ 2023-01-01 06:25:42 │ … │ [A1] │ 1 │
+│ 2023-01-01 06:30:20 │ 2023-01-01 06:35:01 │ 2023-01-01 06:29:20 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:32:06 │ 2023-01-01 06:33:48 │ 2023-01-01 06:31:06 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:33:09 │ 2023-01-01 06:36:01 │ 2023-01-01 06:32:09 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:34:08 │ 2023-01-01 06:39:49 │ 2023-01-01 06:33:08 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:36:40 │ 2023-01-01 06:38:34 │ 2023-01-01 06:35:40 │ … │ [A5, A6, C3, B3] │ 4 │
+│ 2023-01-01 06:37:43 │ 2023-01-01 06:40:48 │ 2023-01-01 06:36:43 │ … │ [A5, A6, C3] │ 3 │
+│ 2023-01-01 06:39:48 │ 2023-01-01 06:46:10 │ 2023-01-01 06:38:48 │ … │ [A6, A5, C3] │ 3 │
+├─────────────────────┴─────────────────────┴─────────────────────┴───┴──────────────────┴────────────────────┤
+│ 9 rows 6 columns (5 shown) │
+└─────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
+"""We clearly see the strengths of DuckDB in how succintly we were able to express this operation. We also find how DuckDB is able to seamlessly integrate with an existing Pandas or Polars pipeline with zero-conversion costs. In fact, we can convert this back to a Polars or Pandas dataframe by appending the ending bracket with db.query(...).pl() and db.query(...).pd() respectively.
Now that we’ve understood the logic that goes into the query, let’s try to optimize the algorithm. We have the three conditions:
+ + + + + +1-- Case 2 in the diagram
+2(B.arrival_time <= A.window_open AND
+3 (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
+4-- Case 3 in the diagram
+5(B.arrival_time >= A.window_open AND
+6 B.departure_time <= A.window_close) OR
+7-- Case 4 in the diagram
+8(B.arrival_time >= A.window_open AND
+9 (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close)
What is common between these three conditions? It takes a while to see it, but it becomes clear that every case requires the other truck’s arrival to be before the window closes and its departure to be after the window opens - the only way two intervals can fail to overlap is for one of them to end before the other begins. This can be simplified to just:
+ + + + + +1B.arrival_time <= A.window_close AND
+2B.departure_time >= A.window_openmaking our query much simpler!
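If the equivalence isn’t obvious, a quick brute-force check in plain Python can build confidence before trusting the shorter predicate (a sketch; small integers stand in for timestamps):

from itertools import product

def three_cases(b_arr, b_dep, w_open, w_close):
    # The original three ON conditions, with the duration written out.
    dur = b_dep - b_arr
    return ((b_arr <= w_open and b_arr + dur >= w_open) or
            (b_arr >= w_open and b_dep <= w_close) or
            (b_arr >= w_open and b_dep - dur <= w_close))

def simplified(b_arr, b_dep, w_open, w_close):
    # The proposed two-condition overlap test.
    return b_arr <= w_close and b_dep >= w_open

# Enumerate all small, well-formed intervals and confirm the two agree.
for b_arr, b_dep, w_open, w_close in product(range(6), repeat=4):
    if b_arr <= b_dep and w_open <= w_close:
        assert three_cases(b_arr, b_dep, w_open, w_close) == \
               simplified(b_arr, b_dep, w_open, w_close)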
+We’ve removed the need for the duration calculation altogether now. Therefore, we can write:
1SELECT
+ 2 A.arrival_time
+ 3 ,A.departure_time
+ 4 ,A.window_open
+ 5 ,A.window_close
+ 6 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+ 7 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+ 8
+ 9FROM (
+10 SELECT *
+11 ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
+12 ,departure_time + (INTERVAL 1 MINUTE) AS window_close
+13 FROM data) A
+14
+15LEFT JOIN data B
+16
+17ON (
+18 B.arrival_time <= A.window_close AND
+19 B.departure_time >= A.window_open
+20)
+21GROUP BY 1, 2, 3, 4Can we simplify this even further?
+I think the SQL query in the above section is already very easy to read. However, it is a little clunky overall, and there is a way we can leverage DuckDB’s extensive optimizations to improve legibility by rewriting the query as a cross join:
+ + + + + + 1SELECT
+ 2 A.arrival_time
+ 3 ,A.departure_time
+ 4 ,A.arrival_time - (INTERVAL 1 MINUTE) AS window_open
+ 5 ,A.departure_time + (INTERVAL 1 MINUTE) AS window_close
+ 6 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+ 7 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+ 8FROM data A, data B
+ 9WHERE B.arrival_time <= window_close
+10AND B.departure_time >= window_open
+11GROUP BY 1, 2, 3, 4Why does this work? Before optimization on DuckDB, this is what the query plan looks like:
+ 1"""
+ 2┌───────────────────────────┐
+ 3│ PROJECTION │
+ 4│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+ 5│ 0 │
+ 6│ 1 │
+ 7│ 2 │
+ 8│ 3 │
+ 9│ docked_trucks │
+10│ docked_truck_count │
+11└─────────────┬─────────────┘
+12┌─────────────┴─────────────┐
+13│ AGGREGATE │
+14│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+15│ arrival_time │
+16│ departure_time │
+17│ window_open │
+18│ window_close │
+19│ list(ID) │
+20└─────────────┬─────────────┘
+21┌─────────────┴─────────────┐
+22│ FILTER │
+23│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+24│ (arrival_time <= │
+25│(departure_time + to_m... │
+26│ AS BIGINT)))) │
+27│ (departure_time >= │
+28│(arrival_time - to_min... │
+29│ AS BIGINT)))) │
+30└─────────────┬─────────────┘
+31┌─────────────┴─────────────┐
+32│ CROSS_PRODUCT ├──────────────┐
+33└─────────────┬─────────────┘ │
+34┌─────────────┴─────────────┐┌─────────────┴─────────────┐
+35│ ARROW_SCAN ││ ARROW_SCAN │
+36└───────────────────────────┘└───────────────────────────┘
+37""" After optimization, the CROSS_PRODUCT is automatically optimized to an interval join!
1"""
+ 2┌───────────────────────────┐
+ 3│ PROJECTION │
+ 4│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+ 5│ 0 │
+ 6│ 1 │
+ 7│ 2 │
+ 8│ 3 │
+ 9│ docked_trucks │
+10│ docked_truck_count │
+11└─────────────┬─────────────┘
+12┌─────────────┴─────────────┐
+13│ AGGREGATE │
+14│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+15│ arrival_time │
+16│ departure_time │
+17│ window_open │
+18│ window_close │
+19│ list(ID) │
+20└─────────────┬─────────────┘
+21┌─────────────┴─────────────┐
+22│ COMPARISON_JOIN │
+23│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+24│ INNER │
+25│ ((departure_time + '00:01 │
+26│ :00'::INTERVAL) >= ├──────────────┐
+27│ arrival_time) │ │
+28│((arrival_time - '00:01:00'│ │
+29│ ::INTERVAL) <= │ │
+30│ departure_time) │ │
+31└─────────────┬─────────────┘ │
+32┌─────────────┴─────────────┐┌─────────────┴─────────────┐
+33│ ARROW_SCAN ││ ARROW_SCAN │
+34└───────────────────────────┘└───────────────────────────┘
+35""" So in effect, we’re actually exploiting a feature of DuckDB to allow us to write our queries in a suboptimal manner for greater readability, and allowing the optmizer to do a good chunk of our work for us. I wouldn’t recommend using this generally, because not all SQL engine optmizers will be able to find an efficient route to these calculations for large datasets.
+I’m glad you asked. Here’s the DuckDB page explaining EXPLAIN (heh). Here’s the code I used:
1import duckdb as db
+ 2db.sql("SET EXPLAIN_OUTPUT='all';")
+ 3print(db.query("""
+ 4EXPLAIN
+ 5SELECT
+ 6 A.arrival_time
+ 7 ,A.departure_time
+ 8 ,A.arrival_time - (INTERVAL 1 MINUTE) AS window_open
+ 9 ,A.departure_time + (INTERVAL 1 MINUTE) AS window_close
+10 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+11 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+12FROM data A, data B
+13WHERE B.arrival_time <= window_close
+14AND B.departure_time >= window_open
+15GROUP BY 1, 2, 3, 4
+16""").pl()[1, 1])data.table waydata.table is a package that has historically been ahead of its time - in both speed and features that it has had. Developement has taken a hit recently, but will likely pick back up. It’s my favourite package on all fronts for data manipulation, but suffers simply from the lack of broader R support across the ML and DL space.
foverlaps functionIf this kind of overlapping join is common, shouldn’t someone have developed a package for it? Turns out, data.table has, and with very specific constraints that make it the perfect solution to our problem (if you don’t mind switching over to R, that is).
The foverlaps function has these requirements:
the data.table objects have to be keyed for automatic recognition of columns; the type of overlap can be specified (e.g. any or within overlap, or matching start and end windows); and by must specify the start and end points of the overlapping window. This isn’t a problem for us now, but it does restrict future uses where we may want non-equi joins on other conditions.
Without further ado:
+ + + + + + 1library(data.table)
+ 2library(lubridate)
+ 3
+ 4######### BOILERPLATE CODE, NO LOGIC HERE ####################
+ 5arrival_time = as_datetime(c(
+ 6 '2023-01-01 06:23:47.000000', '2023-01-01 06:26:42.000000',
+ 7 '2023-01-01 06:30:20.000000', '2023-01-01 06:32:06.000000',
+ 8 '2023-01-01 06:33:09.000000', '2023-01-01 06:34:08.000000',
+ 9 '2023-01-01 06:36:40.000000', '2023-01-01 06:37:43.000000',
+10 '2023-01-01 06:39:48.000000'))
+11departure_time = as_datetime(c(
+12 '2023-01-01 06:25:08.000000', '2023-01-01 06:28:02.000000',
+13 '2023-01-01 06:35:01.000000', '2023-01-01 06:33:48.000000',
+14 '2023-01-01 06:36:01.000000', '2023-01-01 06:39:49.000000',
+15 '2023-01-01 06:38:34.000000', '2023-01-01 06:40:48.000000',
+16 '2023-01-01 06:46:10.000000'))
+17ID = c('A1', 'A1', 'A5', 'A6', 'B3', 'C3', 'A6', 'A5', 'A6')
+18
+19DT = data.table(
+20 arrival_time = arrival_time,
+21 departure_time = departure_time,
+22 ID = ID)
+23######### BOILERPLATE CODE, NO LOGIC HERE ####################
+24
+25# A copy(DT) creates a copy of a data.table that isn't linked
+26# to the original one, so that changes in it don't reflect in
+27# the original DT object.
+28# The `:=` allow assignment by reference (i.e. "in place").
+29DT_with_windows = copy(DT)[, `:=`(
+30 window_start = arrival_time - minutes(1),
+31 window_end = departure_time + minutes(1))]
+32
+33# This step is necessary for the second table, but not the first, but we
+34# key both data.tables to make the foverlap code very succinct.
+35setkeyv(DT, c("arrival_time", "departure_time"))
+36setkeyv(DT_with_windows, c("window_start", "window_end"))
+37
+38# The foverlap function returns a data.table, so we can simply apply
+39# the usual data.table syntax on it!
+40# Since we have the same name of some columns in both data.tables,
+41# the latter table's columns are prefixed with "i." to avoid conflicts.
+42foverlaps(DT, DT_with_windows)[
+43 , .(docked_trucks = list(unique(i.ID)),
+44 docked_truck_count = uniqueN(i.ID))
+45 , .(arrival_time, departure_time)]provides us the output:
+ + + + + + 1 arrival_time departure_time docked_trucks docked_truck_count
+ 2 <POSc> <POSc> <list> <int>
+ 31: 2023-01-01 06:23:47 2023-01-01 06:25:08 A1 1
+ 42: 2023-01-01 06:26:42 2023-01-01 06:28:02 A1 1
+ 53: 2023-01-01 06:30:20 2023-01-01 06:35:01 A5,A6,B3,C3 4
+ 64: 2023-01-01 06:32:06 2023-01-01 06:33:48 A5,A6,B3,C3 4
+ 75: 2023-01-01 06:33:09 2023-01-01 06:36:01 A5,A6,B3,C3 4
+ 86: 2023-01-01 06:34:08 2023-01-01 06:39:49 A5,A6,B3,C3 4
+ 97: 2023-01-01 06:36:40 2023-01-01 06:38:34 B3,C3,A6,A5 4
+108: 2023-01-01 06:37:43 2023-01-01 06:40:48 C3,A6,A5 3
+119: 2023-01-01 06:39:48 2023-01-01 06:46:10 C3,A5,A6 3data.tableThe package offers a wonderful, nearly one-stop solution that doesn’t require you to write the logic out for the query or command yourself, but has a major problem for a lot of users - it requires you to switch your codebase to R, and a lot of your tasks may be on Python or in an SQL pipeline. So, what do you do?
+Consider the effort in maintaining an additional dependency for your analytics pipeline (i.e. R), and the effort that you’ll need to invest to run R from Python, or run an R script in your pipeline and pull the output from it back into the pipeline, and make your call.
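If you do decide the R dependency is worth it, the hand-off can stay fairly small. Here’s a sketch, assuming Rscript is on the PATH and a hypothetical overlap_count.R script that wraps the foverlaps logic above and writes its result to CSV:

import subprocess
import polars as pl

# Hand the data to R, let the foverlaps-based script do the join,
# and pull the result back into the Python pipeline.
data.write_csv("trucks.csv")
subprocess.run(
    ["Rscript", "overlap_count.R", "trucks.csv", "overlap_result.csv"],
    check=True)
result = pl.read_csv("overlap_result.csv")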
+]]>My name is Avinash Mallya (pronounced Uh-vin-aash Muh-ll-yeah), and I’m a data scientist by profession. This website is a creative outlet, and my piece of the internet where I show off.
You’ll find the following:
You can find me on:
Please reach out via one of the above if you want to talk.
data.table, that I found very useful earlier in my career to quickly churn out analyses. It is not ground-breaking, but rather something that anybody with sufficiently basic skills in R can understand, and it saves an immense amount of time.data.table and dplyr dominated), so I was eager to make it better for everybody making the switch.
When I worked in healthcare consulting, I often spent a LOT of my time creating PowerPoint presentations (decks in consulting lingo - not even slide decks). However, it was rather repetitive. Thus, was born PowerPointSnap.
+I’ll write this down as pointers.
+The project is available on this Github repo. The instructions to install it are available there, but here’s the down-low:
+Frankly, a LOT. The base concept of this tool is:
+Here’s a non-exhaustive list of all the options available.
+This is the part of the interface that can be used for shapes (which include charts and tables).
+
To use, first select a shape object, click on “Set”. Then, choose the object you want to Snap its properties to (see how I got the inspiration for the name?). You should be able to copy all compatible properties - if something is not copy-able, the tool will show an error, and then let you exit.
+Note that it’s probably not to apply a property of a shape to a table - if you want to make the entire table orange, there are probably better built-in ways to do it than to use Snap.
+Charts are also supported, with dedicated features for it.
+
What do these features do? You should be able to hover over the option and get a tooltip that shows what it’s capable of, but here’s another summary just in case:
+The next two options deserve their own section.
+Your immediate senior in a consulting environment would frown at your chart, and then exclaim, “I think that’s too many labels for the data points. Can you show them every two/three/four labels? I know this is manual work, but it’s a one time thing!”
+It’s never a one time affair. But don’t worry, we have this nice feature to help us. If you click on the Customize Label option, you will get this (without the “Set” option):
+Never mind the rather unfriendly legend entries. They’re just here to demonstrate that you can do the following kinds of whacky abilities with your own chart!
+Of course, visuals will do it more justice. For example, look at this image:
+
Here’s what you can do:
+This is what your results should look like:
+
Of course, getting those calculations right is a whole different thing that will need some work.
+Oftentimes, you have two tables that show similar values… you know the drill. Here’s what you can do in a scenario such as this:
+
This is what the Tables section of the tool looks like:
+
To align these tables together,
+Here’s what you’ll end up with:
+
Pretty neat, eh?
+]]>In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map positions in the per-category index back to rows of the original data
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
    # To map the row indices back to the original title strings
    # (data has no row_idx column yet at this point, so add one on the fly)
    title_name_map = {
        i: x for i, x in data.with_row_count("row_idx").select("row_idx", "product_title").rows()}
    distances, neighbors = faiss_DB.search(faiss_data, k)

    return (
        pl.DataFrame({
            "from": faiss_indices})
        .with_columns(
            pl.Series("to", neighbors),
            pl.Series("distance", distances))
        .explode("to", "distance")
        .with_columns(
            pl.col("from")
            .map_dict(title_name_map),
            pl.col("to")
            .map_dict(faiss_indices_map)
            .map_dict(title_name_map))
        .filter(pl.col("from") != pl.col("to"))
    )

The next step in the process is to create a network graph using the edge-list. But why?
Remember that we have identified the (k=5) nearest neighbors of each data point. Let's say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are each sufficiently similar to B within a particular minimum threshold, then A will be connected to C through B! Hopefully the small visual below will help.
+
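To make the transitivity concrete, here's a tiny networkx sketch with made-up similarity values (A-B and B-C clear the threshold, A-C doesn't), showing that A and C still end up in the same component through B:

import networkx as nx

# Hypothetical edges with made-up cosine similarities
edges = [("A", "B", 0.97), ("B", "C", 0.96), ("A", "C", 0.80)]
min_cosine_distance = 0.95

graph = nx.Graph()
graph.add_edges_from((u, v) for u, v, similarity in edges if similarity >= min_cosine_distance)

# A and C are connected through B, even though the A-C edge was dropped.
print(list(nx.connected_components(graph)))  # e.g. [{'A', 'B', 'C'}]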
What happens when such a concept is extended to many data points? Not all of them will be connected - because we're applying a minimum threshold that they have to meet. This is the only heuristic part of the rather fast process. Here's one more helpful visual:
+
Very Starry Night-esque vibes here. Let's get to the code.
import networkx as nx

def get_cluster_map(label, k=5, min_cosine_distance=0.95):
    edge_list = (
        get_edge_list(label, k=k)
        .filter(pl.col("distance") >= min_cosine_distance)
    )
    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}

Now that all the parts of the puzzle are together, let's run it to see what kind of clusters you get for Cell Phones Accessories.
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)

Make sure to configure the following if your results aren't good enough:
- The min_cosine_distance value: lower it if you want bigger clusters, raise it if the clusters look too noisy.
- The k value: how many nearest neighbors are considered for each data point.

There will likely be many clusters (you can see how many exactly with len(clusters)). Let's look at a random cluster:
>>> clusters[3]
['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']

Let's see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
>>> clusters[6]
['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
 ...
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']

Running this for every category isn't that hard either (although it may take more than a moment) - just iterate over each category!
clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]

I get it - you often want a solution that "just works". I can come close to it. See below for code and a succinct explanation. For those of my readers who aren't in a hurry, this also serves as a nice summary (and copy-pastable code)!
import sentence_transformers
import faiss
import polars as pl
import numpy as np

# Data is read here. You can download the files from Kaggle here:
# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
data = pl.read_csv("archive/shopmania.csv", new_columns=[
    "product_ID", "product_title", "category_ID", "category_label"])
data = (
    data
    .filter(pl.count().over("category_ID") == 10000)
    .with_row_count("row_idx")
)


# See list of models at www.sbert.net/docs/pretrained_models.html
ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
title_embeddings = (
    ST.encode(
        data.get_column("product_title").to_list(),
        # I'm on a MacBook, you should use `cuda` or `cpu`
        # if you've got different hardware.
        device="mps",
        show_progress_bar=True, convert_to_tensor=True)
    .cpu().numpy())

# Code to create a FAISS index
def create_index(label):
    faiss_indices = (
        data  # this needs to be an argument if you want to create a generic function
        .filter(pl.col("category_label") == label)
        .get_column("row_idx")
        .to_list()
    )

    faiss_data = title_embeddings[faiss_indices]
    d = faiss_data.shape[1]           # Number of dimensions
    faiss_DB = faiss.IndexFlatIP(d)   # Index using Inner Product
    faiss.normalize_L2(faiss_data)    # Normalized L2 with Inner Product search = cosine similarity
    faiss_DB.add(faiss_data)          # Build the index

    return faiss_DB, faiss_data, faiss_indices

# Code to create an edge-list
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map positions in the per-category index back to rows of the original data
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
    # To map the row indices back to the original title strings
    title_name_map = {i: x for i, x in data.select("row_idx", "product_title").rows()}
    distances, neighbors = faiss_DB.search(faiss_data, k)

    return (
        pl.DataFrame({
            "from": faiss_indices})
        .with_columns(
            pl.Series("to", neighbors),
            pl.Series("distance", distances))
        .explode("to", "distance")
        .with_columns(
            pl.col("from")
            .map_dict(title_name_map),
            pl.col("to")
            .map_dict(faiss_indices_map)
            .map_dict(title_name_map))
        .filter(pl.col("from") != pl.col("to"))
    )

# Code to extract components from a Network Graph
import networkx as nx
def get_cluster_map(label, k=5, min_cosine_distance=0.95):
    edge_list = (
        get_edge_list(label, k=k)
        .filter(pl.col("distance") >= min_cosine_distance)
    )
    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}

# Example call to a single category to obtain its clusters
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
# Example call to **all** categories to obtain all clusters
clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]

If you want to write down an algorithmic way of looking at this approach:
1. Embed the product titles with a sentence transformer model.
2. For each category, build a database (faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.
3. Search that database for the k nearest neighbors of every data point and arrange the results as an edge list.
4. Drop the edges whose similarity falls below a minimum threshold.
5. Build a network graph from the remaining edges; its connected components are your clusters of (near-)duplicate titles.

I stumbled upon an interesting Stackoverflow question that was linked via an issue on Polars github repo. The OP asked for a pure Polars solution. At the time of answering the question Polars did not have support for non-equi joins, and any solution using it would be pretty cumbersome.
+I’m more of a right-tool-for-the-job person, so I tried to find a better solution.
+Suppose we have a dataset that captures the arrival and departure times of trucks at a station, along with the truck’s ID.
import polars as pl  # if you don't have polars, run
                     # pip install 'polars[all]'
data = pl.from_repr("""
┌─────────────────────┬─────────────────────┬─────┐
│ arrival_time        ┆ departure_time      ┆ ID  │
│ ---                 ┆ ---                 ┆ --- │
│ datetime[μs]        ┆ datetime[μs]        ┆ str │
╞═════════════════════╪═════════════════════╪═════╡
│ 2023-01-01 06:23:47 ┆ 2023-01-01 06:25:08 ┆ A1  │
│ 2023-01-01 06:26:42 ┆ 2023-01-01 06:28:02 ┆ A1  │
│ 2023-01-01 06:30:20 ┆ 2023-01-01 06:35:01 ┆ A5  │
│ 2023-01-01 06:32:06 ┆ 2023-01-01 06:33:48 ┆ A6  │
│ 2023-01-01 06:33:09 ┆ 2023-01-01 06:36:01 ┆ B3  │
│ 2023-01-01 06:34:08 ┆ 2023-01-01 06:39:49 ┆ C3  │
│ 2023-01-01 06:36:40 ┆ 2023-01-01 06:38:34 ┆ A6  │
│ 2023-01-01 06:37:43 ┆ 2023-01-01 06:40:48 ┆ A5  │
│ 2023-01-01 06:39:48 ┆ 2023-01-01 06:46:10 ┆ A6  │
└─────────────────────┴─────────────────────┴─────┘
""")

We want to identify the number of trucks docked at any given time within a threshold of 1 minute prior to the arrival time of a truck, and 1 minute after the departure of a truck. Equivalently, this means that we need to calculate the number of trucks within a specific window for each row of the data.
+Before we find a general solution to this problem, let’s consider a specific row to understand the problem better:
"""
┌─────────────────────┬─────────────────────┬─────┐
│ arrival_time        ┆ departure_time      ┆ ID  │
│ ---                 ┆ ---                 ┆ --- │
│ datetime[μs]        ┆ datetime[μs]        ┆ str │
╞═════════════════════╪═════════════════════╪═════╡
│ 2023-01-01 06:32:06 ┆ 2023-01-01 06:33:48 ┆ A6  │
└─────────────────────┴─────────────────────┴─────┘
"""

For this row, we need to find the number of trucks that are there between 2023-01-01 06:31:06 (1 minute prior to the arrival_time) and 2023-01-01 06:34:48 (1 minute post the departure_time). Manually going through the original dataset, we see that B3, C3, A6 and A5 are the truck IDs that qualify - they all are at the station at some point between 2023-01-01 06:31:06 and 2023-01-01 06:34:48.
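If you want to see those window boundaries computed rather than worked out by hand, a small Polars sketch (using the same data frame as above; window_open and window_close mirror the names used in the SQL later) would look like this:

import polars as pl

windows = data.with_columns(
    (pl.col("arrival_time") - pl.duration(minutes=1)).alias("window_open"),
    (pl.col("departure_time") + pl.duration(minutes=1)).alias("window_close"),
)
print(windows.row(3))  # the A6 row above: window runs 06:31:06 -> 06:34:48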
There are many cases that will qualify a truck to be present in the overlap window defined by a particular row. Specifically for the example above, we have (this visualization is generalizable, because for each row we can calculate without much difficulty the overlap window relative to the arrival and departure times):
+
Take some time to absorb these cases - it’s important for the part where we write the code for the solution. Note that we need to actually tell our algorithm to filter only for Cases 2, 3 and 4, since Cases 1 and 5 will not satisfy our requirements.
+In theory, we can use any language that has the capability to define rules that meet our algorithmic requirements outlined in the above section to find the solution. Why choose SQL? It’s often able to convey elegantly the logic that was used to execute the algorithm; and while it does come with excessive verbosity at times, it doesn’t quite in this case.
+Note here that we run SQL in Python with almost no setup or boilerplate code - so this is a Python based solution as well (although not quite Pythonic!).
+Once again, in theory, any SQL package or language can be used. Far too few however meet the ease-of-use that DuckDB provides:
- a trivial installation (pip install duckdb),
- no server and next to no boilerplate - it runs in-process, straight from Python,
- the ability to query Pandas and Polars DataFrames directly,

all with mind-blowing speed that stands shoulder-to-shoulder with Polars. We'll also use a few advanced SQL concepts noted below.
+This should be a familiar, albeit not often used concept - a join of a table with itself is a self join. There are few cases where such an operation would make sense, and this happens to be one of them.
+A key concept that we’ll use is the idea of joining on a range of values rather than a specific value. That is, instead of the usual LEFT JOIN ON A.column = B.column, we can do LEFT JOIN ON A.column <= B.column for one row in table A to match to multiple rows in B. DuckDB has a blog post that outlines this join in detail, including fast implementation.
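To see a non-equi join in isolation before the real query, here's a toy example (the a_tbl/b_tbl frames are made up for illustration, not part of the truck data) run through DuckDB from Python - note how a single row of a_tbl matches several rows of b_tbl:

import duckdb as db
import polars as pl

a_tbl = pl.DataFrame({"x": [1, 5]})
b_tbl = pl.DataFrame({"y": [1, 2, 3, 4, 5]})

# Every row of b_tbl with y <= x matches - a range condition, not an equality.
print(db.query("SELECT x, y FROM a_tbl LEFT JOIN b_tbl ON b_tbl.y <= a_tbl.x").pl())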
LIST columns

DuckDB has first class support for LIST columns - that is, each row in a LIST column can have a varying length (much like a Python list), but must have the exact same datatype (like R's vector). Using list columns allows us to eschew an additional GROUP BY operation on top of a WHERE filter or SELECT DISTINCT operation, since we can perform those directly on the LIST column itself.
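As a quick, self-contained illustration of the three list functions we'll use (toy values, not the truck data):

import duckdb as db

print(db.query("""
    SELECT
        LIST(v)                AS all_values,       -- ['a', 'a', 'b']
        LIST_DISTINCT(LIST(v)) AS distinct_values,  -- ['a', 'b']
        LIST_UNIQUE(LIST(v))   AS distinct_count    -- 2
    FROM (VALUES ('a'), ('a'), ('b')) t(v)
"""))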
Dates can be rather difficult to handle well in most tools and languages, with several packages purpose-built to make handling them easier - lubridate from the tidyverse is a stellar example. Thankfully, DuckDB provides a similarly Swiss-army-knife set of tools to deal with them, including INTERVALs (a special data type that represents a period of time independent of specific time values), which can modify TIMESTAMP values using addition or subtraction.
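For instance, the 1-minute window boundaries for the A6 row we looked at earlier can be computed directly with INTERVAL arithmetic:

import duckdb as db

print(db.query("""
    SELECT
        TIMESTAMP '2023-01-01 06:32:06' - INTERVAL 1 MINUTE AS window_open,
        TIMESTAMP '2023-01-01 06:33:48' + INTERVAL 1 MINUTE AS window_close
"""))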
Okay - that was a lot of background. Let's have at it! The query by itself in SQL is as follows (see immediately below for runnable code in Python):
SELECT
  A.arrival_time
  ,A.departure_time
  ,A.window_open
  ,A.window_close
  ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
  ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count

FROM (
  SELECT *
    ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
    ,departure_time + (INTERVAL 1 MINUTE) AS window_close
  FROM data) A

LEFT JOIN (
  SELECT *
    ,DATEDIFF('seconds', arrival_time, departure_time) AS duration
  FROM data) B

ON ((B.arrival_time <= A.window_open AND
     (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
    (B.arrival_time >= A.window_open AND
     B.departure_time <= A.window_close) OR
    (B.arrival_time >= A.window_open AND
     (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close))
GROUP BY 1, 2, 3, 4

A small, succinct query such as this will need a bit of explanation to take it all in. Here's one below, reproducible in Python (make sure to install duckdb first!).
import duckdb as db
db.query("""
    SELECT
      A.arrival_time
      ,A.departure_time
      ,A.window_open
      ,A.window_close
      -- LIST aggregates the values into a LIST column
      -- and LIST_DISTINCT finds the unique values in it
      ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
      -- finally, LIST_UNIQUE calculates the unique number of values in it
      ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count

    FROM (
      SELECT
        *
        ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
        ,departure_time + (INTERVAL 1 MINUTE) AS window_close
      FROM data  -- remember we defined data as the Polars DataFrame with our truck station data
    ) A

    LEFT JOIN (
      SELECT
        *
        -- This is the time, in seconds, between the arrival and departure of
        -- each truck PER ROW in the original data-frame
        ,DATEDIFF('seconds', arrival_time, departure_time) AS duration
      FROM data  -- this is where we perform a self-join
    ) B

    ON (
      -- Case 2 in the diagram;
      (B.arrival_time <= A.window_open AND
       -- Adding the duration here makes sure that the second interval
       -- is at least ENDING AFTER the start of the overlap window
       (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR

      -- Case 3 in the diagram - the simplest of all five cases
      (B.arrival_time >= A.window_open AND
       B.departure_time <= A.window_close) OR

      -- Case 4 in the diagram;
      (B.arrival_time >= A.window_open AND
       -- Subtracting the duration here makes sure that the second interval
       -- STARTS BEFORE the end of the overlap window.
       (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close)
    )
    GROUP BY 1, 2, 3, 4
""")

The output of this query is:
"""
+┌─────────────────────┬─────────────────────┬─────────────────────┬───┬──────────────────┬────────────────────┐
+│ arrival_time │ departure_time │ window_open │ … │ docked_trucks │ docked_truck_count │
+│ timestamp │ timestamp │ timestamp │ │ varchar[] │ uint64 │
+├─────────────────────┼─────────────────────┼─────────────────────┼───┼──────────────────┼────────────────────┤
+│ 2023-01-01 06:23:47 │ 2023-01-01 06:25:08 │ 2023-01-01 06:22:47 │ … │ [A1] │ 1 │
+│ 2023-01-01 06:26:42 │ 2023-01-01 06:28:02 │ 2023-01-01 06:25:42 │ … │ [A1] │ 1 │
+│ 2023-01-01 06:30:20 │ 2023-01-01 06:35:01 │ 2023-01-01 06:29:20 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:32:06 │ 2023-01-01 06:33:48 │ 2023-01-01 06:31:06 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:33:09 │ 2023-01-01 06:36:01 │ 2023-01-01 06:32:09 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:34:08 │ 2023-01-01 06:39:49 │ 2023-01-01 06:33:08 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:36:40 │ 2023-01-01 06:38:34 │ 2023-01-01 06:35:40 │ … │ [A5, A6, C3, B3] │ 4 │
+│ 2023-01-01 06:37:43 │ 2023-01-01 06:40:48 │ 2023-01-01 06:36:43 │ … │ [A5, A6, C3] │ 3 │
+│ 2023-01-01 06:39:48 │ 2023-01-01 06:46:10 │ 2023-01-01 06:38:48 │ … │ [A6, A5, C3] │ 3 │
+├─────────────────────┴─────────────────────┴─────────────────────┴───┴──────────────────┴────────────────────┤
+│ 9 rows 6 columns (5 shown) │
+└─────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
"""

We clearly see the strengths of DuckDB in how succinctly we were able to express this operation. We also see how seamlessly DuckDB integrates with an existing Pandas or Polars pipeline, with zero conversion costs. In fact, we can convert the result back to a Polars or Pandas dataframe by appending .pl() or .pd() to the query, i.e. db.query(...).pl() or db.query(...).pd() respectively.
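For example, both of these hand the result straight back to the dataframe library of your choice:

import duckdb as db

result_pl = db.query("SELECT * FROM data").pl()  # result as a Polars DataFrame
result_pd = db.query("SELECT * FROM data").pd()  # result as a Pandas DataFrame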
Now that we’ve understood the logic that goes into the query, let’s try to optimize the algorithm. We have the three conditions:
-- Case 2 in the diagram
(B.arrival_time <= A.window_open AND
 (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
-- Case 3 in the diagram
(B.arrival_time >= A.window_open AND
 B.departure_time <= A.window_close) OR
-- Case 4 in the diagram
(B.arrival_time >= A.window_open AND
 (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close)

What is common between these three conditions? It takes a while to see it, but it becomes clear that all of these cases require the truck's arrival to be before the window closes, and its departure to be after the window opens. This can be simplified to just:
B.arrival_time <= A.window_close AND
B.departure_time >= A.window_open

making our query much simpler!
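If you'd like to convince yourself that this single condition really is equivalent to the union of Cases 2, 3 and 4 (given that arrival <= departure), here's a small brute-force check in Python over toy integer intervals - purely illustrative, not part of the pipeline:

from itertools import product

# b_arr/b_dep: a truck's arrival and departure; w_open/w_close: the row's overlap window.
# Note that (arrival + duration) is just the departure, and (departure - duration) the arrival.

def overlaps_three_cases(b_arr, b_dep, w_open, w_close):
    case_2 = b_arr <= w_open and b_dep >= w_open
    case_3 = b_arr >= w_open and b_dep <= w_close
    case_4 = b_arr >= w_open and b_arr <= w_close
    return case_2 or case_3 or case_4

def overlaps_simplified(b_arr, b_dep, w_open, w_close):
    return b_arr <= w_close and b_dep >= w_open

# Check every small integer interval against a fixed window [10, 20].
for b_arr, b_dep in product(range(31), repeat=2):
    if b_arr <= b_dep:
        assert overlaps_three_cases(b_arr, b_dep, 10, 20) == overlaps_simplified(b_arr, b_dep, 10, 20)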
We've removed the need for the duration calculation altogether now. Therefore, we can write:
SELECT
  A.arrival_time
  ,A.departure_time
  ,A.window_open
  ,A.window_close
  ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
  ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count

FROM (
  SELECT *
    ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
    ,departure_time + (INTERVAL 1 MINUTE) AS window_close
  FROM data) A

LEFT JOIN data B

ON (
  B.arrival_time <= A.window_close AND
  B.departure_time >= A.window_open
)
GROUP BY 1, 2, 3, 4

Can we simplify this even further?
I think the SQL query in the above section is already very easy to read. However, it is a little clunky overall, and there is a way we can leverage DuckDB's extensive optimizations to improve legibility further by rewriting the query as a cross join:
SELECT
  A.arrival_time
  ,A.departure_time
  ,A.arrival_time - (INTERVAL 1 MINUTE) AS window_open
  ,A.departure_time + (INTERVAL 1 MINUTE) AS window_close
  ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
  ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
FROM data A, data B
WHERE B.arrival_time <= window_close
AND B.departure_time >= window_open
GROUP BY 1, 2, 3, 4

Why does this work? Before optimization on DuckDB, this is what the query plan looks like:
"""
┌───────────────────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│             0             │
│             1             │
│             2             │
│             3             │
│       docked_trucks       │
│     docked_truck_count    │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│         AGGREGATE         │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│        arrival_time       │
│       departure_time      │
│        window_open        │
│        window_close       │
│          list(ID)         │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│           FILTER          │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│     (arrival_time <=      │
│ (departure_time + to_m... │
│       AS BIGINT))))       │
│    (departure_time >=     │
│ (arrival_time - to_min... │
│       AS BIGINT))))       │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│       CROSS_PRODUCT       ├──────────────┐
└─────────────┬─────────────┘              │
┌─────────────┴─────────────┐┌─────────────┴─────────────┐
│         ARROW_SCAN        ││         ARROW_SCAN        │
└───────────────────────────┘└───────────────────────────┘
"""

After optimization, the CROSS_PRODUCT is automatically optimized to an interval join!
"""
┌───────────────────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│             0             │
│             1             │
│             2             │
│             3             │
│       docked_trucks       │
│     docked_truck_count    │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│         AGGREGATE         │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│        arrival_time       │
│       departure_time      │
│        window_open        │
│        window_close       │
│          list(ID)         │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│      COMPARISON_JOIN      │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│           INNER           │
│ ((departure_time + '00:01 │
│     :00'::INTERVAL) >=    ├──────────────┐
│       arrival_time)       │              │
│((arrival_time - '00:01:00'│              │
│       ::INTERVAL) <=      │              │
│      departure_time)      │              │
└─────────────┬─────────────┘              │
┌─────────────┴─────────────┐┌─────────────┴─────────────┐
│         ARROW_SCAN        ││         ARROW_SCAN        │
└───────────────────────────┘└───────────────────────────┘
"""

So, in effect, we're exploiting a feature of DuckDB that allows us to write our queries in a suboptimal manner for greater readability, letting the optimizer do a good chunk of our work for us. I wouldn't recommend relying on this in general, because not all SQL engine optimizers will be able to find an efficient route to these calculations for large datasets.
+I’m glad you asked. Here’s the DuckDB page explaining EXPLAIN (heh). Here’s the code I used:
import duckdb as db
db.sql("SET EXPLAIN_OUTPUT='all';")
print(db.query("""
EXPLAIN
SELECT
  A.arrival_time
  ,A.departure_time
  ,A.arrival_time - (INTERVAL 1 MINUTE) AS window_open
  ,A.departure_time + (INTERVAL 1 MINUTE) AS window_close
  ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
  ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
FROM data A, data B
WHERE B.arrival_time <= window_close
AND B.departure_time >= window_open
GROUP BY 1, 2, 3, 4
""").pl()[1, 1])

The data.table way

data.table is an R package that has historically been ahead of its time - in both speed and features. Development has taken a hit recently, but will likely pick back up. It's my favourite package on all fronts for data manipulation, but it suffers simply from the lack of broader R support across the ML and DL space.
The foverlaps function

If this kind of overlapping join is common, shouldn't someone have developed a package for it? Turns out, data.table has, and with very specific constraints that make it the perfect solution to our problem (if you don't mind switching over to R, that is).
The foverlaps function has these requirements:
- The data.table objects have to be keyed for automatic recognition of columns.
- The join has to be an overlap join of the kind foverlaps supports (e.g. a within overlap, matching start and end windows).
- The key (or by) must specify the start and end points of the overlapping window. This isn't a problem for us now, but it does restrict future uses where we may want non-equi joins on other kinds of conditions.

Without further ado:
library(data.table)
library(lubridate)

######### BOILERPLATE CODE, NO LOGIC HERE ####################
arrival_time = as_datetime(c(
  '2023-01-01 06:23:47.000000', '2023-01-01 06:26:42.000000',
  '2023-01-01 06:30:20.000000', '2023-01-01 06:32:06.000000',
  '2023-01-01 06:33:09.000000', '2023-01-01 06:34:08.000000',
  '2023-01-01 06:36:40.000000', '2023-01-01 06:37:43.000000',
  '2023-01-01 06:39:48.000000'))
departure_time = as_datetime(c(
  '2023-01-01 06:25:08.000000', '2023-01-01 06:28:02.000000',
  '2023-01-01 06:35:01.000000', '2023-01-01 06:33:48.000000',
  '2023-01-01 06:36:01.000000', '2023-01-01 06:39:49.000000',
  '2023-01-01 06:38:34.000000', '2023-01-01 06:40:48.000000',
  '2023-01-01 06:46:10.000000'))
ID = c('A1', 'A1', 'A5', 'A6', 'B3', 'C3', 'A6', 'A5', 'A6')

DT = data.table(
  arrival_time = arrival_time,
  departure_time = departure_time,
  ID = ID)
######### BOILERPLATE CODE, NO LOGIC HERE ####################

# copy(DT) creates a copy of a data.table that isn't linked
# to the original one, so that changes in it don't reflect in
# the original DT object.
# The `:=` allows assignment by reference (i.e. "in place").
DT_with_windows = copy(DT)[, `:=`(
  window_start = arrival_time - minutes(1),
  window_end = departure_time + minutes(1))]

# This step is necessary for the second table, but not the first; we
# key both data.tables to make the foverlaps code very succinct.
setkeyv(DT, c("arrival_time", "departure_time"))
setkeyv(DT_with_windows, c("window_start", "window_end"))

# The foverlaps function returns a data.table, so we can simply apply
# the usual data.table syntax on it!
# Since some columns have the same name in both data.tables,
# the latter table's columns are prefixed with "i." to avoid conflicts.
foverlaps(DT, DT_with_windows)[
  , .(docked_trucks = list(unique(i.ID)),
      docked_truck_count = uniqueN(i.ID))
  , .(arrival_time, departure_time)]

provides us the output:
          arrival_time      departure_time docked_trucks docked_truck_count
                <POSc>              <POSc>        <list>              <int>
1: 2023-01-01 06:23:47 2023-01-01 06:25:08            A1                  1
2: 2023-01-01 06:26:42 2023-01-01 06:28:02            A1                  1
3: 2023-01-01 06:30:20 2023-01-01 06:35:01   A5,A6,B3,C3                  4
4: 2023-01-01 06:32:06 2023-01-01 06:33:48   A5,A6,B3,C3                  4
5: 2023-01-01 06:33:09 2023-01-01 06:36:01   A5,A6,B3,C3                  4
6: 2023-01-01 06:34:08 2023-01-01 06:39:49   A5,A6,B3,C3                  4
7: 2023-01-01 06:36:40 2023-01-01 06:38:34   B3,C3,A6,A5                  4
8: 2023-01-01 06:37:43 2023-01-01 06:40:48      C3,A6,A5                  3
9: 2023-01-01 06:39:48 2023-01-01 06:46:10      C3,A5,A6                  3

data.table

The package offers a wonderful, nearly one-stop solution that doesn't require you to write the logic out for the query or command yourself, but has a major problem for a lot of users - it requires you to switch your codebase to R, and a lot of your tasks may be on Python or in an SQL pipeline. So, what do you do?
Consider the effort of maintaining an additional dependency (R) in your analytics pipeline, and the effort you'll need to invest to run R from Python - or to run an R script in your pipeline and pull its output back into the pipeline - and make your call.
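For what it's worth, the plumbing for the "R script in the pipeline" option can be as small as a subprocess call; here's a minimal sketch, where foverlaps.R and trucks_out.csv are hypothetical file names:

import subprocess
import polars as pl

# Run the R script (assumed to write its foverlaps result to trucks_out.csv) ...
subprocess.run(["Rscript", "foverlaps.R"], check=True)

# ... and pull the output back into the Python side of the pipeline.
docked = pl.read_csv("trucks_out.csv")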
+ + + + + + +I stumbled upon an interesting Stackoverflow question that was linked via an issue on Polars github repo. The OP asked for a pure Polars solution. At the time of answering the question Polars did not have support for non-equi joins, and any solution using it would be pretty cumbersome.
+I’m more of a right-tool-for-the-job person, so I tried to find a better solution.
+Suppose we have a dataset that captures the arrival and departure times of trucks at a station, along with the truck’s ID.
+ + + + + + 1import polars as pl # if you don't have polars, run
+ 2 # pip install 'polars[all]'
+ 3data = pl.from_repr("""
+ 4┌─────────────────────┬─────────────────────┬─────┐
+ 5│ arrival_time ┆ departure_time ┆ ID │
+ 6│ --- ┆ --- ┆ --- │
+ 7│ datetime[μs] ┆ datetime[μs] ┆ str │
+ 8╞═════════════════════╪═════════════════════╪═════╡
+ 9│ 2023-01-01 06:23:47 ┆ 2023-01-01 06:25:08 ┆ A1 │
+10│ 2023-01-01 06:26:42 ┆ 2023-01-01 06:28:02 ┆ A1 │
+11│ 2023-01-01 06:30:20 ┆ 2023-01-01 06:35:01 ┆ A5 │
+12│ 2023-01-01 06:32:06 ┆ 2023-01-01 06:33:48 ┆ A6 │
+13│ 2023-01-01 06:33:09 ┆ 2023-01-01 06:36:01 ┆ B3 │
+14│ 2023-01-01 06:34:08 ┆ 2023-01-01 06:39:49 ┆ C3 │
+15│ 2023-01-01 06:36:40 ┆ 2023-01-01 06:38:34 ┆ A6 │
+16│ 2023-01-01 06:37:43 ┆ 2023-01-01 06:40:48 ┆ A5 │
+17│ 2023-01-01 06:39:48 ┆ 2023-01-01 06:46:10 ┆ A6 │
+18└─────────────────────┴─────────────────────┴─────┘
+19""")We want to identify the number of trucks docked at any given time within a threshold of 1 minute prior to the arrival time of a truck, and 1 minute after the departure of a truck. Equivalently, this means that we need to calculate the number of trucks within a specific window for each row of the data.
+Before we find a general solution to this problem, let’s consider a specific row to understand the problem better:
+ + + + + +1"""
+2┌─────────────────────┬─────────────────────┬─────┐
+3│ arrival_time ┆ departure_time ┆ ID │
+4│ --- ┆ --- ┆ --- │
+5│ datetime[μs] ┆ datetime[μs] ┆ str │
+6╞═════════════════════╪═════════════════════╪═════╡
+7│ 2023-01-01 06:32:06 ┆ 2023-01-01 06:33:48 ┆ A6 │
+8└─────────────────────┴─────────────────────┴─────┘
+9"""For this row, we need to find the number of trucks that are there between 2023-01-01 06:31:06 (1 minute prior to the arrival_time and 2023-01-01 06:34:48 (1 minute post the departure_time). Manually going through the original dataset, we see that B3, C3, A6 and A5 are the truck IDs that qualify - they all are at the station in a duration that is between 2023-01-01 06:31:06 and 2023-01-01 06:34:48.
There are many cases that will qualify a truck to be present in the overlap window defined by a particular row. Specifically for the example above, we have (this visualization is generalizable, because for each row we can calculate without much difficulty the overlap window relative to the arrival and departure times):
+
Take some time to absorb these cases - it’s important for the part where we write the code for the solution. Note that we need to actually tell our algorithm to filter only for Cases 2, 3 and 4, since Cases 1 and 5 will not satisfy our requirements.
+In theory, we can use any language that has the capability to define rules that meet our algorithmic requirements outlined in the above section to find the solution. Why choose SQL? It’s often able to convey elegantly the logic that was used to execute the algorithm; and while it does come with excessive verbosity at times, it doesn’t quite in this case.
+Note here that we run SQL in Python with almost no setup or boilerplate code - so this is a Python based solution as well (although not quite Pythonic!).
+Once again, in theory, any SQL package or language can be used. Far too few however meet the ease-of-use that DuckDB provides:
+pip install duckdb),all with mind-blowing speed that stands shoulder-to-shoulder with Polars. We’ll also use a few advanced SQL concepts noted below.
+This should be a familiar, albeit not often used concept - a join of a table with itself is a self join. There are few cases where such an operation would make sense, and this happens to be one of them.
+A key concept that we’ll use is the idea of joining on a range of values rather than a specific value. That is, instead of the usual LEFT JOIN ON A.column = B.column, we can do LEFT JOIN ON A.column <= B.column for one row in table A to match to multiple rows in B. DuckDB has a blog post that outlines this join in detail, including fast implementation.
LIST columnsDuckDB has first class support for LIST columns - that is, each row in a LIST column can have a varying length (much like a Python list), but must have the exact same datatype (like R’s vector). Using list columns allow us to eschew the use of an additional GROUP BY operation on top of a WHERE filter or SELECT DISTINCT operation, since we can directly perform those on the LIST column itself.
Dates can be rather difficult to handle well in most tools and languages, with several packages purpose built to make handling them easier - lubridate from the tidyverse is a stellar example. Thankfully, DuckDB provides a similar swiss-knife set of tools to deal with it, including specifying INTERVALs (a special data type that represent a period of time independent of specific time values) to modify TIMESTAMP values using addition or subtraction.
Okay - had a lot of background. Let’s have at it! The query by itself in SQL is (see immediately below for runnable code in Python):
+ + + + + + 1SELECT
+ 2 A.arrival_time
+ 3 ,A.departure_time
+ 4 ,A.window_open
+ 5 ,A.window_close
+ 6 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+ 7 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+ 8
+ 9FROM (
+10 SELECT *
+11 ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
+12 ,departure_time + (INTERVAL 1 MINUTE) AS window_close
+13 FROM data) A
+14
+15LEFT JOIN (
+16 SELECT *
+17 ,DATEDIFF('seconds', arrival_time, departure_time) AS duration
+18 FROM data) B
+19
+20ON ((B.arrival_time <= A.window_open AND
+21 (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
+22 (B.arrival_time >= A.window_open AND
+23 B.departure_time <= A.window_close) OR
+24 (B.arrival_time >= A.window_open AND
+25 (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close))
+26GROUP BY 1, 2, 3, 4A small, succinct query such as this will need a bit of explanation to take it all in. Here’s one below, reproducible in Python (make sure to install duckdb first!). Expand it to view.
1import duckdb as db
+ 2db.query("""
+ 3 SELECT
+ 4 A.arrival_time
+ 5 ,A.departure_time
+ 6 ,A.window_open
+ 7 ,A.window_close
+ 8 -- LIST aggregates the values into a LIST column
+ 9 -- and LIST_DISTINCT finds the unique values in it
+10 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+11 -- finally, LIST_UNIQUE calculates the unique number of values in it
+12 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+13
+14 FROM (
+15 SELECT
+16 *
+17 ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
+18 ,departure_time + (INTERVAL 1 MINUTE) AS window_close
+19 FROM data -- remember we defined data as the Polars DataFrame with our truck station data
+20 ) A
+21
+22 LEFT JOIN (
+23 SELECT
+24 *
+25 -- This is the time, in seconds between the arrival and departure of
+26 -- each truck PER ROW in the original data-frame
+27 ,DATEDIFF('seconds', arrival_time, departure_time) AS duration
+28 FROM data -- this is where we perform a self-join
+29 ) B
+30
+31 ON (
+32 -- Case 2 in the diagram;
+33 (B.arrival_time <= A.window_open AND
+34 -- Adding the duration here makes sure that the second interval
+35 -- is at least ENDING AFTER the start of the overlap window
+36 (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
+37
+38 -- Case 3 in the diagram - the simplest of all five cases
+39 (B.arrival_time >= A.window_open AND
+40 B.departure_time <= A.window_close) OR
+41
+42 -- Case 4 in the digram;
+43 (B.arrival_time >= A.window_open AND
+44 -- Subtracting the duration here makes sure that the second interval
+45 -- STARTS BEFORE the end of the overlap window.
+46 (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close)
+47 )
+48 GROUP BY 1, 2, 3, 4
+49""")The output of this query is:
+ + + + + +"""
+┌─────────────────────┬─────────────────────┬─────────────────────┬───┬──────────────────┬────────────────────┐
+│ arrival_time │ departure_time │ window_open │ … │ docked_trucks │ docked_truck_count │
+│ timestamp │ timestamp │ timestamp │ │ varchar[] │ uint64 │
+├─────────────────────┼─────────────────────┼─────────────────────┼───┼──────────────────┼────────────────────┤
+│ 2023-01-01 06:23:47 │ 2023-01-01 06:25:08 │ 2023-01-01 06:22:47 │ … │ [A1] │ 1 │
+│ 2023-01-01 06:26:42 │ 2023-01-01 06:28:02 │ 2023-01-01 06:25:42 │ … │ [A1] │ 1 │
+│ 2023-01-01 06:30:20 │ 2023-01-01 06:35:01 │ 2023-01-01 06:29:20 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:32:06 │ 2023-01-01 06:33:48 │ 2023-01-01 06:31:06 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:33:09 │ 2023-01-01 06:36:01 │ 2023-01-01 06:32:09 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:34:08 │ 2023-01-01 06:39:49 │ 2023-01-01 06:33:08 │ … │ [B3, C3, A6, A5] │ 4 │
+│ 2023-01-01 06:36:40 │ 2023-01-01 06:38:34 │ 2023-01-01 06:35:40 │ … │ [A5, A6, C3, B3] │ 4 │
+│ 2023-01-01 06:37:43 │ 2023-01-01 06:40:48 │ 2023-01-01 06:36:43 │ … │ [A5, A6, C3] │ 3 │
+│ 2023-01-01 06:39:48 │ 2023-01-01 06:46:10 │ 2023-01-01 06:38:48 │ … │ [A6, A5, C3] │ 3 │
+├─────────────────────┴─────────────────────┴─────────────────────┴───┴──────────────────┴────────────────────┤
+│ 9 rows 6 columns (5 shown) │
+└─────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
+"""We clearly see the strengths of DuckDB in how succintly we were able to express this operation. We also find how DuckDB is able to seamlessly integrate with an existing Pandas or Polars pipeline with zero-conversion costs. In fact, we can convert this back to a Polars or Pandas dataframe by appending the ending bracket with db.query(...).pl() and db.query(...).pd() respectively.
Now that we’ve understood the logic that goes into the query, let’s try to optimize the algorithm. We have the three conditions:
+ + + + + +1-- Case 2 in the diagram
+2(B.arrival_time <= A.window_open AND
+3 (B.arrival_time + TO_SECONDS(B.duration)) >= A.window_open) OR
+4-- Case 3 in the diagram
+5(B.arrival_time >= A.window_open AND
+6 B.departure_time <= A.window_close) OR
+7-- Case 4 in the diagram
+8(B.arrival_time >= A.window_open AND
+9 (B.departure_time - TO_SECONDS(B.duration)) <= A.window_close)What is common between these three conditions? It takes a while to see it, but it becomes clear that all of these cases simply require the truck’s interval to start before the overlap window closes, and to end after the overlap window opens. This can be simplified to just:
+ + + + + +1B.arrival_time <= A.window_close AND
+2B.departure_time >= A.window_openmaking our query much simpler!
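+As a quick sanity check - a small, self-contained Python sketch that is not part of the original query - the single condition reproduces the manual count of four trucks for the 06:32:06 A6 row:
+from datetime import datetime
+
+rows = [  # (arrival, departure, ID) from the example data
+    ("2023-01-01 06:23:47", "2023-01-01 06:25:08", "A1"),
+    ("2023-01-01 06:26:42", "2023-01-01 06:28:02", "A1"),
+    ("2023-01-01 06:30:20", "2023-01-01 06:35:01", "A5"),
+    ("2023-01-01 06:32:06", "2023-01-01 06:33:48", "A6"),
+    ("2023-01-01 06:33:09", "2023-01-01 06:36:01", "B3"),
+    ("2023-01-01 06:34:08", "2023-01-01 06:39:49", "C3"),
+    ("2023-01-01 06:36:40", "2023-01-01 06:38:34", "A6"),
+    ("2023-01-01 06:37:43", "2023-01-01 06:40:48", "A5"),
+    ("2023-01-01 06:39:48", "2023-01-01 06:46:10", "A6"),
+]
+fmt = "%Y-%m-%d %H:%M:%S"
+parsed = [(datetime.strptime(a, fmt), datetime.strptime(d, fmt), i) for a, d, i in rows]
+
+# Overlap window for the 06:32:06 A6 row: 1 minute before arrival, 1 minute after departure.
+window_open = datetime(2023, 1, 1, 6, 31, 6)
+window_close = datetime(2023, 1, 1, 6, 34, 48)
+
+# The single simplified condition: the interval starts before the window closes
+# and ends after the window opens.
+docked = {i for a, d, i in parsed if a <= window_close and d >= window_open}
+print(sorted(docked))  # ['A5', 'A6', 'B3', 'C3'] -> 4 unique trucks, matching the query output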
+We’ve removed the need for the duration calculation altogether now. Therefore, we can write:
1SELECT
+ 2 A.arrival_time
+ 3 ,A.departure_time
+ 4 ,A.window_open
+ 5 ,A.window_close
+ 6 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+ 7 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+ 8
+ 9FROM (
+10 SELECT *
+11 ,arrival_time - (INTERVAL 1 MINUTE) AS window_open
+12 ,departure_time + (INTERVAL 1 MINUTE) AS window_close
+13 FROM data) A
+14
+15LEFT JOIN data B
+16
+17ON (
+18 B.arrival_time <= A.window_close AND
+19 B.departure_time >= A.window_open
+20)
+21GROUP BY 1, 2, 3, 4Can we simplify this even further?
+I think the SQL query in the above section is already very easy to read. However, it is a little clunky overall, and there is a way we can leverage DuckDB’s extensive optimizations to improve legibility by rewriting the query as a cross join:
+ + + + + + 1SELECT
+ 2 A.arrival_time
+ 3 ,A.departure_time
+ 4 ,A.arrival_time - (INTERVAL 1 MINUTE) AS window_open
+ 5 ,A.departure_time + (INTERVAL 1 MINUTE) AS window_close
+ 6 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+ 7 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+ 8FROM data A, data B
+ 9WHERE B.arrival_time <= window_close
+10AND B.departure_time >= window_open
+11GROUP BY 1, 2, 3, 4Why does this work? Before optimization on DuckDB, this is what the query plan looks like:
+ 1"""
+ 2┌───────────────────────────┐
+ 3│ PROJECTION │
+ 4│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+ 5│ 0 │
+ 6│ 1 │
+ 7│ 2 │
+ 8│ 3 │
+ 9│ docked_trucks │
+10│ docked_truck_count │
+11└─────────────┬─────────────┘
+12┌─────────────┴─────────────┐
+13│ AGGREGATE │
+14│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+15│ arrival_time │
+16│ departure_time │
+17│ window_open │
+18│ window_close │
+19│ list(ID) │
+20└─────────────┬─────────────┘
+21┌─────────────┴─────────────┐
+22│ FILTER │
+23│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+24│ (arrival_time <= │
+25│(departure_time + to_m... │
+26│ AS BIGINT)))) │
+27│ (departure_time >= │
+28│(arrival_time - to_min... │
+29│ AS BIGINT)))) │
+30└─────────────┬─────────────┘
+31┌─────────────┴─────────────┐
+32│ CROSS_PRODUCT ├──────────────┐
+33└─────────────┬─────────────┘ │
+34┌─────────────┴─────────────┐┌─────────────┴─────────────┐
+35│ ARROW_SCAN ││ ARROW_SCAN │
+36└───────────────────────────┘└───────────────────────────┘
+37""" After optimization, the CROSS_PRODUCT is automatically optimized to an interval join!
1"""
+ 2┌───────────────────────────┐
+ 3│ PROJECTION │
+ 4│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+ 5│ 0 │
+ 6│ 1 │
+ 7│ 2 │
+ 8│ 3 │
+ 9│ docked_trucks │
+10│ docked_truck_count │
+11└─────────────┬─────────────┘
+12┌─────────────┴─────────────┐
+13│ AGGREGATE │
+14│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+15│ arrival_time │
+16│ departure_time │
+17│ window_open │
+18│ window_close │
+19│ list(ID) │
+20└─────────────┬─────────────┘
+21┌─────────────┴─────────────┐
+22│ COMPARISON_JOIN │
+23│ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ │
+24│ INNER │
+25│ ((departure_time + '00:01 │
+26│ :00'::INTERVAL) >= ├──────────────┐
+27│ arrival_time) │ │
+28│((arrival_time - '00:01:00'│ │
+29│ ::INTERVAL) <= │ │
+30│ departure_time) │ │
+31└─────────────┬─────────────┘ │
+32┌─────────────┴─────────────┐┌─────────────┴─────────────┐
+33│ ARROW_SCAN ││ ARROW_SCAN │
+34└───────────────────────────┘└───────────────────────────┘
+35""" So in effect, we’re actually exploiting a feature of DuckDB that lets us write our queries in a suboptimal manner for greater readability, letting the optimizer do a good chunk of the work for us. I wouldn’t recommend relying on this in general, because not all SQL engine optimizers will be able to find an efficient route to these calculations for large datasets.
+I’m glad you asked. Here’s the DuckDB page explaining EXPLAIN (heh). Here’s the code I used:
1import duckdb as db
+ 2db.sql("SET EXPLAIN_OUTPUT='all';")
+ 3print(db.query("""
+ 4EXPLAIN
+ 5SELECT
+ 6 A.arrival_time
+ 7 ,A.departure_time
+ 8 ,A.arrival_time - (INTERVAL 1 MINUTE) AS window_open
+ 9 ,A.departure_time + (INTERVAL 1 MINUTE) AS window_close
+10 ,LIST_DISTINCT(LIST(B.ID)) AS docked_trucks
+11 ,LIST_UNIQUE(LIST(B.ID)) AS docked_truck_count
+12FROM data A, data B
+13WHERE B.arrival_time <= window_close
+14AND B.departure_time >= window_open
+15GROUP BY 1, 2, 3, 4
+16""").pl()[1, 1])data.table waydata.table is a package that has historically been ahead of its time - in both speed and features. Development has taken a hit recently, but will likely pick back up. It’s my favourite package on all fronts for data manipulation, but suffers simply from the lack of broader R support across the ML and DL space.
foverlaps functionIf this kind of overlapping join is common, shouldn’t someone have developed a package for it? Turns out, data.table has, and with very specific constraints that make it the perfect solution to our problem (if you don’t mind switching over to R, that is).
The foverlaps function has these requirements:
data.table objects have to be keyed for automatic recognition of columns.within overlap, matching start and end windows,by must specify the start and end points of the overlapping window. This isn’t a problem for us now, but does restrict for future uses where we may want non-equi joins on other cases.Without further ado:
+ + + + + + 1library(data.table)
+ 2library(lubridate)
+ 3
+ 4######### BOILERPLATE CODE, NO LOGIC HERE ####################
+ 5arrival_time = as_datetime(c(
+ 6 '2023-01-01 06:23:47.000000', '2023-01-01 06:26:42.000000',
+ 7 '2023-01-01 06:30:20.000000', '2023-01-01 06:32:06.000000',
+ 8 '2023-01-01 06:33:09.000000', '2023-01-01 06:34:08.000000',
+ 9 '2023-01-01 06:36:40.000000', '2023-01-01 06:37:43.000000',
+10 '2023-01-01 06:39:48.000000'))
+11departure_time = as_datetime(c(
+12 '2023-01-01 06:25:08.000000', '2023-01-01 06:28:02.000000',
+13 '2023-01-01 06:35:01.000000', '2023-01-01 06:33:48.000000',
+14 '2023-01-01 06:36:01.000000', '2023-01-01 06:39:49.000000',
+15 '2023-01-01 06:38:34.000000', '2023-01-01 06:40:48.000000',
+16 '2023-01-01 06:46:10.000000'))
+17ID = c('A1', 'A1', 'A5', 'A6', 'B3', 'C3', 'A6', 'A5', 'A6')
+18
+19DT = data.table(
+20 arrival_time = arrival_time,
+21 departure_time = departure_time,
+22 ID = ID)
+23######### BOILERPLATE CODE, NO LOGIC HERE ####################
+24
+25# A copy(DT) creates a copy of a data.table that isn't linked
+26# to the original one, so that changes in it don't reflect in
+27# the original DT object.
+28# The `:=` operator allows assignment by reference (i.e. "in place").
+29DT_with_windows = copy(DT)[, `:=`(
+30 window_start = arrival_time - minutes(1),
+31 window_end = departure_time + minutes(1))]
+32
+33# This step is only strictly necessary for the second table, but we
+34# key both data.tables to make the foverlaps code very succinct.
+35setkeyv(DT, c("arrival_time", "departure_time"))
+36setkeyv(DT_with_windows, c("window_start", "window_end"))
+37
+38# The foverlaps function returns a data.table, so we can simply apply
+39# the usual data.table syntax on it!
+40# Since we have the same name of some columns in both data.tables,
+41# the latter table's columns are prefixed with "i." to avoid conflicts.
+42foverlaps(DT, DT_with_windows)[
+43 , .(docked_trucks = list(unique(i.ID)),
+44 docked_truck_count = uniqueN(i.ID))
+45 , .(arrival_time, departure_time)]provides us the output:
+ + + + + + 1 arrival_time departure_time docked_trucks docked_truck_count
+ 2 <POSc> <POSc> <list> <int>
+ 31: 2023-01-01 06:23:47 2023-01-01 06:25:08 A1 1
+ 42: 2023-01-01 06:26:42 2023-01-01 06:28:02 A1 1
+ 53: 2023-01-01 06:30:20 2023-01-01 06:35:01 A5,A6,B3,C3 4
+ 64: 2023-01-01 06:32:06 2023-01-01 06:33:48 A5,A6,B3,C3 4
+ 75: 2023-01-01 06:33:09 2023-01-01 06:36:01 A5,A6,B3,C3 4
+ 86: 2023-01-01 06:34:08 2023-01-01 06:39:49 A5,A6,B3,C3 4
+ 97: 2023-01-01 06:36:40 2023-01-01 06:38:34 B3,C3,A6,A5 4
+108: 2023-01-01 06:37:43 2023-01-01 06:40:48 C3,A6,A5 3
+119: 2023-01-01 06:39:48 2023-01-01 06:46:10 C3,A5,A6 3data.tableThe package offers a wonderful, nearly one-stop solution that doesn’t require you to write the logic out for the query or command yourself, but has a major problem for a lot of users - it requires you to switch your codebase to R, and a lot of your tasks may be on Python or in an SQL pipeline. So, what do you do?
+Consider the effort in maintaining an additional dependency for your analytics pipeline (i.e. R), and the effort that you’ll need to invest to run R from Python, or run an R script in your pipeline and pull the output from it back into the pipeline, and make your call.
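+If you do keep the R step, one low-tech way to call it from a Python pipeline is to shell out to Rscript and read back a file that the script writes. A hypothetical sketch (the script and file names here are made up):
+import subprocess
+import polars as pl
+
+# Run the R script that performs the foverlaps() logic and writes its result to CSV.
+subprocess.run(["Rscript", "overlap_counts.R"], check=True)
+
+# Pull the result back into the Python side of the pipeline.
+docked_counts = pl.read_csv("overlap_counts.csv")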
+]]>
+ + + + + + +Most of my work is on private repositories, but I do find some time to learn new topics, contribute back to some of the open source packages I frequently use, or to create interesting tools.
data.table, that I found very useful earlier in my career to quicky churn out analyses. It is not ground-breaking, but rather something that anybody with sufficient basic skills in R and understand, and save an immense amount of time.data.table and dplyr dominated), so I was eager to make it better for everybody making the switch.+ +
+ + + + +In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+4 ann_index = faiss.IndexFlatL2(d) # Index using Euclidean (L2) distance
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = faiss_data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between -1 and 1.
+19 # If using Euclidean or other distances, we'd have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(faiss_data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
1def get_edge_list(label, k=5):
+ 2 faiss_DB, faiss_data, faiss_indices = create_index(label)
+ 3 # To map the data back to the original `train[b'data']` array
+ 4 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+ 5 # To map the indices back to the original strings
+ 6 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+ 7 distances, neighbors = faiss_DB.search(faiss_data, k)
+ 8
+ 9 return (
+10 pl.DataFrame({
+11 "from": faiss_indices})
+12 .with_columns(
+13 pl.Series("to", neighbors),
+14 pl.Series("distance", distances))
+15 .explode("to", "distance")
+16 .with_columns(
+17 pl.col("from")
+18 .map_dict(title_name_map),
+19 pl.col("to")
+20 .map_dict(faiss_indices_map)
+21 .map_dict(title_name_map))
+22 .filter(pl.col("from") != pl.col("to"))
+23 ) The next step in the process is to create a network graph using the edge-list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are sufficiently similar enough to B within a particular minimum thershold, then A will be connected to C through B! Hopefully a small visual below would help.
+
What happens when such a concept is extended for many data points? Not all of them would be connected - because we’re applying a minimum threshold that they have to meet. This is the only hueristic part of the rather fast process. Here’s one more helpful visual:
+
Very starry night-eque vibes here. Let’s get to the code.
+ + + + + +1import networkx as nx
+2def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+3 edge_list = (
+4 get_edge_list(label, k=k)
+5 .filter(pl.col("distance") >= min_cosine_distance)
+6 )
+7 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+8 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phone Accessories.
1clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)Make sure to configure the following if your results aren’t good enough:
+min_cosine_distance value if you want bigger clusters.There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
1>> clusters[3]
+2['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
+3 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
+4 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
+5 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
+6 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
+7 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
+ + + + + + 1>>> clusters[6]
+ 2['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
+ 3 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
+ 4 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
+ 5 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
+ 6 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
+ 7 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
+ 8 ...
+ 9 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
+10 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
+11 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
+12 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
+13 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
+14 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']This isn’t that hard (although it may take more than a moment). Just iterate it for each category!
+ + + + + +1clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]I get it - you often want a solution that “just works”. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
+ 1import sentence_transformers
+ 2import faiss
+ 3import polars as pl
+ 4import numpy as np
+ 5
+ 6# Data is read here. You download the files from Kaggle here:
+ 7# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
+ 8data = pl.read_csv("archive/shopmania.csv", new_columns=[
+ 9 "product_ID", "product_title", "category_ID", "category_label"])
+10data = (
+11 data
+12 .filter(pl.count().over("category_ID") == 10000)
+13 .with_row_count("row_idx")
+14)
+15
+16
+17# See list of models at www.sbert.net/docs/pretrained_models.html
+18ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+19title_embeddings = (
+20 ST.encode(
+21 data.get_column("product_title").to_list(),
+22 # I'm on a MacBook, you should use `cuda` or `cpu`
+23 # if you've got different hardware.
+24 device="mps",
+25 show_progress_bar=True, convert_to_tensor=True)
+26 .cpu().numpy())
+27
+28# Code to create a FAISS index
+29def create_index(label):
+30 faiss_indices = (
+31 data # this needs to be an argument if you want to create a generic function
+32 .filter(pl.col("category_label") == label)
+33 .get_column("row_idx")
+34 .to_list()
+35 )
+36
+37 faiss_data = title_embeddings[faiss_indices]
+38 d = faiss_data.shape[1] # Number of dimensions
+39 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+40 faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
+41 faiss_DB.add(faiss_data) # Build the index
+42
+43 return faiss_DB, faiss_data, faiss_indices
+44
+45# Code to create an edge-list
+46def get_edge_list(label, k=5):
+47 faiss_DB, faiss_data, faiss_indices = create_index(label)
+48 # To map the data back to the original `train[b'data']` array
+49 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+50 # To map the indices back to the original strings
+51 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+52 distances, neighbors = faiss_DB.search(faiss_data, k)
+53
+54 return (
+55 pl.DataFrame({
+56 "from": faiss_indices})
+57 .with_columns(
+58 pl.Series("to", neighbors),
+59 pl.Series("distance", distances))
+60 .explode("to", "distance")
+61 .with_columns(
+62 pl.col("from")
+63 .map_dict(title_name_map),
+64 pl.col("to")
+65 .map_dict(faiss_indices_map)
+66 .map_dict(title_name_map))
+67 .filter(pl.col("from") != pl.col("to"))
+68 )
+69
+70# Code to extract components from a Network Graph
+71import networkx as nx
+72def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+73 edge_list = (
+74 get_edge_list(label, k=k)
+75 .filter(pl.col("distance") >= min_cosine_distance)
+76 )
+77 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+78 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}
+79
+80# Example call to a single category to obtain its clusters
+81clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+82# Example call to **all** categories to obtain all clusters
+83clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]If you want to write down an algorithmic way of looking at this approach,
+faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
import faiss
import numpy as np
import polars as pl

def create_index(label):
    faiss_indices = (
        data  # this needs to be an argument if you want to create a generic function
        .with_row_count("row_idx")
        .filter(pl.col("category_label") == label)
        .get_column("row_idx")
        .to_list()
    )

    faiss_data = title_embeddings[faiss_indices]
    d = faiss_data.shape[1]          # Number of dimensions
    faiss_DB = faiss.IndexFlatIP(d)  # Index using Inner Product
    faiss.normalize_L2(faiss_data)   # Normalized L2 with Inner Product search = cosine similarity
    # Why cosine similarity? It's easier to specify thresholds - they're always bounded
    # (between -1 and 1). If using Euclidean or another distance, we'd have to spend some
    # time finding a good range where distances are reasonable.
    # See https://stats.stackexchange.com/a/146279 for details.
    faiss_DB.add(faiss_data)         # Build the index

    return faiss_DB, faiss_data, faiss_indices

To proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method that finds the k nearest neighbors of each query vector, along with the (approximate) distance to each. Let’s then write a function that gathers the following information: the index of the row for which nearest neighbors are being searched, the indices of said nearest neighbors, and the distance between them. In network graph parlance, this kind of data is called an edge list, i.e. a list of pairs of nodes that are connected, along with any additional information that specifies a property (in this case, distance) of the edge that connects them.
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map positions within the faiss database back to row indices of the original data
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
    # To map row indices back to the original title strings
    # (add row_idx here, since `data` itself was not given one earlier)
    title_name_map = {i: x for i, x in data.with_row_count("row_idx").select("row_idx", "product_title").rows()}
    distances, neighbors = faiss_DB.search(faiss_data, k)

    return (
        pl.DataFrame({
            "from": faiss_indices})
        .with_columns(
            pl.Series("to", neighbors),
            pl.Series("distance", distances))
        .explode("to", "distance")
        .with_columns(
            pl.col("from")
            .map_dict(title_name_map),
            pl.col("to")
            .map_dict(faiss_indices_map)   # local faiss position -> original row index
            .map_dict(title_name_map))     # original row index -> product title
        .filter(pl.col("from") != pl.col("to"))
    )

The next step in the process is to create a network graph using the edge list. But why?
Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are each sufficiently similar to B (i.e. they meet a particular minimum threshold), then A will be connected to C through B! Hopefully the small visual below helps.
[visual: A and C each connect to B, so A reaches C through B]
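To make the A-B-C case concrete, here is a tiny self-contained sketch with toy node names (not products from the dataset):

import networkx as nx

# A-B and B-C both pass the similarity threshold; A-C does not appear in the edge list.
toy_graph = nx.Graph()
toy_graph.add_edges_from([("A", "B"), ("B", "C")])

# A and C still end up in the same connected component, because both are linked to B.
print(list(nx.connected_components(toy_graph)))  # [{'A', 'B', 'C'}] (set order may vary)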
What happens when such a concept is extended to many data points? Not all of them will be connected - because we’re applying a minimum threshold that every edge has to meet. This is the only heuristic part of this otherwise rather fast process. Here’s one more helpful visual:
[visual: the category’s products as many small, disconnected clusters of points]
Very Starry Night-esque vibes here. Let’s get to the code.
import networkx as nx

def get_cluster_map(label, k=5, min_cosine_distance=0.95):
    # Note: the "distance" column actually holds cosine similarities (inner products of
    # L2-normalized vectors), so a higher threshold means more similar titles.
    edge_list = (
        get_edge_list(label, k=k)
        .filter(pl.col("distance") >= min_cosine_distance)
    )
    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}

Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phones Accessories.
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)

Make sure to configure the following if your results aren’t good enough:
- The number of neighbors k - more neighbors mean more candidate connections per product.
- The min_cosine_distance value - lower it if you want bigger clusters.

There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
>> clusters[3]
['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']

Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
>>> clusters[6]
['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
 ...
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']

This isn’t that hard (although it may take more than a moment). Just iterate it for each category!
clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]

I get it - you often want a solution that “just works”. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
import sentence_transformers
import faiss
import polars as pl
import numpy as np

# Data is read here. You can download the files from Kaggle here:
# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
data = pl.read_csv("archive/shopmania.csv", new_columns=[
    "product_ID", "product_title", "category_ID", "category_label"])
data = (
    data
    .filter(pl.count().over("category_ID") == 10000)
    .with_row_count("row_idx")
)


# See list of models at www.sbert.net/docs/pretrained_models.html
ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
title_embeddings = (
    ST.encode(
        data.get_column("product_title").to_list(),
        # I'm on a MacBook; you should use `cuda` or `cpu`
        # if you've got different hardware.
        device="mps",
        show_progress_bar=True, convert_to_tensor=True)
    .cpu().numpy())

# Code to create a FAISS index
def create_index(label):
    faiss_indices = (
        data  # this needs to be an argument if you want to create a generic function
        .filter(pl.col("category_label") == label)
        .get_column("row_idx")
        .to_list()
    )

    faiss_data = title_embeddings[faiss_indices]
    d = faiss_data.shape[1]          # Number of dimensions
    faiss_DB = faiss.IndexFlatIP(d)  # Index using Inner Product
    faiss.normalize_L2(faiss_data)   # Normalized L2 with Inner Product search = cosine similarity
    faiss_DB.add(faiss_data)         # Build the index

    return faiss_DB, faiss_data, faiss_indices

# Code to create an edge-list
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map positions within the faiss database back to row indices of the original data
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
    # To map row indices back to the original title strings
    title_name_map = {i: x for i, x in data.select("row_idx", "product_title").rows()}
    distances, neighbors = faiss_DB.search(faiss_data, k)

    return (
        pl.DataFrame({
            "from": faiss_indices})
        .with_columns(
            pl.Series("to", neighbors),
            pl.Series("distance", distances))
        .explode("to", "distance")
        .with_columns(
            pl.col("from")
            .map_dict(title_name_map),
            pl.col("to")
            .map_dict(faiss_indices_map)
            .map_dict(title_name_map))
        .filter(pl.col("from") != pl.col("to"))
    )

# Code to extract components from a Network Graph
import networkx as nx
def get_cluster_map(label, k=5, min_cosine_distance=0.95):
    edge_list = (
        get_edge_list(label, k=k)
        .filter(pl.col("distance") >= min_cosine_distance)
    )
    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}

# Example call to a single category to obtain its clusters
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
# Example call to **all** categories to obtain all clusters
clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]

If you want to write down an algorithmic way of looking at this approach,
1. Convert your data into a 2D numerical matrix - here, SentenceTransformer embeddings of the product titles.
2. Build a vector database (I used faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.
3. Query the k nearest neighbors of every point and turn the results into an edge list.
4. Keep only the edges above a minimum similarity threshold and build a network graph from them.
5. The connected components of that graph are your clusters of near-duplicate products.
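The code above stops at the cluster map, so purely as an illustration (the one-title-per-cluster choice here is my assumption, not a prescription), turning those clusters into a small representative sample could look roughly like this:

# Hypothetical final step: keep one title per cluster as the representative sample.
representative_titles = []
for label_clusters in clusters:              # `clusters` is the per-category list built above
    for members in label_clusters.values():  # each value is a list of near-duplicate titles
        representative_titles.append(members[0])

print(len(representative_titles))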
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
1def get_edge_list(label, k=5):
+ 2 faiss_DB, faiss_data, faiss_indices = create_index(label)
+ 3 # To map the data back to the original `train[b'data']` array
+ 4 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+ 5 # To map the indices back to the original strings
+ 6 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+ 7 distances, neighbors = faiss_DB.search(faiss_data, k)
+ 8
+ 9 return (
+10 pl.DataFrame({
+11 "from": faiss_indices})
+12 .with_columns(
+13 pl.Series("to", neighbors),
+14 pl.Series("distance", distances))
+15 .explode("to", "distance")
+16 .with_columns(
+17 pl.col("from")
+18 .map_dict(title_name_map),
+19 pl.col("to")
+20 .map_dict(faiss_indices_map)
+21 .map_dict(title_name_map))
+22 .filter(pl.col("from") != pl.col("to"))
+23 ) The next step in the process is to create a network graph using the edge-list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are sufficiently similar enough to B within a particular minimum thershold, then A will be connected to C through B! Hopefully a small visual below would help.
+
What happens when such a concept is extended for many data points? Not all of them would be connected - because we’re applying a minimum threshold that they have to meet. This is the only hueristic part of the rather fast process. Here’s one more helpful visual:
+
Very starry night-eque vibes here. Let’s get to the code.
+ + + + + +1import networkx as nx
+2def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+3 edge_list = (
+4 get_edge_list(label, k=k)
+5 .filter(pl.col("distance") >= min_cosine_distance)
+6 )
+7 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+8 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phone Accessories.
1clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)Make sure to configure the following if your results aren’t good enough:
+min_cosine_distance value if you want bigger clusters.There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
1>> clusters[3]
+2['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
+3 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
+4 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
+5 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
+6 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
+7 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
+ + + + + + 1>>> clusters[6]
+ 2['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
+ 3 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
+ 4 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
+ 5 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
+ 6 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
+ 7 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
+ 8 ...
+ 9 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
+10 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
+11 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
+12 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
+13 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
+14 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']This isn’t that hard (although it may take more than a moment). Just iterate it for each category!
+ + + + + +1clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]I get it - you often want a solution that “just works”. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
+ 1import sentence_transformers
+ 2import faiss
+ 3import polars as pl
+ 4import numpy as np
+ 5
+ 6# Data is read here. You download the files from Kaggle here:
+ 7# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
+ 8data = pl.read_csv("archive/shopmania.csv", new_columns=[
+ 9 "product_ID", "product_title", "category_ID", "category_label"])
+10data = (
+11 data
+12 .filter(pl.count().over("category_ID") == 10000)
+13 .with_row_count("row_idx")
+14)
+15
+16
+17# See list of models at www.sbert.net/docs/pretrained_models.html
+18ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+19title_embeddings = (
+20 ST.encode(
+21 data.get_column("product_title").to_list(),
+22 # I'm on a MacBook, you should use `cuda` or `cpu`
+23 # if you've got different hardware.
+24 device="mps",
+25 show_progress_bar=True, convert_to_tensor=True)
+26 .cpu().numpy())
+27
+28# Code to create a FAISS index
+29def create_index(label):
+30 faiss_indices = (
+31 data # this needs to be an argument if you want to create a generic function
+32 .filter(pl.col("category_label") == label)
+33 .get_column("row_idx")
+34 .to_list()
+35 )
+36
+37 faiss_data = title_embeddings[faiss_indices]
+38 d = faiss_data.shape[1] # Number of dimensions
+39 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+40 faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
+41 faiss_DB.add(faiss_data) # Build the index
+42
+43 return faiss_DB, faiss_data, faiss_indices
+44
+45# Code to create an edge-list
+46def get_edge_list(label, k=5):
+47 faiss_DB, faiss_data, faiss_indices = create_index(label)
+48 # To map the data back to the original `train[b'data']` array
+49 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+50 # To map the indices back to the original strings
+51 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+52 distances, neighbors = faiss_DB.search(faiss_data, k)
+53
+54 return (
+55 pl.DataFrame({
+56 "from": faiss_indices})
+57 .with_columns(
+58 pl.Series("to", neighbors),
+59 pl.Series("distance", distances))
+60 .explode("to", "distance")
+61 .with_columns(
+62 pl.col("from")
+63 .map_dict(title_name_map),
+64 pl.col("to")
+65 .map_dict(faiss_indices_map)
+66 .map_dict(title_name_map))
+67 .filter(pl.col("from") != pl.col("to"))
+68 )
+69
+70# Code to extract components from a Network Graph
+71import networkx as nx
+72def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+73 edge_list = (
+74 get_edge_list(label, k=k)
+75 .filter(pl.col("distance") >= min_cosine_distance)
+76 )
+77 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+78 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}
+79
+80# Example call to a single category to obtain its clusters
+81clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+82# Example call to **all** categories to obtain all clusters
+83clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]If you want to write down an algorithmic way of looking at this approach,
+faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
1def get_edge_list(label, k=5):
+ 2 faiss_DB, faiss_data, faiss_indices = create_index(label)
+ 3 # To map the data back to the original `train[b'data']` array
+ 4 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+ 5 # To map the indices back to the original strings
+ 6 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+ 7 distances, neighbors = faiss_DB.search(faiss_data, k)
+ 8
+ 9 return (
+10 pl.DataFrame({
+11 "from": faiss_indices})
+12 .with_columns(
+13 pl.Series("to", neighbors),
+14 pl.Series("distance", distances))
+15 .explode("to", "distance")
+16 .with_columns(
+17 pl.col("from")
+18 .map_dict(title_name_map),
+19 pl.col("to")
+20 .map_dict(faiss_indices_map)
+21 .map_dict(title_name_map))
+22 .filter(pl.col("from") != pl.col("to"))
+23 ) The next step in the process is to create a network graph using the edge-list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are sufficiently similar enough to B within a particular minimum thershold, then A will be connected to C through B! Hopefully a small visual below would help.
+
What happens when such a concept is extended for many data points? Not all of them would be connected - because we’re applying a minimum threshold that they have to meet. This is the only hueristic part of the rather fast process. Here’s one more helpful visual:
+
Very starry night-eque vibes here. Let’s get to the code.
+ + + + + +1import networkx as nx
+2def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+3 edge_list = (
+4 get_edge_list(label, k=k)
+5 .filter(pl.col("distance") >= min_cosine_distance)
+6 )
+7 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+8 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phone Accessories.
1clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)Make sure to configure the following if your results aren’t good enough:
+min_cosine_distance value if you want bigger clusters.There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
1>> clusters[3]
+2['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
+3 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
+4 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
+5 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
+6 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
+7 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
+ + + + + + 1>>> clusters[6]
+ 2['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
+ 3 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
+ 4 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
+ 5 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
+ 6 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
+ 7 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
+ 8 ...
+ 9 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
+10 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
+11 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
+12 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
+13 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
+14 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']This isn’t that hard (although it may take more than a moment). Just iterate it for each category!
+ + + + + +1clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]I get it - you often want a solution that “just works”. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
+ 1import sentence_transformers
+ 2import faiss
+ 3import polars as pl
+ 4import numpy as np
+ 5
+ 6# Data is read here. You download the files from Kaggle here:
+ 7# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
+ 8data = pl.read_csv("archive/shopmania.csv", new_columns=[
+ 9 "product_ID", "product_title", "category_ID", "category_label"])
+10data = (
+11 data
+12 .filter(pl.count().over("category_ID") == 10000)
+13 .with_row_count("row_idx")
+14)
+15
+16
+17# See list of models at www.sbert.net/docs/pretrained_models.html
+18ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+19title_embeddings = (
+20 ST.encode(
+21 data.get_column("product_title").to_list(),
+22 # I'm on a MacBook, you should use `cuda` or `cpu`
+23 # if you've got different hardware.
+24 device="mps",
+25 show_progress_bar=True, convert_to_tensor=True)
+26 .cpu().numpy())
+27
+28# Code to create a FAISS index
+29def create_index(label):
+30 faiss_indices = (
+31 data # this needs to be an argument if you want to create a generic function
+32 .filter(pl.col("category_label") == label)
+33 .get_column("row_idx")
+34 .to_list()
+35 )
+36
+37 faiss_data = title_embeddings[faiss_indices]
+38 d = faiss_data.shape[1] # Number of dimensions
+39 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+40 faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
+41 faiss_DB.add(faiss_data) # Build the index
+42
+43 return faiss_DB, faiss_data, faiss_indices
+44
+45# Code to create an edge-list
+46def get_edge_list(label, k=5):
+47 faiss_DB, faiss_data, faiss_indices = create_index(label)
+48 # To map the data back to the original `train[b'data']` array
+49 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+50 # To map the indices back to the original strings
+51 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+52 distances, neighbors = faiss_DB.search(faiss_data, k)
+53
+54 return (
+55 pl.DataFrame({
+56 "from": faiss_indices})
+57 .with_columns(
+58 pl.Series("to", neighbors),
+59 pl.Series("distance", distances))
+60 .explode("to", "distance")
+61 .with_columns(
+62 pl.col("from")
+63 .map_dict(title_name_map),
+64 pl.col("to")
+65 .map_dict(faiss_indices_map)
+66 .map_dict(title_name_map))
+67 .filter(pl.col("from") != pl.col("to"))
+68 )
+69
+70# Code to extract components from a Network Graph
+71import networkx as nx
+72def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+73 edge_list = (
+74 get_edge_list(label, k=k)
+75 .filter(pl.col("distance") >= min_cosine_distance)
+76 )
+77 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+78 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}
+79
+80# Example call to a single category to obtain its clusters
+81clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+82# Example call to **all** categories to obtain all clusters
+83clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]If you want to write down an algorithmic way of looking at this approach,
+faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
1def get_edge_list(label, k=5):
+ 2 faiss_DB, faiss_data, faiss_indices = create_index(label)
+ 3 # To map the data back to the original `train[b'data']` array
+ 4 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+ 5 # To map the indices back to the original strings
+ 6 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+ 7 distances, neighbors = faiss_DB.search(faiss_data, k)
+ 8
+ 9 return (
+10 pl.DataFrame({
+11 "from": faiss_indices})
+12 .with_columns(
+13 pl.Series("to", neighbors),
+14 pl.Series("distance", distances))
+15 .explode("to", "distance")
+16 .with_columns(
+17 pl.col("from")
+18 .map_dict(title_name_map),
+19 pl.col("to")
+20 .map_dict(faiss_indices_map)
+21 .map_dict(title_name_map))
+22 .filter(pl.col("from") != pl.col("to"))
+23 ) The next step in the process is to create a network graph using the edge-list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are sufficiently similar enough to B within a particular minimum thershold, then A will be connected to C through B! Hopefully a small visual below would help.
(figure: point A connected to point C through their shared neighbor B)
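To make the neighbor-of-a-neighbor idea concrete, here is a tiny, self-contained sketch with three hypothetical points (not from the dataset):

import networkx as nx

# A-B and B-C are edges (B is a near neighbor of both),
# but A and C were never directly matched to each other.
g = nx.Graph()
g.add_edges_from([("A", "B"), ("B", "C")])

print(list(nx.connected_components(g)))  # [{'A', 'B', 'C'}] - A reaches C through B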
What happens when such a concept is extended to many data points? Not all of them will be connected - because we're applying a minimum threshold that every edge has to meet. This is the only heuristic part of the rather fast process. Here's one more helpful visual:
(figure: many points joined into separate components by threshold-filtered nearest-neighbor edges)
Very starry-night-esque vibes here. Let's get to the code.
import networkx as nx

def get_cluster_map(label, k=5, min_cosine_distance=0.95):
    # Note: because the index uses inner products on normalized vectors, the
    # "distance" column actually holds cosine similarity scores (higher = closer).
    edge_list = (
        get_edge_list(label, k=k)
        .filter(pl.col("distance") >= min_cosine_distance)
    )
    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}

Now that all the parts of the puzzle are together, let's run it and see what kind of clusters you get for Cell Phones Accessories.
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)

Make sure to tune the following if your results aren't good enough (a quick way to explore this is sketched just below):

- the min_cosine_distance value - lower it if you want bigger clusters.
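As a rough way to pick that value, you could sweep the threshold and watch how the number of clusters reacts. A minimal sketch, assuming the functions above are in scope; the threshold values here are arbitrary, and the loop rebuilds the index each time, so it isn't the fastest:

# Hypothetical tuning loop: lower thresholds merge more titles into fewer, bigger clusters.
for threshold in (0.90, 0.93, 0.95, 0.97):
    cluster_map = get_cluster_map("Cell Phones Accessories", k=5, min_cosine_distance=threshold)
    print(threshold, len(cluster_map))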
There will likely be many clusters (you can see how many exactly with len(clusters)). Let's look at a random cluster:

>> clusters[3]
['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']

Let's see another cluster, which had 172(!) members in my run (the clusters themselves will be stable, but their indices may change between runs owing to some inherent randomness in the process).
>>> clusters[6]
['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
 ...
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']

Doing this for every category isn't that hard either (although it may take more than a moment) - just iterate over each category!
clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]

I get it - you often want a solution that "just works". I can come close to that. See below for the code and a succinct explanation. For those of my readers who aren't in a hurry, this also serves as a nice summary (and copy-pastable code)!
import sentence_transformers
import faiss
import polars as pl
import numpy as np

# Data is read here. You can download the files from Kaggle here:
# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
data = pl.read_csv("archive/shopmania.csv", new_columns=[
    "product_ID", "product_title", "category_ID", "category_label"])
data = (
    data
    .filter(pl.count().over("category_ID") == 10000)
    .with_row_count("row_idx")
)


# See list of models at www.sbert.net/docs/pretrained_models.html
ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
title_embeddings = (
    ST.encode(
        data.get_column("product_title").to_list(),
        # I'm on a MacBook; use `cuda` or `cpu`
        # if you've got different hardware.
        device="mps",
        show_progress_bar=True, convert_to_tensor=True)
    .cpu().numpy())

# Code to create a FAISS index
def create_index(label):
    faiss_indices = (
        data # this needs to be an argument if you want to create a generic function
        .filter(pl.col("category_label") == label)
        .get_column("row_idx")
        .to_list()
    )

    faiss_data = title_embeddings[faiss_indices]
    d = faiss_data.shape[1]         # Number of dimensions
    faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
    faiss.normalize_L2(faiss_data)  # Normalized L2 with Inner Product search = cosine similarity
    faiss_DB.add(faiss_data)        # Build the index

    return faiss_DB, faiss_data, faiss_indices

# Code to create an edge-list
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map positions inside `faiss_data` back to row indices of the original data
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
    # To map the row indices back to the original strings
    title_name_map = {i: x for i, x in data.select("row_idx", "product_title").rows()}
    distances, neighbors = faiss_DB.search(faiss_data, k)

    return (
        pl.DataFrame({
            "from": faiss_indices})
        .with_columns(
            pl.Series("to", neighbors),
            pl.Series("distance", distances))
        .explode("to", "distance")
        .with_columns(
            pl.col("from")
            .map_dict(title_name_map),
            pl.col("to")
            .map_dict(faiss_indices_map)
            .map_dict(title_name_map))
        .filter(pl.col("from") != pl.col("to"))
    )

# Code to extract components from a network graph
import networkx as nx
def get_cluster_map(label, k=5, min_cosine_distance=0.95):
    edge_list = (
        get_edge_list(label, k=k)
        .filter(pl.col("distance") >= min_cosine_distance)
    )
    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}

# Example call for a single category to obtain its clusters
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
# Example call for **all** categories to obtain all clusters
clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]

If you want to write down an algorithmic way of looking at this approach:
- Turn your data into a 2D matrix - here, SentenceTransformer embeddings of the product titles.
- For each label, build a database (faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.
- Query the database for the k nearest neighbors of every point and store the result as an edge list.
- Drop edges below a minimum similarity threshold and build a network graph from the rest.
- Each connected component of the graph is a cluster of near-duplicate data points.
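As a final, hypothetical step - one possible way to collapse the clusters into the representative sample itself, not something prescribed above - you could keep a single title per cluster and add back every title that never joined a cluster:

# Hypothetical helper: not part of the code above.
def pick_representatives(cluster_map, all_titles):
    # One (arbitrary) member per cluster - any member works, since by
    # construction the members of a cluster are near-duplicates of each other...
    reps = [members[0] for members in cluster_map.values()]
    # ...plus every title that never ended up in any cluster,
    # since those have no near-duplicates to be represented by.
    clustered = {title for members in cluster_map.values() for title in members}
    singletons = [t for t in all_titles if t not in clustered]
    return reps + singletons

# Example usage for one category:
cell_phone_titles = (
    data
    .filter(pl.col("category_label") == "Cell Phones Accessories")
    .get_column("product_title")
    .to_list()
)
sample = pick_representatives(get_cluster_map("Cell Phones Accessories"), cell_phone_titles)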
+faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
1def get_edge_list(label, k=5):
+ 2 faiss_DB, faiss_data, faiss_indices = create_index(label)
+ 3 # To map the data back to the original `train[b'data']` array
+ 4 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+ 5 # To map the indices back to the original strings
+ 6 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+ 7 distances, neighbors = faiss_DB.search(faiss_data, k)
+ 8
+ 9 return (
+10 pl.DataFrame({
+11 "from": faiss_indices})
+12 .with_columns(
+13 pl.Series("to", neighbors),
+14 pl.Series("distance", distances))
+15 .explode("to", "distance")
+16 .with_columns(
+17 pl.col("from")
+18 .map_dict(title_name_map),
+19 pl.col("to")
+20 .map_dict(faiss_indices_map)
+21 .map_dict(title_name_map))
+22 .filter(pl.col("from") != pl.col("to"))
+23 ) The next step in the process is to create a network graph using the edge-list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are sufficiently similar enough to B within a particular minimum thershold, then A will be connected to C through B! Hopefully a small visual below would help.
+
What happens when such a concept is extended for many data points? Not all of them would be connected - because we’re applying a minimum threshold that they have to meet. This is the only hueristic part of the rather fast process. Here’s one more helpful visual:
+
Very starry night-eque vibes here. Let’s get to the code.
+ + + + + +1import networkx as nx
+2def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+3 edge_list = (
+4 get_edge_list(label, k=k)
+5 .filter(pl.col("distance") >= min_cosine_distance)
+6 )
+7 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+8 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phone Accessories.
1clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)Make sure to configure the following if your results aren’t good enough:
+min_cosine_distance value if you want bigger clusters.There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
1>> clusters[3]
+2['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
+3 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
+4 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
+5 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
+6 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
+7 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
+ + + + + + 1>>> clusters[6]
+ 2['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
+ 3 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
+ 4 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
+ 5 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
+ 6 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
+ 7 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
+ 8 ...
+ 9 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
+10 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
+11 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
+12 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
+13 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
+14 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']This isn’t that hard (although it may take more than a moment). Just iterate it for each category!
+ + + + + +1clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]I get it - you often want a solution that “just works”. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
+ 1import sentence_transformers
+ 2import faiss
+ 3import polars as pl
+ 4import numpy as np
+ 5
+ 6# Data is read here. You download the files from Kaggle here:
+ 7# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
+ 8data = pl.read_csv("archive/shopmania.csv", new_columns=[
+ 9 "product_ID", "product_title", "category_ID", "category_label"])
+10data = (
+11 data
+12 .filter(pl.count().over("category_ID") == 10000)
+13 .with_row_count("row_idx")
+14)
+15
+16
+17# See list of models at www.sbert.net/docs/pretrained_models.html
+18ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+19title_embeddings = (
+20 ST.encode(
+21 data.get_column("product_title").to_list(),
+22 # I'm on a MacBook, you should use `cuda` or `cpu`
+23 # if you've got different hardware.
+24 device="mps",
+25 show_progress_bar=True, convert_to_tensor=True)
+26 .cpu().numpy())
+27
+28# Code to create a FAISS index
+29def create_index(label):
+30 faiss_indices = (
+31 data # this needs to be an argument if you want to create a generic function
+32 .filter(pl.col("category_label") == label)
+33 .get_column("row_idx")
+34 .to_list()
+35 )
+36
+37 faiss_data = title_embeddings[faiss_indices]
+38 d = faiss_data.shape[1] # Number of dimensions
+39 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+40 faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
+41 faiss_DB.add(faiss_data) # Build the index
+42
+43 return faiss_DB, faiss_data, faiss_indices
+44
+45# Code to create an edge-list
+46def get_edge_list(label, k=5):
+47 faiss_DB, faiss_data, faiss_indices = create_index(label)
+48 # To map the data back to the original `train[b'data']` array
+49 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+50 # To map the indices back to the original strings
+51 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+52 distances, neighbors = faiss_DB.search(faiss_data, k)
+53
+54 return (
+55 pl.DataFrame({
+56 "from": faiss_indices})
+57 .with_columns(
+58 pl.Series("to", neighbors),
+59 pl.Series("distance", distances))
+60 .explode("to", "distance")
+61 .with_columns(
+62 pl.col("from")
+63 .map_dict(title_name_map),
+64 pl.col("to")
+65 .map_dict(faiss_indices_map)
+66 .map_dict(title_name_map))
+67 .filter(pl.col("from") != pl.col("to"))
+68 )
+69
+70# Code to extract components from a Network Graph
+71import networkx as nx
+72def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+73 edge_list = (
+74 get_edge_list(label, k=k)
+75 .filter(pl.col("distance") >= min_cosine_distance)
+76 )
+77 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+78 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}
+79
+80# Example call to a single category to obtain its clusters
+81clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+82# Example call to **all** categories to obtain all clusters
+83clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]If you want to write down an algorithmic way of looking at this approach,
+faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
1def get_edge_list(label, k=5):
+ 2 faiss_DB, faiss_data, faiss_indices = create_index(label)
+ 3 # To map the data back to the original `train[b'data']` array
+ 4 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+ 5 # To map the indices back to the original strings
+ 6 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+ 7 distances, neighbors = faiss_DB.search(faiss_data, k)
+ 8
+ 9 return (
+10 pl.DataFrame({
+11 "from": faiss_indices})
+12 .with_columns(
+13 pl.Series("to", neighbors),
+14 pl.Series("distance", distances))
+15 .explode("to", "distance")
+16 .with_columns(
+17 pl.col("from")
+18 .map_dict(title_name_map),
+19 pl.col("to")
+20 .map_dict(faiss_indices_map)
+21 .map_dict(title_name_map))
+22 .filter(pl.col("from") != pl.col("to"))
+23 ) The next step in the process is to create a network graph using the edge-list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are sufficiently similar enough to B within a particular minimum thershold, then A will be connected to C through B! Hopefully a small visual below would help.
+
What happens when such a concept is extended for many data points? Not all of them would be connected - because we’re applying a minimum threshold that they have to meet. This is the only hueristic part of the rather fast process. Here’s one more helpful visual:
+
Very starry night-eque vibes here. Let’s get to the code.
+ + + + + +1import networkx as nx
+2def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+3 edge_list = (
+4 get_edge_list(label, k=k)
+5 .filter(pl.col("distance") >= min_cosine_distance)
+6 )
+7 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+8 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phone Accessories.
1clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)Make sure to configure the following if your results aren’t good enough:
+min_cosine_distance value if you want bigger clusters.There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
1>> clusters[3]
+2['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
+3 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
+4 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
+5 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
+6 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
+7 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
+>>> clusters[6]
+['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
+ 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
+ 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
+ ...
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
+ 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
+ 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
+ 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']
+Getting clusters for every category isn’t that hard either (although it may take more than a moment) - just iterate over each category!
+clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]
+I get it - you often want a solution that “just works”. I can come close to it. See below for the full code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
+import sentence_transformers
+import faiss
+import polars as pl
+import numpy as np
+
+# Data is read here. You download the files from Kaggle here:
+# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
+data = pl.read_csv("archive/shopmania.csv", new_columns=[
+    "product_ID", "product_title", "category_ID", "category_label"])
+data = (
+    data
+    .filter(pl.count().over("category_ID") == 10000)
+    .with_row_count("row_idx")
+)
+
+
+# See list of models at www.sbert.net/docs/pretrained_models.html
+ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+title_embeddings = (
+    ST.encode(
+        data.get_column("product_title").to_list(),
+        # I'm on a MacBook, you should use `cuda` or `cpu`
+        # if you've got different hardware.
+        device="mps",
+        show_progress_bar=True, convert_to_tensor=True)
+    .cpu().numpy())
+
+# Code to create a FAISS index
+def create_index(label):
+    faiss_indices = (
+        data  # this needs to be an argument if you want to create a generic function
+        .filter(pl.col("category_label") == label)
+        .get_column("row_idx")
+        .to_list()
+    )
+
+    faiss_data = title_embeddings[faiss_indices]
+    d = faiss_data.shape[1]          # Number of dimensions
+    faiss_DB = faiss.IndexFlatIP(d)  # Index using Inner Product
+    faiss.normalize_L2(faiss_data)   # Normalized L2 with Inner Product search = cosine similarity
+    faiss_DB.add(faiss_data)         # Build the index
+
+    return faiss_DB, faiss_data, faiss_indices
+
+# Code to create an edge-list
+def get_edge_list(label, k=5):
+    faiss_DB, faiss_data, faiss_indices = create_index(label)
+    # To map faiss's positional indices back to row indices in the original `data` frame
+    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
+    # To map row indices back to the original product title strings
+    title_name_map = {i: x for i, x in data.select("row_idx", "product_title").rows()}
+    distances, neighbors = faiss_DB.search(faiss_data, k)
+
+    return (
+        pl.DataFrame({
+            "from": faiss_indices})
+        .with_columns(
+            pl.Series("to", neighbors),
+            pl.Series("distance", distances))
+        .explode("to", "distance")
+        .with_columns(
+            pl.col("from")
+            .map_dict(title_name_map),
+            pl.col("to")
+            .map_dict(faiss_indices_map)
+            .map_dict(title_name_map))
+        .filter(pl.col("from") != pl.col("to"))
+    )
+
+# Code to extract components from a Network Graph
+import networkx as nx
+def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+    edge_list = (
+        get_edge_list(label, k=k)
+        .filter(pl.col("distance") >= min_cosine_distance)
+    )
+    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}
+
+# Example call to a single category to obtain its clusters
+clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+# Example call to **all** categories to obtain all clusters
+clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]
+If you want to write down an algorithmic way of looking at this approach:
+Create a 2D-matrix representation of your data - here, SentenceTransformer embeddings of the product titles.
+Index it with an ANN library (like faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.
+For every point, keep only the nearest-neighbor pairs that clear a minimum similarity threshold - that’s your edge list.
+Build a network graph from the edge list and extract its connected components - each component is a cluster of near-duplicates to sample from.
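+As a final touch, here is a minimal sketch of how the clusters could be turned into the representative sample we set out to build. The helper name and the rule it uses (keep one title per cluster, plus every title that never joined a cluster) are my own illustrative choices, not the only sensible ones:
+def representative_titles(cluster_map, all_titles):
+    # One arbitrary member per cluster of near-duplicates...
+    kept = [members[0] for members in cluster_map.values()]
+    clustered = {title for members in cluster_map.values() for title in members}
+    # ...plus every title that never joined a cluster (it has no near-duplicates to begin with).
+    kept.extend(title for title in all_titles if title not in clustered)
+    return kept
+
+cluster_map = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+cell_phone_titles = (
+    data
+    .filter(pl.col("category_label") == "Cell Phones Accessories")
+    .get_column("product_title")
+    .to_list()
+)
+sample = representative_titles(cluster_map, cell_phone_titles)
+Tightening or loosening min_cosine_distance changes how aggressively near-duplicates get collapsed, and therefore how small this sample ends up.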
+When I worked in healthcare consulting, I often spent a LOT of my time creating PowerPoint presentations (decks in consulting lingo - not even slide decks). However, the work was rather repetitive. Thus was born PowerPointSnap.
+I’ll write this down as pointers.
+The project is available on this GitHub repo. The instructions to install it are available there, but here’s the down-low:
+What can it do? Frankly, a LOT. The base concept of this tool is:
+Here’s a non-exhaustive list of the options available.
+This is the part of the interface that can be used for shapes (which include charts and tables).
+
+To use it, first select a shape object and click “Set”. Then choose the object you want to Snap those properties onto (see how I got the inspiration for the name?). You should be able to copy all compatible properties - if something isn’t copy-able, the tool will show an error and then let you exit.
+Note that it’s probably not ideal to apply a property of a shape to a table - if you want to make the entire table orange, there are probably better built-in ways to do it than to use Snap.
+Charts are also supported, with dedicated features of their own.
+
+What do these features do? You should be able to hover over each option and get a tooltip that shows what it’s capable of, but here’s another summary just in case:
+The next two options deserve their own section.
+Your immediate senior in a consulting environment would frown at your chart, and then exclaim, “I think that’s too many labels for the data points. Can you show them every two/three/four labels? I know this is manual work, but it’s a one time thing!”
+It’s never a one-time affair. But don’t worry, we have this nice feature to help us. If you click on the Customize Label option, you will get this (without the “Set” option):
+Never mind the rather unfriendly legend entries - they’re just there to demonstrate the kinds of wacky things you can do with your own chart!
+Of course, visuals will do it more justice. For example, look at this image:
+
Here’s what you can do:
+This is what your results should look like:
+
Of course, getting those calculations right is a whole different thing that will need some work.
+Oftentimes, you have two tables that show similar values… you know the drill. Here’s what you can do in a scenario such as this:
+
This is what the Tables section of the tool looks like:
+
To align these tables together,
+Here’s what you’ll end up with:
+
Pretty neat, eh?
+]]>When I worked in healthcare consulting, I often spent a LOT of my time creating PowerPoint presentations (decks in consulting lingo - not even slide decks). However, it was rather repetitive. Thus, was born PowerPointSnap.
+I’ll write this down as pointers.
+The project is available on this Github repo. The instructions to install it are available there, but here’s the down-low:
+Frankly, a LOT. The base concept of this tool is:
+Here’s a non-exhaustive list of all the options available.
+This is the part of the interface that can be used for shapes (which include charts and tables).
+
To use, first select a shape object, click on “Set”. Then, choose the object you want to Snap its properties to (see how I got the inspiration for the name?). You should be able to copy all compatible properties - if something is not copy-able, the tool will show an error, and then let you exit.
+Note that it’s probably not to apply a property of a shape to a table - if you want to make the entire table orange, there are probably better built-in ways to do it than to use Snap.
+Charts are also supported, with dedicated features for it.
+
What do these features do? You should be able to hover over the option and get a tooltip that shows what it’s capable of, but here’s another summary just in case:
+The next two options deserve their own section.
+Your immediate senior in a consulting environment would frown at your chart, and then exclaim, “I think that’s too many labels for the data points. Can you show them every two/three/four labels? I know this is manual work, but it’s a one time thing!”
+It’s never a one time affair. But don’t worry, we have this nice feature to help us. If you click on the Customize Label option, you will get this (without the “Set” option):
+Never mind the rather unfriendly legend entries. They’re just here to demonstrate that you can do the following kinds of whacky abilities with your own chart!
+Of course, visuals will do it more justice. For example, look at this image:
+
Here’s what you can do:
+This is what your results should look like:
+
Of course, getting those calculations right is a whole different thing that will need some work.
+Oftentimes, you have two tables that show similar values… you know the drill. Here’s what you can do in a scenario such as this:
+
This is what the Tables section of the tool looks like:
+
To align these tables together,
+Here’s what you’ll end up with:
+
Pretty neat, eh?
+]]>In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
1def get_edge_list(label, k=5):
+ 2 faiss_DB, faiss_data, faiss_indices = create_index(label)
+ 3 # To map the data back to the original `train[b'data']` array
+ 4 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+ 5 # To map the indices back to the original strings
+ 6 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+ 7 distances, neighbors = faiss_DB.search(faiss_data, k)
+ 8
+ 9 return (
+10 pl.DataFrame({
+11 "from": faiss_indices})
+12 .with_columns(
+13 pl.Series("to", neighbors),
+14 pl.Series("distance", distances))
+15 .explode("to", "distance")
+16 .with_columns(
+17 pl.col("from")
+18 .map_dict(title_name_map),
+19 pl.col("to")
+20 .map_dict(faiss_indices_map)
+21 .map_dict(title_name_map))
+22 .filter(pl.col("from") != pl.col("to"))
+23 ) The next step in the process is to create a network graph using the edge-list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say that we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are sufficiently similar enough to B within a particular minimum thershold, then A will be connected to C through B! Hopefully a small visual below would help.
+
What happens when such a concept is extended for many data points? Not all of them would be connected - because we’re applying a minimum threshold that they have to meet. This is the only hueristic part of the rather fast process. Here’s one more helpful visual:
+
Very starry night-eque vibes here. Let’s get to the code.
+ + + + + +1import networkx as nx
+2def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+3 edge_list = (
+4 get_edge_list(label, k=k)
+5 .filter(pl.col("distance") >= min_cosine_distance)
+6 )
+7 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+8 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phone Accessories.
1clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)Make sure to configure the following if your results aren’t good enough:
+min_cosine_distance value if you want bigger clusters.There will likely be many clusters (you can see how many exactly with len(clusters)). Let’s look at a random cluster:
1>> clusters[3]
+2['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
+3 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
+4 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
+5 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
+6 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
+7 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']Let’s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).
+ + + + + + 1>>> clusters[6]
+ 2['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
+ 3 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
+ 4 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
+ 5 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
+ 6 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
+ 7 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
+ 8 ...
+ 9 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
+10 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
+11 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
+12 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
+13 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
+14 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']This isn’t that hard (although it may take more than a moment). Just iterate it for each category!
+ + + + + +1clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]I get it - you often want a solution that “just works”. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
+ 1import sentence_transformers
+ 2import faiss
+ 3import polars as pl
+ 4import numpy as np
+ 5
+ 6# Data is read here. You download the files from Kaggle here:
+ 7# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
+ 8data = pl.read_csv("archive/shopmania.csv", new_columns=[
+ 9 "product_ID", "product_title", "category_ID", "category_label"])
+10data = (
+11 data
+12 .filter(pl.count().over("category_ID") == 10000)
+13 .with_row_count("row_idx")
+14)
+15
+16
+17# See list of models at www.sbert.net/docs/pretrained_models.html
+18ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+19title_embeddings = (
+20 ST.encode(
+21 data.get_column("product_title").to_list(),
+22 # I'm on a MacBook, you should use `cuda` or `cpu`
+23 # if you've got different hardware.
+24 device="mps",
+25 show_progress_bar=True, convert_to_tensor=True)
+26 .cpu().numpy())
+27
+28# Code to create a FAISS index
+29def create_index(label):
+30 faiss_indices = (
+31 data # this needs to be an argument if you want to create a generic function
+32 .filter(pl.col("category_label") == label)
+33 .get_column("row_idx")
+34 .to_list()
+35 )
+36
+37 faiss_data = title_embeddings[faiss_indices]
+38 d = faiss_data.shape[1] # Number of dimensions
+39 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+40 faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
+41 faiss_DB.add(faiss_data) # Build the index
+42
+43 return faiss_DB, faiss_data, faiss_indices
+44
+45# Code to create an edge-list
+46def get_edge_list(label, k=5):
+47 faiss_DB, faiss_data, faiss_indices = create_index(label)
+48 # To map the data back to the original `train[b'data']` array
+49 faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
+50 # To map the indices back to the original strings
+51 title_name_map = {i: x for i,x in data.select("row_idx", "product_title").rows()}
+52 distances, neighbors = faiss_DB.search(faiss_data, k)
+53
+54 return (
+55 pl.DataFrame({
+56 "from": faiss_indices})
+57 .with_columns(
+58 pl.Series("to", neighbors),
+59 pl.Series("distance", distances))
+60 .explode("to", "distance")
+61 .with_columns(
+62 pl.col("from")
+63 .map_dict(title_name_map),
+64 pl.col("to")
+65 .map_dict(faiss_indices_map)
+66 .map_dict(title_name_map))
+67 .filter(pl.col("from") != pl.col("to"))
+68 )
+69
+70# Code to extract components from a Network Graph
+71import networkx as nx
+72def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+73 edge_list = (
+74 get_edge_list(label, k=k)
+75 .filter(pl.col("distance") >= min_cosine_distance)
+76 )
+77 graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+78 return {i: list(x) for i,x in enumerate(nx.connected_components(graph))}
+79
+80# Example call to a single category to obtain its clusters
+81clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+82# Example call to **all** categories to obtain all clusters
+83clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]If you want to write down an algorithmic way of looking at this approach,
+faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.In this day and age, we’re not short on data. Good data, on the other hand, is very valuable. When you’ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
+Let’s formalize the problem a little so that a proper approach can be developed. Here’s the problem statement:
+Here’s what you need to do:
+Generally, three things come to mind:
+This data can be practically anything that can be represented as a 2D matrix.
+There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they’ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some generic task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.
+For this specific article, I will use the ShopMania dataset on Kaggle. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I’m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:
+++ + + + + +NOTE: whenever I want to show an output along with the code I used for it, you’ll see the characters
+>>indicating the command used, and the output to be without those prefixes.
1>> import polars as pl
+ 2>> data = pl.read_csv("archive/shopmania.csv")
+ 3>> data
+ 4shape: (313_705, 4)
+ 5┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐
+ 6│ product_ID ┆ product_title ┆ category_ID ┆ category_label │
+ 7│ --- ┆ --- ┆ --- ┆ --- │
+ 8│ i64 ┆ str ┆ i64 ┆ str │
+ 9╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡
+10│ 2 ┆ twilight central park print ┆ 2 ┆ Collectibles │
+11│ 3 ┆ fox print ┆ 2 ┆ Collectibles │
+12│ 4 ┆ circulo de papel wall art ┆ 2 ┆ Collectibles │
+13│ 5 ┆ hidden path print ┆ 2 ┆ Collectibles │
+14│ … ┆ … ┆ … ┆ … │
+15│ 313703 ┆ deago anti fog swimming diving full face mask ┆ 229 ┆ Water Sports │
+16│ ┆ surface snorkel scuba fr gopro black s/m ┆ ┆ │
+17│ 313704 ┆ etc buys full face gopro compatible snorkel scuba ┆ 229 ┆ Water Sports │
+18│ ┆ diving mask blue large/xtralarge blue ┆ ┆ │
+19│ 313705 ┆ men 039 s full face breathe free diving snorkel mask ┆ 229 ┆ Water Sports │
+20│ ┆ scuba optional hd camera blue mask only adult men ┆ ┆ │
+21│ 313706 ┆ women 039 s full face breathe free diving snorkel ┆ 229 ┆ Water Sports │
+22│ ┆ mask scuba optional hd camera black mask only ┆ ┆ │
+23│ ┆ children and women ┆ ┆ │
+24└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘The data documentation on Kaggle states:
+++The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.
+
For demonstration, I’ll just limit the categories to those that have exactly 10,000 occurences.
+ + + + + +1data = (
+2 data
+3 .filter(pl.count().over("category_ID") == 10000)
+4)You’ll notice that there are only 17 categories in this dataset. Run this to verify that fact.
+ + + + + + 1>>> data.get_column("category_label").unique()
+ 2shape: (17,)
+ 3Series: 'category_label' [str]
+ 4[
+ 5 "Kitchen & Dining"
+ 6 "Scarves and wraps"
+ 7 "Handbags & Wallets"
+ 8 "Rugs Tapestry & Linens"
+ 9 "Cell Phones Accessories"
+10 "Men's Clothing"
+11 "Jewelry"
+12 "Belts"
+13 "Men Lingerie"
+14 "Crafts"
+15 "Football"
+16 "Medical Supplies"
+17 "Adult"
+18 "Hunting"
+19 "Women's Clothing"
+20 "Pet Supply"
+21 "Office Supplies"
+22]Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.
+Okay - so now we have exactly 10,000 products per category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:
+++Craft a small representative sample for each category.
+
Why small? It helps that it’ll make the model faster to train - and keep the training data manageable in size.
+I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer’s rather simple: use SentenceTransformers to get a string’s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I’ve noticed that SentenceTransformers are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.
SentenceTransformer embeddingsThis part is rather simple. If you’re unable to install SentenceTransformers, please check their website.
+ + + + + +1import sentence_transformers
+2# See list of models at www.sbert.net/docs/pretrained_models.html
+3ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+4title_embeddings = (
+5 ST.encode(
+6 data.get_column("product_title").to_list(),
+7 show_progress_bar=True, convert_to_tensor=True)
+8 .numpy())This process will be slow (~30 minutes) if you don’t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to .numpy() at the end is to directly get a single numpy array - otherwise you get a list of numpy arrays, which is rather inefficient. Further, SentenceTransformers will try to run on the GPU if available, and if so, you will need to write .cpu().numpy() so that the tensor is copied from the GPU to the CPU.
++NOTE: for a proof-of-concept implementation, or if you’re on the CPU, try the
+all-MiniLM-L6-v2model. It’s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.
Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. Approximate nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least one of the nearest neighbors (hence the term approximate).
+There are several algorithms that you can use - I shall proceed with faiss, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are available here.
I’ll explain why we’re in the nearest neighbor territory in due course.
+To build the database, all we need is the title_embeddings matrix.
1import faiss
+2def create_index(title_embeddings):
+3 d = title_embeddings.shape[1] # Number of dimensions
+4 ann_index = faiss.IndexFlatL2(d) # Index using Eucledian Matrix
+5 ann_index.add(title_embeddings) # Build the index
+6
+7 return ann_index # Faiss considers databases an "index"This does create a database. But remember, we’re trying to find representative samples - which means we need to do this by the category (or label). So let’s design a function that sends only the necessary data as that for a particular category, and then create the database. We’ll need three pieces of information from this function:
+faiss database.faiss database.(2) and (3) will help us later in rebuilding a “network graph” that will allow us to reference the original data points.
+ + + + + + 1import faiss
+ 2import numpy as np
+ 3import polars as pl
+ 4
+ 5def create_index(label):
+ 6 faiss_indices = (
+ 7 data # this needs to be an argument if you want to create a generic function
+ 8 .with_row_count("row_idx")
+ 9 .filter(pl.col("category_label") == label)
+10 .get_column("row_idx")
+11 .to_list()
+12 )
+13
+14 faiss_data = title_embeddings[faiss_indices]
+15 d = data.shape[1] # Number of dimensions
+16 faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
+17 faiss.normalize_L2(data) # Normalized L2 with Inner Product search = cosine similarity
+18 # Why cosine similarity? It's easier to specify thresholds - they'll always be between 0 and 1.4.
+19 # If using Eucledian or other distance, we'll have to spend some time finding a good range
+20 # where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.
+21 faiss_DB.add(data) # Build the index
+22
+23 return faiss_DB, faiss_data, faiss_indicesTo proceed with getting a representative sample, the next step is to find the nearest neighbors for all data points in the database. This isn’t too hard - faiss index objects have a built-in search method to find the k nearest neighbors for a given index, along with the (approximate) distance to it. Let’s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an edge list i.e. a list of pair of nodes that are connected, along with any additional information that specifies a property (in this case distance) of the edge that connects these nodes.
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map positions within faiss_data back to row indices in the original data
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
    # To map row indices back to the original title strings
    title_name_map = {i: x for i, x in data.select("row_idx", "product_title").rows()}
    distances, neighbors = faiss_DB.search(faiss_data, k)

    return (
        pl.DataFrame({
            "from": faiss_indices})
        .with_columns(
            pl.Series("to", neighbors),
            pl.Series("distance", distances))
        .explode("to", "distance")
        .with_columns(
            pl.col("from")
            .map_dict(title_name_map),
            pl.col("to")
            .map_dict(faiss_indices_map)
            .map_dict(title_name_map))
        .filter(pl.col("from") != pl.col("to"))
    )
+The next step in the process is to create a network graph using the edge list. But why?
+Remember that we have identified the (k=5) nearest neighbors of each data point. Let’s say we have a point A that has a nearest neighbor B. C is not a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are each sufficiently similar to B - within a particular minimum threshold - then A will be connected to C through B! Hopefully the small visual below will help.
+
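+The same idea in code - a toy sketch with made-up node names and similarity scores, where A-B and B-C clear the threshold but A-C does not:
+import networkx as nx
+
+toy_edges = [("A", "B", 0.97), ("B", "C", 0.96), ("A", "C", 0.80)]
+min_threshold = 0.95
+
+graph = nx.Graph()
+graph.add_edges_from((u, v) for u, v, score in toy_edges if score >= min_threshold)
+print(list(nx.connected_components(graph)))  # [{'A', 'B', 'C'}] - A reaches C through B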
What happens when such a concept is extended to many data points? Not all of them will be connected - because we’re applying a minimum threshold that they have to meet. This is the only heuristic part of the otherwise rather fast process. Here’s one more helpful visual:
+
Very Starry Night-esque vibes here. Let’s get to the code.
+import networkx as nx
+
+def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+    edge_list = (
+        get_edge_list(label, k=k)
+        .filter(pl.col("distance") >= min_cosine_distance)
+    )
+    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}
+Now that all the parts of the puzzle are together, let’s run it to see what kind of clusters you get for Cell Phones Accessories.
clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
Make sure to tweak the following if your results aren’t good enough:
+- Adjust the k value to consider more (or fewer) nearest neighbors per point.
+- Lower the min_cosine_distance value if you want bigger clusters.
+There will likely be many clusters (you can see exactly how many with len(clusters)). Let’s look at a random cluster:
>>> clusters[3]
+['smartphone lanyard with card slot for any phone up to 6 yellow 72570099',
+ 'smartphone lanyard with card slot for any phone up to 6 black 72570093',
+ 'smartphone lanyard with card slot for any phone up to 6 lightblue 72570097',
+ 'smartphone lanyard with card slot for any phone up to 6 blue 72570095',
+ 'smartphone lanyard with card slot for any phone up to 6 green 72570101',
+ 'smartphone lanyard with card slot for any phone up to 6 pink 72570091']
+Let’s see another cluster, which had 172(!) members in my run (the clusters themselves will be stable, but their indices may change between runs owing to some inherent randomness in the process).
+>>> clusters[6]
+['otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58',
+ 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a',
+ 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16',
+ ...
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17',
+ 'otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a',
+ 'otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08',
+ 'otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a',
+ 'otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22']
+This isn’t that hard (although it may take more than a moment) - just iterate over each category!
+clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]
+I get it - you often want a solution that “just works”. I can come close to it. See below for the code and a succinct explanation. For those of my readers who aren’t in a hurry, this also serves as a nice summary (and copy-pastable code)!
+import sentence_transformers
+import faiss
+import polars as pl
+import numpy as np
+
+# Data is read here. You can download the files from Kaggle here:
+# https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization
+data = pl.read_csv("archive/shopmania.csv", new_columns=[
+    "product_ID", "product_title", "category_ID", "category_label"])
+data = (
+    data
+    .filter(pl.count().over("category_ID") == 10000)
+    .with_row_count("row_idx")
+)
+
+
+# See list of models at www.sbert.net/docs/pretrained_models.html
+ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
+title_embeddings = (
+    ST.encode(
+        data.get_column("product_title").to_list(),
+        # I'm on a MacBook; use `cuda` or `cpu`
+        # if you've got different hardware.
+        device="mps",
+        show_progress_bar=True, convert_to_tensor=True)
+    .cpu().numpy())
+
+# Code to create a FAISS index
+def create_index(label):
+    faiss_indices = (
+        data  # this needs to be an argument if you want to create a generic function
+        .filter(pl.col("category_label") == label)
+        .get_column("row_idx")
+        .to_list()
+    )
+
+    faiss_data = title_embeddings[faiss_indices]
+    d = faiss_data.shape[1]            # Number of dimensions
+    faiss_DB = faiss.IndexFlatIP(d)    # Index using Inner Product
+    faiss.normalize_L2(faiss_data)     # Normalized L2 with Inner Product search = cosine similarity
+    faiss_DB.add(faiss_data)           # Build the index
+
+    return faiss_DB, faiss_data, faiss_indices
+
+# Code to create an edge list
+def get_edge_list(label, k=5):
+    faiss_DB, faiss_data, faiss_indices = create_index(label)
+    # To map positions within faiss_data back to row indices in the original data
+    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}
+    # To map row indices back to the original title strings
+    title_name_map = {i: x for i, x in data.select("row_idx", "product_title").rows()}
+    distances, neighbors = faiss_DB.search(faiss_data, k)
+
+    return (
+        pl.DataFrame({
+            "from": faiss_indices})
+        .with_columns(
+            pl.Series("to", neighbors),
+            pl.Series("distance", distances))
+        .explode("to", "distance")
+        .with_columns(
+            pl.col("from")
+            .map_dict(title_name_map),
+            pl.col("to")
+            .map_dict(faiss_indices_map)
+            .map_dict(title_name_map))
+        .filter(pl.col("from") != pl.col("to"))
+    )
+
+# Code to extract connected components from a network graph
+import networkx as nx
+def get_cluster_map(label, k=5, min_cosine_distance=0.95):
+    edge_list = (
+        get_edge_list(label, k=k)
+        .filter(pl.col("distance") >= min_cosine_distance)
+    )
+    graph = nx.from_pandas_edgelist(edge_list.to_pandas(), source="from", target="to")
+    return {i: list(x) for i, x in enumerate(nx.connected_components(graph))}
+
+# Example call to a single category to obtain its clusters
+clusters = get_cluster_map("Cell Phones Accessories", 5, 0.95)
+# Example call to **all** categories to obtain all clusters
+clusters = [get_cluster_map(x, 5, 0.95) for x in data.get_column("category_label").unique()]
+If you want to write down an algorithmic way of looking at this approach:
+1. Embed the product titles with a sentence-transformer model.
+2. For each category, build a database (faiss) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.
+3. Search for the k nearest neighbors of every point and keep only the pairs above a minimum similarity threshold - this gives you an edge list.
+4. Build a network graph from the edge list; its connected components are your clusters, from which you can pick representative samples.
+When I worked in healthcare consulting, I often spent a LOT of my time creating PowerPoint presentations (decks in consulting lingo - not even slide decks). However, the work was rather repetitive. Thus was born PowerPointSnap.
+I’ll write this down as pointers.
+The project is available on this Github repo. The instructions to install it are available there, but here’s the lowdown:
+Frankly, a LOT. The base concept of this tool is:
+Here’s a non-exhaustive list of all the options available.
+This is the part of the interface that can be used for shapes (which include charts and tables).
+
To use it, first select a shape object and click on “Set”. Then choose the object you want to Snap those properties to (see how I got the inspiration for the name?). You should be able to copy all compatible properties - if something isn’t copy-able, the tool will show an error and let you exit.
+Note that it’s probably not a good idea to apply a property of a shape to a table - if you want to make an entire table orange, there are probably better built-in ways to do it than to use Snap.
+Charts are also supported, with dedicated features for them.
+
What do these features do? You should be able to hover over the option and get a tooltip that shows what it’s capable of, but here’s another summary just in case:
+The next two options deserve their own section.
+Your immediate senior in a consulting environment would frown at your chart, and then exclaim, “I think that’s too many labels for the data points. Can you show them every two/three/four labels? I know this is manual work, but it’s a one-time thing!”
+It’s never a one-time affair. But don’t worry, we have this nice feature to help us. If you click on the Customize Label option, you will get this (without the “Set” option):
+Never mind the rather unfriendly legend entries. They’re just there to demonstrate the kinds of wacky things you can do with your own chart!
+Of course, visuals will do it more justice. For example, look at this image:
+
Here’s what you can do:
+This is what your results should look like:
+
Of course, getting those calculations right is a whole different thing that will need some work.
+Oftentimes, you have two tables that show similar values… you know the drill. Here’s what you can do in a scenario such as this:
+
This is what the Tables section of the tool looks like:
+
To align these tables together,
+Here’s what you’ll end up with:
+
Pretty neat, eh?
+]]>