Merge pull request #5 from avimallu/RepresentativeSamples

fix typos and unused arguments
This commit is contained in:
avimallu
2023-10-30 17:29:01 -05:00
committed by GitHub

View File

@@ -164,7 +164,7 @@ def create_index(title_embeddings):
return ann_index # Faiss considers databases an "index"
```
This does create *a* database. But remember, we're trying to find *representative samples* - which means we need to do this *by* the category (or label). So let's design a function that sends only the necessary data as that for a particular category, and the create the database. We'll need three pieces of information from this function:
This does create *a* database. But remember, we're trying to find *representative samples* - which means we need to do this *by* the category (or label). So let's design a function that selects only the data for a particular category, and then creates the database from it. We'll need three pieces of information from this function:
1. The actual `faiss` database.
2. The actual subset of data that was used to build this index.
@@ -177,9 +177,9 @@ import faiss
import numpy as np
import polars as pl
def create_index(label, return_data=False):
def create_index(label):
faiss_indices = (
data
data # this needs to be an argument if you want to create a generic function
.with_row_count("row_idx")
.filter(pl.col("category_label") == label)
.get_column("row_idx")
@@ -204,7 +204,7 @@ To proceed with getting a representative sample, the next step is to find the ne
```py
def get_edge_list(label, k=5):
faiss_DB, faiss_data, faiss_indices = create_index(label, return_data=True)
faiss_DB, faiss_data, faiss_indices = create_index(label)
# To map the data back to the original `train[b'data']` array
faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
# To map the indices back to the original strings
@@ -334,21 +334,24 @@ data = (
ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
title_embeddings = (
ST.encode(
data.get_column("product_title").to_list(), device="mps",
data.get_column("product_title").to_list(),
# I'm on a MacBook, you should use `cuda` or `cpu`
# if you've got different hardware.
device="mps",
show_progress_bar=True, convert_to_tensor=True)
.cpu().numpy())
# Code to create a FAISS index
def create_index(label, return_data=False):
def create_index(label):
faiss_indices = (
data
data # this needs to be an argument if you want to create a generic function
.filter(pl.col("category_label") == label)
.get_column("row_idx")
.to_list()
)
faiss_data = title_embeddings[faiss_indices]
d = faiss_data.shape[1] # Number of dimensions
d = faiss_data.shape[1] # Number of dimensions
faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
faiss_DB.add(faiss_data) # Build the index
@@ -357,7 +360,7 @@ def create_index(label, return_data=False):
# Code to create an edge-list
def get_edge_list(label, k=5):
faiss_DB, faiss_data, faiss_indices = create_index(label, return_data=True)
faiss_DB, faiss_data, faiss_indices = create_index(label)
# To map the data back to the original `train[b'data']` array
faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
# To map the indices back to the original strings