File size: 8,760 Bytes
c28dddb
7847fd8
 
 
 
 
c28dddb
7847fd8
 
 
c28dddb
7847fd8
 
 
c28dddb
7847fd8
c28dddb
 
 
 
 
 
7847fd8
 
 
 
 
 
 
 
 
 
 
 
c28dddb
 
 
 
7847fd8
 
c28dddb
 
 
7847fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684bfaa
 
 
 
 
 
 
 
 
 
 
 
 
 
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7847fd8
 
 
 
 
 
 
 
 
 
c28dddb
7847fd8
c28dddb
 
 
 
 
 
684bfaa
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7847fd8
 
 
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import argparse
import base64
import json
import os
import re
import sys
from io import BytesIO
from typing import Optional

import json_repair
from openai import AzureOpenAI
from PIL import Image

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from scripts.graph_pred.prompt_workflow_new import messages

# Azure OpenAI connection settings, read from the environment at import time.
# os.environ.get returns None for unset variables; nothing is validated here.

endpoint = os.environ.get("ENDPOINT")
api_key = os.environ.get("API_KEY")
api_version = os.environ.get("API_VERSION")
model_name = os.environ.get("MODEL_NAME")
# Request timeout handed to the client; overridable via GPT_TIMEOUT
# (falls back to 120 when the variable is unset).
DEFAULT_GPT_TIMEOUT = float(os.environ.get("GPT_TIMEOUT", 120))
# GPT-5.x counts reasoning tokens against this cap, so it must leave
# enough room for both reasoning and the visible reply.
GPT5_DEFAULT_MAX_COMPLETION_TOKENS = 8192
# Request keys that prepare_chat_completion_payload removes before a
# GPT-5.x request is sent.
GPT5_UNSUPPORTED_PARAMS = (
    "temperature",
    "top_p",
    "frequency_penalty",
    "presence_penalty",
    "stop",
)

# Module-wide client shared by every request below. max_retries=0 is passed
# so the SDK is asked not to retry on its own.
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=api_key,
    api_version=api_version,
    timeout=DEFAULT_GPT_TIMEOUT,
    max_retries=0,
)


def is_gpt5_model(model: Optional[str]) -> bool:
    """Tell whether a model/deployment name refers to the GPT-5 family.

    The match is a case-insensitive substring search for "gpt-5" or "gpt5".
    ``None`` and the empty string are reported as not GPT-5.
    """
    if not model:
        return False
    lowered = model.lower()
    return any(marker in lowered for marker in ("gpt-5", "gpt5"))


def prepare_chat_completion_payload(payload: dict) -> dict:
    """Return a copy of the chat-completion kwargs adapted for GPT-5.x.

    For non-GPT-5 models the copy is returned untouched. For GPT-5 models:

    * ``max_tokens`` is removed; unless the caller already supplied
      ``max_completion_tokens``, that key is set to the larger of the
      removed ``max_tokens`` value and GPT5_DEFAULT_MAX_COMPLETION_TOKENS.
    * every key listed in GPT5_UNSUPPORTED_PARAMS is dropped.

    The input dict itself is never modified.
    """
    request = dict(payload)
    if is_gpt5_model(request.get("model")):
        legacy_cap = request.pop("max_tokens", None)
        if "max_completion_tokens" not in request:
            floor = GPT5_DEFAULT_MAX_COMPLETION_TOKENS
            request["max_completion_tokens"] = max(legacy_cap or 0, floor)
        # Rebuild without the sampling knobs GPT-5.x does not accept.
        request = {
            key: value
            for key, value in request.items()
            if key not in GPT5_UNSUPPORTED_PARAMS
        }
    return request


def parse_graph_response(content: str) -> dict:
    """Parse a graph JSON response with or without markdown fences.

    If the text contains a fenced block (``` or ```json), only the first
    fenced block is parsed; otherwise the whole text is handed to
    ``json_repair``.

    Args:
        content: Raw text of the model reply (may be ``None`` or empty).

    Returns:
        The decoded JSON object as a dict.

    Raises:
        ValueError: If the reply is empty, cannot be parsed, or decodes to
            something other than a JSON object. The message carries a
            500-character preview of the reply for debugging.
    """
    if not content:
        raise ValueError("GPT response is empty.")

    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", content, re.DOTALL)
    json_text = fenced.group(1) if fenced else content
    preview = content[:500].replace("\n", "\\n")

    try:
        parsed = json_repair.loads(json_text)
    except Exception as exc:
        raise ValueError(f"Failed to parse GPT graph response: {preview}") from exc

    # json_repair is lenient: for text it cannot recover it may hand back a
    # string or list instead of raising. Callers walk the result as a dict,
    # so reject anything else here with a readable error.
    if not isinstance(parsed, dict):
        raise ValueError(f"Failed to parse GPT graph response: {preview}")
    return parsed


def encode_image(image_path: str, center_crop=False):
    """Load an image, bring it to 224x224 and return it as base64 PNG text.

    Args:
        image_path: Path of the image file to read.
        center_crop: When true, resize to 256x256 and keep the central
            224x224 window; otherwise resize directly to 224x224.

    Returns:
        The PNG-encoded 224x224 image as a base64 string (UTF-8 text).
    """
    # Image.open keeps the underlying file open; the context manager makes
    # sure it is released once the resized copy has been produced.
    with Image.open(image_path) as source:
        if center_crop:
            scaled = source.resize((256, 256))
            width, height = scaled.size
            left = (width - 224) // 2
            top = (height - 224) // 2
            image = scaled.crop((left, top, left + 224, top + 224))
        else:
            image = source.resize((224, 224))

    # Serialize the in-memory copy as PNG and base64-encode the bytes.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def display_image(image_data):
    """Decode base64 image data and show it through PIL's ``Image.show``."""
    raw_bytes = base64.b64decode(image_data)
    picture = Image.open(BytesIO(raw_bytes))
    picture.show()
    picture.close()


def convert_format(src):
    """Flatten the nested response dict into a list of tree nodes.

    Each mapping is read as ``{name: children}``: only its first key is
    used, and the value is treated as the list of child mappings when it is
    a list (any other value means "no children"). Ids are handed out in
    pre-order starting at 0, the root's parent is -1, and the returned list
    is ordered by id. An empty mapping yields no node; as a child it
    contributes ``None`` to its parent's ``children`` list.
    """
    collected = []
    next_id = 0

    def _visit(mapping, parent_id):
        nonlocal next_id

        # Only the first key/value pair of a mapping defines the node.
        first_entry = next(iter(mapping.items()), None)
        if first_entry is None:
            return None
        name, payload = first_entry

        # Reserve this node's id before descending so parents precede
        # their children in the numbering.
        node_id = next_id
        next_id += 1

        record = {
            "id": node_id,
            "parent": parent_id,
            "name": name,
            "children": [],
        }
        if isinstance(payload, list):
            record["children"] = [_visit(child, node_id) for child in payload]

        collected.append(record)
        return node_id

    _visit(src, -1)
    # Nodes were appended children-first; put them back in id order.
    return sorted(collected, key=lambda record: record["id"])

def predict_graph_twomode(image_path, first_img_data=None, second_img_data=None, debug=False, center_crop=False):
    """Predict the part connectivity graph from a pair of images.

    When either encoded image is missing, both are (re)encoded from disk:
    the first from ``image_path`` and the second from the same path with
    every 'close' replaced by 'open'. The two images are appended as one
    user turn to the shared prompt ``messages`` and sent to the model.

    ``debug`` is accepted but not used by this function.

    Returns:
        A dict with ``diffuse_tree`` (the flattened node list) and
        ``original_response`` (the raw text returned by the model).
    """
    if first_img_data is None or second_img_data is None:
        first_img_data = encode_image(image_path, center_crop)
        second_img_data = encode_image(image_path.replace('close', 'open'), center_crop)

    # One user turn carrying both images as inline PNG data URLs.
    image_parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{encoded}"},
        }
        for encoded in (first_img_data, second_img_data)
    ]
    # Build a new list so the shared prompt messages are not mutated.
    conversation = [*messages, {"role": "user", "content": image_parts}]

    request = prepare_chat_completion_payload(
        {
            "model": model_name,
            "messages": conversation,
            "response_format": {"type": "text"},
            "temperature": 1,
            "max_tokens": 4096,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
        }
    )
    completion = client.chat.completions.create(**request)
    print('processing the response...')

    content = completion.choices[0].message.content
    src = parse_graph_response(content)
    print(src)

    return {
        "diffuse_tree": convert_format(src),
        "original_response": content,
    }

def save_response(save_path, response):
    """Write ``response`` to ``save_path`` as JSON indented by 4 spaces."""
    with open(save_path, "w") as out_file:
        json.dump(response, out_file, indent=4)



def gpt_infer_image_category(image1, image2):
    """Ask the model which object category a pair of images shows.

    Args:
        image1: First image, embedded in the request as base64 PNG data.
        image2: Second image, embedded the same way.

    Returns:
        The model's reply with surrounding whitespace removed. The prompt
        asks for exactly one of: Table, Dishwasher, StorageFurniture,
        Refrigerator, WashingMachine, Microwave, Oven — the reply is not
        checked against that list here.

    Raises:
        ValueError: If the model returned no message content.
    """
    system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."

    text_prompt = (
        "Given two images of an object, determine its category. "
        "The category must be one of the following: Table, Dishwasher, StorageFurniture, "
        "Refrigerator, WashingMachine, Microwave, Oven. "
        "Output only the category name and nothing else. Do not include any other text."
    )

    content_user = [
        {
            "type": "text",
            "text": text_prompt,
        },
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image1}"},
        },
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image2}"},
        },
    ]
    payload = {
        "messages": [
            {"role": "system", "content": system_role},
            {"role": "user", "content": content_user},
        ],
        "temperature": 0.1,
        "max_tokens": 500,
        "top_p": 0.1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
        "model": model_name,
    }
    completion = client.chat.completions.create(
        **prepare_chat_completion_payload(payload)
    )
    response = completion.choices[0].message.content

    # The reply is a bare category name, not JSON, so there is nothing to
    # parse. Fail clearly when the model produced no content at all, and
    # trim stray whitespace/newlines so the name can be compared directly.
    if response is None:
        raise ValueError("GPT response is empty.")
    return response.strip()


if __name__ == "__main__":
    # Command-line entry point: predict the two-mode graph for one image and
    # save it next to the requested save path.
    parser = argparse.ArgumentParser(description="Predict the part connectivity graph from an image")
    parser.add_argument("--img_path", type=str, required=True, help="path to the image")
    parser.add_argument("--save_path", type=str, required=True, help="path to the save the response")
    parser.add_argument("--center_crop", action="store_true", help="whether to center crop the image to 224x224, otherwise resize to 224x224")
    args = parser.parse_args()

    # The two-mode result is written to "<save_path without extension>twomode.json".
    # splitext gives the same name as the previous save_path[:-5] for ".json"
    # paths and stays correct for any other (or no) extension.
    save_root, _ = os.path.splitext(args.save_path)
    twomode_save_path = save_root + 'twomode.json'

    try:
        # Only predict_graph_twomode exists in this module; the earlier call
        # to an undefined predict_graph raised NameError before anything was
        # saved. center_crop must be passed by keyword: the second positional
        # parameter is first_img_data, not center_crop.
        response = predict_graph_twomode(args.img_path, center_crop=args.center_crop)
        save_response(twomode_save_path, response)
    except Exception as e:
        # Best-effort batch behaviour: record the failure and exit normally.
        with open('openai_err.log', 'a') as f:
            f.write('---------------------------\n')
            f.write(f'{args.img_path}\n')
            f.write(f'{e}\n')