Files
www/public/blog/002_representative_samples/index.html
Avinash Mallya 57eff46d6c Switch to Hugo
2025-09-13 21:27:23 -05:00

243 lines
74 KiB
HTML
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
<!doctype html><html lang=en-US><head><meta http-equiv=X-Clacks-Overhead content="GNU Terry Pratchett"><meta charset=utf-8><meta name=viewport content="width=device-width,initial-scale=1"><title>Finding representative samples efficiently for large datasets | Avinash's Blog</title><meta name=title content="Finding representative samples efficiently for large datasets"><meta name=description content="Premise
In this day and age, we&rsquo;re not short on data. Good data, on the other hand, is very valuable. When you&rsquo;ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
Let&rsquo;s formalize the problem a little so that a proper approach can be developed. Here&rsquo;s the problem statement:
You have a large-ish set of (imperfectly) labelled data points. These data points can be represented as a 2D matrix.
You need to train a model to classify these data points on either these labels, or on labels dervied from imperfect labels.
You need a good (but not perfect) representative sample for the model to be generalizable, but there are too many data points for each label to manually pick representative examples.
In a hurry?
Here&rsquo;s what you need to do:"><meta name=author content="Avinash Mallya"><meta name=keywords content="representative,samples,faiss,approximate,nearest,neighbor,network,graph,networkx,polars,category,"><meta property="og:url" content="https://avimallu.dev/blog/002_representative_samples/"><meta property="og:site_name" content="Avinash's Blog"><meta property="og:title" content="Finding representative samples efficiently for large datasets"><meta property="og:description" content="Premise In this day and age, were not short on data. Good data, on the other hand, is very valuable. When youve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
Lets formalize the problem a little so that a proper approach can be developed. Heres the problem statement:
You have a large-ish set of (imperfectly) labelled data points. These data points can be represented as a 2D matrix. You need to train a model to classify these data points on either these labels, or on labels dervied from imperfect labels. You need a good (but not perfect) representative sample for the model to be generalizable, but there are too many data points for each label to manually pick representative examples. In a hurry? Heres what you need to do:"><meta property="og:locale" content="en_US"><meta property="og:type" content="article"><meta property="article:section" content="blog"><meta property="article:published_time" content="2023-10-19T00:00:00+00:00"><meta property="article:modified_time" content="2023-10-19T00:00:00+00:00"><meta property="article:tag" content="Representative"><meta property="article:tag" content="Samples"><meta property="article:tag" content="Faiss"><meta property="article:tag" content="Approximate"><meta property="article:tag" content="Nearest"><meta property="article:tag" content="Neighbor"><meta property="og:image" content="https://avimallu.dev/static/favicon.ico"><meta name=twitter:card content="summary_large_image"><meta name=twitter:image content="https://avimallu.dev/static/favicon.ico"><meta name=twitter:title content="Finding representative samples efficiently for large datasets"><meta name=twitter:description content="Premise In this day and age, were not short on data. Good data, on the other hand, is very valuable. When youve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
Lets formalize the problem a little so that a proper approach can be developed. Heres the problem statement:
You have a large-ish set of (imperfectly) labelled data points. These data points can be represented as a 2D matrix. You need to train a model to classify these data points on either these labels, or on labels dervied from imperfect labels. You need a good (but not perfect) representative sample for the model to be generalizable, but there are too many data points for each label to manually pick representative examples. In a hurry? Heres what you need to do:"><meta itemprop=name content="Finding representative samples efficiently for large datasets"><meta itemprop=description content="Premise In this day and age, were not short on data. Good data, on the other hand, is very valuable. When youve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.
Lets formalize the problem a little so that a proper approach can be developed. Heres the problem statement:
You have a large-ish set of (imperfectly) labelled data points. These data points can be represented as a 2D matrix. You need to train a model to classify these data points on either these labels, or on labels dervied from imperfect labels. You need a good (but not perfect) representative sample for the model to be generalizable, but there are too many data points for each label to manually pick representative examples. In a hurry? Heres what you need to do:"><meta itemprop=datePublished content="2023-10-19T00:00:00+00:00"><meta itemprop=dateModified content="2023-10-19T00:00:00+00:00"><meta itemprop=wordCount content="3202"><meta itemprop=image content="https://avimallu.dev/static/favicon.ico"><meta itemprop=keywords content="Representative,Samples,Faiss,Approximate,Nearest,Neighbor,Network,Graph,Networkx,Polars,Category"><meta name=referrer content="no-referrer-when-downgrade"><link href=/original.min.css rel=stylesheet><link href=/syntax.min.css rel=stylesheet></head><body><header><a class=skip-link href=#main-content>Skip to main content</a>
<a href=/ class=title><h1>Avinash's Blog</h1></a><nav><a href=/>about</a>
<a href=/blog/>blog</a>
<a href=/projects/>projects</a>
<a href=https://avimallu.dev/index.xml>rss</a></nav></header><main id=main-content><h1>Finding representative samples efficiently for large datasets</h1><p class=byline><time datetime=2023-10-19 pubdate>2023-10-19
</time>· Avinash Mallya</p><content><h1 id=premise>Premise</h1><p>In this day and age, we&rsquo;re not short on data. <em>Good</em> data, on the other hand, is very valuable. When you&rsquo;ve got a large amount of improperly labelled data, it may become hard to find to find a representative dataset to train a model on such that it generalizes well.</p><p>Let&rsquo;s formalize the problem a little so that a proper approach can be developed. Here&rsquo;s the problem statement:</p><ol><li>You have a large-ish set of (imperfectly) labelled data points. These data points can be represented as a 2D matrix.</li><li>You need to train a model to classify these data points on either these labels, or on labels dervied from imperfect labels.</li><li>You need a good (but not perfect) representative sample for the model to be generalizable, but there are too many data points for each label to manually pick representative examples.</li></ol><h2 id=in-a-hurry>In a hurry?</h2><p>Here&rsquo;s what you need to do:</p><ol><li>Read the premise and see if it fits your problem.</li><li>Go to the <strong>For the folks in a hurry!</strong> section at the end to find the generic solution and how it works.</li></ol><h2 id=why-do-we-need-representative-samples>Why do we need representative samples?</h2><p>Generally, three things come to mind:</p><ol><li>Allows the model to be generalizable for all <em>kinds</em> of data points <em>within</em> a category.</li><li>Allows for faster training of the model - you need <em>fewer</em> data points to get the same accuracy!</li><li>Allows maintaining the training set - if your training set needs validation by experts or annotations, this keeps your costs low!</li></ol><h1 id=define-the-data>Define the data</h1><p>This data can be practically anything that can be represented as a 2D matrix.</p><p>There are exceptions. Raw image data (as numbers) might get difficult because even if you flatten them, they&rsquo;ll be significant correlation between features. For example, a face can appear practically anywhere in the image, and all pixels centered around the face will be highly correlated, even if they are on different lines. A workaround in this case would be to pipe the image through a CNN model that has been trained on some <em>generic</em> task and produces a 1D representation of a single image in the final hidden layer before the output. Other data will need further processing along similar lines.</p><h2 id=get-a-specific-dataset>Get a specific dataset</h2><p>For this specific article, I will use the <a href=https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization/data>ShopMania dataset on Kaggle</a>. I apologize in advance for not using a more easily accessible dataset (you need to sign into Kaggle to download it) - and I&rsquo;m not 100% sure if the GPL allows me to create a copy of the data and place it in my own repository. Nevertheless, the data (if you download it and choose to use it instead of some other dataset) will look like this:</p><blockquote><p><strong>NOTE</strong>: whenever I want to show an output <em>along</em> with the code I used for it, you&rsquo;ll see the characters <code>>></code> indicating the command used, and the output to be without those prefixes.</p></blockquote><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln> 1</span><span class=cl><span class=o>&gt;&gt;</span> <span class=kn>import</span> <span class=nn>polars</span> <span class=k>as</span> <span class=nn>pl</span>
</span></span><span class=line><span class=ln> 2</span><span class=cl><span class=o>&gt;&gt;</span> <span class=n>data</span> <span class=o>=</span> <span class=n>pl</span><span class=o>.</span><span class=n>read_csv</span><span class=p>(</span><span class=s2>&#34;archive/shopmania.csv&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln> 3</span><span class=cl><span class=o>&gt;&gt;</span> <span class=n>data</span>
</span></span><span class=line><span class=ln> 4</span><span class=cl><span class=n>shape</span><span class=p>:</span> <span class=p>(</span><span class=mi>313_705</span><span class=p>,</span> <span class=mi>4</span><span class=p>)</span>
</span></span><span class=line><span class=ln> 5</span><span class=cl><span class=err>┌────────────┬──────────────────────────────────────────────────────┬─────────────┬────────────────┐</span>
</span></span><span class=line><span class=ln> 6</span><span class=cl><span class=err></span> <span class=n>product_ID</span> <span class=err></span> <span class=n>product_title</span> <span class=err></span> <span class=n>category_ID</span> <span class=err></span> <span class=n>category_label</span> <span class=err></span>
</span></span><span class=line><span class=ln> 7</span><span class=cl><span class=err></span> <span class=o>---</span> <span class=err></span> <span class=o>---</span> <span class=err></span> <span class=o>---</span> <span class=err></span> <span class=o>---</span> <span class=err></span>
</span></span><span class=line><span class=ln> 8</span><span class=cl><span class=err></span> <span class=n>i64</span> <span class=err></span> <span class=nb>str</span> <span class=err></span> <span class=n>i64</span> <span class=err></span> <span class=nb>str</span> <span class=err></span>
</span></span><span class=line><span class=ln> 9</span><span class=cl><span class=err>╞════════════╪══════════════════════════════════════════════════════╪═════════════╪════════════════╡</span>
</span></span><span class=line><span class=ln>10</span><span class=cl><span class=err></span> <span class=mi>2</span> <span class=err></span> <span class=n>twilight</span> <span class=n>central</span> <span class=n>park</span> <span class=nb>print</span> <span class=err></span> <span class=mi>2</span> <span class=err></span> <span class=n>Collectibles</span> <span class=err></span>
</span></span><span class=line><span class=ln>11</span><span class=cl><span class=err></span> <span class=mi>3</span> <span class=err></span> <span class=n>fox</span> <span class=nb>print</span> <span class=err></span> <span class=mi>2</span> <span class=err></span> <span class=n>Collectibles</span> <span class=err></span>
</span></span><span class=line><span class=ln>12</span><span class=cl><span class=err></span> <span class=mi>4</span> <span class=err></span> <span class=n>circulo</span> <span class=n>de</span> <span class=n>papel</span> <span class=n>wall</span> <span class=n>art</span> <span class=err></span> <span class=mi>2</span> <span class=err></span> <span class=n>Collectibles</span> <span class=err></span>
</span></span><span class=line><span class=ln>13</span><span class=cl><span class=err></span> <span class=mi>5</span> <span class=err></span> <span class=n>hidden</span> <span class=n>path</span> <span class=nb>print</span> <span class=err></span> <span class=mi>2</span> <span class=err></span> <span class=n>Collectibles</span> <span class=err></span>
</span></span><span class=line><span class=ln>14</span><span class=cl><span class=err></span> <span class=err></span> <span class=err></span> <span class=err></span> <span class=err></span> <span class=err></span> <span class=err></span> <span class=err></span> <span class=err></span>
</span></span><span class=line><span class=ln>15</span><span class=cl><span class=err></span> <span class=mi>313703</span> <span class=err></span> <span class=n>deago</span> <span class=n>anti</span> <span class=n>fog</span> <span class=n>swimming</span> <span class=n>diving</span> <span class=n>full</span> <span class=n>face</span> <span class=n>mask</span> <span class=err></span> <span class=mi>229</span> <span class=err></span> <span class=n>Water</span> <span class=n>Sports</span> <span class=err></span>
</span></span><span class=line><span class=ln>16</span><span class=cl><span class=err></span> <span class=err></span> <span class=n>surface</span> <span class=n>snorkel</span> <span class=n>scuba</span> <span class=n>fr</span> <span class=n>gopro</span> <span class=n>black</span> <span class=n>s</span><span class=o>/</span><span class=n>m</span> <span class=err></span> <span class=err></span> <span class=err></span>
</span></span><span class=line><span class=ln>17</span><span class=cl><span class=err></span> <span class=mi>313704</span> <span class=err></span> <span class=n>etc</span> <span class=n>buys</span> <span class=n>full</span> <span class=n>face</span> <span class=n>gopro</span> <span class=n>compatible</span> <span class=n>snorkel</span> <span class=n>scuba</span> <span class=err></span> <span class=mi>229</span> <span class=err></span> <span class=n>Water</span> <span class=n>Sports</span> <span class=err></span>
</span></span><span class=line><span class=ln>18</span><span class=cl><span class=err></span> <span class=err></span> <span class=n>diving</span> <span class=n>mask</span> <span class=n>blue</span> <span class=n>large</span><span class=o>/</span><span class=n>xtralarge</span> <span class=n>blue</span> <span class=err></span> <span class=err></span> <span class=err></span>
</span></span><span class=line><span class=ln>19</span><span class=cl><span class=err></span> <span class=mi>313705</span> <span class=err></span> <span class=n>men</span> <span class=mi>039</span> <span class=n>s</span> <span class=n>full</span> <span class=n>face</span> <span class=n>breathe</span> <span class=n>free</span> <span class=n>diving</span> <span class=n>snorkel</span> <span class=n>mask</span> <span class=err></span> <span class=mi>229</span> <span class=err></span> <span class=n>Water</span> <span class=n>Sports</span> <span class=err></span>
</span></span><span class=line><span class=ln>20</span><span class=cl><span class=err></span> <span class=err></span> <span class=n>scuba</span> <span class=n>optional</span> <span class=n>hd</span> <span class=n>camera</span> <span class=n>blue</span> <span class=n>mask</span> <span class=n>only</span> <span class=n>adult</span> <span class=n>men</span> <span class=err></span> <span class=err></span> <span class=err></span>
</span></span><span class=line><span class=ln>21</span><span class=cl><span class=err></span> <span class=mi>313706</span> <span class=err></span> <span class=n>women</span> <span class=mi>039</span> <span class=n>s</span> <span class=n>full</span> <span class=n>face</span> <span class=n>breathe</span> <span class=n>free</span> <span class=n>diving</span> <span class=n>snorkel</span> <span class=err></span> <span class=mi>229</span> <span class=err></span> <span class=n>Water</span> <span class=n>Sports</span> <span class=err></span>
</span></span><span class=line><span class=ln>22</span><span class=cl><span class=err></span> <span class=err></span> <span class=n>mask</span> <span class=n>scuba</span> <span class=n>optional</span> <span class=n>hd</span> <span class=n>camera</span> <span class=n>black</span> <span class=n>mask</span> <span class=n>only</span> <span class=err></span> <span class=err></span> <span class=err></span>
</span></span><span class=line><span class=ln>23</span><span class=cl><span class=err></span> <span class=err></span> <span class=n>children</span> <span class=ow>and</span> <span class=n>women</span> <span class=err></span> <span class=err></span> <span class=err></span>
</span></span><span class=line><span class=ln>24</span><span class=cl><span class=err>└────────────┴──────────────────────────────────────────────────────┴─────────────┴────────────────┘</span></span></span></code></pre></div><p>The data documentation on Kaggle states:</p><blockquote><p>The first dataset originates from ShopMania, a popular online product comparison platform. It enlists tens of millions of products organized in a three-level hierarchy that includes 230 categories. The two higher levels of the hierarchy include 39 categories, whereas the third lower level accommodates the rest 191 leaf categories. Each product is categorized into this tree structure by being mapped to only one leaf category. Some of these 191 leaf categories contain millions of products. However, shopmania.com allows only the first 10,000 products to be retrieved from each category. Under this restriction, our crawler managed to collect 313,706 products.</p></blockquote><p>For demonstration, I&rsquo;ll just limit the categories to those that have exactly 10,000 occurences.</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln>1</span><span class=cl><span class=n>data</span> <span class=o>=</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>2</span><span class=cl> <span class=n>data</span>
</span></span><span class=line><span class=ln>3</span><span class=cl> <span class=o>.</span><span class=n>filter</span><span class=p>(</span><span class=n>pl</span><span class=o>.</span><span class=n>count</span><span class=p>()</span><span class=o>.</span><span class=n>over</span><span class=p>(</span><span class=s2>&#34;category_ID&#34;</span><span class=p>)</span> <span class=o>==</span> <span class=mi>10000</span><span class=p>)</span>
</span></span><span class=line><span class=ln>4</span><span class=cl><span class=p>)</span></span></span></code></pre></div><p>You&rsquo;ll notice that there are only 17 categories in this dataset. Run this to verify that fact.</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln> 1</span><span class=cl><span class=o>&gt;&gt;&gt;</span> <span class=n>data</span><span class=o>.</span><span class=n>get_column</span><span class=p>(</span><span class=s2>&#34;category_label&#34;</span><span class=p>)</span><span class=o>.</span><span class=n>unique</span><span class=p>()</span>
</span></span><span class=line><span class=ln> 2</span><span class=cl><span class=n>shape</span><span class=p>:</span> <span class=p>(</span><span class=mi>17</span><span class=p>,)</span>
</span></span><span class=line><span class=ln> 3</span><span class=cl><span class=n>Series</span><span class=p>:</span> <span class=s1>&#39;category_label&#39;</span> <span class=p>[</span><span class=nb>str</span><span class=p>]</span>
</span></span><span class=line><span class=ln> 4</span><span class=cl><span class=p>[</span>
</span></span><span class=line><span class=ln> 5</span><span class=cl> <span class=s2>&#34;Kitchen &amp; Dining&#34;</span>
</span></span><span class=line><span class=ln> 6</span><span class=cl> <span class=s2>&#34;Scarves and wraps&#34;</span>
</span></span><span class=line><span class=ln> 7</span><span class=cl> <span class=s2>&#34;Handbags &amp; Wallets&#34;</span>
</span></span><span class=line><span class=ln> 8</span><span class=cl> <span class=s2>&#34;Rugs Tapestry &amp; Linens&#34;</span>
</span></span><span class=line><span class=ln> 9</span><span class=cl> <span class=s2>&#34;Cell Phones Accessories&#34;</span>
</span></span><span class=line><span class=ln>10</span><span class=cl> <span class=s2>&#34;Men&#39;s Clothing&#34;</span>
</span></span><span class=line><span class=ln>11</span><span class=cl> <span class=s2>&#34;Jewelry&#34;</span>
</span></span><span class=line><span class=ln>12</span><span class=cl> <span class=s2>&#34;Belts&#34;</span>
</span></span><span class=line><span class=ln>13</span><span class=cl> <span class=s2>&#34;Men Lingerie&#34;</span>
</span></span><span class=line><span class=ln>14</span><span class=cl> <span class=s2>&#34;Crafts&#34;</span>
</span></span><span class=line><span class=ln>15</span><span class=cl> <span class=s2>&#34;Football&#34;</span>
</span></span><span class=line><span class=ln>16</span><span class=cl> <span class=s2>&#34;Medical Supplies&#34;</span>
</span></span><span class=line><span class=ln>17</span><span class=cl> <span class=s2>&#34;Adult&#34;</span>
</span></span><span class=line><span class=ln>18</span><span class=cl> <span class=s2>&#34;Hunting&#34;</span>
</span></span><span class=line><span class=ln>19</span><span class=cl> <span class=s2>&#34;Women&#39;s Clothing&#34;</span>
</span></span><span class=line><span class=ln>20</span><span class=cl> <span class=s2>&#34;Pet Supply&#34;</span>
</span></span><span class=line><span class=ln>21</span><span class=cl> <span class=s2>&#34;Office Supplies&#34;</span>
</span></span><span class=line><span class=ln>22</span><span class=cl><span class=p>]</span></span></span></code></pre></div><p>Note that this is very easy in Polars, which is the package I typically use for data manipulation. I recommend using it over Pandas.</p><h2 id=specify-the-task>Specify the task</h2><p>Okay - so now we have exactly 10,000 products <em>per</em> category. We only have the title of the product that can be leveraged for categorization. So let me define the task this way:</p><blockquote><p>Craft a <em>small</em> representative sample for each category.</p></blockquote><p>Why small? It helps that it&rsquo;ll make the model faster to train - <em>and</em> keep the training data manageable in size.</p><h1 id=finding-representative-samples>Finding representative samples</h1><p>I mentioned earlier that we need to represent data as a 2D matrix for the technique I have in mind to work. How can I translate a list of text to a matrix? The answer&rsquo;s rather simple: use <code>SentenceTransformers</code> to get a string&rsquo;s embedding. You could also use more classic techniques like computing TF-IDF values, or use more advanced transformers, but I&rsquo;ve noticed that <code>SentenceTransformers</code> are able to capture semantic meaning of sentences rather well (assuming you use a good model suited for the language the data is in) - they are trained on sentence similarity after all.</p><h2 id=getting-sentencetransformer-embeddings>Getting <code>SentenceTransformer</code> embeddings</h2><p>This part is rather simple. If you&rsquo;re unable to install SentenceTransformers, <a href=https://www.sbert.net/docs/installation.html>please check their website</a>.</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln>1</span><span class=cl><span class=kn>import</span> <span class=nn>sentence_transformers</span>
</span></span><span class=line><span class=ln>2</span><span class=cl><span class=c1># See list of models at www.sbert.net/docs/pretrained_models.html</span>
</span></span><span class=line><span class=ln>3</span><span class=cl><span class=n>ST</span> <span class=o>=</span> <span class=n>sentence_transformers</span><span class=o>.</span><span class=n>SentenceTransformer</span><span class=p>(</span><span class=s2>&#34;all-mpnet-base-v2&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>4</span><span class=cl><span class=n>title_embeddings</span> <span class=o>=</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>5</span><span class=cl> <span class=n>ST</span><span class=o>.</span><span class=n>encode</span><span class=p>(</span>
</span></span><span class=line><span class=ln>6</span><span class=cl> <span class=n>data</span><span class=o>.</span><span class=n>get_column</span><span class=p>(</span><span class=s2>&#34;product_title&#34;</span><span class=p>)</span><span class=o>.</span><span class=n>to_list</span><span class=p>(),</span>
</span></span><span class=line><span class=ln>7</span><span class=cl> <span class=n>show_progress_bar</span><span class=o>=</span><span class=kc>True</span><span class=p>,</span> <span class=n>convert_to_tensor</span><span class=o>=</span><span class=kc>True</span><span class=p>)</span>
</span></span><span class=line><span class=ln>8</span><span class=cl> <span class=o>.</span><span class=n>numpy</span><span class=p>())</span></span></span></code></pre></div><p>This process will be slow (~30 minutes) if you don&rsquo;t have a GPU. There are faster approaches, but they are slightly more involved than would be beneficial for a blog post. The wait will be worth it, I promise! In addition, the call to <code>.numpy()</code> at the end is to directly get a single <code>numpy</code> array - otherwise you get a <code>list</code> of <code>numpy</code> arrays, which is rather inefficient. Further, <code>SentenceTransformers</code> will try to run on the GPU if available, and if so, you will need to write <code>.cpu().numpy()</code> so that the tensor is copied from the GPU to the CPU.</p><blockquote><p><strong>NOTE</strong>: for a proof-of-concept implementation, or if you&rsquo;re on the CPU, try the <code>all-MiniLM-L6-v2</code> model. It&rsquo;s a much smaller and much faster model, although you sacrifice a little in terms of accuracy.</p></blockquote><h2 id=the-concept-of-_approximate_-nearest-neighbors>The concept of <em>approximate</em> nearest neighbors</h2><p>Performing any kind of nearest neighbor algorithm on medium scale datasets (even bordering 10,000 rows and tens of columns) tends to be slow. A primary driver of this was the need to calculate all, or nearly all distances between all data points. <em>Approximate</em> nearest neighbor (ANN) algorithms work around this through various approaches, which warrant their own blog post. For now, it would suffice to understand that there are shortcuts that ANN algorithms take to give you if not the exact nearest neighbor, at least <em>one</em> of the nearest neighbors (hence the term <em>approximate</em>).</p><p>There are several algorithms that you can use - I shall proceed with <code>faiss</code>, because it has a nice Python interface and is rather easy to work with. You can use any algorithm - a full list of the major ones are <a href=https://github.com/erikbern/ann-benchmarks>available here</a>.</p><p>I&rsquo;ll explain why we&rsquo;re in the nearest neighbor territory in due course.</p><h3 id=building-the-database>Building the database</h3><p>To build the database, all we need is the <code>title_embeddings</code> matrix.</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln>1</span><span class=cl><span class=kn>import</span> <span class=nn>faiss</span>
</span></span><span class=line><span class=ln>2</span><span class=cl><span class=k>def</span> <span class=nf>create_index</span><span class=p>(</span><span class=n>title_embeddings</span><span class=p>):</span>
</span></span><span class=line><span class=ln>3</span><span class=cl> <span class=n>d</span> <span class=o>=</span> <span class=n>title_embeddings</span><span class=o>.</span><span class=n>shape</span><span class=p>[</span><span class=mi>1</span><span class=p>]</span> <span class=c1># Number of dimensions</span>
</span></span><span class=line><span class=ln>4</span><span class=cl> <span class=n>ann_index</span> <span class=o>=</span> <span class=n>faiss</span><span class=o>.</span><span class=n>IndexFlatL2</span><span class=p>(</span><span class=n>d</span><span class=p>)</span> <span class=c1># Index using Eucledian Matrix</span>
</span></span><span class=line><span class=ln>5</span><span class=cl> <span class=n>ann_index</span><span class=o>.</span><span class=n>add</span><span class=p>(</span><span class=n>title_embeddings</span><span class=p>)</span> <span class=c1># Build the index</span>
</span></span><span class=line><span class=ln>6</span><span class=cl>
</span></span><span class=line><span class=ln>7</span><span class=cl> <span class=k>return</span> <span class=n>ann_index</span> <span class=c1># Faiss considers databases an &#34;index&#34;</span></span></span></code></pre></div><p>This does create <em>a</em> database. But remember, we&rsquo;re trying to find <em>representative samples</em> - which means we need to do this <em>by</em> the category (or label). So let&rsquo;s design a function that sends only the necessary data as that for a particular category, and then create the database. We&rsquo;ll need three pieces of information from this function:</p><ol><li>The actual <code>faiss</code> database.</li><li>The actual subset of data that was used to build this index.</li><li>The label indices with respect to the original data that went into the <code>faiss</code> database.</li></ol><p>(2) and (3) will help us later in rebuilding a &ldquo;network graph&rdquo; that will allow us to reference the original data points.</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln> 1</span><span class=cl><span class=kn>import</span> <span class=nn>faiss</span>
</span></span><span class=line><span class=ln> 2</span><span class=cl><span class=kn>import</span> <span class=nn>numpy</span> <span class=k>as</span> <span class=nn>np</span>
</span></span><span class=line><span class=ln> 3</span><span class=cl><span class=kn>import</span> <span class=nn>polars</span> <span class=k>as</span> <span class=nn>pl</span>
</span></span><span class=line><span class=ln> 4</span><span class=cl>
</span></span><span class=line><span class=ln> 5</span><span class=cl><span class=k>def</span> <span class=nf>create_index</span><span class=p>(</span><span class=n>label</span><span class=p>):</span>
</span></span><span class=line><span class=ln> 6</span><span class=cl> <span class=n>faiss_indices</span> <span class=o>=</span> <span class=p>(</span>
</span></span><span class=line><span class=ln> 7</span><span class=cl> <span class=n>data</span> <span class=c1># this needs to be an argument if you want to create a generic function</span>
</span></span><span class=line><span class=ln> 8</span><span class=cl> <span class=o>.</span><span class=n>with_row_count</span><span class=p>(</span><span class=s2>&#34;row_idx&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln> 9</span><span class=cl> <span class=o>.</span><span class=n>filter</span><span class=p>(</span><span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;category_label&#34;</span><span class=p>)</span> <span class=o>==</span> <span class=n>label</span><span class=p>)</span>
</span></span><span class=line><span class=ln>10</span><span class=cl> <span class=o>.</span><span class=n>get_column</span><span class=p>(</span><span class=s2>&#34;row_idx&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>11</span><span class=cl> <span class=o>.</span><span class=n>to_list</span><span class=p>()</span>
</span></span><span class=line><span class=ln>12</span><span class=cl> <span class=p>)</span>
</span></span><span class=line><span class=ln>13</span><span class=cl>
</span></span><span class=line><span class=ln>14</span><span class=cl> <span class=n>faiss_data</span> <span class=o>=</span> <span class=n>title_embeddings</span><span class=p>[</span><span class=n>faiss_indices</span><span class=p>]</span>
</span></span><span class=line><span class=ln>15</span><span class=cl> <span class=n>d</span> <span class=o>=</span> <span class=n>data</span><span class=o>.</span><span class=n>shape</span><span class=p>[</span><span class=mi>1</span><span class=p>]</span> <span class=c1># Number of dimensions</span>
</span></span><span class=line><span class=ln>16</span><span class=cl> <span class=n>faiss_DB</span> <span class=o>=</span> <span class=n>faiss</span><span class=o>.</span><span class=n>IndexFlatIP</span><span class=p>(</span><span class=n>d</span><span class=p>)</span> <span class=c1># Index using Inner Product</span>
</span></span><span class=line><span class=ln>17</span><span class=cl> <span class=n>faiss</span><span class=o>.</span><span class=n>normalize_L2</span><span class=p>(</span><span class=n>data</span><span class=p>)</span> <span class=c1># Normalized L2 with Inner Product search = cosine similarity</span>
</span></span><span class=line><span class=ln>18</span><span class=cl> <span class=c1># Why cosine similarity? It&#39;s easier to specify thresholds - they&#39;ll always be between 0 and 1.4.</span>
</span></span><span class=line><span class=ln>19</span><span class=cl> <span class=c1># If using Eucledian or other distance, we&#39;ll have to spend some time finding a good range</span>
</span></span><span class=line><span class=ln>20</span><span class=cl> <span class=c1># where distances are reasonable. See https://stats.stackexchange.com/a/146279 for details.</span>
</span></span><span class=line><span class=ln>21</span><span class=cl> <span class=n>faiss_DB</span><span class=o>.</span><span class=n>add</span><span class=p>(</span><span class=n>data</span><span class=p>)</span> <span class=c1># Build the index</span>
</span></span><span class=line><span class=ln>22</span><span class=cl>
</span></span><span class=line><span class=ln>23</span><span class=cl> <span class=k>return</span> <span class=n>faiss_DB</span><span class=p>,</span> <span class=n>faiss_data</span><span class=p>,</span> <span class=n>faiss_indices</span></span></span></code></pre></div><h3 id=identifying-the-nearest-neighbors>Identifying the nearest neighbors</h3><p>To proceed with getting a representative sample, the next step is to find the nearest neighbors for <strong>all</strong> data points in the database. This isn&rsquo;t too hard - <code>faiss</code> <code>index</code> objects have a built-in <code>search</code> method to find the <code>k</code> nearest neighbors for a given index, along with the (approximate) distance to it. Let&rsquo;s then write a function to get the following information: the label index for whom nearest neighbors are being searched, the indices of said nearest neighbors and the distance between them. In network graph parlance, this kind of data is called an <em>edge list</em> i.e. a list of pair of <em>nodes</em> that are connected, along with any additional information that specifies a property (in this case distance) of the <em>edge</em> that connects these <em>nodes</em>.</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln> 1</span><span class=cl><span class=k>def</span> <span class=nf>get_edge_list</span><span class=p>(</span><span class=n>label</span><span class=p>,</span> <span class=n>k</span><span class=o>=</span><span class=mi>5</span><span class=p>):</span>
</span></span><span class=line><span class=ln> 2</span><span class=cl> <span class=n>faiss_DB</span><span class=p>,</span> <span class=n>faiss_data</span><span class=p>,</span> <span class=n>faiss_indices</span> <span class=o>=</span> <span class=n>create_index</span><span class=p>(</span><span class=n>label</span><span class=p>)</span>
</span></span><span class=line><span class=ln> 3</span><span class=cl> <span class=c1># To map the data back to the original `train[b&#39;data&#39;]` array</span>
</span></span><span class=line><span class=ln> 4</span><span class=cl> <span class=n>faiss_indices_map</span> <span class=o>=</span> <span class=p>{</span><span class=n>i</span><span class=p>:</span> <span class=n>x</span> <span class=k>for</span> <span class=n>i</span><span class=p>,</span><span class=n>x</span> <span class=ow>in</span> <span class=nb>enumerate</span><span class=p>(</span><span class=n>faiss_indices</span><span class=p>)}</span>
</span></span><span class=line><span class=ln> 5</span><span class=cl> <span class=c1># To map the indices back to the original strings</span>
</span></span><span class=line><span class=ln> 6</span><span class=cl> <span class=n>title_name_map</span> <span class=o>=</span> <span class=p>{</span><span class=n>i</span><span class=p>:</span> <span class=n>x</span> <span class=k>for</span> <span class=n>i</span><span class=p>,</span><span class=n>x</span> <span class=ow>in</span> <span class=n>data</span><span class=o>.</span><span class=n>select</span><span class=p>(</span><span class=s2>&#34;row_idx&#34;</span><span class=p>,</span> <span class=s2>&#34;product_title&#34;</span><span class=p>)</span><span class=o>.</span><span class=n>rows</span><span class=p>()}</span>
</span></span><span class=line><span class=ln> 7</span><span class=cl> <span class=n>distances</span><span class=p>,</span> <span class=n>neighbors</span> <span class=o>=</span> <span class=n>faiss_DB</span><span class=o>.</span><span class=n>search</span><span class=p>(</span><span class=n>faiss_data</span><span class=p>,</span> <span class=n>k</span><span class=p>)</span>
</span></span><span class=line><span class=ln> 8</span><span class=cl>
</span></span><span class=line><span class=ln> 9</span><span class=cl> <span class=k>return</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>10</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>DataFrame</span><span class=p>({</span>
</span></span><span class=line><span class=ln>11</span><span class=cl> <span class=s2>&#34;from&#34;</span><span class=p>:</span> <span class=n>faiss_indices</span><span class=p>})</span>
</span></span><span class=line><span class=ln>12</span><span class=cl> <span class=o>.</span><span class=n>with_columns</span><span class=p>(</span>
</span></span><span class=line><span class=ln>13</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>Series</span><span class=p>(</span><span class=s2>&#34;to&#34;</span><span class=p>,</span> <span class=n>neighbors</span><span class=p>),</span>
</span></span><span class=line><span class=ln>14</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>Series</span><span class=p>(</span><span class=s2>&#34;distance&#34;</span><span class=p>,</span> <span class=n>distances</span><span class=p>))</span>
</span></span><span class=line><span class=ln>15</span><span class=cl> <span class=o>.</span><span class=n>explode</span><span class=p>(</span><span class=s2>&#34;to&#34;</span><span class=p>,</span> <span class=s2>&#34;distance&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>16</span><span class=cl> <span class=o>.</span><span class=n>with_columns</span><span class=p>(</span>
</span></span><span class=line><span class=ln>17</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;from&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>18</span><span class=cl> <span class=o>.</span><span class=n>map_dict</span><span class=p>(</span><span class=n>title_name_map</span><span class=p>),</span>
</span></span><span class=line><span class=ln>19</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;to&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>20</span><span class=cl> <span class=o>.</span><span class=n>map_dict</span><span class=p>(</span><span class=n>faiss_indices_map</span><span class=p>)</span>
</span></span><span class=line><span class=ln>21</span><span class=cl> <span class=o>.</span><span class=n>map_dict</span><span class=p>(</span><span class=n>title_name_map</span><span class=p>))</span>
</span></span><span class=line><span class=ln>22</span><span class=cl> <span class=o>.</span><span class=n>filter</span><span class=p>(</span><span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;from&#34;</span><span class=p>)</span> <span class=o>!=</span> <span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;to&#34;</span><span class=p>))</span>
</span></span><span class=line><span class=ln>23</span><span class=cl> <span class=p>)</span> </span></span></code></pre></div><h3 id=networkx-and-connected-components>NetworkX and Connected Components</h3><p>The next step in the process is to create a network graph using the edge-list. But why?</p><p>Remember that we have identified the (k=5) nearest neighbors of <strong>each</strong> data point. Let&rsquo;s say that we have a point A that has a nearest neighbor B. C is <strong>not</strong> a nearest neighbor of A, but it is a nearest neighbor of B. In a network graph, if A and C are sufficiently similar enough to B within a particular <em>minimum thershold</em>, then A will be connected to C through B! Hopefully a small visual below would help.</p><p><img src=/blog/002_representative_samples/001_Network_Cluster_1.png alt="How a network component is formed."></p><p>What happens when such a concept is extended for many data points? Not all of them would be connected - because we&rsquo;re applying a <em>minimum</em> threshold that they have to meet. This is the only hueristic part of the rather fast process. Here&rsquo;s one more helpful visual:</p><p><img src=/blog/002_representative_samples/002_Network_Cluster_2.png alt="How a network cluster is formed."></p><p>Very starry night-eque vibes here. Let&rsquo;s get to the code.</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln>1</span><span class=cl><span class=kn>import</span> <span class=nn>networkx</span> <span class=k>as</span> <span class=nn>nx</span>
</span></span><span class=line><span class=ln>2</span><span class=cl><span class=k>def</span> <span class=nf>get_cluster_map</span><span class=p>(</span><span class=n>label</span><span class=p>,</span> <span class=n>k</span><span class=o>=</span><span class=mi>5</span><span class=p>,</span> <span class=n>min_cosine_distance</span><span class=o>=</span><span class=mf>0.95</span><span class=p>):</span>
</span></span><span class=line><span class=ln>3</span><span class=cl> <span class=n>edge_list</span> <span class=o>=</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>4</span><span class=cl> <span class=n>get_edge_list</span><span class=p>(</span><span class=n>label</span><span class=p>,</span> <span class=n>k</span><span class=o>=</span><span class=n>k</span><span class=p>)</span>
</span></span><span class=line><span class=ln>5</span><span class=cl> <span class=o>.</span><span class=n>filter</span><span class=p>(</span><span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;distance&#34;</span><span class=p>)</span> <span class=o>&gt;=</span> <span class=n>min_cosine_distance</span><span class=p>)</span>
</span></span><span class=line><span class=ln>6</span><span class=cl> <span class=p>)</span>
</span></span><span class=line><span class=ln>7</span><span class=cl> <span class=n>graph</span> <span class=o>=</span> <span class=n>nx</span><span class=o>.</span><span class=n>from_pandas_edgelist</span><span class=p>(</span><span class=n>edge_list</span><span class=o>.</span><span class=n>to_pandas</span><span class=p>(),</span> <span class=n>source</span><span class=o>=</span><span class=s2>&#34;from&#34;</span><span class=p>,</span> <span class=n>target</span><span class=o>=</span><span class=s2>&#34;to&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>8</span><span class=cl> <span class=k>return</span> <span class=p>{</span><span class=n>i</span><span class=p>:</span> <span class=nb>list</span><span class=p>(</span><span class=n>x</span><span class=p>)</span> <span class=k>for</span> <span class=n>i</span><span class=p>,</span><span class=n>x</span> <span class=ow>in</span> <span class=nb>enumerate</span><span class=p>(</span><span class=n>nx</span><span class=o>.</span><span class=n>connected_components</span><span class=p>(</span><span class=n>graph</span><span class=p>))}</span></span></span></code></pre></div><h1 id=getting-clusters>Getting clusters</h1><p>Now that all the parts of the puzzle are together, let&rsquo;s run it to see what kind of clusters you get for <code>Cell Phone Accessories</code>.</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln>1</span><span class=cl><span class=n>clusters</span> <span class=o>=</span> <span class=n>get_cluster_map</span><span class=p>(</span><span class=s2>&#34;Cell Phones Accessories&#34;</span><span class=p>,</span> <span class=mi>5</span><span class=p>,</span> <span class=mf>0.95</span><span class=p>)</span></span></span></code></pre></div><p>Make sure to configure the following if your results aren&rsquo;t good enough:</p><ol><li>Relax the <code>min_cosine_distance</code> value if you want <em>bigger</em> clusters.</li><li>Increase the number of nearest neighbors if you want <em>more</em> matches.</li></ol><h2 id=viewing-the-components>Viewing the components</h2><p>There will likely be many clusters (you can see how many exactly with <code>len(clusters)</code>). Let&rsquo;s look at a random cluster:</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln>1</span><span class=cl><span class=o>&gt;&gt;</span> <span class=n>clusters</span><span class=p>[</span><span class=mi>3</span><span class=p>]</span>
</span></span><span class=line><span class=ln>2</span><span class=cl><span class=p>[</span><span class=s1>&#39;smartphone lanyard with card slot for any phone up to 6 yellow 72570099&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>3</span><span class=cl> <span class=s1>&#39;smartphone lanyard with card slot for any phone up to 6 black 72570093&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>4</span><span class=cl> <span class=s1>&#39;smartphone lanyard with card slot for any phone up to 6 lightblue 72570097&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>5</span><span class=cl> <span class=s1>&#39;smartphone lanyard with card slot for any phone up to 6 blue 72570095&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>6</span><span class=cl> <span class=s1>&#39;smartphone lanyard with card slot for any phone up to 6 green 72570101&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>7</span><span class=cl> <span class=s1>&#39;smartphone lanyard with card slot for any phone up to 6 pink 72570091&#39;</span><span class=p>]</span></span></span></code></pre></div><p>Let&rsquo;s see another cluster that had 172(!) members in my run (the clusters themselves will be stable, but their indices may change in each run owing to some inherent randomness in the process).</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln> 1</span><span class=cl><span class=o>&gt;&gt;&gt;</span> <span class=n>clusters</span><span class=p>[</span><span class=mi>6</span><span class=p>]</span>
</span></span><span class=line><span class=ln> 2</span><span class=cl><span class=p>[</span><span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case snowflakes iphone 8/7 op qq z051a&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln> 3</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 arrows blue op qq a02 58&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln> 4</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s golden pineapple op qq z089a&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln> 5</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s butteryfly delight yellow op qq z029d&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln> 6</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 luck of the irish op qq a01 45&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln> 7</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid white op qq a02 16&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln> 8</span><span class=cl> <span class=o>...</span>
</span></span><span class=line><span class=ln> 9</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 flying arrows white op qq hip 20&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>10</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 brides maid pink white op qq a02 17&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>11</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case iphone 8/7 anemone flowers white op qq z036a&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>12</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case mustache iphone 8/7 op qq hip 08&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>13</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7 modern clear printed phone case oh snap iphone 8/7 op qq z053a&#39;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>14</span><span class=cl> <span class=s1>&#39;otm essentials iphone 8/7/6s clear printed phone case single iphone 8/7/6s desert cacti orange pink op qq a02 22&#39;</span><span class=p>]</span></span></span></code></pre></div><h2 id=running-for-all-categories>Running for all categories</h2><p>This isn&rsquo;t that hard (although it may take more than a moment). Just iterate it for each category!</p><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln>1</span><span class=cl><span class=n>clusters</span> <span class=o>=</span> <span class=p>[</span><span class=n>get_cluster_map</span><span class=p>(</span><span class=n>x</span><span class=p>,</span> <span class=mi>5</span><span class=p>,</span> <span class=mf>0.95</span><span class=p>)</span> <span class=k>for</span> <span class=n>x</span> <span class=ow>in</span> <span class=n>data</span><span class=o>.</span><span class=n>get_column</span><span class=p>(</span><span class=s2>&#34;category_label&#34;</span><span class=p>)</span><span class=o>.</span><span class=n>unique</span><span class=p>()]</span></span></span></code></pre></div><h1 id=for-the-folks-in-a-hurry>For the folks in a hurry!</h1><p>I get it - you often want a solution that &ldquo;just works&rdquo;. I can come close to it. See below for code and a succinct explanation. For those of my readers who aren&rsquo;t in a hurry, this also serves as a nice summary (and copy-pastable code)!</p><h2 id=the-code>The code</h2><div class=highlight><pre tabindex=0 class=chroma><code class=language-py data-lang=py><span class=line><span class=ln> 1</span><span class=cl><span class=kn>import</span> <span class=nn>sentence_transformers</span>
</span></span><span class=line><span class=ln> 2</span><span class=cl><span class=kn>import</span> <span class=nn>faiss</span>
</span></span><span class=line><span class=ln> 3</span><span class=cl><span class=kn>import</span> <span class=nn>polars</span> <span class=k>as</span> <span class=nn>pl</span>
</span></span><span class=line><span class=ln> 4</span><span class=cl><span class=kn>import</span> <span class=nn>numpy</span> <span class=k>as</span> <span class=nn>np</span>
</span></span><span class=line><span class=ln> 5</span><span class=cl>
</span></span><span class=line><span class=ln> 6</span><span class=cl><span class=c1># Data is read here. You download the files from Kaggle here: </span>
</span></span><span class=line><span class=ln> 7</span><span class=cl><span class=c1># https://www.kaggle.com/datasets/lakritidis/product-classification-and-categorization</span>
</span></span><span class=line><span class=ln> 8</span><span class=cl><span class=n>data</span> <span class=o>=</span> <span class=n>pl</span><span class=o>.</span><span class=n>read_csv</span><span class=p>(</span><span class=s2>&#34;archive/shopmania.csv&#34;</span><span class=p>,</span> <span class=n>new_columns</span><span class=o>=</span><span class=p>[</span>
</span></span><span class=line><span class=ln> 9</span><span class=cl> <span class=s2>&#34;product_ID&#34;</span><span class=p>,</span> <span class=s2>&#34;product_title&#34;</span><span class=p>,</span> <span class=s2>&#34;category_ID&#34;</span><span class=p>,</span> <span class=s2>&#34;category_label&#34;</span><span class=p>])</span>
</span></span><span class=line><span class=ln>10</span><span class=cl><span class=n>data</span> <span class=o>=</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>11</span><span class=cl> <span class=n>data</span>
</span></span><span class=line><span class=ln>12</span><span class=cl> <span class=o>.</span><span class=n>filter</span><span class=p>(</span><span class=n>pl</span><span class=o>.</span><span class=n>count</span><span class=p>()</span><span class=o>.</span><span class=n>over</span><span class=p>(</span><span class=s2>&#34;category_ID&#34;</span><span class=p>)</span> <span class=o>==</span> <span class=mi>10000</span><span class=p>)</span>
</span></span><span class=line><span class=ln>13</span><span class=cl> <span class=o>.</span><span class=n>with_row_count</span><span class=p>(</span><span class=s2>&#34;row_idx&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>14</span><span class=cl><span class=p>)</span>
</span></span><span class=line><span class=ln>15</span><span class=cl>
</span></span><span class=line><span class=ln>16</span><span class=cl>
</span></span><span class=line><span class=ln>17</span><span class=cl><span class=c1># See list of models at www.sbert.net/docs/pretrained_models.html</span>
</span></span><span class=line><span class=ln>18</span><span class=cl><span class=n>ST</span> <span class=o>=</span> <span class=n>sentence_transformers</span><span class=o>.</span><span class=n>SentenceTransformer</span><span class=p>(</span><span class=s2>&#34;all-mpnet-base-v2&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>19</span><span class=cl><span class=n>title_embeddings</span> <span class=o>=</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>20</span><span class=cl> <span class=n>ST</span><span class=o>.</span><span class=n>encode</span><span class=p>(</span>
</span></span><span class=line><span class=ln>21</span><span class=cl> <span class=n>data</span><span class=o>.</span><span class=n>get_column</span><span class=p>(</span><span class=s2>&#34;product_title&#34;</span><span class=p>)</span><span class=o>.</span><span class=n>to_list</span><span class=p>(),</span>
</span></span><span class=line><span class=ln>22</span><span class=cl> <span class=c1># I&#39;m on a MacBook, you should use `cuda` or `cpu`</span>
</span></span><span class=line><span class=ln>23</span><span class=cl> <span class=c1># if you&#39;ve got different hardware.</span>
</span></span><span class=line><span class=ln>24</span><span class=cl> <span class=n>device</span><span class=o>=</span><span class=s2>&#34;mps&#34;</span><span class=p>,</span>
</span></span><span class=line><span class=ln>25</span><span class=cl> <span class=n>show_progress_bar</span><span class=o>=</span><span class=kc>True</span><span class=p>,</span> <span class=n>convert_to_tensor</span><span class=o>=</span><span class=kc>True</span><span class=p>)</span>
</span></span><span class=line><span class=ln>26</span><span class=cl> <span class=o>.</span><span class=n>cpu</span><span class=p>()</span><span class=o>.</span><span class=n>numpy</span><span class=p>())</span>
</span></span><span class=line><span class=ln>27</span><span class=cl>
</span></span><span class=line><span class=ln>28</span><span class=cl><span class=c1># Code to create a FAISS index</span>
</span></span><span class=line><span class=ln>29</span><span class=cl><span class=k>def</span> <span class=nf>create_index</span><span class=p>(</span><span class=n>label</span><span class=p>):</span>
</span></span><span class=line><span class=ln>30</span><span class=cl> <span class=n>faiss_indices</span> <span class=o>=</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>31</span><span class=cl> <span class=n>data</span> <span class=c1># this needs to be an argument if you want to create a generic function</span>
</span></span><span class=line><span class=ln>32</span><span class=cl> <span class=o>.</span><span class=n>filter</span><span class=p>(</span><span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;category_label&#34;</span><span class=p>)</span> <span class=o>==</span> <span class=n>label</span><span class=p>)</span>
</span></span><span class=line><span class=ln>33</span><span class=cl> <span class=o>.</span><span class=n>get_column</span><span class=p>(</span><span class=s2>&#34;row_idx&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>34</span><span class=cl> <span class=o>.</span><span class=n>to_list</span><span class=p>()</span>
</span></span><span class=line><span class=ln>35</span><span class=cl> <span class=p>)</span>
</span></span><span class=line><span class=ln>36</span><span class=cl>
</span></span><span class=line><span class=ln>37</span><span class=cl> <span class=n>faiss_data</span> <span class=o>=</span> <span class=n>title_embeddings</span><span class=p>[</span><span class=n>faiss_indices</span><span class=p>]</span>
</span></span><span class=line><span class=ln>38</span><span class=cl> <span class=n>d</span> <span class=o>=</span> <span class=n>faiss_data</span><span class=o>.</span><span class=n>shape</span><span class=p>[</span><span class=mi>1</span><span class=p>]</span> <span class=c1># Number of dimensions</span>
</span></span><span class=line><span class=ln>39</span><span class=cl> <span class=n>faiss_DB</span> <span class=o>=</span> <span class=n>faiss</span><span class=o>.</span><span class=n>IndexFlatIP</span><span class=p>(</span><span class=n>d</span><span class=p>)</span> <span class=c1># Index using Inner Product</span>
</span></span><span class=line><span class=ln>40</span><span class=cl> <span class=n>faiss</span><span class=o>.</span><span class=n>normalize_L2</span><span class=p>(</span><span class=n>faiss_data</span><span class=p>)</span> <span class=c1># Normalized L2 with Inner Product search = cosine similarity</span>
</span></span><span class=line><span class=ln>41</span><span class=cl> <span class=n>faiss_DB</span><span class=o>.</span><span class=n>add</span><span class=p>(</span><span class=n>faiss_data</span><span class=p>)</span> <span class=c1># Build the index</span>
</span></span><span class=line><span class=ln>42</span><span class=cl>
</span></span><span class=line><span class=ln>43</span><span class=cl> <span class=k>return</span> <span class=n>faiss_DB</span><span class=p>,</span> <span class=n>faiss_data</span><span class=p>,</span> <span class=n>faiss_indices</span>
</span></span><span class=line><span class=ln>44</span><span class=cl>
</span></span><span class=line><span class=ln>45</span><span class=cl><span class=c1># Code to create an edge-list</span>
</span></span><span class=line><span class=ln>46</span><span class=cl><span class=k>def</span> <span class=nf>get_edge_list</span><span class=p>(</span><span class=n>label</span><span class=p>,</span> <span class=n>k</span><span class=o>=</span><span class=mi>5</span><span class=p>):</span>
</span></span><span class=line><span class=ln>47</span><span class=cl> <span class=n>faiss_DB</span><span class=p>,</span> <span class=n>faiss_data</span><span class=p>,</span> <span class=n>faiss_indices</span> <span class=o>=</span> <span class=n>create_index</span><span class=p>(</span><span class=n>label</span><span class=p>)</span>
</span></span><span class=line><span class=ln>48</span><span class=cl> <span class=c1># To map the data back to the original `train[b&#39;data&#39;]` array</span>
</span></span><span class=line><span class=ln>49</span><span class=cl> <span class=n>faiss_indices_map</span> <span class=o>=</span> <span class=p>{</span><span class=n>i</span><span class=p>:</span> <span class=n>x</span> <span class=k>for</span> <span class=n>i</span><span class=p>,</span><span class=n>x</span> <span class=ow>in</span> <span class=nb>enumerate</span><span class=p>(</span><span class=n>faiss_indices</span><span class=p>)}</span>
</span></span><span class=line><span class=ln>50</span><span class=cl> <span class=c1># To map the indices back to the original strings</span>
</span></span><span class=line><span class=ln>51</span><span class=cl> <span class=n>title_name_map</span> <span class=o>=</span> <span class=p>{</span><span class=n>i</span><span class=p>:</span> <span class=n>x</span> <span class=k>for</span> <span class=n>i</span><span class=p>,</span><span class=n>x</span> <span class=ow>in</span> <span class=n>data</span><span class=o>.</span><span class=n>select</span><span class=p>(</span><span class=s2>&#34;row_idx&#34;</span><span class=p>,</span> <span class=s2>&#34;product_title&#34;</span><span class=p>)</span><span class=o>.</span><span class=n>rows</span><span class=p>()}</span>
</span></span><span class=line><span class=ln>52</span><span class=cl> <span class=n>distances</span><span class=p>,</span> <span class=n>neighbors</span> <span class=o>=</span> <span class=n>faiss_DB</span><span class=o>.</span><span class=n>search</span><span class=p>(</span><span class=n>faiss_data</span><span class=p>,</span> <span class=n>k</span><span class=p>)</span>
</span></span><span class=line><span class=ln>53</span><span class=cl>
</span></span><span class=line><span class=ln>54</span><span class=cl> <span class=k>return</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>55</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>DataFrame</span><span class=p>({</span>
</span></span><span class=line><span class=ln>56</span><span class=cl> <span class=s2>&#34;from&#34;</span><span class=p>:</span> <span class=n>faiss_indices</span><span class=p>})</span>
</span></span><span class=line><span class=ln>57</span><span class=cl> <span class=o>.</span><span class=n>with_columns</span><span class=p>(</span>
</span></span><span class=line><span class=ln>58</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>Series</span><span class=p>(</span><span class=s2>&#34;to&#34;</span><span class=p>,</span> <span class=n>neighbors</span><span class=p>),</span>
</span></span><span class=line><span class=ln>59</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>Series</span><span class=p>(</span><span class=s2>&#34;distance&#34;</span><span class=p>,</span> <span class=n>distances</span><span class=p>))</span>
</span></span><span class=line><span class=ln>60</span><span class=cl> <span class=o>.</span><span class=n>explode</span><span class=p>(</span><span class=s2>&#34;to&#34;</span><span class=p>,</span> <span class=s2>&#34;distance&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>61</span><span class=cl> <span class=o>.</span><span class=n>with_columns</span><span class=p>(</span>
</span></span><span class=line><span class=ln>62</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;from&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>63</span><span class=cl> <span class=o>.</span><span class=n>map_dict</span><span class=p>(</span><span class=n>title_name_map</span><span class=p>),</span>
</span></span><span class=line><span class=ln>64</span><span class=cl> <span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;to&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>65</span><span class=cl> <span class=o>.</span><span class=n>map_dict</span><span class=p>(</span><span class=n>faiss_indices_map</span><span class=p>)</span>
</span></span><span class=line><span class=ln>66</span><span class=cl> <span class=o>.</span><span class=n>map_dict</span><span class=p>(</span><span class=n>title_name_map</span><span class=p>))</span>
</span></span><span class=line><span class=ln>67</span><span class=cl> <span class=o>.</span><span class=n>filter</span><span class=p>(</span><span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;from&#34;</span><span class=p>)</span> <span class=o>!=</span> <span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;to&#34;</span><span class=p>))</span>
</span></span><span class=line><span class=ln>68</span><span class=cl> <span class=p>)</span>
</span></span><span class=line><span class=ln>69</span><span class=cl>
</span></span><span class=line><span class=ln>70</span><span class=cl><span class=c1># Code to extract components from a Network Graph</span>
</span></span><span class=line><span class=ln>71</span><span class=cl><span class=kn>import</span> <span class=nn>networkx</span> <span class=k>as</span> <span class=nn>nx</span>
</span></span><span class=line><span class=ln>72</span><span class=cl><span class=k>def</span> <span class=nf>get_cluster_map</span><span class=p>(</span><span class=n>label</span><span class=p>,</span> <span class=n>k</span><span class=o>=</span><span class=mi>5</span><span class=p>,</span> <span class=n>min_cosine_distance</span><span class=o>=</span><span class=mf>0.95</span><span class=p>):</span>
</span></span><span class=line><span class=ln>73</span><span class=cl> <span class=n>edge_list</span> <span class=o>=</span> <span class=p>(</span>
</span></span><span class=line><span class=ln>74</span><span class=cl> <span class=n>get_edge_list</span><span class=p>(</span><span class=n>label</span><span class=p>,</span> <span class=n>k</span><span class=o>=</span><span class=n>k</span><span class=p>)</span>
</span></span><span class=line><span class=ln>75</span><span class=cl> <span class=o>.</span><span class=n>filter</span><span class=p>(</span><span class=n>pl</span><span class=o>.</span><span class=n>col</span><span class=p>(</span><span class=s2>&#34;distance&#34;</span><span class=p>)</span> <span class=o>&gt;=</span> <span class=n>min_cosine_distance</span><span class=p>)</span>
</span></span><span class=line><span class=ln>76</span><span class=cl> <span class=p>)</span>
</span></span><span class=line><span class=ln>77</span><span class=cl> <span class=n>graph</span> <span class=o>=</span> <span class=n>nx</span><span class=o>.</span><span class=n>from_pandas_edgelist</span><span class=p>(</span><span class=n>edge_list</span><span class=o>.</span><span class=n>to_pandas</span><span class=p>(),</span> <span class=n>source</span><span class=o>=</span><span class=s2>&#34;from&#34;</span><span class=p>,</span> <span class=n>target</span><span class=o>=</span><span class=s2>&#34;to&#34;</span><span class=p>)</span>
</span></span><span class=line><span class=ln>78</span><span class=cl> <span class=k>return</span> <span class=p>{</span><span class=n>i</span><span class=p>:</span> <span class=nb>list</span><span class=p>(</span><span class=n>x</span><span class=p>)</span> <span class=k>for</span> <span class=n>i</span><span class=p>,</span><span class=n>x</span> <span class=ow>in</span> <span class=nb>enumerate</span><span class=p>(</span><span class=n>nx</span><span class=o>.</span><span class=n>connected_components</span><span class=p>(</span><span class=n>graph</span><span class=p>))}</span>
</span></span><span class=line><span class=ln>79</span><span class=cl>
</span></span><span class=line><span class=ln>80</span><span class=cl><span class=c1># Example call to a single category to obtain its clusters</span>
</span></span><span class=line><span class=ln>81</span><span class=cl><span class=n>clusters</span> <span class=o>=</span> <span class=n>get_cluster_map</span><span class=p>(</span><span class=s2>&#34;Cell Phones Accessories&#34;</span><span class=p>,</span> <span class=mi>5</span><span class=p>,</span> <span class=mf>0.95</span><span class=p>)</span>
</span></span><span class=line><span class=ln>82</span><span class=cl><span class=c1># Example call to **all** categories to obtain all clusters</span>
</span></span><span class=line><span class=ln>83</span><span class=cl><span class=n>clusters</span> <span class=o>=</span> <span class=p>[</span><span class=n>get_cluster_map</span><span class=p>(</span><span class=n>x</span><span class=p>,</span> <span class=mi>5</span><span class=p>,</span> <span class=mf>0.95</span><span class=p>)</span> <span class=k>for</span> <span class=n>x</span> <span class=ow>in</span> <span class=n>data</span><span class=o>.</span><span class=n>get_column</span><span class=p>(</span><span class=s2>&#34;category_label&#34;</span><span class=p>)</span><span class=o>.</span><span class=n>unique</span><span class=p>()]</span></span></span></code></pre></div><h2 id=how-the-code-works>How the code works</h2><p>If you want to write down an algorithmic way of looking at this approach,</p><ol><li>Obtain a 2D representation of the labelled/categorized data. This can be embeddings for strings, the final hidden state output from a generic CNN model for images, or a good ol&rsquo; tabular dataset where all numbers are normalized and can be expressed as such.</li><li>Create an ANN database (based on a package such as <code>faiss</code>) that allows you fast nearest neighbor searches. Use cosine similarity for an easy threshold determination step.</li><li>Obtain an edge-list of k (from 5 to 100) nearest neighbors for <strong>all</strong> (or a sample of data points in case your dataset is incredibly HUGE) data points in the ANN database.</li><li>Apply a minimum threshold on similarity (completely based on heuristics), and obtain the connected components of the network graph from the filtered edge-list you just created.</li><li>Map all indices back to their source data-points that make sense, and pick any number of items from each cluster (usually, I end up picking one element from each cluster), and you now have your representative sample!</li></ol></content><p><a class=blog-tags href=/tags/representative/>#representative</a>
<a class=blog-tags href=/tags/samples/>#samples</a>
<a class=blog-tags href=/tags/faiss/>#faiss</a>
<a class=blog-tags href=/tags/approximate/>#approximate</a>
<a class=blog-tags href=/tags/nearest/>#nearest</a>
<a class=blog-tags href=/tags/neighbor/>#neighbor</a>
<a class=blog-tags href=/tags/network/>#network</a>
<a class=blog-tags href=/tags/graph/>#graph</a>
<a class=blog-tags href=/tags/networkx/>#networkx</a>
<a class=blog-tags href=/tags/polars/>#polars</a>
<a class=blog-tags href=/tags/category/>#category</a></p></main><footer><small>© Avinash Mallya | Design via <a href=https://github.com/clente/hugo-bearcub>Bear Cub</a>.</small></footer></body></html>