mirrored 18 minutes ago
0
Pengxiang-Lifeat/dart_gui (#371) 00b6468
"""
Dart Agent - Custom agent for GUI automation using Dart models
Based on UITARSAgent structure but using Dart-specific utilities and prompts
"""
import ast
import base64
import logging
import math
import os
import re
import time
from io import BytesIO
from typing import Dict, List, Any
from PIL import Image
from openai import OpenAI
import backoff
import openai
import requests
from requests.exceptions import SSLError
from google.api_core.exceptions import (
    BadRequest,
    InternalServerError,
    InvalidArgument,
    ResourceExhausted,
)

# Import Dart-specific utilities and prompts
from mm_agents.dart_gui.utils import (
    pil_to_base64,
    parse_action_to_structure_output,
    parsing_response_to_pyautogui_code,
    parse_action,
    escape_single_quotes,
    round_by_factor,
    ceil_by_factor,
    floor_by_factor,
    linear_resize,
    smart_resize,
    add_box_token,
    IMAGE_FACTOR,
    MIN_PIXELS,
    MAX_PIXELS,
    MAX_RATIO,
    FINISH_WORD,
    WAIT_WORD,
    ENV_FAIL_WORD,
    CALL_USER
)

from mm_agents.dart_gui.prompts import (
    COMPUTER_USE_PROMPT,
    COMPUTER_USE_PROMPT_WITH_CALL_USER,
    UITARS_ACTION_SPACE,
    UITARS_CALL_USR_ACTION_SPACE,
    UITARS_USR_PROMPT_THOUGHT,
    UITARS_USR_PROMPT_NOTHOUGHT
)

logger = logging.getLogger("desktopenv.agent")

class DartAgent:
    def __init__(
        self,
        model: str,
        runtime_conf: Dict,
        platform="ubuntu",
        max_tokens=1000,
        top_p=0.9,
        top_k=1.0,
        temperature=0.0,
        action_space="pyautogui",
        observation_type="screenshot",
        max_trajectory_length=50,
        model_type="qwen25vl",
        **kwargs
    ):
        """Configure sampling options, the model endpoint and empty history.

        Args:
            model: Model name sent with every request.
            runtime_conf: Mapping of runtime overrides. Keys read here:
                ``max_tokens``, ``top_p``, ``top_k``, ``temperature``,
                ``infer_mode``, ``prompt_style``, ``input_swap``,
                ``language``, ``max_pixels``, ``min_pixels``, ``history_n``,
                ``max_images``, ``max_texts``, ``dart_api_key`` and
                ``dart_base_url``.
            platform: Stored on the instance as ``self.platform``.
            max_tokens, top_p, top_k, temperature: Defaults used when the
                matching key is absent from ``runtime_conf``.
            action_space: Stored on the instance as ``self.action_space``.
            observation_type: Observation format; checked later by
                ``_add_observation``.
            max_trajectory_length: Response count at which ``predict``
                starts returning ``["FAIL"]``.
            model_type: Forwarded to the action parser.
            **kwargs: Accepted for call-site compatibility and ignored.

        Endpoint selection: ``runtime_conf["dart_base_url"]`` takes priority;
        otherwise ``DART_API_URL`` (or ``DOUBAO_API_URL``) from the
        environment is used. A URL containing ``/generate`` is called
        directly over HTTP (``self.dart_direct_url``); any other URL is
        normalised to end in ``/v1`` and wrapped in an ``OpenAI`` client
        (``self.vlm``). When no URL is available both attributes are None.
        """
        self.model = model
        self.platform = platform
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.model_type = model_type
        self.runtime_conf = runtime_conf

        # Runtime configuration overrides the constructor defaults.
        self.max_tokens = self.runtime_conf.get("max_tokens", max_tokens)
        self.top_p = self.runtime_conf.get("top_p", top_p)
        self.top_k = self.runtime_conf.get("top_k", top_k)
        self.temperature = self.runtime_conf.get("temperature", temperature)
        self.infer_mode = self.runtime_conf.get("infer_mode", "dart_mode")
        self.prompt_style = self.runtime_conf.get("prompt_style", "dart_style")
        self.input_swap = self.runtime_conf.get("input_swap", False)
        self.language = self.runtime_conf.get("language", "English")
        self.max_pixels = self.runtime_conf.get("max_pixels", MAX_PIXELS)
        self.min_pixels = self.runtime_conf.get("min_pixels", MIN_PIXELS)
        self.history_n = self.runtime_conf.get("history_n", 5)

        # Dart-specific limits on how much context is kept in the prompt.
        self.max_images = self.runtime_conf.get("max_images", 5)
        self.max_texts = self.runtime_conf.get("max_texts", 35)

        # Resolve the endpoint and its credentials: explicit configuration
        # first, environment variables as the fallback.
        dart_api_key = self.runtime_conf.get("dart_api_key", "")
        endpoint_url = self.runtime_conf.get("dart_base_url", "")
        if endpoint_url:
            api_key = dart_api_key
        else:
            endpoint_url = os.environ.get('DART_API_URL', os.environ.get('DOUBAO_API_URL'))
            api_key = os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY'))

        self.vlm = None
        self.dart_direct_url = None
        if not endpoint_url:
            # Nothing to talk to: make the misconfiguration visible up front
            # instead of only failing on the first model call.
            logger.warning(
                "No model endpoint configured (runtime_conf['dart_base_url'], "
                "DART_API_URL and DOUBAO_API_URL are all unset)"
            )
        elif '/generate' in endpoint_url:
            # Direct generation endpoint: use the URL verbatim (no /v1
            # suffix) and bypass the OpenAI client entirely.
            logger.info(f"使用直接生成端点: {endpoint_url}")
            self.dart_direct_url = endpoint_url
        else:
            # OpenAI-compatible endpoint: make sure the base URL ends in /v1.
            if not endpoint_url.endswith('/v1'):
                endpoint_url = endpoint_url.rstrip('/') + '/v1'
            self.vlm = OpenAI(
                base_url=endpoint_url,
                api_key=api_key,
            )

        # Per-step trajectory records.
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.history_images = []
        self.history_responses = []

        # Conversation state.
        self.base_messages = []          # sent to the model (base64 images)
        self.base_messages_for_save = [] # stored copy (image file paths)
        self.prompt_dialogue = []        # sent to the model
        self.save_dialogue = []          # stored copy, trimmed in lockstep
        self.save_dialogue_full = []     # stored copy, never trimmed
        self.image_refs = []             # where each live image sits

        # Every image path ever used, kept even after trimming.
        self.all_image_paths = []

        # Path of the screenshot for the current step, used when saving.
        self.current_screenshot_path = None

        # Prompt template and action space depend on the inference mode.
        if self.infer_mode == "dart_mode":
            self.prompt_action_space = UITARS_ACTION_SPACE
            self.prompt_template = COMPUTER_USE_PROMPT
        else:
            self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
            if self.prompt_style == "qwen2vl_no_thought":
                self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
            else:
                # "qwen2vl_user" and every unrecognised style share this one.
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT

        self.action_parse_res_factor = 1000

        logger.info(f"Initialized DartAgent with model: {self.model}, mode: {self.infer_mode}")

    def reset(self, runtime_logger=None):
        """Clear all per-episode state so the agent can start a new task.

        Every history/dialogue container is replaced by a fresh empty list
        and the current screenshot path is cleared. ``runtime_logger`` is
        accepted but not used.
        """
        list_attributes = (
            # trajectory records
            "thoughts",
            "actions",
            "observations",
            "history_images",
            "history_responses",
            # conversation state
            "base_messages",
            "base_messages_for_save",
            "prompt_dialogue",
            "save_dialogue",
            "save_dialogue_full",
            "image_refs",
            "all_image_paths",
        )
        for attribute in list_attributes:
            setattr(self, attribute, [])
        self.current_screenshot_path = None

        logger.info("DartAgent reset")

    def set_base_messages(self, instruction: str):
        """Create the two opening messages of the conversation.

        The first is a fixed system message; the second is a user message
        whose only content part is ``COMPUTER_USE_PROMPT`` formatted with the
        task instruction and ``self.language``. An independent deep copy is
        kept in ``self.base_messages_for_save``.
        """
        from copy import deepcopy

        task_prompt = COMPUTER_USE_PROMPT.format(
            instruction=instruction,
            language=self.language
        )
        system_message = {
            "role": "system",
            "content": "You are a helpful assistant."
        }
        user_message = {
            "role": "user",
            "content": [{"type": "text", "text": task_prompt}]
        }

        self.base_messages = [system_message, user_message]
        # The saved variant later diverges (file paths instead of base64).
        self.base_messages_for_save = deepcopy(self.base_messages)

    def set_current_screenshot_path(self, screenshot_path: str) -> None:
        """Record the file path of the screenshot for the upcoming step.

        ``predict`` passes this value to ``_set_first_frame`` / ``_add_image``
        so the saved dialogue refers to the image by path.

        Args:
            screenshot_path: Path to store in ``self.current_screenshot_path``.
        """
        self.current_screenshot_path = screenshot_path

    def predict(
        self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
    ) -> tuple:
        """Run one agent step: observe, query the model, parse its answer.

        Args:
            instruction: Task description; only used to build the base
                messages when they have not been set yet.
            obs: Observation mapping; ``obs["screenshot"]`` is read.
            last_action_after_obs: Accepted but not used.

        Returns:
            ``(response_text, actions_list)``. ``actions_list`` is
            ``["DONE"]`` when the model call or the parsing fails, a
            single-item list with the terminal code when one is detected,
            ``["FAIL"]`` once the response count reaches
            ``max_trajectory_length``, and the parsed actions otherwise.
        """
        if not self.base_messages:
            self.set_base_messages(instruction)

        self._add_observation(obs)

        # The very first screenshot lives inside the base user message;
        # later ones are appended to the dialogue as the result of the
        # previous action.
        is_first_step = len(self.observations) == 1
        if is_first_step:
            self._set_first_frame(obs["screenshot"], self.current_screenshot_path)
        else:
            self._add_image(obs["screenshot"], self.current_screenshot_path)

        prediction = self._call_model(self._build_messages())
        if prediction is None:
            return "client error", ["DONE"]

        self._add_text(prediction)

        try:
            actions = self._parse_and_convert_actions(
                prediction, self._get_current_image_size()
            )
            terminal_action = self._check_terminal_actions(actions)
            if terminal_action:
                self.actions.append(actions)
                return prediction, [terminal_action]
        except Exception as e:
            failure = f"Parsing action error: {prediction}, error: {e}"
            logger.error(failure)
            return failure, ["DONE"]

        self.actions.append(actions)

        # Step budget exhausted: the parsed actions stay in the history but
        # the caller is told to fail.
        if len(self.history_responses) >= self.max_trajectory_length:
            return prediction, ["FAIL"]
        return prediction, actions

    # Retry policy passed to `backoff`: constant wait strategy with
    # interval=30, at most 10 tries, triggered only by the exception types
    # listed below.
    @backoff.on_exception(
        backoff.constant,
        (
            # General exceptions
            SSLError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,
        ),
        interval=30,
        max_tries=10,
    )
    def predict_with_backoff(self, instruction: str, obs: Dict, last_action_after_obs: Dict = None):
        """Call ``predict`` under the retry decorator above.

        Arguments and return value are exactly those of ``predict``.

        NOTE(review): ``_call_model`` catches every ``Exception`` raised by
        the model call and ``predict`` catches parsing errors, so the listed
        exception types may rarely reach this decorator — confirm whether the
        retry is still effective.
        """
        return self.predict(instruction, obs, last_action_after_obs)

    def get_trajectory(self) -> List[Dict]:
        """Return one record per stored observation.

        Each record has the keys ``observation``, ``thought`` and ``action``.
        Steps that have no recorded thought or action get ``""`` and ``[]``
        respectively.
        """
        def _entry(history, step, default):
            # Histories can be shorter than the observation list.
            return history[step] if step < len(history) else default

        return [
            {
                "observation": observation,
                "thought": _entry(self.thoughts, step, ""),
                "action": _entry(self.actions, step, []),
            }
            for step, observation in enumerate(self.observations)
        ]
    
    def get_full_messages(self) -> List[Dict]:
        """Return the whole conversation in its save-friendly form.

        The result is a new list: the saved base messages (system prompt and
        initial user message) followed by the untrimmed dialogue, which
        still contains every image entry.
        """
        return [*self.base_messages_for_save, *self.save_dialogue_full]
    
    def get_all_image_paths(self) -> List[str]:
        """Return a shallow copy of every image path recorded so far."""
        return list(self.all_image_paths)

    # ========== Private Methods ==========
    
    def _validate_trajectory(self):
        """Assert that observations, actions and thoughts have equal length."""
        observation_count = len(self.observations)
        action_count = len(self.actions)
        thought_count = len(self.thoughts)
        assert (
            observation_count == action_count == thought_count
        ), "The number of observations and actions should be the same."

    def _add_observation(self, obs: Dict):
        """Append the screenshot from ``obs`` to ``self.observations``.

        Accessibility-tree processing is not implemented for Dart, so the
        stored ``accessibility_tree`` is always ``None`` — also when
        ``observation_type`` is ``"screenshot_a11y_tree"``.

        Args:
            obs: Observation mapping; must contain ``"screenshot"``.

        Raises:
            ValueError: ``self.observation_type`` is neither ``"screenshot"``
                nor ``"screenshot_a11y_tree"``.
            KeyError: ``obs`` has no ``"screenshot"`` entry.
        """
        if self.observation_type not in ("screenshot", "screenshot_a11y_tree"):
            raise ValueError("Invalid observation_type type: " + self.observation_type)

        self.observations.append({
            "screenshot": obs["screenshot"],
            "accessibility_tree": None,
        })


    def _build_messages(self) -> List[Dict]:
        """Return a new list: the base messages followed by the live dialogue."""
        return [*self.base_messages, *self.prompt_dialogue]

    def _call_model(self, messages: List[Dict]) -> str:
        """Query the model, making up to three attempts.

        The direct generation endpoint is used when ``self.dart_direct_url``
        is set; otherwise the request goes through the OpenAI-compatible
        client in ``self.vlm``.

        Args:
            messages: Chat messages to send.

        Returns:
            The response text, or ``None`` when no endpoint is configured or
            every attempt raised. Exceptions are logged, never propagated.
        """
        direct_url = getattr(self, 'dart_direct_url', None)

        # Without any endpoint every attempt would only raise AttributeError
        # on ``None``; fail fast with a message that names the real problem.
        if not direct_url and getattr(self, 'vlm', None) is None:
            logger.error(
                "No model endpoint configured: neither a direct generate URL "
                "nor an OpenAI-compatible client is available"
            )
            return None

        max_attempts = 3
        for _ in range(max_attempts):
            try:
                if direct_url:
                    prediction = self._call_direct_generate_endpoint(messages)
                else:
                    response = self.vlm.chat.completions.create(
                        model=self.model,
                        messages=messages,
                        frequency_penalty=1,
                        max_tokens=self.max_tokens,
                        temperature=self.temperature,
                        top_p=self.top_p
                    )
                    prediction = response.choices[0].message.content

                logger.info(f"Model response: {prediction}")
                return prediction

            except Exception as e:
                logger.error(f"Error when fetching response from client: {e}")

        logger.error("Reach max retry times to fetch response from client")
        return None

    def _call_direct_generate_endpoint(self, messages: List[Dict]) -> str:
        """POST the conversation to ``self.dart_direct_url`` and extract the text.

        The request is attempted up to three times with a 60 second timeout
        each and a 2 second pause between attempts. The bearer token comes
        from ``runtime_conf["dart_api_key"]`` and falls back to the
        ``DART_API_KEY`` / ``DOUBAO_API_KEY`` environment variables, the same
        pair ``__init__`` uses for the OpenAI-compatible client.

        Args:
            messages: Chat messages to send.

        Returns:
            The generated text. The response JSON is probed for, in order,
            ``choices[0].message.content``, ``response``, ``text`` and
            ``content``; if none is present the whole payload is returned as
            a string.

        Raises:
            requests.RequestException: The last attempt failed (connection
                error, timeout or non-2xx status).
            Exception: Anything else raised while building the request or
                decoding the response; it is logged and re-raised.
        """
        try:
            payload = {
                "messages": messages,
                "model": self.model,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "frequency_penalty": 1
            }

            # The URL may have come from the environment, so the key must be
            # allowed to come from there too.
            api_key = (
                self.runtime_conf.get('dart_api_key', '')
                or os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY'))
                or ''
            )
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"
            }

            max_retries = 3
            response = None

            for attempt in range(max_retries):
                try:
                    logger.info(f"尝试第 {attempt + 1} 次请求...")
                    response = requests.post(
                        self.dart_direct_url,
                        json=payload,
                        headers=headers,
                        timeout=60
                    )
                    response.raise_for_status()
                    break  # success: leave the retry loop
                except requests.RequestException as e:
                    logger.warning(f"第 {attempt + 1} 次请求失败: {e}")
                    if attempt == max_retries - 1:
                        logger.error(f"所有 {max_retries} 次重试都失败了")
                        raise  # keep the original traceback
                    logger.info("等待后重试...")
                    time.sleep(2)

            result = response.json()

            # Accept several response layouts, most specific first.
            if 'choices' in result and len(result['choices']) > 0:
                # OpenAI-compatible layout
                return result['choices'][0]['message']['content']
            for field in ('response', 'text', 'content'):
                if field in result:
                    return result[field]

            # Unknown layout: hand back the raw payload so the caller can at
            # least log what came in.
            logger.warning(f"未知的响应格式: {result}")
            return str(result)

        except Exception as e:
            logger.error(f"直接端点调用失败: {e}")
            raise

    def _add_text(self, assistant_txt: str):
        """Record a model reply in the histories and in all three dialogues.

        The raw text goes into ``history_responses`` and ``thoughts``; the
        dialogue message carries the text after ``add_box_token``. The same
        message object is shared by the prompt, saved and full-saved
        dialogues. Trimming runs afterwards.
        """
        for history in (self.history_responses, self.thoughts):
            history.append(assistant_txt)

        reply = {
            "role": "assistant",
            "content": add_box_token(assistant_txt)
        }
        for dialogue in (self.prompt_dialogue, self.save_dialogue, self.save_dialogue_full):
            dialogue.append(reply)

        self._trim()

    def _set_first_frame(self, obs_img: bytes, frame_path: str = None):
        """Attach the first screenshot to the initial user message.

        The model-facing message receives the image as a base64 data URL;
        the saved copy receives a path instead. The path is ``frame_path``,
        else ``self.current_screenshot_path``, else ``"first_frame.png"``.
        The image position is recorded in ``self.image_refs``.
        """
        user_content = self.base_messages[1]["content"]
        user_content.append(
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64," + pil_to_base64(obs_img)}
            }
        )

        # First truthy candidate wins; the literal is the last resort.
        saved_path = frame_path or self.current_screenshot_path or "first_frame.png"
        self.all_image_paths.append(saved_path)

        self.base_messages_for_save[1]["content"].append(
            {
                "type": "image_url",
                "image_url": saved_path
            }
        )

        # Remember where the image sits so trimming can remove it later.
        self.image_refs.append(
            {
                "source": "base",
                "msg_idx": 1,
                "content_idx": len(user_content) - 1,
            }
        )

    def _add_image(self, img_bytes: bytes, frame_path: str = None):
        """Append a screenshot to the dialogue as a user message.

        The prompt dialogue gets a base64 data URL. The two saved dialogues
        each get their own message that refers to the image by path:
        ``frame_path``, else ``self.current_screenshot_path``, else a
        ``frame_<n>.png`` placeholder numbered by the current length of
        ``save_dialogue``. Trimming runs afterwards.
        """
        data_url = "data:image/png;base64," + pil_to_base64(img_bytes)
        self.prompt_dialogue.append({
            "role": "user",
            "content": [{"type": "image_url", "image_url": {"url": data_url}}]
        })

        # The placeholder is only computed when no real path is available,
        # and before save_dialogue grows.
        saved_path = (
            frame_path
            or self.current_screenshot_path
            or f"frame_{len(self.save_dialogue)}.png"
        )
        self.all_image_paths.append(saved_path)

        def _saved_message():
            # A fresh dict per dialogue keeps the two lists independent.
            return {
                "role": "user",
                "content": [{"type": "image_url", "image_url": saved_path}]
            }

        self.save_dialogue.append(_saved_message())       # trimmed in lockstep
        self.save_dialogue_full.append(_saved_message())  # never trimmed

        self.image_refs.append(
            {
                "source": "dialogue",
                "msg_idx": len(self.prompt_dialogue) - 1,
                "content_idx": None,
            }
        )

        self._trim()

    def _trim(self):
        """Drop the oldest images/replies until both limits are respected.

        Images are handled first: while more than ``max_images`` are
        referenced, the oldest one is removed, either from the base message
        content or as a whole dialogue message. Then, while more than
        ``max_texts`` assistant replies remain in the prompt dialogue, the
        earliest reply is removed.
        """
        images_left = len(self.image_refs)
        replies_left = sum(
            1 for message in self.prompt_dialogue if message["role"] == "assistant"
        )

        while True:
            if images_left > self.max_images:
                # Too many images: discard the oldest one.
                oldest = self.image_refs.pop(0)
                if oldest["source"] == "base":
                    base_content = self.base_messages[oldest["msg_idx"]]["content"]
                    base_content.pop(oldest["content_idx"])
                else:
                    # Image that lives in the dialogue.
                    self._remove_dialogue_msg(oldest["msg_idx"])
                images_left -= 1
            elif replies_left > self.max_texts:
                # Too many replies: discard the earliest assistant message.
                earliest = next(
                    (
                        position
                        for position, message in enumerate(self.prompt_dialogue)
                        if message["role"] == "assistant"
                    ),
                    None,
                )
                if earliest is not None:
                    self._remove_dialogue_msg(earliest)
                    replies_left -= 1
            else:
                break

    def _remove_dialogue_msg(self, idx: int):
        """Delete message ``idx`` from the prompt and saved dialogues.

        ``save_dialogue_full`` is left alone because it is never trimmed.
        Image references are rebuilt: a dialogue reference to the deleted
        message is dropped, and dialogue references that pointed past it
        are shifted down by one.
        """
        self.prompt_dialogue.pop(idx)
        self.save_dialogue.pop(idx)

        surviving_refs = []
        for ref in self.image_refs:
            if ref["source"] == "dialogue":
                position = ref["msg_idx"]
                if position == idx:
                    # This image went away with the message.
                    continue
                if position > idx:
                    # Everything after the removed message moves up.
                    ref = {**ref, "msg_idx": position - 1}
            surviving_refs.append(ref)
        self.image_refs = surviving_refs

    def _get_current_image_size(self) -> tuple:
        """Return ``(height, width)`` of the most recent screenshot.

        Supported screenshot representations are raw encoded image bytes, a
        PIL image, and a base64 string (optionally a ``data:`` URL). When
        there is no observation, the type is unsupported or the image cannot
        be decoded, the default ``(1080, 1920)`` is returned and a warning
        is logged for the latter two cases.
        """
        default_size = (1080, 1920)
        if not self.observations:
            return default_size

        try:
            screenshot = self.observations[-1]["screenshot"]

            if isinstance(screenshot, Image.Image):
                return (screenshot.height, screenshot.width)

            if isinstance(screenshot, str):
                # Strip a "data:image/...;base64," prefix if there is one.
                encoded = (
                    screenshot.split(",", 1)[1]
                    if screenshot.startswith("data:")
                    else screenshot
                )
                screenshot = base64.b64decode(encoded)

            if isinstance(screenshot, (bytes, bytearray)):
                # Close the decoder as soon as the header has been read.
                with Image.open(BytesIO(screenshot)) as current_image:
                    return (current_image.height, current_image.width)

            logger.warning(
                f"Unsupported screenshot type {type(screenshot).__name__}; "
                f"falling back to {default_size}"
            )
        except Exception as e:
            logger.warning(f"Error getting image size: {e}")

        return default_size

    def _parse_and_convert_actions(self, prediction: str, image_size: tuple) -> List[str]:
        """Turn a model response into a list of executable actions.

        The response is first parsed into structured actions, then each one
        is converted with ``parsing_response_to_pyautogui_code``. A
        conversion that raises is logged and contributes ``"FAIL"`` in its
        place, so the result has one entry per parsed action.

        Args:
            prediction: Raw model response.
            image_size: ``(height, width)`` of the current screenshot.
        """
        height, width = image_size

        structured_actions = parse_action_to_structure_output(
            prediction,
            factor=self.action_parse_res_factor,
            origin_resized_height=height,
            origin_resized_width=width,
            model_type=self.model_type,
            max_pixels=self.max_pixels,
            min_pixels=self.min_pixels
        )

        converted = []
        for structured in structured_actions:
            try:
                code = parsing_response_to_pyautogui_code(
                    structured,
                    image_height=height,
                    image_width=width,
                    input_swap=self.input_swap
                )
            except Exception as e:
                logger.error(f"Error generating pyautogui code: {e}")
                code = "FAIL"
            converted.append(code)

        return converted

    

    def _check_terminal_actions(self, actions: List[str]) -> str:
        """Return the terminal code for the first terminal action, if any.

        Only dict entries that carry an ``"action_type"`` key are inspected.
        The finish word maps to ``"DONE"``, the wait word to ``"WAIT"``, and
        both the environment-failure and call-user words to ``"FAIL"``.
        ``None`` is returned when nothing matches.
        """
        # Checked in this order; the first equal word decides.
        terminal_codes = (
            (FINISH_WORD, "DONE"),
            (WAIT_WORD, "WAIT"),
            (ENV_FAIL_WORD, "FAIL"),
            (CALL_USER, "FAIL"),
        )

        for candidate in actions:
            if not isinstance(candidate, dict) or "action_type" not in candidate:
                continue
            action_type = candidate["action_type"]
            for word, code in terminal_codes:
                if action_type == word:
                    return code
        return None