Compare commits
2 Commits
main
...
switch_to_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce147d0e08 | ||
|
|
a86f856ec4 |
@@ -5,11 +5,11 @@ description = "Add your description here"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"gradio>=5.9.1",
|
|
||||||
"great-tables>=0.15.0",
|
"great-tables>=0.15.0",
|
||||||
"more-itertools>=10.5.0",
|
"more-itertools>=10.5.0",
|
||||||
"ollama>=0.4.4",
|
"ollama>=0.4.4",
|
||||||
"polars>=1.18.0",
|
"polars>=1.18.0",
|
||||||
|
"reflex>=0.6.8",
|
||||||
"selenium>=4.27.1",
|
"selenium>=4.27.1",
|
||||||
"surya-ocr>=0.8.1",
|
"surya-ocr>=0.8.1",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,625 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Literal, TypedDict, TypeVar
|
|
||||||
|
|
||||||
import argparse as agp
|
|
||||||
import gradio as gr
|
|
||||||
import polars as pl
|
|
||||||
from PIL.Image import Image
|
|
||||||
|
|
||||||
from src.app.split_ai import ReceiptReader
|
|
||||||
from src.app.utils import css_code, head_html, spinner_html
|
|
||||||
|
|
||||||
ComponentType = TypeVar("ComponentType")
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_splits(
    item_names: list[str],
    item_people: list[list[str]],
    item_amounts: list[float],
    total: float,
    tip: float,
    tax: float,
    people_list: list[str],
    tip_split_proportionally: bool,
    tax_split_proportionally: bool,
    cashback_discount: float,
    return_detailed_table: bool = False,
) -> gr.DataFrame:
    """
    Calculate every person's share of a receipt and wrap it in a Gradio DataFrame.

    Args:
        item_names: Names of the items being split.
        item_people: For each item, the people who are splitting its cost.
        item_amounts: Amounts of the items being split.
        total: The total amount in the receipt.
        tip: The tip in the receipt.
        tax: The tax in the receipt.
        people_list: Names of everyone splitting the receipt; fixes the column order.
        tip_split_proportionally: If True the tip is split proportional to each
            person's pre-tax/tip cost, otherwise equally between all people.
        tax_split_proportionally: If True the tax is split proportional to each
            person's pre-tax/tip cost, otherwise equally between all people.
        cashback_discount: Fraction of each person's total that is deducted
            (the value is multiplied directly, e.g. 0.03 removes 3%).
        return_detailed_table: Return the full calculation table instead of the
            simplified per-person one.

    Returns:
        A visible DataFrame with the detailed or simplified splits. If any item
        has nobody assigned, a warning is shown and a hidden, empty DataFrame is
        returned instead.
    """
    pairs = list(zip(item_names, item_people))
    unsplit_names = [name for name, split in pairs if len(split) == 0]
    if len(pairs) - len(unsplit_names) != len(item_people):
        gr.Warning(
            f"Pending splits: {','.join(unsplit_names)}",
            title="Can't show splits yet",
        )
        return gr.DataFrame(pl.DataFrame(), visible=False)

    # Deliberately avoiding going the numpy route here since the data is very small anyway.
    # One row per item, one column per person (in `people_list` order).
    split_amounts: list[list[float]] = [
        [
            amount * (1 / len(split) if person in split else 0.0)
            for person in people_list
        ]
        for split, amount in zip(item_people, item_amounts)
    ]

    split_subtotals = [sum(column) for column in zip(*split_amounts)]
    subtotal = total - tip - tax
    # Equal split of the tip divides the *tip* (the original divided the tax here).
    split_tips = [
        x / subtotal * tip if tip_split_proportionally else tip / len(people_list)
        for x in split_subtotals
    ]
    split_taxes = [
        x / subtotal * tax if tax_split_proportionally else tax / len(people_list)
        for x in split_subtotals
    ]
    split_totals_pre_cashback = [
        split_subtotal + split_tip + split_tax
        for split_subtotal, split_tip, split_tax in zip(
            split_subtotals, split_tips, split_taxes
        )
    ]
    split_cashback = [-x * cashback_discount for x in split_totals_pre_cashback]
    split_totals_post_cashback = [
        x * (1 - cashback_discount) for x in split_totals_pre_cashback
    ]

    # The three lists below are row-aligned: label, per-person values, row total.
    first_col_names = list(item_names) + [
        "Subtotal",
        "Tip",
        "Tax",
        "Cashback",
        "Total",
    ]
    splits = split_amounts + [
        split_subtotals,
        split_tips,
        split_taxes,
        split_cashback,
        split_totals_post_cashback,
    ]
    horizontal_totals = list(item_amounts) + [
        subtotal,
        tip,
        tax,
        sum(split_cashback),
        sum(split_totals_post_cashback),
    ]
    full_calculation_df = (
        pl.DataFrame(
            {
                "Item": first_col_names,
                "splits": splits,
                "Total": horizontal_totals,
            },
            schema={
                "Item": pl.String,
                "splits": pl.List(pl.Float64),
                "Total": pl.Float64,
            },
        )
        # Turn the per-row list into one named column per person.
        .with_columns(pl.col("splits").list.to_struct(fields=people_list))
        .unnest("splits")
        .with_columns(pl.col(pl.Float64).round(2))
    )
    if return_detailed_table:
        return gr.DataFrame(full_calculation_df, visible=True)

    # Keep only the "Total" row and pivot it into (Person, Split) pairs.
    simple_calculation = (
        full_calculation_df.filter(pl.col("Item").eq("Total"))
        .select(pl.exclude("Total"))
        .transpose(include_header=True, header_name="Person", column_names=["Split"])
        .filter(pl.col("Person").ne("Item"))
    )
    return gr.DataFrame(simple_calculation, visible=True)
|
|
||||||
|
|
||||||
|
|
||||||
# A single receipt line: `name` is the item label, `amount` is its price.
Item = TypedDict("Item", {"name": str, "amount": float})
|
|
||||||
|
|
||||||
|
|
||||||
class ItemSplitter:
    """
    Builds the Gradio widgets used to edit one receipt item and pick who shares it.

    Args:
        item: The receipt item (name and amount) this splitter represents.
        people_list: Names offered as choices in the item's checkbox group.
    """

    def __init__(
        self,
        item: Item,
        people_list: list[str],
    ) -> None:
        self.people_list_state = people_list
        self.item = item
        # Shared constructor kwargs for read-only and editable widgets.
        self.no_interaction_kwargs = {"interactive": False, "min_width": 10}
        self.interaction_kwargs = self.no_interaction_kwargs | {
            "container": False,
            "interactive": True,
        }

    def name_textbox(self, item_name: str) -> gr.Textbox:
        """Return an editable textbox pre-filled with `item_name`."""
        return gr.Textbox(
            item_name, show_label=False, scale=8, **self.interaction_kwargs
        )

    def amount_number(self, item_amount: float) -> gr.Number:
        """Return an editable two-decimal number field holding `item_amount`."""
        return gr.Number(
            value=item_amount, precision=2, scale=3, **self.interaction_kwargs
        )

    def split_status_button(
        self, choices: list[str] | None = None, status: Literal["⚠️", "🆗"] = "⚠️"
    ) -> gr.Button:
        """
        Return the non-interactive status badge for this item.

        Args:
            choices: Currently selected people. When given, it overrides `status`:
                an empty selection shows "⚠️", anything else "🆗".
            status: Badge text used only when `choices` is None.

        Returns:
            A button using the "primary" variant when at least one person is
            selected, and the "huggingface" variant otherwise (including when
            `choices` is None).
        """
        if choices is not None:
            status = "🆗" if len(choices) > 0 else "⚠️"
        variant: Literal["huggingface", "primary"] = (
            "primary" if choices else "huggingface"
        )
        return gr.Button(
            value=status, variant=variant, scale=1, **self.no_interaction_kwargs
        )

    def delete_item(self, item_list: list[Item]) -> list[Item]:
        """
        Return a copy of `item_list` without the first entry equal to this item.

        The input list is left untouched, and a list that no longer contains the
        item is returned unchanged instead of raising `ValueError`.
        """
        remaining = list(item_list)
        try:
            remaining.remove(self.item)
        except ValueError:
            pass
        return remaining

    def delete_item_button(self) -> gr.Button:
        """Return the clickable "❌" button for this item."""
        kwargs = self.no_interaction_kwargs | {"interactive": True}
        return gr.Button(value="❌", variant="stop", **kwargs)

    def people_list_checkbox(self, people_list: list[str]) -> gr.CheckboxGroup:
        """Return a checkbox group with one choice per name in `people_list`."""
        return gr.CheckboxGroup(choices=people_list, **self.interaction_kwargs)

    def generate(self) -> tuple[gr.Textbox, gr.CheckboxGroup, gr.Number, gr.Button]:
        """Create the item's widgets; currently always the mobile layout."""
        return self.generate_mobile()

    def generate_mobile(
        self,
    ) -> tuple[gr.Textbox, gr.CheckboxGroup, gr.Number, gr.Button]:
        """
        Create the widgets for this item and wire the status badge to the selection.

        Returns:
            (name textbox, people checkbox group, amount number, delete button).
        """
        # NOTE(review): nesting was reconstructed from a flattened diff view; the
        # checkbox group is assumed to sit below the row, not inside it - confirm.
        with gr.Row(variant="default", equal_height=True):
            item_name_textbox = self.name_textbox(self.item["name"])
            item_amount_number = self.amount_number(self.item["amount"])
            split_status_button = self.split_status_button(status="⚠️")
            delete_item_button = self.delete_item_button()
        people_list_checkbox = self.people_list_checkbox(self.people_list_state)
        # Refresh the badge whenever the selection changes.
        people_list_checkbox.change(
            lambda x: self.split_status_button(choices=x),
            people_list_checkbox,
            split_status_button,
        )
        return (
            item_name_textbox,
            people_list_checkbox,
            item_amount_number,
            delete_item_button,
        )
|
|
||||||
|
|
||||||
|
|
||||||
class SplitAIApp:
    """
    Gradio application that reads a receipt image and lets people split its items.

    Args:
        llm_model: Model name forwarded to `ReceiptReader`.
    """

    valid_split_variant: Literal["primary"] = "primary"
    invalid_split_variant: Literal["huggingface"] = "huggingface"

    def __init__(self, llm_model: str):
        self.receipt_reader = ReceiptReader(llm_model)
        self.demo = self.create_app()

    @staticmethod
    def prepare_calculate_splits_kwargs(num_records: int, *all_values) -> gr.DataFrame:
        """
        This method is necessary because only a list[gr.Component] or similar can be sent as
        `inputs` to an event listener. Therefore, it is unpacked here and prepared into a
        dictionary based on how it is sent by the event. This method is specifically for
        the `get_split_button.click` event listener.

        Args:
            num_records: The number of items present to split.
            *all_values: `num_records` item names, then `num_records` people
                selections, then `num_records` amounts, followed by total, tip, tax,
                people list, the two proportional-split flags, the cashback
                percentage and the detailed-table flag.

        Returns:
            gr.DataFrame
        """
        kwargs = {
            "item_names": all_values[:num_records],
            "item_people": all_values[num_records : num_records * 2],
            "item_amounts": all_values[num_records * 2 : num_records * 3],
        }
        trailing_names = (
            "total",
            "tip",
            "tax",
            "people_list",
            "tip_split_proportionally",
            "tax_split_proportionally",
            "cashback_discount",
            "return_detailed_table",
        )
        kwargs.update(zip(trailing_names, all_values[num_records * 3 :]))
        # The UI provides a percentage; `calculate_splits` expects a fraction.
        # A missing (None) value is treated as no cashback.
        kwargs["cashback_discount"] = (kwargs["cashback_discount"] or 0) / 100
        return calculate_splits(**kwargs)

    @staticmethod
    def update_component_attributes(
        component: ComponentType, **kwargs
    ) -> ComponentType:
        """
        This requirement is in place because Gradio expects you to provide A NEW INSTANCE of
        the component that you want to update with its attributes changed. It seems like it
        doesn't replace the component, but updates it this way. Very weird behavior.

        Args:
            component: The gradio component to update attributes for.
            **kwargs: (attribute, value) pairs passed to the component's constructor.

        Returns:
            A new instance of the component's class built from `kwargs`.

        Raises:
            TypeError: Re-raised (after printing a hint) when the constructor
                rejects the provided keyword arguments.
        """
        gradio_class = type(component)
        try:
            return gradio_class(**kwargs)
        except TypeError:
            print(
                f"The Gradio component {gradio_class} does not have one of the provided attribute keys."
            )
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def validate_people_list(people_textbox) -> tuple[gr.Image, list]:
        """
        Parse the comma-separated names and unlock the image uploader when valid.

        Blank entries (e.g. from a trailing comma) are dropped; at least two names
        are required.

        Returns:
            An interactive image component plus the parsed names, or a
            non-interactive image component plus an empty list after a warning.
        """
        people_list = [
            name.strip() for name in people_textbox.split(",") if name.strip()
        ]
        if len(people_list) >= 2:
            return gr.Image(interactive=True), people_list
        gr.Warning("You need to enter a list of names separated by commas.")
        return gr.Image(interactive=False), []

    def create_app(self) -> gr.Blocks:
        """
        Build the Blocks layout and register every event listener.

        Returns:
            The assembled `gr.Blocks` application.
        """
        # NOTE(review): container nesting below was reconstructed from a flattened
        # diff view - confirm the layout against the real file.
        # `head_html` required to prevent iOS from scaling the UI when clicking on a textbox.
        with gr.Blocks(
            css=css_code,
            head=head_html,
            theme="JohnSmith9982/small_and_pretty",
            fill_width=True,
        ) as split_app:
            with gr.Column():
                self.people_textbox = gr.Textbox(
                    placeholder="Split names with a comma",
                    label="Who all are splitting this receipt?",
                    lines=1,
                    autofocus=True,
                    submit_btn="Submit",
                )
                self.people_list = gr.State([])
                self.image_uploader = gr.Image(
                    show_label=False, scale=1, type="pil", interactive=False
                )
                self.people_textbox.submit(
                    SplitAIApp.validate_people_list,
                    [self.people_textbox],
                    [self.image_uploader, self.people_list],
                )
            with gr.Column():
                with gr.Column():
                    with gr.Row():
                        self.merchant = gr.Textbox(
                            interactive=True,
                            label="Merchant Name",
                            min_width=20,
                            visible=False,
                            scale=2,
                        )
                        self.receipt_date = gr.DateTime(
                            interactive=True,
                            include_time=False,
                            type="datetime",
                            label="Date",
                            min_width=20,
                            visible=False,
                            scale=1,
                        )
                    with gr.Row():
                        self.total_amount = gr.Number(
                            interactive=True,
                            label="Total",
                            minimum=0,
                            min_width=20,
                            visible=False,
                            precision=2,
                        )
                        self.tip_amount = gr.Number(
                            interactive=True,
                            label="Tip",
                            min_width=20,
                            minimum=0,
                            visible=False,
                            precision=2,
                        )
                        self.tax_amount = gr.Number(
                            interactive=True,
                            label="Tax",
                            minimum=0,
                            min_width=20,
                            visible=False,
                            precision=2,
                        )
                self.items = gr.State([])

                @gr.render(inputs=[self.items, self.people_list])
                def render_items(items: list[Item], people_list: list[str]):
                    item_names = []
                    item_peoples = []
                    item_amounts = []
                    for item in items:
                        with gr.Column(variant="compact"):
                            splitter = ItemSplitter(item, people_list)
                            item_name, item_people, item_amount, delete_item_button = (
                                splitter.generate()
                            )
                            # This event needs to be defined outside the ItemSplitter class
                            # because it references a gr.State variable. All Gradio components
                            # can be properly passed ONLY via event listeners, as their state is
                            # managed by Gradio outside the flow of the Python app.
                            delete_item_button.click(
                                splitter.delete_item, self.items, self.items
                            )
                            item_names.append(item_name)
                            item_peoples.append(item_people)
                            item_amounts.append(item_amount)

                    self.split_tip_proportionally = gr.Checkbox(
                        value=True,
                        label="Split tip proportional to other costs",
                        info="If unchecked, will split equally.",
                        interactive=True,
                    )
                    self.split_tax_proportionally = gr.Checkbox(
                        value=True,
                        label="Split tax proportional to other costs",
                        info="If unchecked, will split equally.",
                        interactive=True,
                    )
                    self.add_cashback_discount = gr.Number(
                        minimum=0,
                        maximum=100,
                        value=0,
                        step=0.5,
                        label="Cashback discount to apply on total",
                        info="Choose a number between 0% and 100%.",
                        interactive=True,
                    )
                    self.show_detailed_table = gr.Checkbox(
                        value=False,
                        label="Show a detailed calculation table",
                        info="If unchecked, will just show the splits.",
                        interactive=True,
                    )

                    with gr.Row():
                        self.integrity_markdown = gr.Markdown(
                            show_label=False, value="", visible=False
                        )

                    with gr.Row():
                        get_splits_button = gr.Button(
                            "Get Splits", variant="primary", scale=5, min_width=10
                        )
                        get_splits_button.click(
                            lambda *x: SplitAIApp.prepare_calculate_splits_kwargs(
                                len(item_names), *x
                            ),
                            inputs=(
                                item_names
                                + item_peoples
                                + item_amounts
                                + [
                                    self.total_amount,
                                    self.tip_amount,
                                    self.tax_amount,
                                    self.people_list,
                                    self.split_tip_proportionally,
                                    self.split_tax_proportionally,
                                    self.add_cashback_discount,
                                    self.show_detailed_table,
                                ]
                            ),
                            outputs=self.display_result,
                        )

                        add_item_button = gr.Button(
                            "➕", variant="secondary", scale=1, min_width=10
                        )

                    def add_item(
                        items: list[Item],
                    ):
                        new_item_name = f"Item {len(items)+1}"
                        return items + [
                            {
                                "name": new_item_name,
                                "amount": 0.0,
                            }
                        ]

                    add_item_button.click(
                        add_item,
                        inputs=[self.items],
                        outputs=[self.items],
                    )

                    # Constantly keep track of whether totals match or not.
                    def integrity_checker(*args) -> gr.Markdown:
                        # Treat a missing (None) value as zero so the check never crashes.
                        values = [0.0 if x is None else x for x in args]
                        *amounts, tip_amount, tax_amount, total_amount = values
                        subtotal = round(sum(amounts), 2)
                        # Compare at cent precision: exact float equality reports
                        # spurious mismatches (e.g. 0.1 + 0.2 != 0.3).
                        if round(subtotal + tip_amount + tax_amount, 2) != round(
                            total_amount, 2
                        ):
                            return gr.Markdown(
                                f"⚠️ Looks like the total ({total_amount}) doesn't match the value of subtotal ({subtotal}) + tip ({tip_amount}) + tax ({tax_amount}) ⚠️",
                                show_label=False,
                                visible=True,
                            )
                        return gr.Markdown(visible=False)

                    gr.on(
                        triggers=[x.change for x in item_amounts]
                        + [
                            self.tip_amount.change,
                            self.tax_amount.change,
                            self.total_amount.change,
                        ],
                        fn=integrity_checker,
                        inputs=item_amounts
                        + [
                            self.tip_amount,
                            self.tax_amount,
                            self.total_amount,
                        ],
                        outputs=[self.integrity_markdown],
                    )

                self.display_result = gr.DataFrame(value=None, visible=False)

                self.spinner_html = gr.HTML(
                    spinner_html,
                    visible=False,
                    padding=False,
                )

                # Show the spinner, run the receipt extraction, then hide the spinner.
                self.image_uploader.upload(
                    lambda: gr.HTML(visible=True), inputs=None, outputs=self.spinner_html
                ).then(
                    self.process_image,
                    inputs=[self.image_uploader, self.items],
                    outputs=[
                        self.merchant,
                        self.receipt_date,
                        self.total_amount,
                        self.tip_amount,
                        self.tax_amount,
                        self.items,
                    ],
                    show_progress="hidden",
                ).then(
                    lambda: gr.HTML(visible=False), inputs=None, outputs=self.spinner_html
                )

        return split_app

    def process_image(
        self,
        image: Image,
        items: list[Item],
    ) -> list:
        """
        Extract receipt fields from `image` and produce the component updates.

        Args:
            image: The uploaded receipt image.
            items: The current item list; it is not modified.

        Returns:
            New merchant, date, total, tip and tax components (made visible and
            filled with the extracted values) followed by the item list extended
            with the extracted receipt items.
        """
        receipt_string = self.receipt_reader.get_ordered_text(image)
        receipt_extracted = self.receipt_reader.extract_components(receipt_string)
        print(receipt_extracted)
        # Keys read below: "merchant", "receipt_date", "total", "tip", "tax" and
        # "receipt_items" (entries with "name" and "amount").
        field_components = [
            ("merchant", self.merchant),
            ("receipt_date", self.receipt_date),
            ("total", self.total_amount),
            ("tip", self.tip_amount),
            ("tax", self.tax_amount),
        ]
        out = [
            self.update_component_attributes(
                component, value=receipt_extracted[key], visible=True
            )
            for key, component in field_components
        ]
        # Build a new list rather than extending the incoming state in place.
        updated_items = items + [
            {"name": x["name"], "amount": x["amount"]}
            for x in receipt_extracted["receipt_items"]
        ]
        out.append(updated_items)
        return out

    def launch(self, expose_to_local_network: bool = False):
        """
        Start the app with queuing enabled.

        Args:
            expose_to_local_network: If True, bind to 0.0.0.0 on port 7860;
                otherwise use Gradio's launch defaults.
        """
        launch_kwargs = (
            {"server_name": "0.0.0.0", "server_port": 7860}
            if expose_to_local_network
            else {}
        )
        self.demo.queue().launch(**launch_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def arg_parser() -> agp.ArgumentParser:
    """Build the command-line parser; its only option is `-m/--model`."""
    parser = agp.ArgumentParser()
    model_option = {
        "type": str,
        "default": "qwen2.5:7b",
        "help": "Choose the LLM model used.",
    }
    parser.add_argument("-m", "--model", **model_option)
    return parser
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Parse the CLI, build the app for the chosen model and expose it on the
    # local network.
    args = arg_parser().parse_args()
    demo = SplitAIApp(llm_model=args.model)
    demo.launch(expose_to_local_network=True)
|
|
||||||
275
src/core/split.py
Normal file
275
src/core/split.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
class IncompleteSplitError(Exception):
    """Raised when at least one receipt item has nobody assigned to it."""

    def __init__(self, message):
        # A message is mandatory; it becomes the exception's only argument.
        Exception.__init__(self, message)
|
||||||
|
|
||||||
|
|
||||||
|
class SplitCalculator:
|
||||||
|
"""
|
||||||
|
A simple, but long class to calculate splits for a provided receipt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item_names: Names of the items being split.
|
||||||
|
item_people: A list of people for each item who are splitting its cost.
|
||||||
|
item_amounts: Amounts of the items being split.
|
||||||
|
total: The total amount in the receipt
|
||||||
|
tip: The tip in the receipt
|
||||||
|
tax: The tax in the receipt
|
||||||
|
people_list: The total number of people splitting the receipt.
|
||||||
|
tip_split_proportionally: Indicator for whether the tip is split proportional to pre-tax/tip cost.
|
||||||
|
tax_split_proportionally: Indicator for whether the tax is split proportional to pre-tax/tip cost.
|
||||||
|
cashback_discount: The total will be reduced by this percentage value.
|
||||||
|
return_detailed_table: Indicator to return full calculation table or a simplified one.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
item_names: list[str],
|
||||||
|
item_people: list[list[str]],
|
||||||
|
item_amounts: list[float],
|
||||||
|
receipt_total: float,
|
||||||
|
receipt_tip: float,
|
||||||
|
receipt_tax: float,
|
||||||
|
people_list: list[str],
|
||||||
|
tip_split_proportionally: bool,
|
||||||
|
tax_split_proportionally: bool,
|
||||||
|
cashback_discount: float,
|
||||||
|
return_detailed_table: bool = False,
|
||||||
|
):
|
||||||
|
self.item_names = item_names
|
||||||
|
self.item_people = item_people
|
||||||
|
self.item_amounts = item_amounts
|
||||||
|
self.receipt_total = receipt_total
|
||||||
|
self.receipt_tip = receipt_tip
|
||||||
|
self.receipt_tax = receipt_tax
|
||||||
|
self.people_list = people_list
|
||||||
|
self.tip_split_proportionally = tip_split_proportionally
|
||||||
|
self.tax_split_proportionally = tax_split_proportionally
|
||||||
|
self.cashback_discount = cashback_discount
|
||||||
|
self.return_detailed_table = return_detailed_table
|
||||||
|
|
||||||
|
self.subtotal = self.receipt_total - self.receipt_tip - self.receipt_tax
|
||||||
|
self.split_tips: float | None = None
|
||||||
|
self.split_taxes: float | None = None
|
||||||
|
|
||||||
|
def validate_splits(self):
|
||||||
|
split_count = 0
|
||||||
|
unsplit_names = []
|
||||||
|
for name, split in zip(self.item_names, self.item_people):
|
||||||
|
if len(split) > 0:
|
||||||
|
split_count += 1
|
||||||
|
else:
|
||||||
|
unsplit_names.append(name)
|
||||||
|
if split_count != len(self.item_people):
|
||||||
|
raise IncompleteSplitError(
|
||||||
|
f"The following items have not been assigned splits: {','.join(unsplit_names)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def distribute_amount(self, amount: float, split_subtotals: list[float]):
|
||||||
|
"""
|
||||||
|
Distribute `amount` equally, or distribute it proportionally, among
|
||||||
|
all the people involved in the split.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
x / self.subtotal * amount
|
||||||
|
if self.tax_split_proportionally
|
||||||
|
else amount / len(self.people_list)
|
||||||
|
for x in split_subtotals
|
||||||
|
]
|
||||||
|
|
||||||
|
def subtract_cashback(
|
||||||
|
self, split_totals: list[float]
|
||||||
|
) -> tuple[list[float], list[float]]:
|
||||||
|
split_cashback = [-x * cashback_discount for x in split_totals]
|
||||||
|
split_totals_minus_cashback = [
|
||||||
|
x * (1 - cashback_discount) for x in split_totals
|
||||||
|
]
|
||||||
|
return split_cashback, split_totals_minus_cashback
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
split_arrays: list[list[float]] = []
|
||||||
|
for split in self.item_people:
|
||||||
|
split_array = [
|
||||||
|
1 / len(split) if x in split else 0.0 for x in self.people_list
|
||||||
|
]
|
||||||
|
split_arrays.append(split_array)
|
||||||
|
split_amounts: list[list[float]] = []
|
||||||
|
for split_array, amount in zip(split_arrays, self.item_amounts):
|
||||||
|
split_amount = [amount * split for split in split_array]
|
||||||
|
split_amounts.append(split_amount)
|
||||||
|
|
||||||
|
split_subtotals = [sum(x) for x in zip(*split_amounts)]
|
||||||
|
split_tips = self.distribute_amount(self.receipt_tip, split_subtotals)
|
||||||
|
split_taxes = self.distribute_amount(self.receipt_tax, split_subtotals)
|
||||||
|
|
||||||
|
split_totals = [
|
||||||
|
split_subtotal + split_tip + split_tax
|
||||||
|
for split_subtotal, split_tip, split_tax in zip(
|
||||||
|
split_subtotals, split_tips, split_taxes
|
||||||
|
)
|
||||||
|
]
|
||||||
|
split_cashback, split_totals_minus_cashback = self.subtract_cashback(
|
||||||
|
split_totals
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_splits(
    item_names: list[str],
    item_people: list[list[str]],
    item_amounts: list[float],
    total: float,
    tip: float,
    tax: float,
    people_list: list[str],
    tip_split_proportionally: bool,
    tax_split_proportionally: bool,
    cashback_discount: float,
    return_detailed_table: bool = False,
) -> pl.DataFrame:
    """
    A simple, but long function to calculate splits for a provided receipt.

    Args:
        item_names: Names of the items being split.
        item_people: A list of people for each item who are splitting its cost.
        item_amounts: Amounts of the items being split.
        total: The total amount in the receipt.
        tip: The tip in the receipt.
        tax: The tax in the receipt.
        people_list: Names of everyone splitting the receipt; fixes the column order.
        tip_split_proportionally: Indicator for whether the tip is split proportional to pre-tax/tip cost.
        tax_split_proportionally: Indicator for whether the tax is split proportional to pre-tax/tip cost.
        cashback_discount: Fraction of each person's total that is deducted
            (the value is multiplied directly, e.g. 0.03 removes 3%).
        return_detailed_table: Indicator to return full calculation table or a simplified one.

    Returns:
        A DataFrame form of the provided values along with their calculated splits or a simplified version.

    Raises:
        IncompleteSplitError: If any item has nobody assigned to it.
    """
    pairs = list(zip(item_names, item_people))
    unsplit_names = [name for name, split in pairs if len(split) == 0]
    if len(pairs) - len(unsplit_names) != len(item_people):
        raise IncompleteSplitError(
            f"The following items have not been assigned splits: {','.join(unsplit_names)}"
        )

    # Deliberately avoiding going the numpy route here since the data is very small anyway.
    # One row per item, one column per person (in `people_list` order).
    split_amounts: list[list[float]] = [
        [
            amount * (1 / len(split) if person in split else 0.0)
            for person in people_list
        ]
        for split, amount in zip(item_people, item_amounts)
    ]

    split_subtotals = [sum(column) for column in zip(*split_amounts)]
    subtotal = total - tip - tax
    # Equal split of the tip divides the *tip* (the original divided the tax here).
    split_tips = [
        x / subtotal * tip if tip_split_proportionally else tip / len(people_list)
        for x in split_subtotals
    ]
    split_taxes = [
        x / subtotal * tax if tax_split_proportionally else tax / len(people_list)
        for x in split_subtotals
    ]
    split_totals_pre_cashback = [
        split_subtotal + split_tip + split_tax
        for split_subtotal, split_tip, split_tax in zip(
            split_subtotals, split_tips, split_taxes
        )
    ]
    split_cashback = [-x * cashback_discount for x in split_totals_pre_cashback]
    split_totals_post_cashback = [
        x * (1 - cashback_discount) for x in split_totals_pre_cashback
    ]

    # The three lists below are row-aligned: label, per-person values, row total.
    first_col_names = list(item_names) + [
        "Subtotal",
        "Tip",
        "Tax",
        "Cashback",
        "Total",
    ]
    splits = split_amounts + [
        split_subtotals,
        split_tips,
        split_taxes,
        split_cashback,
        split_totals_post_cashback,
    ]
    horizontal_totals = list(item_amounts) + [
        subtotal,
        tip,
        tax,
        sum(split_cashback),
        sum(split_totals_post_cashback),
    ]
    full_calculation_df = (
        pl.DataFrame(
            {
                "Item": first_col_names,
                "splits": splits,
                "Total": horizontal_totals,
            },
            schema={
                "Item": pl.String,
                "splits": pl.List(pl.Float64),
                "Total": pl.Float64,
            },
        )
        # Turn the per-row list into one named column per person.
        .with_columns(pl.col("splits").list.to_struct(fields=people_list))
        .unnest("splits")
        .with_columns(pl.col(pl.Float64).round(2))
    )
    if return_detailed_table:
        return full_calculation_df

    # Keep only the "Total" row and pivot it into (Person, Split) pairs.
    simple_calculation = (
        full_calculation_df.filter(pl.col("Item").eq("Total"))
        .select(pl.exclude("Total"))
        .transpose(include_header=True, header_name="Person", column_names=["Split"])
        .filter(pl.col("Person").ne("Item"))
    )
    return simple_calculation
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Example usage: three items shared between three people, printed as the
    # detailed table.
    people_list = ["Alice", "Bob", "Charlie"]
    item_names = ["Item 1", "Item 2", "Item 3"]
    item_people = [["Alice", "Bob"], ["Alice"], ["Bob", "Charlie"]]
    item_amounts = [10.0, 20.0, 30.0]
    total, tip, tax = 70.0, 6.0, 4.0
    tip_split_proportionally = True
    tax_split_proportionally = True
    cashback_discount = 0.03

    result_df = calculate_splits(
        item_names=item_names,
        item_people=item_people,
        item_amounts=item_amounts,
        total=total,
        tip=tip,
        tax=tax,
        people_list=people_list,
        tip_split_proportionally=tip_split_proportionally,
        tax_split_proportionally=tax_split_proportionally,
        cashback_discount=cashback_discount,
        return_detailed_table=True,
    )
    print(result_df)
|
||||||
Reference in New Issue
Block a user