| """
|
| This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
|
| - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
|
| - https://github.com/openai/prm800k
|
| """
|
| import multiprocessing
|
| from math import isclose
|
| from typing import Union
|
|
|
| from sympy import simplify, N
|
| from sympy.parsing.sympy_parser import parse_expr
|
| from sympy.parsing.latex import parse_latex
|
|
|
|
|
def is_digit(s):
    """Return True when ``s`` can be read as a float once commas are removed.

    The value is stringified with ``str`` first, so numbers and numeric
    strings such as ``"1,000"`` or ``"1e3"`` are accepted. Anything that makes
    ``float`` raise ``ValueError`` yields False.
    """
    try:
        candidate = str(s).replace(",", "")
        float(candidate)
    except ValueError:
        return False
    else:
        return True
|
|
|
def math_equal(prediction: Union[bool, float, str],
               reference: Union[float, str],
               include_percentage: bool = True,
               is_close: bool = True,
               timeout: bool = False,
               ) -> bool:
    """Decide whether ``prediction`` matches ``reference`` as a math answer.

    The checks run in this order:

    1. Numeric: if both sides parse as floats (commas stripped), compare them.
       With ``include_percentage`` the reference scaled by 1/100 and by 100 is
       also accepted. When both sides are numeric, this step's answer is final.
    2. Empty predictions (``None``, ``""``, ...) other than ``0``/``False``
       never match.
    3. Bracket-insensitive string match when the prediction is wrapped in
       ``[...]`` or ``(...)``.
    4. Element-wise match of two sequences that use the same bracket type and
       have the same number of comma-separated parts, compared recursively.
    5. Symbolic match via ``symbolic_equal``. This runs in a subprocess via
       ``call_with_timeout`` when ``timeout`` is true.

    Args:
        prediction: Candidate answer.
        reference: Ground-truth answer.
        include_percentage: Also accept ``reference / 100`` and
            ``reference * 100`` in the numeric check.
        is_close: Compare numbers with ``math.isclose(rel_tol=1e-4)`` instead
            of exact ``==``.
        timeout: Run symbolic checks, including those for sequence elements,
            in a subprocess with ``call_with_timeout``'s default time limit.

    Returns:
        True if any check succeeds, otherwise False.
    """
    # 1. Numeric comparison. This step is best-effort: any unexpected error
    #    falls through to the string and symbolic checks below.
    try:
        if is_digit(prediction) and is_digit(reference):
            pred_num = float(str(prediction).replace(",", ""))
            ref_num = float(str(reference).replace(",", ""))

            if include_percentage:
                # Tolerate answers given as a fraction vs. a percentage.
                candidates = [ref_num / 100, ref_num, ref_num * 100]
            else:
                candidates = [ref_num]

            if is_close:
                return any(isclose(c, pred_num, rel_tol=1e-4) for c in candidates)
            return any(c == pred_num for c in candidates)
    except Exception:
        pass

    # 2. An empty prediction never matches; 0 and False are legitimate answers.
    if not prediction and prediction not in [0, False]:
        return False

    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # 3. Compare with brackets, braces and parentheses removed. The guards on
    #    the reference's opening bracket presumably keep "[a, b]" and "(a, b)"
    #    (closed vs. open interval) from being conflated.
    if ((prediction.startswith("[") and prediction.endswith("]")
         and not reference.startswith("("))
            or (prediction.startswith("(") and prediction.endswith(")")
                and not reference.startswith("["))):
        drop_groupers = str.maketrans("", "", "{}()")
        pred_str = prediction.strip("[]()").translate(drop_groupers)
        ref_str = reference.strip("[]()").translate(drop_groupers)
        if pred_str == ref_str:
            return True

    # 4. Element-wise comparison of sequences that use the same bracket type.
    both_square = (prediction.startswith("[") and prediction.endswith("]")
                   and reference.startswith("[") and reference.endswith("]"))
    both_round = (prediction.startswith("(") and prediction.endswith(")")
                  and reference.startswith("(") and reference.endswith(")"))
    if both_square or both_round:
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        # Forward ``timeout`` so nested symbolic checks keep the same
        # protection. The generator short-circuits on the first mismatch.
        if len(pred_parts) == len(ref_parts) and all(
                math_equal(p, r, include_percentage, is_close, timeout)
                for p, r in zip(pred_parts, ref_parts)):
            return True

    # 5. Symbolic comparison (optionally isolated in a subprocess).
    if timeout:
        return bool(call_with_timeout(symbolic_equal_process, prediction, reference))
    return bool(symbolic_equal(prediction, reference))
|
|
|
|
|
def math_equal_process(param):
    """Return ``math_equal(param[-2], param[-1])``.

    Only the last two entries of ``param`` are used, as prediction and
    reference. Earlier entries are ignored. This is presumably meant as a
    single-argument worker for ``map``/process pools.

    The comparison is computed exactly once. An earlier version evaluated
    ``math_equal`` twice, which can be expensive with sympy, and printed a
    debug line from the worker.
    """
    return math_equal(param[-2], param[-1])
|
|
|
|
|
def symbolic_equal(a, b):
    """Return True if ``a`` and ``b`` are equal according to sympy.

    Each input is parsed first with ``parse_latex``, then with ``parse_expr``.
    If both parsers fail, the raw value is used unchanged. Two checks follow:

    1. ``simplify(a - b) == 0``
    2. Numeric evaluation with ``N`` compared by ``math.isclose(rel_tol=1e-3)``

    Parse and evaluation errors count as "not equal" and are not raised.
    ``Exception`` is caught rather than using a bare ``except``, so
    KeyboardInterrupt/SystemExit still propagate.
    """
    def _parse(s):
        # NOTE(security): sympy's ``parse_expr`` evaluates its input with
        # ``eval``. Predictions are typically model output; do not feed this
        # untrusted text outside a sandbox or subprocess.
        for parser in (parse_latex, parse_expr):
            try:
                return parser(s)
            except Exception:
                continue
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass

    try:
        if isclose(N(a), N(b), rel_tol=1e-3):
            return True
    except Exception:
        pass

    return False
|
|
|
|
|
def symbolic_equal_process(a, b, output_queue):
    """Subprocess target: put ``symbolic_equal(a, b)`` on ``output_queue``.

    ``call_with_timeout`` appends the queue as the final positional argument,
    which is why it comes last here.
    """
    output_queue.put(symbolic_equal(a, b))
|
|
|
|
|
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func(*args, output_queue, **kwargs)`` in a child process with a deadline.

    ``func`` receives a ``multiprocessing.Queue`` as its last positional
    argument and is expected to ``put`` its result there (see
    ``symbolic_equal_process``).

    Args:
        func: Callable run in a new ``multiprocessing.Process``. Under the
            "spawn"/"forkserver" start methods it must be picklable.
        *args: Positional arguments passed before the queue.
        timeout: Seconds to wait for the child before terminating it.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The first object ``func`` put on the queue, or ``False`` in either of
        these cases:

        - the child was still running after ``timeout`` seconds;
        - the child exited without putting anything, e.g. because it raised.

    NOTE: the child cannot exit until its queued data has been flushed to the
    pipe. A very large result (bigger than the OS pipe buffer) can therefore
    keep it alive past ``timeout``, and the call is then reported as a timeout.
    """
    import queue  # local import: only needed for ``queue.Empty``

    # A child that exited normally has already flushed its queued data, so a
    # short bounded wait is enough to read it.
    result_grace = 0.1

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    # Never block indefinitely: if the child died before calling put(), the
    # queue stays empty and an unbounded get() would hang forever.
    try:
        return output_queue.get(timeout=result_grace)
    except queue.Empty:
        return False
|
|
|
|
|
def _test_math_equal():
    """Smoke check: print whether two equivalent LaTeX fractions compare equal."""
    lhs = "\\frac{x}{7}+\\frac{2}{7}"
    rhs = "\\frac{x+2}{7}"
    verdict = math_equal(lhs, rhs, timeout=True)
    print(verdict)
|
|
|
if __name__ == "__main__":
    # Run the smoke check only when executed as a script, not on import.
    _test_math_equal()
|
|
|