Compare commits

2 Commits

Author SHA1 Message Date
Avinash Mallya
ce147d0e08 WIP, move split calculation to its own class 2025-01-12 21:50:05 -06:00
Avinash Mallya
a86f856ec4 Start removal of Gradio and streamlit 2025-01-12 20:19:02 -06:00
4 changed files with 949 additions and 866 deletions

View File

@@ -5,11 +5,11 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"gradio>=5.9.1",
"great-tables>=0.15.0",
"more-itertools>=10.5.0",
"ollama>=0.4.4",
"polars>=1.18.0",
"reflex>=0.6.8",
"selenium>=4.27.1",
"surya-ocr>=0.8.1",
]

View File

@@ -1,625 +0,0 @@
from datetime import datetime
from typing import Literal, TypedDict, TypeVar
import argparse as agp
import gradio as gr
import polars as pl
from PIL.Image import Image
from src.app.split_ai import ReceiptReader
from src.app.utils import css_code, head_html, spinner_html
ComponentType = TypeVar("ComponentType")
def calculate_splits(
item_names: list[str],
item_people: list[list[str]],
item_amounts: list[float],
total: float,
tip: float,
tax: float,
people_list: list[str],
tip_split_proportionally: bool,
tax_split_proportionally: bool,
cashback_discount: float,
return_detailed_table: bool = False,
) -> gr.DataFrame:
"""
A simple, but long function to calculate splits for a provided receipt.
Args:
item_names: Names of the items being split.
item_people: A list of people for each item who are splitting its cost.
item_amounts: Amounts of the items being split.
total: The total amount in the receipt
tip: The tip in the receipt
tax: The tax in the receipt
people_list: The total number of people splitting the receipt.
tip_split_proportionally: Indicator for whether the tip is split proportional to pre-tax/tip cost.
tax_split_proportionally: Indicator for whether the tax is split proportional to pre-tax/tip cost.
cashback_discount: The total will be reduced by this percentage value.
return_detailed_table: Indicator to return full calculation table or a simplified one.
Returns:
A DataFrame form of the provided values along with their calculated splits or a simplified version.
"""
split_count = 0
unsplit_names = []
checkbox_count = len(item_people)
for name, split in zip(item_names, item_people):
if len(split) > 0:
split_count += 1
else:
unsplit_names.append(name)
if split_count != checkbox_count:
gr.Warning(
f"Pending splits: {','.join(unsplit_names)}",
title="Can't show splits yet",
)
return gr.DataFrame(pl.DataFrame(), visible=False)
else:
# Deliberately avoiding going the numpy route here since the data is very small anyway.
split_arrays: list[list[float]] = []
for split in item_people:
split_array = [1 / len(split) if x in split else 0.0 for x in people_list]
split_arrays.append(split_array)
split_amounts: list[list[float]] = []
for split_array, amount in zip(split_arrays, item_amounts):
split_amount = [amount * split for split in split_array]
split_amounts.append(split_amount)
split_subtotals = [sum(x) for x in zip(*split_amounts)]
subtotal = total - tip - tax
split_tips = [
x / subtotal * tip if tip_split_proportionally else tax / len(people_list)
for x in split_subtotals
]
split_taxes = [
x / subtotal * tax if tax_split_proportionally else tax / len(people_list)
for x in split_subtotals
]
split_totals_pre_cashback = [
split_subtotal + split_tip + split_tax
for split_subtotal, split_tip, split_tax in zip(
split_subtotals, split_tips, split_taxes
)
]
split_cashback = [-x * cashback_discount for x in split_totals_pre_cashback]
split_totals_post_cashback = [
x * (1 - cashback_discount) for x in split_totals_pre_cashback
]
first_col_names = list(item_names) + [
"Subtotal",
"Tip",
"Tax",
"Cashback",
"Total",
]
splits = split_amounts + [
split_subtotals,
split_tips,
split_taxes,
split_cashback,
split_totals_post_cashback,
]
horizontal_totals = list(item_amounts) + [
subtotal,
tip,
tax,
sum(split_cashback),
sum(split_totals_post_cashback),
]
full_calculation_df = (
pl.DataFrame(
{
"Item": first_col_names,
"splits": splits,
"Total": horizontal_totals,
},
schema={
"Item": pl.String,
"splits": pl.List(pl.Float64),
"Total": pl.Float64,
},
)
.with_columns(pl.col("splits").list.to_struct(fields=people_list))
.unnest("splits")
.with_columns(pl.col(pl.Float64).round(2))
)
if return_detailed_table:
return gr.DataFrame(full_calculation_df, visible=True)
else:
simple_calculation = (
full_calculation_df.filter(pl.col("Item").eq("Total"))
.select(pl.exclude("Total"))
.transpose(
include_header=True, header_name="Person", column_names=["Split"]
)
.filter(pl.col("Person").ne("Item"))
)
return gr.DataFrame(simple_calculation, visible=True)
class Item(TypedDict):
name: str
amount: float
class ItemSplitter:
def __init__(
self,
item: Item,
people_list: list[str],
) -> None:
self.people_list_state = people_list
self.item = item
self.no_interaction_kwargs = {"interactive": False, "min_width": 10}
self.interaction_kwargs = self.no_interaction_kwargs | {
"container": False,
"interactive": True,
}
def name_textbox(self, item_name: str) -> gr.Textbox:
return gr.Textbox(
item_name, show_label=False, scale=8, **self.interaction_kwargs
)
def amount_number(self, item_amount: float) -> gr.Number:
return gr.Number(
value=item_amount, precision=2, scale=3, **self.interaction_kwargs
)
def split_status_button(
self, choices: list[str] | None = None, status: Literal["⚠️", "🆗"] = "⚠️"
) -> gr.Button:
if choices is not None:
if len(choices) == 0:
status = "⚠️"
else:
status = "🆗"
else:
choices = []
variant: Literal["huggingface", "primary"] = (
"huggingface" if (status == "⚠️") | (len(choices) == 0) else "primary"
)
return gr.Button(
value=status, variant=variant, scale=1, **self.no_interaction_kwargs
)
def delete_item(self, item_list: list[Item]) -> list[Item]:
item_list.remove(self.item)
return item_list
def delete_item_button(self) -> gr.Button:
kwargs = self.no_interaction_kwargs | {"interactive": True}
return gr.Button(value="", variant="stop", **kwargs)
def people_list_checkbox(self, people_list: list[str]) -> gr.CheckboxGroup:
return gr.CheckboxGroup(choices=people_list, **self.interaction_kwargs)
def generate(self) -> tuple[gr.Textbox, gr.CheckboxGroup, gr.Number, gr.Button]:
return self.generate_mobile()
def generate_mobile(self):
with gr.Row(variant="default", equal_height=True):
item_name_textbox = self.name_textbox(self.item["name"])
item_amount_number = self.amount_number(self.item["amount"])
split_status_button = self.split_status_button(status="⚠️")
delete_item_button = self.delete_item_button()
people_list_checkbox = self.people_list_checkbox(self.people_list_state)
people_list_checkbox.change(
lambda x: self.split_status_button(choices=x),
people_list_checkbox,
split_status_button,
)
return (
item_name_textbox,
people_list_checkbox,
item_amount_number,
delete_item_button,
)
class SplitAIApp:
valid_split_variant: Literal["primary"] = "primary"
invalid_split_variant: Literal["huggingface"] = "huggingface"
def __init__(self, llm_model: str):
self.receipt_reader = ReceiptReader(llm_model)
self.demo = self.create_app()
@staticmethod
def prepare_calculate_splits_kwargs(num_records: int, *all_values) -> gr.DataFrame:
"""
This method is necessary because only a list[gr.Component] or similar can be sent as
`inputs` to an event listener. Therefore, it is unpacked here and prepared into a
dictionary based on how it is sent by the event. This method is specifically for
the `get_split_button.click` event listener.
Args:
num_records: The number of items present to split.
*all_values: A list of components to forward to `calculate_splits`.
Returns:
gr.DataFrame
"""
kwargs = {
"item_names": all_values[:num_records],
"item_people": all_values[num_records : num_records * 2],
"item_amounts": all_values[num_records * 2 : num_records * 3],
}
additional_kwargs = {
k: v
for k, v in zip(
[
"total",
"tip",
"tax",
"people_list",
"tip_split_proportionally",
"tax_split_proportionally",
"cashback_discount",
"return_detailed_table",
],
tuple(all_values[num_records * 3 :]),
)
}
additional_kwargs["cashback_discount"] /= 100
kwargs.update(additional_kwargs)
return calculate_splits(**kwargs)
@staticmethod
def update_component_attributes(
component: ComponentType, **kwargs
) -> ComponentType:
"""
This requirement is in place because Gradio expects you to provide A NEW INSTANCE of
the component that you want to update with its attributes changed. It seems like it
doesn't replace the component, but updates it this way. Very weird behavior.
Args:
component: The gradio component to update attributes for.
**kwargs: (attribute, value) pairs to update in child.
Returns:
A new instance of child's class with the updated attributes.
"""
gradio_class = type(component)
try:
return gradio_class(**kwargs)
except Exception as err:
print(
f"The Gradio component {gradio_class} does not have one of the provided attribute keys."
)
raise err
@staticmethod
def validate_people_list(people_textbox) -> tuple[gr.Image, list]:
if "," in people_textbox and people_textbox[-1] != ",":
people_list = [x.strip() for x in people_textbox.split(",")]
return gr.Image(interactive=True), people_list
else:
gr.Warning("You need to enter a list of names separated by commas.")
return gr.Image(interactive=False), []
def create_app(self) -> gr.Blocks():
# `head_html` required to prevent iOS from scaling the UI when clicking on a textbox.
with gr.Blocks(
css=css_code,
head=head_html,
theme="JohnSmith9982/small_and_pretty",
fill_width=True,
) as split_app:
with gr.Column():
self.people_textbox = gr.Textbox(
placeholder="Split names with a comma",
label="Who all are splitting this receipt?",
lines=1,
autofocus=True,
submit_btn="Submit",
)
self.people_list = gr.State([])
self.image_uploader = gr.Image(
show_label=False, scale=1, type="pil", interactive=False
)
self.people_textbox.submit(
SplitAIApp.validate_people_list,
[self.people_textbox],
[self.image_uploader, self.people_list],
)
with gr.Column():
with gr.Column():
with gr.Row():
self.merchant = gr.Textbox(
interactive=True,
label="Merchant Name",
min_width=20,
visible=False,
scale=2,
)
self.receipt_date = gr.DateTime(
interactive=True,
include_time=False,
type="datetime",
label="Date",
min_width=20,
visible=False,
scale=1,
)
with gr.Row():
self.total_amount = gr.Number(
interactive=True,
label="Total",
minimum=0,
min_width=20,
visible=False,
precision=2,
)
self.tip_amount = gr.Number(
interactive=True,
label="Tip",
min_width=20,
minimum=0,
visible=False,
precision=2,
)
self.tax_amount = gr.Number(
interactive=True,
label="Tax",
minimum=0,
min_width=20,
visible=False,
precision=2,
)
self.items = gr.State([])
@gr.render(inputs=[self.items, self.people_list])
def render_items(items: list[Item], people_list: list[str]):
item_names = []
item_peoples = []
item_amounts = []
for key, item in enumerate(items):
with gr.Column(variant="compact"):
splitter = ItemSplitter(item, people_list)
item_name, item_people, item_amount, delete_item_button = (
splitter.generate()
)
# This event needs to be defined outside the ItemSplitter class
# because it references a gr.State variable. All Gradio components
# can be properly pass ONLY via event listeners, as their state is
# managed by Gradio outside the flow of the Python app.
delete_item_button.click(
splitter.delete_item, self.items, self.items
)
item_names.append(item_name)
item_peoples.append(item_people)
item_amounts.append(item_amount)
self.split_tip_proportionally = gr.Checkbox(
value=True,
label="Split tip proportional to other costs",
info="If unchecked, will split equally.",
interactive=True,
)
self.split_tax_proportionally = gr.Checkbox(
value=True,
label="Split tax proportional to other costs",
info="If unchecked, will split equally.",
interactive=True,
)
self.add_cashback_discount = gr.Number(
minimum=0,
maximum=100,
value=0,
step=0.5,
label="Cashback discount to apply on total",
info="Choose a number between 0% and 100%.",
interactive=True,
)
self.show_detailed_table = gr.Checkbox(
value=False,
label="Show a detailed calculation table",
info="If unchecked, will just show the splits.",
interactive=True,
)
with gr.Row():
self.integrity_markdown = gr.Markdown(
show_label=False, value="", visible=False
)
with gr.Row():
get_splits_button = gr.Button(
"Get Splits", variant="primary", scale=5, min_width=10
)
get_splits_button.click(
lambda *x: SplitAIApp.prepare_calculate_splits_kwargs(
len(item_names), *x
),
inputs=(
item_names
+ item_peoples
+ item_amounts
+ [
self.total_amount,
self.tip_amount,
self.tax_amount,
self.people_list,
self.split_tip_proportionally,
self.split_tax_proportionally,
self.add_cashback_discount,
self.show_detailed_table,
]
),
outputs=self.display_result,
)
add_item_button = gr.Button(
"", variant="secondary", scale=1, min_width=10
)
def add_item(
items: list[Item],
):
new_item_name = f"Item {len(items)+1}"
return items + [
{
"name": new_item_name,
"amount": 0.0,
}
]
add_item_button.click(
add_item,
inputs=[self.items],
outputs=[self.items],
)
# Constantly keep track of whether totals match or not.
def integrity_checker(*args) -> gr.Markdown:
items = args[: len(args) - 3]
tip_amount, tax_amount, total_amount = args[len(args) - 3 :]
subtotal = sum(items)
if subtotal + tip_amount + tax_amount != total_amount:
return gr.Markdown(
f"⚠️ Looks like the total ({total_amount}) doesn't match the value of subtotal ({subtotal}) + tip ({tip_amount}) + tax ({tax_amount}) ⚠️",
show_label=False,
visible=True,
)
else:
return gr.Markdown(visible=False)
gr.on(
triggers=[x.change for x in item_amounts]
+ [
self.tip_amount.change,
self.tax_amount.change,
self.total_amount.change,
],
fn=integrity_checker,
inputs=item_amounts
+ [
self.tip_amount,
self.tax_amount,
self.total_amount,
],
outputs=[self.integrity_markdown],
)
self.display_result = gr.DataFrame(value=None, visible=False)
self.spinner_html = gr.HTML(
spinner_html,
visible=False,
padding=False,
)
self.image_uploader.upload(
lambda: gr.HTML(visible=True), inputs=None, outputs=self.spinner_html
).then(
self.process_image,
inputs=[self.image_uploader, self.items],
outputs=[
self.merchant,
self.receipt_date,
self.total_amount,
self.tip_amount,
self.tax_amount,
self.items,
],
show_progress="hidden",
).then(
lambda: gr.HTML(visible=False), inputs=None, outputs=self.spinner_html
)
return split_app
def process_image(
self,
image: Image,
items: gr.State,
): # -> gr.State:
receipt_string = self.receipt_reader.get_ordered_text(image)
receipt_extracted = self.receipt_reader.extract_components(receipt_string)
print(receipt_extracted)
# receipt_extracted = {
# "merchant": "FUBAR",
# "receipt_date": datetime.now(),
# "total": 15,
# "tip": 0,
# "tax": 3,
# "item_amounts": [
# {"name": "PET TOY", "currency": "$", "amount": 2},
# {"name": "FLOPPY PUPPY", "currency": "$", "amount": 4},
# {"name": "SSSUPREME S", "currency": "$", "amount": 6},
# ],
# }
key_value_updates = [
{
"component": self.merchant,
"kwargs": {"value": receipt_extracted["merchant"], "visible": True},
},
{
"component": self.receipt_date,
"kwargs": {
"value": receipt_extracted["receipt_date"],
"visible": True,
},
},
{
"component": self.total_amount,
"kwargs": {
"value": receipt_extracted["total"],
"visible": True,
},
},
{
"component": self.tip_amount,
"kwargs": {
"value": receipt_extracted["tip"],
"visible": True,
},
},
{
"component": self.tax_amount,
"kwargs": {
"value": receipt_extracted["tax"],
"visible": True,
},
},
]
out = [
self.update_component_attributes(x["component"], **x["kwargs"])
for x in key_value_updates
]
items += [
{"name": x["name"], "amount": x["amount"]}
for x in receipt_extracted["receipt_items"]
]
out += [items]
return out
def launch(self, expose_to_local_network: bool = False):
if expose_to_local_network:
self.demo.queue().launch(server_name="0.0.0.0", server_port=7860)
else:
self.demo.queue().launch()
def arg_parser() -> agp.ArgumentParser:
ag = agp.ArgumentParser()
ag.add_argument(
"-m",
"--model",
type=str,
default="qwen2.5:7b",
help="Choose the LLM model used.",
)
return ag
if __name__ == "__main__":
args = arg_parser().parse_args()
demo = SplitAIApp(args.model)
demo.launch(True)

275
src/core/split.py Normal file
View File

@@ -0,0 +1,275 @@
import polars as pl
class IncompleteSplitError(Exception):
def __init__(
self,
message,
):
super().__init__(message)
class SplitCalculator:
"""
A simple, but long class to calculate splits for a provided receipt.
Args:
item_names: Names of the items being split.
item_people: A list of people for each item who are splitting its cost.
item_amounts: Amounts of the items being split.
total: The total amount in the receipt
tip: The tip in the receipt
tax: The tax in the receipt
people_list: The total number of people splitting the receipt.
tip_split_proportionally: Indicator for whether the tip is split proportional to pre-tax/tip cost.
tax_split_proportionally: Indicator for whether the tax is split proportional to pre-tax/tip cost.
cashback_discount: The total will be reduced by this percentage value.
return_detailed_table: Indicator to return full calculation table or a simplified one.
"""
def __init__(
self,
item_names: list[str],
item_people: list[list[str]],
item_amounts: list[float],
receipt_total: float,
receipt_tip: float,
receipt_tax: float,
people_list: list[str],
tip_split_proportionally: bool,
tax_split_proportionally: bool,
cashback_discount: float,
return_detailed_table: bool = False,
):
self.item_names = item_names
self.item_people = item_people
self.item_amounts = item_amounts
self.receipt_total = receipt_total
self.receipt_tip = receipt_tip
self.receipt_tax = receipt_tax
self.people_list = people_list
self.tip_split_proportionally = tip_split_proportionally
self.tax_split_proportionally = tax_split_proportionally
self.cashback_discount = cashback_discount
self.return_detailed_table = return_detailed_table
self.subtotal = self.receipt_total - self.receipt_tip - self.receipt_tax
self.split_tips: float | None = None
self.split_taxes: float | None = None
def validate_splits(self):
split_count = 0
unsplit_names = []
for name, split in zip(self.item_names, self.item_people):
if len(split) > 0:
split_count += 1
else:
unsplit_names.append(name)
if split_count != len(self.item_people):
raise IncompleteSplitError(
f"The following items have not been assigned splits: {','.join(unsplit_names)}"
)
def distribute_amount(self, amount: float, split_subtotals: list[float]):
"""
Distribute `amount` equally, or distribute it proportionally, among
all the people involved in the split.
"""
return [
x / self.subtotal * amount
if self.tax_split_proportionally
else amount / len(self.people_list)
for x in split_subtotals
]
def subtract_cashback(
self, split_totals: list[float]
) -> tuple[list[float], list[float]]:
split_cashback = [-x * cashback_discount for x in split_totals]
split_totals_minus_cashback = [
x * (1 - cashback_discount) for x in split_totals
]
return split_cashback, split_totals_minus_cashback
def forward(self):
split_arrays: list[list[float]] = []
for split in self.item_people:
split_array = [
1 / len(split) if x in split else 0.0 for x in self.people_list
]
split_arrays.append(split_array)
split_amounts: list[list[float]] = []
for split_array, amount in zip(split_arrays, self.item_amounts):
split_amount = [amount * split for split in split_array]
split_amounts.append(split_amount)
split_subtotals = [sum(x) for x in zip(*split_amounts)]
split_tips = self.distribute_amount(self.receipt_tip, split_subtotals)
split_taxes = self.distribute_amount(self.receipt_tax, split_subtotals)
split_totals = [
split_subtotal + split_tip + split_tax
for split_subtotal, split_tip, split_tax in zip(
split_subtotals, split_tips, split_taxes
)
]
split_cashback, split_totals_minus_cashback = self.subtract_cashback(
split_totals
)
def calculate_splits(
item_names: list[str],
item_people: list[list[str]],
item_amounts: list[float],
total: float,
tip: float,
tax: float,
people_list: list[str],
tip_split_proportionally: bool,
tax_split_proportionally: bool,
cashback_discount: float,
return_detailed_table: bool = False,
) -> pl.DataFrame:
"""
A simple, but long function to calculate splits for a provided receipt.
Args:
item_names: Names of the items being split.
item_people: A list of people for each item who are splitting its cost.
item_amounts: Amounts of the items being split.
total: The total amount in the receipt
tip: The tip in the receipt
tax: The tax in the receipt
people_list: The total number of people splitting the receipt.
tip_split_proportionally: Indicator for whether the tip is split proportional to pre-tax/tip cost.
tax_split_proportionally: Indicator for whether the tax is split proportional to pre-tax/tip cost.
cashback_discount: The total will be reduced by this percentage value.
return_detailed_table: Indicator to return full calculation table or a simplified one.
Returns:
A DataFrame form of the provided values along with their calculated splits or a simplified version.
"""
split_count = 0
unsplit_names = []
checkbox_count = len(item_people)
for name, split in zip(item_names, item_people):
if len(split) > 0:
split_count += 1
else:
unsplit_names.append(name)
if split_count != checkbox_count:
raise IncompleteSplitError(
f"The following items have not been assigned splits: {','.join(unsplit_names)}"
)
return None
else:
# Deliberately avoiding going the numpy route here since the data is very small anyway.
split_arrays: list[list[float]] = []
for split in item_people:
split_array = [1 / len(split) if x in split else 0.0 for x in people_list]
split_arrays.append(split_array)
split_amounts: list[list[float]] = []
for split_array, amount in zip(split_arrays, item_amounts):
split_amount = [amount * split for split in split_array]
split_amounts.append(split_amount)
split_subtotals = [sum(x) for x in zip(*split_amounts)]
subtotal = total - tip - tax
split_tips = [
x / subtotal * tip if tip_split_proportionally else tax / len(people_list)
for x in split_subtotals
]
split_taxes = [
x / subtotal * tax if tax_split_proportionally else tax / len(people_list)
for x in split_subtotals
]
split_totals_pre_cashback = [
split_subtotal + split_tip + split_tax
for split_subtotal, split_tip, split_tax in zip(
split_subtotals, split_tips, split_taxes
)
]
split_cashback = [-x * cashback_discount for x in split_totals_pre_cashback]
split_totals_post_cashback = [
x * (1 - cashback_discount) for x in split_totals_pre_cashback
]
first_col_names = list(item_names) + [
"Subtotal",
"Tip",
"Tax",
"Cashback",
"Total",
]
splits = split_amounts + [
split_subtotals,
split_tips,
split_taxes,
split_cashback,
split_totals_post_cashback,
]
horizontal_totals = list(item_amounts) + [
subtotal,
tip,
tax,
sum(split_cashback),
sum(split_totals_post_cashback),
]
full_calculation_df = (
pl.DataFrame(
{
"Item": first_col_names,
"splits": splits,
"Total": horizontal_totals,
},
schema={
"Item": pl.String,
"splits": pl.List(pl.Float64),
"Total": pl.Float64,
},
)
.with_columns(pl.col("splits").list.to_struct(fields=people_list))
.unnest("splits")
.with_columns(pl.col(pl.Float64).round(2))
)
if return_detailed_table:
return full_calculation_df
else:
simple_calculation = (
full_calculation_df.filter(pl.col("Item").eq("Total"))
.select(pl.exclude("Total"))
.transpose(
include_header=True, header_name="Person", column_names=["Split"]
)
.filter(pl.col("Person").ne("Item"))
)
return simple_calculation
if __name__ == "__main__":
# Example usage
item_names = ["Item 1", "Item 2", "Item 3"]
item_people = [["Alice", "Bob"], ["Alice"], ["Bob", "Charlie"]]
item_amounts = [10.0, 20.0, 30.0]
total = 70.0
tip = 6.0
tax = 4.0
people_list = ["Alice", "Bob", "Charlie"]
tip_split_proportionally = True
tax_split_proportionally = True
cashback_discount = 0.03
result_df = calculate_splits(
item_names,
item_people,
item_amounts,
total,
tip,
tax,
people_list,
tip_split_proportionally,
tax_split_proportionally,
cashback_discount,
return_detailed_table=True,
)
print(result_df)

913
uv.lock generated

File diff suppressed because it is too large Load Diff