Merge pull request #5 from avimallu/RepresentativeSamples
fix typos and unused arguments
@@ -164,7 +164,7 @@ def create_index(title_embeddings):
     return ann_index # Faiss considers databases an "index"
 ```
 
-This does create *a* database. But remember, we're trying to find *representative samples* - which means we need to do this *by* the category (or label). So let's design a function that sends only the necessary data as that for a particular category, and the create the database. We'll need three pieces of information from this function:
+This does create *a* database. But remember, we're trying to find *representative samples* - which means we need to do this *by* the category (or label). So let's design a function that sends only the necessary data as that for a particular category, and then create the database. We'll need three pieces of information from this function:
 
 1. The actual `faiss` database.
 2. The actual subset of data that was used to build this index.
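
For reference while reviewing the signature change below, this is roughly what the per-category function looks like once the unused `return_data` argument is gone (a consolidated sketch, not part of the diff: the `return` line is inferred from how `get_edge_list` unpacks the result, and `data`/`title_embeddings` are assumed to be the Polars DataFrame and embedding matrix defined elsewhere in the post):

```py
# Sketch only, not part of this commit. Assumes `data` is a Polars DataFrame
# that already carries a "row_idx" column and `title_embeddings` is the
# NumPy array of sentence embeddings built earlier in the post; the return
# statement is inferred from the call site in `get_edge_list`.
import faiss
import polars as pl

def create_index(label):
    faiss_indices = (
        data  # this needs to be an argument if you want to create a generic function
        .filter(pl.col("category_label") == label)
        .get_column("row_idx")
        .to_list()
    )

    faiss_data = title_embeddings[faiss_indices]
    d = faiss_data.shape[1]          # Number of dimensions
    faiss_DB = faiss.IndexFlatIP(d)  # Index using Inner Product
    faiss.normalize_L2(faiss_data)   # Normalized L2 + Inner Product search = cosine similarity
    faiss_DB.add(faiss_data)         # Build the index
    return faiss_DB, faiss_data, faiss_indices
```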
@@ -177,9 +177,9 @@ import faiss
 import numpy as np
 import polars as pl
 
-def create_index(label, return_data=False):
+def create_index(label):
     faiss_indices = (
-        data
+        data # this needs to be an argument if you want to create a generic function
         .with_row_count("row_idx")
         .filter(pl.col("category_label") == label)
         .get_column("row_idx")
@@ -204,7 +204,7 @@ To proceed with getting a representative sample, the next step is to find the ne
 
 ```py
 def get_edge_list(label, k=5):
-    faiss_DB, faiss_data, faiss_indices = create_index(label, return_data=True)
+    faiss_DB, faiss_data, faiss_indices = create_index(label)
     # To map the data back to the original `train[b'data']` array
     faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
     # To map the indices back to the original strings
@@ -334,21 +334,24 @@ data = (
 ST = sentence_transformers.SentenceTransformer("all-mpnet-base-v2")
 title_embeddings = (
     ST.encode(
-        data.get_column("product_title").to_list(), device="mps",
+        data.get_column("product_title").to_list(),
+        # I'm on a MacBook, you should use `cuda` or `cpu`
+        # if you've got different hardware.
+        device="mps",
         show_progress_bar=True, convert_to_tensor=True)
     .cpu().numpy())
 
 # Code to create a FAISS index
-def create_index(label, return_data=False):
+def create_index(label):
     faiss_indices = (
-        data
+        data # this needs to be an argument if you want to create a generic function
         .filter(pl.col("category_label") == label)
         .get_column("row_idx")
         .to_list()
     )
 
     faiss_data = title_embeddings[faiss_indices]
     d = faiss_data.shape[1] # Number of dimensions
     faiss_DB = faiss.IndexFlatIP(d) # Index using Inner Product
     faiss.normalize_L2(faiss_data) # Normalized L2 with Inner Product search = cosine similarity
     faiss_DB.add(faiss_data) # Build the index
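
The new comments around `device="mps"` are advice rather than code. If the listing should run unchanged on other machines, one option (a sketch, not something this commit adds) is to pick the device at runtime via PyTorch's availability checks and pass it to `ST.encode`:

```py
# Sketch only, not part of this commit: choose the encoding device based on
# the hardware that's actually available instead of hard-coding "mps".
import torch

if torch.cuda.is_available():
    device = "cuda"   # NVIDIA GPU
elif torch.backends.mps.is_available():
    device = "mps"    # Apple Silicon
else:
    device = "cpu"

# then: ST.encode(..., device=device, ...)
```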
@@ -357,7 +360,7 @@ def create_index(label, return_data=False):
 
 # Code to create an edge-list
 def get_edge_list(label, k=5):
-    faiss_DB, faiss_data, faiss_indices = create_index(label, return_data=True)
+    faiss_DB, faiss_data, faiss_indices = create_index(label)
     # To map the data back to the original `train[b'data']` array
     faiss_indices_map = {i: x for i,x in enumerate(faiss_indices)}
     # To map the indices back to the original strings
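
The rest of `get_edge_list` sits outside this diff. For the full picture, the neighbour search it leads into looks roughly like this (a sketch: the `search` call and the edge-list construction are assumptions based on the surrounding text, not the post's exact code):

```py
# Rough sketch of how get_edge_list might continue; the post's actual
# implementation is outside this diff. `create_index` is the function
# changed above, and k is the number of neighbours per product title.
def get_edge_list(label, k=5):
    faiss_DB, faiss_data, faiss_indices = create_index(label)
    # To map the data back to the original `train[b'data']` array
    faiss_indices_map = {i: x for i, x in enumerate(faiss_indices)}

    # Query the index with its own (normalized) vectors: for every title in
    # this category, fetch its k most cosine-similar titles.
    similarities, neighbors = faiss_DB.search(faiss_data, k)

    # Flatten the results into (source, target, similarity) edges, mapping
    # positions within the per-category index back to row indices in the
    # original data and dropping self-matches.
    return [
        (faiss_indices_map[i], faiss_indices_map[j], float(similarities[i, n]))
        for i, row in enumerate(neighbors)
        for n, j in enumerate(row)
        if i != j
    ]
```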