Initial stylized working version
This commit is contained in:
13
.gitignore
vendored
Normal file
13
.gitignore
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
|
|
||||||
|
# Data Files
|
||||||
|
data/*
|
||||||
15
.idea/SplitAI.iml
generated
Normal file
15
.idea/SplitAI.iml
generated
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="Python 3.11 (SplitAI)" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="PackageRequirementsSettings">
|
||||||
|
<option name="requirementsPath" value="" />
|
||||||
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="format" value="GOOGLE" />
|
||||||
|
<option name="myDocStringFormat" value="Google" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.11 (SplitAI)" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (SplitAI)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/../SplitAI/.idea/SplitAI.iml" filepath="$PROJECT_DIR$/../SplitAI/.idea/SplitAI.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
8
.idea/ruff.xml
generated
Normal file
8
.idea/ruff.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="RuffConfigService">
|
||||||
|
<option name="runRuffOnSave" value="true" />
|
||||||
|
<option name="useRuffFormat" value="true" />
|
||||||
|
<option name="useRuffServer" value="true" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
11
.idea/ryecharm-overrides.xml
generated
Normal file
11
.idea/ryecharm-overrides.xml
generated
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="insyncwithfoo.ryecharm.configurations.ruff.Override">
|
||||||
|
<option name="names">
|
||||||
|
<map>
|
||||||
|
<entry key="crossPlatformExecutableResolution" value="true" />
|
||||||
|
<entry key="executable" value="true" />
|
||||||
|
</map>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/ryecharm.xml
generated
Normal file
6
.idea/ryecharm.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="insyncwithfoo.ryecharm.configurations.ruff.Local">
|
||||||
|
<option name="executable" value="ruff" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
70
.idea/workspace.xml
generated
Normal file
70
.idea/workspace.xml
generated
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="AutoImportSettings">
|
||||||
|
<option name="autoReloadType" value="SELECTIVE" />
|
||||||
|
</component>
|
||||||
|
<component name="ChangeListManager">
|
||||||
|
<list default="true" id="b6145d7e-32b4-4917-b6ec-19187eb432a7" name="Changes" comment="">
|
||||||
|
<change afterPath="$PROJECT_DIR$/.gitignore" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/SplitAI.iml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/ruff.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/ryecharm-overrides.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/ryecharm.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/.python-version" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/pyproject.toml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/src/app/gradio_ui.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/src/app/split_ai.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/src/app/utils.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/src/llm/ollama.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/src/llm/prompts.yaml" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/src/ocr/surya.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/tests/test_ollama.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/tests/test_surya.py" afterDir="false" />
|
||||||
|
<change afterPath="$PROJECT_DIR$/uv.lock" afterDir="false" />
|
||||||
|
</list>
|
||||||
|
<option name="SHOW_DIALOG" value="false" />
|
||||||
|
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||||
|
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
||||||
|
<option name="LAST_RESOLUTION" value="IGNORE" />
|
||||||
|
</component>
|
||||||
|
<component name="Git.Settings">
|
||||||
|
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||||
|
</component>
|
||||||
|
<component name="ProblemsViewState">
|
||||||
|
<option name="selectedTabId" value="CurrentFile" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectColorInfo">{
|
||||||
|
"associatedIndex": 6
|
||||||
|
}</component>
|
||||||
|
<component name="ProjectId" id="2qtfzFBMMaI540YQLSoIQUrWHCr" />
|
||||||
|
<component name="ProjectViewState">
|
||||||
|
<option name="hideEmptyMiddlePackages" value="true" />
|
||||||
|
<option name="showLibraryContents" value="true" />
|
||||||
|
</component>
|
||||||
|
<component name="PropertiesComponent">{
|
||||||
|
"keyToString": {
|
||||||
|
"RunOnceActivity.ShowReadmeOnStart": "true",
|
||||||
|
"RunOnceActivity.git.unshallow": "true",
|
||||||
|
"SHARE_PROJECT_CONFIGURATION_FILES": "true",
|
||||||
|
"git-widget-placeholder": "main",
|
||||||
|
"last_opened_file_path": "/Users/avimallu/Documents/SplitAI",
|
||||||
|
"settings.editor.selected.configurable": "preferences.pluginManager"
|
||||||
|
}
|
||||||
|
}</component>
|
||||||
|
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
||||||
|
<component name="TaskManager">
|
||||||
|
<task active="true" id="Default" summary="Default task">
|
||||||
|
<changelist id="b6145d7e-32b4-4917-b6ec-19187eb432a7" name="Changes" comment="" />
|
||||||
|
<created>1735487052741</created>
|
||||||
|
<option name="number" value="Default" />
|
||||||
|
<option name="presentableId" value="Default" />
|
||||||
|
<updated>1735487052741</updated>
|
||||||
|
</task>
|
||||||
|
<servers />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.11
|
||||||
21
pyproject.toml
Normal file
21
pyproject.toml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
[project]
|
||||||
|
name = "SplitAI"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"gradio>=5.9.1",
|
||||||
|
"great-tables>=0.15.0",
|
||||||
|
"more-itertools>=10.5.0",
|
||||||
|
"ollama>=0.4.4",
|
||||||
|
"polars>=1.18.0",
|
||||||
|
"selenium>=4.27.1",
|
||||||
|
"surya-ocr>=0.8.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.3.4",
|
||||||
|
"ruff>=0.8.4",
|
||||||
|
]
|
||||||
596
src/app/gradio_ui.py
Normal file
596
src/app/gradio_ui.py
Normal file
@@ -0,0 +1,596 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal, TypedDict, TypeVar
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import polars as pl
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
from src.app.split_ai import ReceiptReader
|
||||||
|
from src.app.utils import css_code, head_html, spinner_html
|
||||||
|
|
||||||
|
ComponentType = TypeVar("ComponentType")
|
||||||
|
|
||||||
|
def calculate_splits(
    item_names: list[str],
    item_people: list[list[str]],
    item_amounts: list[float],
    total: float,
    tip: float,
    tax: float,
    people_list: list[str],
    tip_split_proportionally: bool,
    tax_split_proportionally: bool,
    cashback_discount: float,
    return_detailed_table: bool = False,
) -> gr.DataFrame:
    """
    Calculate how much each person owes for a receipt.

    Args:
        item_names: Names of the items being split.
        item_people: For each item, the people who are splitting its cost.
        item_amounts: Amounts of the items being split.
        total: The total amount in the receipt.
        tip: The tip in the receipt.
        tax: The tax in the receipt.
        people_list: Names of everyone splitting the receipt; one output column per name.
        tip_split_proportionally: If True the tip is split proportional to each person's
            pre-tax/tip cost, otherwise equally.
        tax_split_proportionally: If True the tax is split proportional to each person's
            pre-tax/tip cost, otherwise equally.
        cashback_discount: Fraction (e.g. 0.05 for 5%) by which every person's total is reduced.
        return_detailed_table: Return the full calculation table instead of the simplified one.

    Returns:
        A visible DataFrame with either the full calculation (one row per item plus
        Subtotal/Tip/Tax/Cashback/Total rows) or one row per person with their final split.
        If any item has nobody assigned, a warning is raised in the UI and a hidden, empty
        DataFrame is returned instead.
    """
    # Every item needs at least one person before a split can be computed.
    unsplit_names = [
        name for name, people in zip(item_names, item_people) if len(people) == 0
    ]
    if unsplit_names:
        gr.Warning(
            f"Pending splits: {','.join(unsplit_names)}",
            title="Can't show splits yet",
        )
        return gr.DataFrame(pl.DataFrame(), visible=False)

    # Deliberately avoiding going the numpy route here since the data is very small anyway.
    # One row per item, one column per person: the amount of that item owed by that person.
    split_amounts: list[list[float]] = []
    for people, amount in zip(item_people, item_amounts):
        share = 1 / len(people)
        split_amounts.append(
            [amount * share if person in people else 0.0 for person in people_list]
        )

    # Column-wise sums. With no items at all `zip(*[])` would yield nothing, so fall
    # back to an explicit zero per person to keep every row the same width.
    if split_amounts:
        split_subtotals = [sum(column) for column in zip(*split_amounts)]
    else:
        split_subtotals = [0.0] * len(people_list)

    subtotal = total - tip - tax

    def share_out(charge: float, proportionally: bool) -> list[float]:
        # A proportional split is undefined when the subtotal is zero, so it falls
        # back to an equal split instead of dividing by zero.
        if proportionally and subtotal != 0:
            return [x / subtotal * charge for x in split_subtotals]
        return [charge / len(people_list) for _ in split_subtotals]

    # The equal-split branch must divide the charge being split (tip for tips, tax for taxes).
    split_tips = share_out(tip, tip_split_proportionally)
    split_taxes = share_out(tax, tax_split_proportionally)

    split_totals_pre_cashback = [
        person_subtotal + person_tip + person_tax
        for person_subtotal, person_tip, person_tax in zip(
            split_subtotals, split_tips, split_taxes
        )
    ]
    # Cashback is shown as a negative row so the rows above add up to the final total.
    split_cashback = [-x * cashback_discount for x in split_totals_pre_cashback]
    split_totals_post_cashback = [
        x * (1 - cashback_discount) for x in split_totals_pre_cashback
    ]

    first_col_names = list(item_names) + ["Subtotal", "Tip", "Tax", "Cashback", "Total"]
    splits = split_amounts + [
        split_subtotals,
        split_tips,
        split_taxes,
        split_cashback,
        split_totals_post_cashback,
    ]
    horizontal_totals = list(item_amounts) + [
        subtotal,
        tip,
        tax,
        sum(split_cashback),
        sum(split_totals_post_cashback),
    ]

    full_calculation_df = (
        pl.DataFrame(
            {
                "Item": first_col_names,
                "splits": splits,
                "Total": horizontal_totals,
            },
            schema={
                "Item": pl.String,
                "splits": pl.List(pl.Float64),
                "Total": pl.Float64,
            },
        )
        # Expand the per-person list into one column per person.
        .with_columns(pl.col("splits").list.to_struct(fields=people_list))
        .unnest("splits")
        .with_columns(pl.col(pl.Float64).round(2))
    )
    if return_detailed_table:
        return gr.DataFrame(full_calculation_df, visible=True)

    # Simplified view: keep only the "Total" row and pivot it to one row per person.
    simple_calculation = (
        full_calculation_df.filter(pl.col("Item").eq("Total"))
        .select(pl.exclude("Total"))
        .transpose(include_header=True, header_name="Person", column_names=["Split"])
        .filter(pl.col("Person").ne("Item"))
    )
    return gr.DataFrame(simple_calculation, visible=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Item(TypedDict):
    """A single receipt line item as stored in the app's item list."""

    # Display name of the item.
    name: str
    # Cost of the item.
    amount: float
|
||||||
|
|
||||||
|
|
||||||
|
class ItemSplitter:
    """Builds the row of Gradio widgets used to edit and split one receipt item."""

    def __init__(
        self,
        item: Item,
        people_list: list[str],
    ) -> None:
        self.item = item
        self.people_list_state = people_list
        # Shared keyword arguments for read-only and editable widgets respectively.
        self.no_interaction_kwargs = {"interactive": False, "min_width": 10}
        self.interaction_kwargs = {
            **self.no_interaction_kwargs,
            "container": False,
            "interactive": True,
        }

    def name_textbox(self, item_name: str) -> gr.Textbox:
        """Return an editable textbox pre-filled with the item's name."""
        textbox = gr.Textbox(
            item_name, show_label=False, scale=8, **self.interaction_kwargs
        )
        return textbox

    def amount_number(self, item_amount: float) -> gr.Number:
        """Return an editable two-decimal number field pre-filled with the item's amount."""
        number = gr.Number(
            value=item_amount, precision=2, scale=3, **self.interaction_kwargs
        )
        return number

    def split_status_button(
        self, choices: list[str] | None = None, status: Literal["⚠️", "🆗"] = "⚠️"
    ) -> gr.Button:
        """
        Return a non-interactive button indicating whether the item has been assigned.

        Args:
            choices: The currently selected people. When given, it overrides `status`:
                an empty selection shows the warning icon, anything else the OK icon.
            status: Icon to show when `choices` is not provided.
        """
        if choices is None:
            choices = []
        elif len(choices) == 0:
            status = "⚠️"
        else:
            status = "🆗"

        variant: Literal["huggingface", "primary"]
        if status == "⚠️" or len(choices) == 0:
            variant = "huggingface"
        else:
            variant = "primary"
        return gr.Button(
            value=status, variant=variant, scale=1, **self.no_interaction_kwargs
        )

    def delete_item(self, item_list: list[Item]) -> list[Item]:
        """Remove this splitter's item from `item_list` (in place) and return the list."""
        item_list.remove(self.item)
        return item_list

    def delete_item_button(self) -> gr.Button:
        """Return the clickable button used to delete this item."""
        button_kwargs = {**self.no_interaction_kwargs, "interactive": True}
        return gr.Button(value="❌", variant="stop", **button_kwargs)

    def people_list_checkbox(self, people_list: list[str]) -> gr.CheckboxGroup:
        """Return a checkbox group with one option per person."""
        return gr.CheckboxGroup(choices=people_list, **self.interaction_kwargs)

    def generate(self) -> tuple[gr.Textbox, gr.CheckboxGroup, gr.Number, gr.Button]:
        """Create the widgets for this item; currently always the mobile layout."""
        return self.generate_mobile()

    def generate_mobile(self):
        """
        Create the item's widgets and wire the status button to the checkbox group.

        Returns:
            (name textbox, people checkbox group, amount number, delete button)
        """
        # NOTE(review): the row/checkbox nesting is reconstructed; confirm that the
        # checkbox group is meant to sit below the row rather than inside it.
        with gr.Row(variant="default", equal_height=True):
            name_box = self.name_textbox(self.item["name"])
            amount_box = self.amount_number(self.item["amount"])
            status_button = self.split_status_button(status="⚠️")
            remove_button = self.delete_item_button()
        people_boxes = self.people_list_checkbox(self.people_list_state)

        def refresh_status(selected):
            # Re-render the status icon whenever the selection changes.
            return self.split_status_button(choices=selected)

        people_boxes.change(refresh_status, people_boxes, status_button)
        return (
            name_box,
            people_boxes,
            amount_box,
            remove_button,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class SplitAIApp:
    """Gradio UI for reading a receipt and splitting its cost between several people."""

    valid_split_variant: Literal["primary"] = "primary"
    invalid_split_variant: Literal["huggingface"] = "huggingface"

    def __init__(self, llm_model: str):
        """
        Args:
            llm_model: Model name forwarded to `ReceiptReader`.
        """
        self.receipt_reader = ReceiptReader(llm_model)
        self.demo = self.create_app()

    @staticmethod
    def prepare_calculate_splits_kwargs(num_records: int, *all_values) -> gr.DataFrame:
        """
        This method is necessary because only a list[gr.Component] or similar can be sent as
        `inputs` to an event listener. Therefore, it is unpacked here and prepared into a
        dictionary based on how it is sent by the event. This method is specifically for
        the `get_split_button.click` event listener.

        Args:
            num_records: The number of items present to split.
            *all_values: `num_records` item names, then `num_records` people selections,
                then `num_records` amounts, followed by total, tip, tax, people list,
                tip/tax proportional flags, cashback percentage and the detailed-table flag.

        Returns:
            gr.DataFrame
        """
        kwargs = {
            "item_names": all_values[:num_records],
            "item_people": all_values[num_records : num_records * 2],
            "item_amounts": all_values[num_records * 2 : num_records * 3],
        }
        trailing_names = [
            "total",
            "tip",
            "tax",
            "people_list",
            "tip_split_proportionally",
            "tax_split_proportionally",
            "cashback_discount",
            "return_detailed_table",
        ]
        additional_kwargs = dict(zip(trailing_names, all_values[num_records * 3 :]))
        # The UI supplies a percentage; `calculate_splits` multiplies by it as a fraction.
        additional_kwargs["cashback_discount"] /= 100
        kwargs.update(additional_kwargs)
        return calculate_splits(**kwargs)

    @staticmethod
    def update_component_attributes(component: ComponentType, **kwargs) -> ComponentType:
        """
        This requirement is in place because Gradio expects you to provide A NEW INSTANCE of
        the component that you want to update with its attributes changed. It seems like it
        doesn't replace the component, but updates it this way. Very weird behavior.

        Args:
            component: The gradio component to update attributes for.
            **kwargs: (attribute, value) pairs to update in child.

        Returns:
            A new instance of child's class with the updated attributes.

        Raises:
            Exception: Whatever the component's constructor raised, re-raised unchanged.
        """
        gradio_class = type(component)
        try:
            return gradio_class(**kwargs)
        except Exception:
            print(
                f"The Gradio component {gradio_class} does not have one of the provided attribute keys."
            )
            # Bare `raise` keeps the original traceback intact.
            raise

    @staticmethod
    def validate_people_list(people_textbox) -> tuple[gr.Image, list[str]]:
        """
        Parse the comma-separated names and unlock the image uploader when they are valid.

        Blank entries (e.g. from a leading/trailing comma) and repeated names are dropped,
        because every name becomes a checkbox choice and a result column. At least two
        distinct names are required.

        Args:
            people_textbox: Raw text entered by the user.

        Returns:
            An interactive uploader plus the parsed names, or a locked uploader and an
            empty list (with a UI warning) when the input is not usable.
        """
        raw_names = (people_textbox or "").split(",")
        stripped = (name.strip() for name in raw_names)
        # dict.fromkeys de-duplicates while preserving the order the user typed.
        people_list = list(dict.fromkeys(name for name in stripped if name))
        if len(people_list) >= 2:
            return gr.Image(interactive=True), people_list
        gr.Warning("You need to enter a list of names separated by commas.")
        return gr.Image(interactive=False), []

    def create_app(self) -> gr.Blocks:
        """
        Build the Blocks layout and wire up all event listeners.

        NOTE(review): the nesting of the Column/Row containers below is reconstructed;
        verify the layout visually.

        Returns:
            The assembled `gr.Blocks` app (not yet launched).
        """
        # `head_html` required to prevent iOS from scaling the UI when clicking on a textbox.
        with gr.Blocks(
            css=css_code,
            head=head_html,
            theme="JohnSmith9982/small_and_pretty",
            fill_width=True,
        ) as split_app:
            with gr.Column():
                self.people_textbox = gr.Textbox(
                    placeholder="Split names with a comma",
                    label="Who all are splitting this receipt?",
                    lines=1,
                    autofocus=True,
                    submit_btn="Submit",
                )
                self.people_list = gr.State([])
                # Locked until a valid list of people has been submitted.
                self.image_uploader = gr.Image(
                    show_label=False, scale=1, type="pil", interactive=False
                )
                self.people_textbox.submit(
                    SplitAIApp.validate_people_list,
                    [self.people_textbox],
                    [self.image_uploader, self.people_list],
                )
            with gr.Column():
                # Receipt-level fields stay hidden until an image has been processed.
                with gr.Column():
                    with gr.Row():
                        self.merchant = gr.Textbox(
                            interactive=True,
                            label="Merchant Name",
                            min_width=20,
                            visible=False,
                            scale=2,
                        )
                        self.receipt_date = gr.DateTime(
                            interactive=True,
                            include_time=False,
                            type="datetime",
                            label="Date",
                            min_width=20,
                            visible=False,
                            scale=1,
                        )
                    with gr.Row():
                        self.total_amount = gr.Number(
                            interactive=True,
                            label="Total",
                            min_width=20,
                            visible=False,
                            precision=2,
                        )
                        self.tip_amount = gr.Number(
                            interactive=True,
                            label="Tip",
                            min_width=20,
                            visible=False,
                            precision=2,
                        )
                        self.tax_amount = gr.Number(
                            interactive=True,
                            label="Tax",
                            min_width=20,
                            visible=False,
                            precision=2,
                        )
                self.items = gr.State([])

                @gr.render(inputs=[self.items, self.people_list])
                def render_items(items: list[Item], people_list: list[str]):
                    item_names = []
                    item_peoples = []
                    item_amounts = []
                    for item in items:
                        with gr.Column(variant="compact"):
                            splitter = ItemSplitter(item, people_list)
                            item_name, item_people, item_amount, delete_item_button = (
                                splitter.generate()
                            )
                            # This event needs to be defined outside the ItemSplitter class
                            # because it references a gr.State variable. Gradio components
                            # can be properly passed ONLY via event listeners, as their
                            # state is managed by Gradio outside the flow of the Python app.
                            delete_item_button.click(
                                splitter.delete_item, self.items, self.items
                            )
                            item_names.append(item_name)
                            item_peoples.append(item_people)
                            item_amounts.append(item_amount)

                    self.split_tip_proportionally = gr.Checkbox(
                        value=True,
                        label="Split tip proportional to other costs",
                        info="If unchecked, will split equally.",
                        interactive=True,
                    )
                    self.split_tax_proportionally = gr.Checkbox(
                        value=True,
                        label="Split tax proportional to other costs",
                        info="If unchecked, will split equally.",
                        interactive=True,
                    )
                    self.add_cashback_discount = gr.Slider(
                        minimum=0,
                        maximum=25,
                        value=0,
                        step=0.5,
                        label="Cashback discount to apply on total",
                        # Keep the hint in sync with `maximum` above.
                        info="Choose a number between 0% and 25%.",
                        interactive=True,
                    )
                    self.show_detailed_table = gr.Checkbox(
                        value=False,
                        label="Show a detailed calculation table",
                        info="If unchecked, will just show the splits.",
                        interactive=True,
                    )

                    with gr.Row():
                        self.integrity_markdown = gr.Markdown(
                            show_label=False, value="", visible=False
                        )

                    with gr.Row():
                        get_splits_button = gr.Button(
                            "Get Splits", variant="primary", scale=5, min_width=10
                        )
                        get_splits_button.click(
                            lambda *x: SplitAIApp.prepare_calculate_splits_kwargs(
                                len(item_names), *x
                            ),
                            inputs=(
                                item_names
                                + item_peoples
                                + item_amounts
                                + [
                                    self.total_amount,
                                    self.tip_amount,
                                    self.tax_amount,
                                    self.people_list,
                                    self.split_tip_proportionally,
                                    self.split_tax_proportionally,
                                    self.add_cashback_discount,
                                    self.show_detailed_table,
                                ]
                            ),
                            outputs=self.display_result,
                        )

                        add_item_button = gr.Button(
                            "➕", variant="secondary", scale=1, min_width=10
                        )

                    def add_item(
                        items: list[Item],
                    ):
                        # Return a new list rather than mutating the state value.
                        new_item_name = f"Item {len(items)+1}"
                        return items + [
                            {
                                "name": new_item_name,
                                "amount": 0.0,
                            }
                        ]

                    add_item_button.click(
                        add_item,
                        inputs=[self.items],
                        outputs=[self.items],
                    )

                    # Constantly keep track of whether totals match or not.
                    def integrity_checker(*args) -> gr.Markdown:
                        # An emptied Number field may arrive as None; count it as zero
                        # instead of failing inside the arithmetic below.
                        values = [0.0 if value is None else value for value in args]
                        *amounts, tip_amount, tax_amount, total_amount = values
                        subtotal = round(sum(amounts), 2)
                        # Compare at cent precision: exact float equality reports
                        # spurious mismatches (e.g. 0.1 + 0.2 != 0.3).
                        expected_total = round(subtotal + tip_amount + tax_amount, 2)
                        if expected_total != round(total_amount, 2):
                            return gr.Markdown(
                                f"⚠️ Looks like the total ({total_amount}) doesn't match the value of subtotal ({subtotal}) + tip ({tip_amount}) + tax ({tax_amount}) ⚠️",
                                show_label=False,
                                visible=True,
                            )
                        return gr.Markdown(visible=False)

                    gr.on(
                        triggers=[x.change for x in item_amounts]
                        + [
                            self.tip_amount.change,
                            self.tax_amount.change,
                            self.total_amount.change,
                        ],
                        fn=integrity_checker,
                        inputs=item_amounts
                        + [
                            self.tip_amount,
                            self.tax_amount,
                            self.total_amount,
                        ],
                        outputs=[self.integrity_markdown],
                    )

                self.display_result = gr.DataFrame(value=None, visible=False)

                self.spinner_html = gr.HTML(
                    spinner_html,
                    visible=False,
                    padding=False,
                )

            # Show the spinner, process the receipt, then hide the spinner again.
            self.image_uploader.upload(
                lambda: gr.HTML(visible=True), inputs=None, outputs=self.spinner_html
            ).then(
                self.process_image,
                inputs=[self.image_uploader, self.items],
                outputs=[
                    self.merchant,
                    self.receipt_date,
                    self.total_amount,
                    self.tip_amount,
                    self.tax_amount,
                    self.items,
                ],
                show_progress="hidden",
            ).then(
                lambda: gr.HTML(visible=False), inputs=None, outputs=self.spinner_html
            )

        return split_app

    def process_image(
        self,
        image: Image,
        items: list[Item],
    ) -> list:
        """
        Turn an uploaded receipt into updates for the receipt fields and the item list.

        Args:
            image: The uploaded receipt image (currently unused, see TODO below).
            items: The current value of the items state.

        Returns:
            Updated merchant, date, total, tip and tax components (in that order),
            followed by the new item list.
        """
        # TODO: the OCR + LLM pipeline is disabled and a fixed sample receipt is used
        # instead. Re-enable once ready:
        # gr.Info("Running OCR on image of receipt.")
        # receipt_string = self.receipt_reader.get_ordered_text(image)
        # gr.Info("Extracting components. Please be patient.")
        # receipt_extracted = self.receipt_reader.extract_components(receipt_string)
        receipt_extracted = {
            "merchant": "FUBAR",
            "receipt_date": datetime.now(),
            "total": {"amount": 15},
            "tip": {"amount": 0},
            "tax": {"amount": 3},
            "item_amounts": [
                {"name": "PET TOY", "currency": "$", "amount": 2},
                {"name": "FLOPPY PUPPY", "currency": "$", "amount": 4},
                {"name": "SSSUPREME S", "currency": "$", "amount": 6},
            ],
        }
        field_values = [
            (self.merchant, receipt_extracted["merchant"]),
            (self.receipt_date, receipt_extracted["receipt_date"]),
            (self.total_amount, receipt_extracted["total"]["amount"]),
            (self.tip_amount, receipt_extracted["tip"]["amount"]),
            (self.tax_amount, receipt_extracted["tax"]["amount"]),
        ]
        # Each field is filled in and made visible at the same time.
        out = [
            self.update_component_attributes(component, value=value, visible=True)
            for component, value in field_values
        ]
        # Build a new list instead of extending the incoming state value in place.
        new_items = items + [
            {"name": x["name"], "amount": x["amount"]}
            for x in receipt_extracted["item_amounts"]
        ]
        out.append(new_items)
        return out

    def launch(self, expose_to_local_network: bool = False):
        """
        Start the app with request queuing enabled.

        Args:
            expose_to_local_network: If True, bind to 0.0.0.0 on port 7860; otherwise
                use Gradio's launch defaults.
        """
        if expose_to_local_network:
            self.demo.queue().launch(server_name="0.0.0.0", server_port=7860)
        else:
            self.demo.queue().launch()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the app directly, reachable from other devices on the local network.
    demo = SplitAIApp(llm_model="qwen2.5")
    demo.launch(expose_to_local_network=True)
|
||||||
22
src/app/split_ai.py
Normal file
22
src/app/split_ai.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from src.llm.ollama import LLMExtractor, ReceiptExtracted
|
||||||
|
from src.ocr.surya import SuryaOCR
|
||||||
|
from PIL.Image import Image
|
||||||
|
from PIL.Image import open as PIL_open
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptReader:
    """Combines the OCR model and the LLM extractor behind one small interface."""

    def __init__(self, llm_model: str):
        """
        Args:
            llm_model: Model name passed to `LLMExtractor`.
        """
        self.ocr_model = SuryaOCR()
        self.llm_model = LLMExtractor(model=llm_model)

    def get_ordered_text(self, image: Image) -> str:
        """Run OCR on `image` and return the text produced by `ordered_ocr_text`."""
        ordered_text = self.ocr_model.ordered_ocr_text(image)
        return ordered_text

    def extract_components(self, receipt_string: str) -> ReceiptExtracted:
        """Pass the receipt text to the LLM extractor (with alerts disabled) and return its result."""
        extracted = self.llm_model.forward(receipt_string, enable_alerts=False)
        return extracted
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke run: OCR the sample image, then extract fields from the text.
    image = PIL_open("data/pytest/image_upload_test.png")
    # ReceiptReader requires an LLM model name; "qwen2.5" matches the app entry point.
    sr = ReceiptReader("qwen2.5")
    # extract_components expects the OCR text, not the image itself.
    receipt_string = sr.get_ordered_text(image)
    print(sr.extract_components(receipt_string))
|
||||||
52
src/app/utils.py
Normal file
52
src/app/utils.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# Extra <head> markup: fixes the viewport and disables pinch-zoom on mobile.
head_html = """
<meta name="viewport"
content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
"""

# Custom stylesheet for the Gradio app.
# NOTE: `inputmode` and `pattern` are HTML attributes, not CSS properties, so
# the former `input[type="number"]` rule that declared them was ignored by
# browsers and has been removed. A numeric keypad has to be requested on the
# input element itself.
css_code = """
input {
    font-size: 16px !important;
}
#fullwidth-checkgroup .gr-checkgroup {
    width: 100%;
}
footer {
    display: none !important;
}

.loader {
    width: 48px;
    height: 48px;
    border: 5px solid #FFF;
    border-bottom-color: transparent;
    border-radius: 50%;
    display: inline-block;
    box-sizing: border-box;
    animation: rotation 1s linear infinite;
}

@keyframes rotation {
    0% {
        transform: rotate(0deg);
    }
    100% {
        transform: rotate(360deg);
    }
}
"""

# Centered spinner markup; relies on the `.loader` rule in `css_code`.
spinner_html = """
<div style="
display: flex;
justify-content: center;
align-items: center;
height: 10vh;
width: 100%;
">
<span class='loader'></span>
</div>
"""
|
||||||
196
src/llm/ollama.py
Normal file
196
src/llm/ollama.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
import gradio as gr
|
||||||
|
from ollama import chat
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import TypedDict, Optional, Literal, Any, Type, Annotated
|
||||||
|
|
||||||
|
|
||||||
|
class _MessageBase(TypedDict):
    # Keys present on every chat message.
    role: Literal["user", "assistant", "system", "tool"]
    content: Optional[str]


class Message(_MessageBase, total=False):
    """
    A single chat message passed to ``ollama.chat``.

    ``role`` and ``content`` are required. ``images`` may be omitted, which is
    how every text-only message in this project is built.
    """

    images: Optional[list[str] | bytes | list[Path]]
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptMerchant(BaseModel):
    """Structured-output schema for the merchant-name prompt."""

    # Merchant name as reported by the LLM.
    name: str
|
||||||
|
|
||||||
|
class ReceiptDate(BaseModel):
    """Structured-output schema for the receipt-date prompt."""

    # Receipt date as reported by the LLM.
    date: datetime
|
||||||
|
|
||||||
|
def is_valid_currency_code(value: str) -> bool:
    """
    Check whether ``value`` has the shape of an ISO 4217 currency code.

    Args:
        value: Candidate currency code.

    Returns:
        ``True`` when ``value`` is exactly three uppercase ASCII letters.
    """
    # isascii() rejects non-ASCII letters such as "ÄÖÜ", which isalpha() and
    # isupper() alone would accept.
    return (
        len(value) == 3
        and value.isascii()
        and value.isalpha()
        and value.isupper()
    )
|
||||||
|
|
||||||
|
class ReceiptAmount(BaseModel):
    """Structured-output schema for a single currency/amount pair."""

    # NOTE(review): the bare callable is only Annotated metadata here;
    # presumably pydantic does not run it as a validator. Confirm whether
    # validation is intended.
    currency: Annotated[str, is_valid_currency_code]
    # Monetary value expressed in `currency`.
    amount: float
|
||||||
|
|
||||||
|
class Amount(TypedDict):
    """Dictionary form of a currency/amount pair, as returned by the extractors."""

    # Currency code; the annotation carries `is_valid_currency_code` as metadata.
    currency: Annotated[str, is_valid_currency_code]
    # Monetary value expressed in `currency`.
    amount: float
|
||||||
|
|
||||||
|
class ReceiptItemized(BaseModel):
    """Structured-output schema for the itemized-receipt prompt."""

    class ReceiptLineItemAmount(BaseModel):
        """One line item: its name plus the currency/amount paid for it."""

        name: str
        currency: Annotated[str, is_valid_currency_code]
        amount: float

    # All line items found on the receipt; the field name is the JSON key
    # that callers read from the parsed response.
    ItemizedReceipt: list[ReceiptLineItemAmount]
|
||||||
|
|
||||||
|
class ItemizedAmounts(TypedDict):
    """Dictionary form of one receipt line item."""

    # Item name as it appears on the receipt.
    name: str
    # Currency code; the annotation carries `is_valid_currency_code` as metadata.
    currency: Annotated[str, is_valid_currency_code]
    # Amount paid for the item.
    amount: float
|
||||||
|
|
||||||
|
class ReceiptExtracted(TypedDict):
    """
    Combined result of all receipt extraction steps.

    The keys mirror the dictionary that ``LLMExtractor.forward`` actually
    builds (``total`` / ``tip`` / ``tax``), which is also what the UI reads.
    """

    merchant: str
    receipt_date: datetime
    total: Amount
    tip: Amount
    tax: Amount
    item_amounts: list[ItemizedAmounts]
|
||||||
|
|
||||||
|
class LLMExtractor:
    """Extract structured receipt fields from OCR text with an ``ollama`` chat model."""

    def __init__(self, model: str, prompt_path: str | Path = "./src/llm/prompts.yaml"):
        """
        Load the prompt templates.

        Args:
            model: A model name that has been downloaded by ``ollama``.
            prompt_path: A str or Path to a YAML file with various prompts.
        """
        self.model = model
        self.prompt_path = Path(prompt_path)
        with open(self.prompt_path, "r", encoding="utf-8") as f:
            self.prompts = yaml.safe_load(f)

    def get_chat_response(
        self,
        messages: list[Message],
        structured_output_format: Type[BaseModel] | None = None,
    ) -> Any:
        """
        Send ``messages`` to the model with temperature 0.

        Args:
            messages: The conversation so far.
            structured_output_format: Pydantic model class whose JSON schema
                constrains the reply. Can be ``None`` for free-form text.

        Returns:
            The JSON-decoded reply when ``structured_output_format`` is given,
            otherwise the raw message content.
        """
        # Only derive a schema when a model class was supplied; calling
        # model_json_schema() on None would raise AttributeError.
        schema = (
            structured_output_format.model_json_schema()
            if structured_output_format is not None
            else None
        )
        response = chat(
            model=self.model,
            messages=messages,
            format=schema,
            options={"temperature": 0},
        )
        content = response.message.content
        if schema is None:
            return content
        return json.loads(content)

    def load_substituted_prompt(self, prompt_name: str, **kwargs) -> str:
        """
        Fetch a prompt template and fill in its ``[[ name ]]`` placeholders.

        Args:
            prompt_name: Key of the template in the loaded prompts.
            **kwargs: Placeholder name -> replacement string.

        Returns:
            The prompt with every supplied placeholder replaced.
        """
        prompt = self.prompts[prompt_name]
        for placeholder, replacement in kwargs.items():
            prompt = prompt.replace(f"[[ {placeholder} ]]", replacement)
        return prompt

    def _system_messages(self, prompt_name: str, receipt_string: str) -> list[Message]:
        """Build the single system message shared by every extraction call."""
        return [
            {
                "role": "system",
                "content": self.load_substituted_prompt(
                    prompt_name, receipt_string=receipt_string
                ),
            },
        ]

    def extract_merchant_name(self, receipt_string: str) -> str:
        """Return the merchant name the model reads from the receipt text."""
        messages = self._system_messages("extract_merchant_name", receipt_string)
        return self.get_chat_response(messages, ReceiptMerchant)["name"]

    def extract_receipt_date(self, receipt_string: str) -> datetime:
        """Return the receipt date, parsed from the model's ISO-format reply."""
        messages = self._system_messages("extract_receipt_date", receipt_string)
        receipt_date = self.get_chat_response(messages, ReceiptDate)["date"]
        return datetime.fromisoformat(receipt_date)

    def extract_receipt_total_amount(self, receipt_string: str) -> Amount:
        """Return the currency and amount of the receipt total."""
        messages = self._system_messages("extract_receipt_total_amount", receipt_string)
        return self.get_chat_response(messages, ReceiptAmount)

    def extract_receipt_tax_amount(self, receipt_string: str) -> Amount:
        """Return the currency and amount of the overall tax."""
        messages = self._system_messages("extract_receipt_tax_amount", receipt_string)
        return self.get_chat_response(messages, ReceiptAmount)

    def extract_receipt_tip_amount(self, receipt_string: str) -> Amount:
        """Return the currency and amount of the tip."""
        messages = self._system_messages("extract_receipt_tip_amount", receipt_string)
        return self.get_chat_response(messages, ReceiptAmount)

    def extract_receipt_items(self, receipt_string: str) -> list[ItemizedAmounts]:
        """Return the line items (name, currency, amount) found on the receipt."""
        messages = self._system_messages("extract_receipt_items", receipt_string)
        return self.get_chat_response(messages, ReceiptItemized)["ItemizedReceipt"]

    def forward(self, receipt_string, enable_alerts: bool = False) -> ReceiptExtracted:
        """
        Run every extraction step in sequence and collect the results.

        Args:
            receipt_string: OCR text of the receipt.
            enable_alerts: When true, show a ``gr.Info`` progress message after
                each step. Off by default, so no UI calls are made.

        Returns:
            Dictionary with ``merchant``, ``receipt_date``, ``total``, ``tip``,
            ``tax`` and ``item_amounts``.
        """

        def notify(text: str) -> None:
            # Progress alerts are opt-in.
            if enable_alerts:
                gr.Info(text)

        merchant = self.extract_merchant_name(receipt_string)
        notify("Extracted merchant name.")
        receipt_date = self.extract_receipt_date(receipt_string)
        notify("Extracted receipt date.")
        total_amount = self.extract_receipt_total_amount(receipt_string)
        notify("Extracted total amount.")
        tip_amount = self.extract_receipt_tip_amount(receipt_string)
        notify("Extracted tip amount.")
        tax_amount = self.extract_receipt_tax_amount(receipt_string)
        notify("Extracted tax amount. Extracting individual items - this will take time.")
        item_amounts = self.extract_receipt_items(receipt_string)
        notify("Extracted individual items. Populating items now!")
        return {
            "merchant": merchant,
            "receipt_date": receipt_date,
            "total": total_amount,
            "tip": tip_amount,
            "tax": tax_amount,
            "item_amounts": item_amounts,
        }
|
||||||
89
src/llm/prompts.yaml
Normal file
89
src/llm/prompts.yaml
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
extract_merchant_name: >
|
||||||
|
You are an experienced cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
|
||||||
|
|
||||||
|
You are trying to identify the merchant name.
|
||||||
|
For restaurants, this will be the restaurant name.
|
||||||
|
For a grocery store or supermarket, this will be the supermarket's name.
|
||||||
|
For other retailers, it will be the retailer's name.
|
||||||
|
If you can't find a name, provide `null` as the response.
|
||||||
|
|
||||||
|
The receipt string is provided below:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[ receipt_string ]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Extract the name of the merchant from this receipt using your knowledge of the industry and the receipt itself.
|
||||||
|
|
||||||
|
extract_receipt_date: >
|
||||||
|
You are an experienced cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
|
||||||
|
|
||||||
|
You are trying to identify the date (in YYYY-MM-DD format) on the receipt.
|
||||||
|
If you can't find a date, provide `null` as the response.
|
||||||
|
|
||||||
|
The receipt string is provided below:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[ receipt_string ]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Extract the date of this receipt using your knowledge of the industry and the receipt itself.
|
||||||
|
|
||||||
|
extract_receipt_total_amount: >
|
||||||
|
You are an experienced cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
|
||||||
|
|
||||||
|
You are trying to identify the total amount on the receipt, as well as the three letter currency code (ISO 4217).
|
||||||
|
If you cannot find the total amount, provide '0' as the total amount.
|
||||||
|
|
||||||
|
The receipt string is provided below:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[ receipt_string ]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Extract the currency and amount of the total in this receipt using your knowledge of the industry
|
||||||
|
and the receipt itself.
|
||||||
|
|
||||||
|
extract_receipt_tax_amount: >
|
||||||
|
You are an experienced cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
|
||||||
|
|
||||||
|
You are trying to identify the overall tax amount on the receipt, as well as the three letter currency code (ISO 4217).
|
||||||
|
If you cannot find the tax amount, provide '0' for the tax amount.
|
||||||
|
|
||||||
|
The receipt string is provided below:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[ receipt_string ]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Extract the currency and amount of the overall tax in this receipt using your knowledge of the industry
|
||||||
|
and the receipt itself.
|
||||||
|
|
||||||
|
extract_receipt_tip_amount: >
|
||||||
|
You are an experienced cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
|
||||||
|
|
||||||
|
You are trying to identify the tip paid (including gratuity, but not tax) on the receipt, as well as the three letter
|
||||||
|
currency code (ISO 4217). If there is no tip, provide '0' as the tip.
|
||||||
|
|
||||||
|
The receipt string is provided below:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[ receipt_string ]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Extract the currency and amount of the overall tip in this receipt using your knowledge of the industry
|
||||||
|
and the receipt itself.
|
||||||
|
|
||||||
|
extract_receipt_items: >
|
||||||
|
You are an experienced cashier who is trying to make sense of a receipt string that a computing system has run OCR on.
|
||||||
|
|
||||||
|
You are trying to identify each of the items present, and the amount paid for those items.
|
||||||
|
|
||||||
|
The receipt string is provided below:
|
||||||
|
|
||||||
|
```
|
||||||
|
[[ receipt_string ]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Extract a list of items by their name, currency (3 letter code) and amount paid for it using your knowledge of the industry
|
||||||
|
and the receipt itself.
|
||||||
185
src/ocr/surya.py
Normal file
185
src/ocr/surya.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from PIL.Image import open as PIL_open
|
||||||
|
from PIL.Image import Image
|
||||||
|
from surya.ocr import run_ocr
|
||||||
|
from surya.model.detection.model import (
|
||||||
|
load_model as load_det_model,
|
||||||
|
load_processor as load_det_processor,
|
||||||
|
)
|
||||||
|
from surya.model.recognition.model import load_model as load_rec_model
|
||||||
|
from surya.model.recognition.processor import load_processor as load_rec_processor
|
||||||
|
from typing import Literal, TypedDict, overload, Optional
|
||||||
|
from collections import defaultdict
|
||||||
|
from more_itertools import bucket
|
||||||
|
|
||||||
|
# A bounding polygon: a list of [x, y] vertex pairs.
Polygon = list[list[float]]
# A bounding box expressed as a flat list of floats.
BBox = list[float]
|
||||||
|
|
||||||
|
|
||||||
|
class SuryaOCRResult(TypedDict):
    """One recognized text line, as found in Surya's dumped ``text_lines`` output."""

    # Outline of the text line.
    polygon: Polygon
    # Recognition confidence; may be absent (None).
    confidence: Optional[float]
    # Recognized text content.
    text: str
    # Bounding box of the text line.
    bbox: BBox
|
||||||
|
|
||||||
|
|
||||||
|
class SplitAIOCRResult(TypedDict):
    """Column-wise OCR result: parallel lists with one entry per text line."""

    # Outline of each recognized text line.
    polygons: list[Polygon]
    # Recognized text of each line.
    texts: list[str]
    # Bounding box of each line.
    bboxes: list[BBox]
|
||||||
|
|
||||||
|
|
||||||
|
class SuryaOCR:
    """Run Surya OCR on an image and arrange the recognized text into reading order."""

    def __init__(self):
        """Load Surya's detection and recognition models and processors."""
        self.det_processor = load_det_processor()
        self.det_model = load_det_model()
        self.rec_model = load_rec_model()
        self.rec_processor = load_rec_processor()

    @overload
    def ocr_image(self, image: Image, return_format: Literal["polygon"]) -> list[Polygon]: ...
    @overload
    def ocr_image(self, image: Image, return_format: Literal["text"]) -> list[str]: ...
    @overload
    def ocr_image(self, image: Image, return_format: Literal["bbox"]) -> list[BBox]: ...
    @overload
    def ocr_image(self, image: Image, return_format: Literal["confidence"]) -> list[Optional[float]]: ...
    @overload
    def ocr_image(self, image: Image, return_format: None) -> SplitAIOCRResult: ...
    def ocr_image(
        self,
        image: Image,
        return_format: Literal["polygon", "text", "bbox", "confidence"] | None,
    ) -> SplitAIOCRResult | list[Polygon] | list[str] | list[BBox] | list[Optional[float]]:
        """
        Run OCR (English) on a single PIL image.

        Args:
            image: the PIL image for Surya to process
            return_format: Specify one of the allowed values to return only this key from SuryaOCR's output.
                ``None`` returns polygons, texts and bboxes together.

        Returns:
            A list of the requested key (one entry per text line), or a
            ``SplitAIOCRResult`` of parallel lists when ``return_format`` is ``None``.
            See `<https://github.com/VikParuchuri/surya>` for Surya's output format.
        """
        images = [image]
        langs = [["en"]]

        ocr_output = run_ocr(
            images,
            langs,
            self.det_model,
            self.det_processor,
            self.rec_model,
            self.rec_processor,
        )
        # A single image was submitted, so only the first result is relevant.
        ocr_as_list: list[SuryaOCRResult] = ocr_output[0].model_dump()["text_lines"]
        if isinstance(return_format, str):
            return [line[return_format] for line in ocr_as_list]
        return {
            "polygons": [line["polygon"] for line in ocr_as_list],
            "texts": [line["text"] for line in ocr_as_list],
            "bboxes": [line["bbox"] for line in ocr_as_list],
        }

    def ordered_ocr_text(self, image: Image) -> str:
        """
        OCR ``image`` and return its text in reading order.

        Text fragments grouped onto the same line are joined with spaces;
        lines are joined with newlines.
        """
        split_ocr_result = self.ocr_image(image, None)
        texts = split_ocr_result["texts"]
        ordered_text = [
            " ".join(texts[idx] for idx in group)
            for group in SuryaOCR.order_polygons(split_ocr_result["polygons"])
        ]
        return "\n".join(ordered_text)

    @staticmethod
    def get_centroid_of_bounding_polygon(
        polygon: Polygon,
    ) -> tuple[float, float]:
        """
        Return the centre of a bounding polygon as the average of its vertices.

        Any bounding polygon whose centroid differs significantly from another
        one is unlikely to be on the same line.

        Args:
            polygon: `[[x_1, y_1], [x_2, y_2], [x_3, y_3], [x_4, y_4]]` coordinates of
                the polygon, clockwise from top left.

        Returns:
            The ``(x, y)`` centroid of the bounding polygon.
        """
        # Average over all vertices. Averaging only x_1/x_4 and y_1/y_2 would
        # give the left edge and the top edge rather than the centre.
        xs, ys = zip(*polygon)
        return sum(xs) / len(xs), sum(ys) / len(ys)

    @staticmethod
    def get_height_of_bounding_polygon(polygon: Polygon) -> float:
        """
        Return the height of a bounding polygon, measured along its left edge.

        Args:
            polygon: `[[x_1, y_1], [x_2, y_2], [x_3, y_3], [x_4, y_4]]` coordinates of
                the polygon, clockwise from top left.

        Returns:
            The height of the bounding polygon (bottom-left y minus top-left y).
        """
        (_, y_1), _, _, (_, y_4) = polygon
        return y_4 - y_1

    @staticmethod
    def order_polygons(polygons: list[Polygon]) -> list[list[int]]:
        """
        Given a list of bounding polygons (from an OCR framework) for recognized text,
        attempt to group them into lines and order each line left-to-right.

        This is intended to work independent of the orientation of the receipts, but
        currently is at a PoC stage where it is assumed that the receipt is horizontally
        positioned.

        Args:
            polygons: Bounding polygons of recognized text.

        Returns:
            One list per estimated line, each holding indices into ``polygons``
            sorted by horizontal centre. Lines are ordered by the index of the
            polygon that opened them. Empty input yields an empty list.
        """
        if not polygons:
            # Nothing was recognized, so there is nothing to order.
            return []

        centroids = [SuryaOCR.get_centroid_of_bounding_polygon(p) for p in polygons]
        x_midpoints = [x for x, _ in centroids]
        y_midpoints = [y for _, y in centroids]
        heights = [SuryaOCR.get_height_of_bounding_polygon(p) for p in polygons]

        # A polygon joins a line when its vertical centre falls inside the
        # central `threshold` fraction of the line-opening polygon's height.
        threshold = 0.6
        y_bands = [
            (mid - height * threshold * 0.5, mid + height * threshold * 0.5)
            for mid, height in zip(y_midpoints, heights)
        ]

        # Assign line groups: polygon index -> index of the polygon it was grouped under.
        # NOTE: a polygon already grouped elsewhere can still attract later
        # polygons under its own index; that behaviour is kept as-is.
        line_groups: dict[int, int] = {}
        for idx, (low, high) in enumerate(y_bands):
            if idx not in line_groups:
                line_groups[idx] = idx
            for jdx in range(idx + 1, len(y_bands)):
                if jdx not in line_groups and low <= y_midpoints[jdx] <= high:
                    line_groups[jdx] = idx

        # Collect members per group, then sort each group by horizontal centre.
        members: dict[int, list[int]] = defaultdict(list)
        for idx, group in line_groups.items():
            members[group].append(idx)

        return [
            sorted(members[group], key=lambda i: x_midpoints[i])
            for group in sorted(members)
        ]
|
||||||
54
tests/test_ollama.py
Normal file
54
tests/test_ollama.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from src.llm.ollama import LLMExtractor
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_response():
    """A plain chat request (no structured format) should return a string."""
    # Imported locally because the module header only imports LLMExtractor.
    from src.llm.ollama import Message

    messages: list[Message] = [{"role": "user", "content": "Hello"}]
    # get_chat_response is a method of LLMExtractor, not a module-level function.
    output = LLMExtractor(model="gemma2:27b").get_chat_response(messages)
    assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_structured_response():
    """A structured request should return data that validates against the model."""
    # Imported locally because the module header does not provide these names.
    from pydantic import BaseModel
    from src.llm.ollama import Message

    class ModelData(BaseModel):
        name: str
        knowledge_cutoff: date

    messages: list[Message] = [{"role": "user", "content": "Tell me about yourself"}]
    # The method takes the model class itself and derives the JSON schema internally.
    output = LLMExtractor(model="gemma2:27b").get_chat_response(
        messages,
        structured_output_format=ModelData,
    )
    assert ModelData.model_validate(output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_image_upload():
    """A message carrying an image should yield a non-empty text reply."""
    # Imported locally because the module header only imports LLMExtractor.
    from src.llm.ollama import Message

    messages: list[Message] = [
        {
            "role": "user",
            "content": "Run OCR on this image.",
            "images": ["./data/pytest/image_upload_test.png"],
        }
    ]
    output = LLMExtractor(model="llava:34b").get_chat_response(messages)
    print(output)
    # Assert on the reply instead of failing unconditionally.
    assert isinstance(output, str)
    assert output
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_extractor():
    """End-to-end check of each extractor against the sample receipt text."""
    llm_extractor = LLMExtractor(model="gemma2:27b")
    with open("data/pytest/receipt_ocr_test.txt", "r") as f:
        receipt_ocr_text = f.read()
    assert llm_extractor.extract_merchant_name(receipt_ocr_text) == "Walmart"
    # extract_receipt_date returns a datetime, which never compares equal to a
    # plain date, so compare the date part only.
    assert llm_extractor.extract_receipt_date(receipt_ocr_text).date() == date.fromisoformat("2017-07-28")
    assert llm_extractor.extract_receipt_total_amount(receipt_ocr_text) == {"currency": "USD", "amount": 98.21}
    assert llm_extractor.extract_receipt_tip_amount(receipt_ocr_text) == {"currency": "USD", "amount": 0}
    assert llm_extractor.extract_receipt_tax_amount(receipt_ocr_text) == {"currency": "USD", "amount": 4.59}
    items = llm_extractor.extract_receipt_items(receipt_ocr_text)
    # The exact items in these cannot be always identical as there will be some inherent variance in LLM output.
    # Thus, we check if a few control totals match, at least approximately.
    assert len(items) == 26
    assert round(sum(x["amount"] for x in items), 0) == 94
|
||||||
7
tests/test_surya.py
Normal file
7
tests/test_surya.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from src.ocr.surya import SuryaOCR
|
||||||
|
|
||||||
|
|
||||||
|
def test_ocr_sample_image():
    """OCR on the sample image should produce a result."""
    # Imported locally because the module header only imports SuryaOCR.
    from PIL.Image import open as PIL_open

    image_path = "data/pytest/image_upload_test.png"
    # ocr_image is a SuryaOCR method that takes a PIL image plus a return
    # format; None requests the combined polygons/texts/bboxes result.
    output = SuryaOCR().ocr_image(PIL_open(image_path), None)
    assert output is not None
|
||||||
Reference in New Issue
Block a user