diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 00000000..2a281bb7
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,3 @@
+{
+    "onCreateCommand": "./scripts/codespaces_create_and_start_containers.sh"
+}
diff --git a/.env.example b/.env.example
index c5a40e8e..3b2d9e0d 100644
--- a/.env.example
+++ b/.env.example
@@ -8,11 +8,26 @@ conn.port=2222
# exchange with the user for your target VM
conn.username='bob'
+# To use key-based authentication only, set conn.password='' (empty, with no space)
+# Otherwise, insert the password for the instance here
conn.password='secret'
+# To use username and password authentication only, set conn.keyfilename='' (empty, with no space)
+# Otherwise, insert the file path of the key file here (for example, '/home/bob/.ssh/sshkey.rsa')
+conn.keyfilename=''
# which LLM model to use (can be anything openai supports, or if you use a custom llm.api_url, anything your api provides for the model parameter)
llm.model='gpt-3.5-turbo'
llm.context_size=16385
# how many rounds should this thing go?
-max_turns = 20
\ No newline at end of file
+max_turns = 20
+
+# The following four parameters are only relevant for the RAG use case
+# rag_database_folder_name: Name of the folder where the vector store will be saved.
+# rag_embedding: The name of the embedding model used. Currently only the OpenAI API is supported.
+# openai_api_key: API key that is used for the embedding model.
+# rag_return_token_limit: The upper bound for the RAG output.
+rag_database_folder_name = "vectorDB"
+rag_embedding = "text-embedding-3-small"
+openai_api_key = 'your-openai-key'
+rag_return_token_limit = 1000
diff --git a/.env.example.aws b/.env.example.aws
new file mode 100644
index 00000000..0577209e
--- /dev/null
+++ b/.env.example.aws
@@ -0,0 +1,23 @@
+llm.api_key='your-openai-key'
+log_db.connection_string='log_db.sqlite3'
+
+# exchange with the IP of your target VM
+conn.host='enter the public IP of AWS Instance'
+conn.hostname='DNS of AWS Instance'
+conn.port=22
+
+# user of target AWS Instance
+conn.username='bob'
+# To use key-based authentication only, set conn.password='' (empty, with no space)
+# Otherwise, insert the password for the instance here
+conn.password=''
+# To use username and password authentication only, set conn.keyfilename='' (empty, with no space)
+# Otherwise, insert the file path of the key file here (for example, '/home/bob/.ssh/awskey.pem')
+conn.keyfilename='/home/bob/.ssh/awskey.pem'
+
+# which LLM model to use (can be anything openai supports, or if you use a custom llm.api_url, anything your api provides for the model parameter)
+llm.model='gpt-3.5-turbo'
+llm.context_size=16385
+
+# how many rounds should this thing go?
+max_turns = 20
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 52d2ad20..5b8b06cb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
.env
venv/
+.venv/
__pycache__/
*.swp
*.log
@@ -15,3 +16,12 @@ src/hackingBuddyGPT/usecases/web_api_testing/openapi_spec/
src/hackingBuddyGPT/usecases/web_api_testing/converted_files/
/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_spec/
/src/hackingBuddyGPT/usecases/web_api_testing/documentation/reports/
+scripts/codespaces_ansible.cfg
+scripts/codespaces_ansible_hosts.ini
+scripts/codespaces_ansible_id_rsa
+scripts/codespaces_ansible_id_rsa.pub
+scripts/mac_ansible.cfg
+scripts/mac_ansible_hosts.ini
+scripts/mac_ansible_id_rsa
+scripts/mac_ansible_id_rsa.pub
+.aider*
diff --git a/CODESPACES.md b/CODESPACES.md
new file mode 100644
index 00000000..23296ef7
--- /dev/null
+++ b/CODESPACES.md
@@ -0,0 +1,179 @@
+# Use Case: GitHub Codespaces
+
+**Backstory**
+
+https://github.com/ipa-lab/hackingBuddyGPT/pull/85#issuecomment-2331166997
+
+> Would it be possible to add codespace support to hackingbuddygpt in a way, that only spawns a single container (maybe with the suid/sudo use-case) and starts hackingBuddyGPT against that container? That might be the 'easiest' show-case/use-case for a new user.
+
**Steps**
1. Go to https://github.com/ipa-lab/hackingBuddyGPT
2. Click the "Code" button.
3. Click the "Codespaces" tab.
4. Click the "Create codespace on main" button.
5. Wait for Codespaces to start — this may take upwards of 10 minutes.

> Setting up remote connection: Building codespace...

6. After Codespaces has started, you may need to open a new Terminal via the Command Palette:

Press the key combination:

> `⇧⌘P` `Shift+Command+P` (Mac) / `Ctrl+Shift+P` (Windows/Linux)

In the Command Palette, type `>` and `Terminal: Create New Terminal` and press the return key.

7. You should see a new terminal similar to the following:

> 👋 Welcome to Codespaces! You are on our default image.
>
> `-` It includes runtimes and tools for Python, Node.js, Docker, and more. See the full list here: https://aka.ms/ghcs-default-image
>
> `-` Want to use a custom image instead? Learn more here: https://aka.ms/configure-codespace
>
> 🔍 To explore VS Code to its fullest, search using the Command Palette (Cmd/Ctrl + Shift + P or F1).
>
> 📝 Edit away, run your app as usual, and we'll automatically make it available for you to access.
>
> @github-username ➜ /workspaces/hackingBuddyGPT (main) $

Type the following to run the script manually:
```bash
./scripts/codespaces_start_hackingbuddygpt_against_a_container.sh
```
8. Eventually, you should see:

> As of May 2024, running hackingBuddyGPT with GPT-4-turbo against a benchmark containing 13 VMs (with a maximum of 20 tries per VM) costs around $5.
>
> Therefore, running hackingBuddyGPT with GPT-4-turbo against a container with a maximum of 10 tries would cost around $0.20.
>
> Enter your OpenAI API key and press the return key:

9. As requested, please enter your OpenAI API key and press the return key.

10. hackingBuddyGPT should start:

> Starting hackingBuddyGPT against a container...

11. If your OpenAI API key is *valid*, then you should see output similar to the following:

> [00:00:00] Starting turn 1 of 10
>
> Got command from LLM:
>
> …
>
> [00:01:00] Starting turn 10 of 10
>
> …
>
> Run finished
>
> maximum turn number reached

12.
If your OpenAI API key is *invalid*, then you should see output similar to the following:

> [00:00:00] Starting turn 1 of 10
>
> Traceback (most recent call last):
>
> …
>
> Exception: Error from OpenAI Gateway (401

13. Alternatively, use Google Gemini instead of OpenAI:

**Prerequisites:**

```bash
python -m venv venv
```

```bash
source ./venv/bin/activate
```

```bash
pip install -e .
```

**Use gemini-openai-proxy and Gemini:**

http://localhost:8080 is gemini-openai-proxy

`gpt-4` maps to `gemini-1.5-flash-latest`

Hence use `gpt-4` below in `--llm.model=gpt-4`

The Gemini free tier has a limit of 15 requests per minute and 1,500 requests per day

Hence `--max_turns 999999999` will eventually exceed the daily limit

**Run gemini-openai-proxy**

```bash
docker run --restart=unless-stopped -it -d -p 8080:8080 --name gemini zhu327/gemini-openai-proxy:latest
```

**Manually enter your GEMINI_API_KEY value, obtained from** https://aistudio.google.com/app/apikey

```bash
export GEMINI_API_KEY=
```

**Starting hackingBuddyGPT against a container...**

```bash
wintermute LinuxPrivesc --llm.api_key=$GEMINI_API_KEY --llm.model=gpt-4 --llm.context_size=1000000 --conn.host=192.168.122.151 --conn.username=lowpriv --conn.password=trustno1 --conn.hostname=test1 --llm.api_url=http://localhost:8080 --llm.api_backoff=60 --max_turns 999999999
```

**Google AI Studio: The Gemini free tier has a limit of 15 requests per minute and 1,500 requests per day:**

https://ai.google.dev/pricing#1_5flash

> Gemini 1.5 Flash
>
> The Gemini API “free tier” is offered through the API service with lower rate limits for testing purposes. Google AI Studio usage is completely free in all available countries.
>
> Rate Limits
>
> 15 RPM (requests per minute)
>
> 1 million TPM (tokens per minute)
>
> 1,500 RPD (requests per day)
>
> Used to improve Google's products
>
> Yes

https://ai.google.dev/gemini-api/terms#data-use-unpaid

> How Google Uses Your Data
>
> When you use Unpaid Services, including, for example, Google AI Studio and the unpaid quota on Gemini API, Google uses the content you submit to the Services and any generated responses to provide, improve, and develop Google products and services and machine learning technologies, including Google's enterprise features, products, and services, consistent with our Privacy Policy https://policies.google.com/privacy
>
> To help with quality and improve our products, human reviewers may read, annotate, and process your API input and output. Google takes steps to protect your privacy as part of this process. This includes disconnecting this data from your Google Account, API key, and Cloud project before reviewers see or annotate it. **Do not submit sensitive, confidential, or personal information to the Unpaid Services.**

**README.md and Disclaimers:**

https://github.com/ipa-lab/hackingBuddyGPT/blob/main/README.md

**Please refer to [README.md](https://github.com/ipa-lab/hackingBuddyGPT/blob/main/README.md) for all disclaimers.**

Please note and accept all of them.
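**Optional: sanity-check the proxy and your key**

Before starting a long run, you can optionally verify that the proxy and your key respond. This is a minimal sketch, assuming the gemini-openai-proxy container from the step above is listening on http://localhost:8080 and `GEMINI_API_KEY` is exported; a JSON completion confirms the key works, while an error body corresponds to the invalid-key traceback shown earlier:

```bash
# One-off request against the OpenAI-compatible chat completions endpoint;
# the proxy maps the gpt-4 model name to gemini-1.5-flash-latest
curl -sS http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $GEMINI_API_KEY" \
  -d '{"model": "gpt-4", "messages": [{"role": "user", "content": "Say OK."}]}'
```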
+
+**References:**
+* https://docs.github.com/en/codespaces
+* https://docs.github.com/en/codespaces/getting-started/quickstart
+* https://docs.github.com/en/codespaces/reference/using-the-vs-code-command-palette-in-codespaces
+* https://openai.com/api/pricing/
+* https://platform.openai.com/docs/quickstart
+* https://platform.openai.com/api-keys
+* https://ai.google.dev/gemini-api/docs/ai-studio-quickstart
+* https://aistudio.google.com/
+* https://aistudio.google.com/app/apikey
+* https://ai.google.dev/
+* https://ai.google.dev/gemini-api/docs/api-key
+* https://github.com/zhu327/gemini-openai-proxy
+* https://hub.docker.com/r/zhu327/gemini-openai-proxy
diff --git a/MAC.md b/MAC.md
new file mode 100644
index 00000000..067ceff7
--- /dev/null
+++ b/MAC.md
@@ -0,0 +1,129 @@
+## Use Case: Mac, Docker Desktop and Gemini-OpenAI-Proxy
+
+**Docker Desktop runs containers in a virtual machine on Mac.**
+
+**Run hackingBuddyGPT on Mac as follows:**
+
Target a localhost container ansible-ready-ubuntu

via Docker Desktop https://docs.docker.com/desktop/setup/install/mac-install/

and Gemini-OpenAI-Proxy https://github.com/zhu327/gemini-openai-proxy

There are bugs in Docker Desktop on Mac that prevent the creation of a custom Docker network 192.168.65.0/24

Therefore, a dynamic localhost TCP port (49152 or higher) is used for the ansible-ready-ubuntu container

http://localhost:8080 is gemini-openai-proxy

gpt-4 maps to gemini-1.5-flash-latest

Hence use gpt-4 below in --llm.model=gpt-4

The Gemini free tier has a limit of 15 requests per minute and 1,500 requests per day

Hence --max_turns 999999999 will eventually exceed the daily limit

For example:

```zsh
export GEMINI_API_KEY=

export PORT=49152

wintermute LinuxPrivesc --llm.api_key=$GEMINI_API_KEY --llm.model=gpt-4 --llm.context_size=1000000 --conn.host=localhost --conn.port $PORT --conn.username=lowpriv --conn.password=trustno1 --conn.hostname=test1 --llm.api_url=http://localhost:8080 --llm.api_backoff=60 --max_turns 999999999
```

The above example is consolidated into shell scripts with prerequisites as follows:

**Prerequisite: Install Homebrew and Bash version 5:**

```zsh
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```

**Install Bash version 5 via Homebrew:**

```zsh
brew install bash
```

Bash version 4 or higher is needed for `scripts/mac_create_and_start_containers.sh`.

Homebrew provides GNU Bash version 5 (licensed GPLv3+), whereas macOS ships Bash version 3 (licensed GPLv2).

**Create and start containers:**

```zsh
./scripts/mac_create_and_start_containers.sh
```

**Start hackingBuddyGPT against a container:**

```zsh
export GEMINI_API_KEY=
```

```zsh
./scripts/mac_start_hackingbuddygpt_against_a_container.sh
```

**Troubleshooting:**

**Docker Desktop: Internal Server Error**

```zsh
Server:
ERROR: request returned Internal Server Error for API route and version http://%2FUsers%2Fusername%2F.docker%2Frun%2Fdocker.sock/v1.47/info, check if the server supports the requested API version
errors pretty printing info
```

You may need to uninstall Docker Desktop https://docs.docker.com/desktop/uninstall/ and reinstall it from https://docs.docker.com/desktop/setup/install/mac-install/ and try again.

Alternatively, restart Docker Desktop and try again.
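**Verify the target container and its SSH port:**

If a run cannot connect, it can help to confirm that the ansible-ready-ubuntu container is up and to re-derive the host port mapped to its SSH service. The following is a minimal sketch using standard Docker CLI commands, assuming the key generated by the setup script is at `./scripts/mac_ansible_id_rsa`:

```zsh
# Show the container and its host->container port mapping
docker ps --filter name=ansible-ready-ubuntu --format '{{.Names}} {{.Ports}}'

# Extract the host port mapped to container port 22 and test key-based SSH
PORT=$(docker port ansible-ready-ubuntu 22 | head -n1 | cut -d: -f2)
ssh -i ./scripts/mac_ansible_id_rsa -p "$PORT" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ansible@localhost echo ok
```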
+ +**There are known issues with Docker Desktop on Mac, such as:** + +* Bug: Docker CLI Hangs for all commands +https://github.com/docker/for-mac/issues/6940 + +* Regression: Docker does not recover from resource saver mode +https://github.com/docker/for-mac/issues/6933 + +**Google AI Studio: Gemini free tier has a limit of 15 requests per minute, and 1500 requests per day:** + +https://ai.google.dev/pricing#1_5flash + +> Gemini 1.5 Flash +> +> The Gemini API “free tier” is offered through the API service with lower rate limits for testing purposes. Google AI Studio usage is completely free in all available countries. +> +> Rate Limits +> +> 15 RPM (requests per minute) +> +> 1 million TPM (tokens per minute) +> +> 1,500 RPD (requests per day) +> +> Used to improve Google's products +> +> Yes + +https://ai.google.dev/gemini-api/terms#data-use-unpaid + +> How Google Uses Your Data +> +> When you use Unpaid Services, including, for example, Google AI Studio and the unpaid quota on Gemini API, Google uses the content you submit to the Services and any generated responses to provide, improve, and develop Google products and services and machine learning technologies, including Google's enterprise features, products, and services, consistent with our Privacy Policy https://policies.google.com/privacy +> +> To help with quality and improve our products, human reviewers may read, annotate, and process your API input and output. Google takes steps to protect your privacy as part of this process. This includes disconnecting this data from your Google Account, API key, and Cloud project before reviewers see or annotate it. **Do not submit sensitive, confidential, or personal information to the Unpaid Services.** + +**README.md and Disclaimers:** + +https://github.com/ipa-lab/hackingBuddyGPT/blob/main/README.md + +**Please refer to [README.md](https://github.com/ipa-lab/hackingBuddyGPT/blob/main/README.md) for all disclaimers.** + +Please note and accept all of them. diff --git a/README.md b/README.md index b7a64c12..a33b12f8 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +**NEITHER THE IPA-LAB NOR HACKINGBUDDYGPT ARE INVOLVED IN ANY CRYPTO COIN! ALL INFORMATION TO THE CONTRARY IS BEING USED TO SCAM YOU! THE TWITTER ACCOUNT THAT CURRENTLY EXISTS IS JUST TRYING TO GET YOUR MONEY, DO NOT FALL FOR IT!** + + #
HackingBuddyGPT [![Discord](https://dcbadge.vercel.app/api/server/vr4PhSM8yN?style=flat&compact=true)](https://discord.gg/vr4PhSM8yN)
*Helping Ethical Hackers use LLMs in 50 Lines of Code or less..* @@ -12,7 +15,7 @@ If you want to use hackingBuddyGPT and need help selecting the best LLM for your ## hackingBuddyGPT in the News -- **upcoming** 2024-11-20: [Manuel Reinsperger](https://www.github.com/neverbolt) will present hackingBuddyGPT at the [European Symposium on Security and Artificial Intelligence (ESSAI)](https://essai-conference.eu/) +- 2024-11-20: [Manuel Reinsperger](https://www.github.com/neverbolt) presented hackingBuddyGPT at the [European Symposium on Security and Artificial Intelligence (ESSAI)](https://essai-conference.eu/) - 2024-07-26: The [GitHub Accelerator Showcase](https://github.blog/open-source/maintainers/github-accelerator-showcase-celebrating-our-second-cohort-and-whats-next/) features hackingBuddyGPT - 2024-07-24: [Juergen](https://github.com/citostyle) speaks at [Open Source + mezcal night @ GitHub HQ](https://lu.ma/bx120myg) - 2024-05-23: hackingBuddyGPT is part of [GitHub Accelerator 2024](https://github.blog/news-insights/company-news/2024-github-accelerator-meet-the-11-projects-shaping-open-source-ai/) @@ -82,38 +85,38 @@ template_next_cmd = Template(filename=str(template_dir / "next_cmd.txt")) class MinimalLinuxPrivesc(Agent): - conn: SSHConnection = None + _sliding_history: SlidingCliHistory = None + _max_history_size: int = 0 def init(self): super().init() + self._sliding_history = SlidingCliHistory(self.llm) + self._max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - self.llm.count_tokens(template_next_cmd.source) + self.add_capability(SSHRunCommand(conn=self.conn), default=True) self.add_capability(SSHTestCredential(conn=self.conn)) - self._template_size = self.llm.count_tokens(template_next_cmd.source) - def perform_round(self, turn: int) -> bool: - got_root: bool = False + @log_conversation("Asking LLM for a new command...") + def perform_round(self, turn: int, log: Logger) -> bool: + # get as much history as fits into the target context size + history = self._sliding_history.get_history(self._max_history_size) - with self._log.console.status("[bold green]Asking LLM for a new command..."): - # get as much history as fits into the target context size - history = self._sliding_history.get_history(self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size) + # get the next command from the LLM + answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn) + message_id = log.call_response(answer) - # get the next command from the LLM - answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn) - cmd = llm_util.cmd_output_fixer(answer.result) + # clean the command, load and execute it + cmd = llm_util.cmd_output_fixer(answer.result) + capability, arguments = cmd.split(" ", 1) + result, got_root = self.run_capability(message_id, "0", capability, arguments, calling_mode=CapabilityCallingMode.Direct, log=log) - with self._log.console.status("[bold green]Executing that command..."): - self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) - - # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) + # store the results in our local history self._sliding_history.add_command(cmd, result) - self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) - # if we got root, we 
can stop the loop + # signal if we were successful in our task return got_root @@ -154,7 +157,7 @@ We try to keep our python dependencies as light as possible. This should allow f To get everything up and running, clone the repo, download requirements, setup API keys and credentials, and start `wintermute.py`: -~~~ bash +```bash # clone the repository $ git clone https://github.com/ipa-lab/hackingBuddyGPT.git $ cd hackingBuddyGPT @@ -166,23 +169,135 @@ $ source ./venv/bin/activate # install python requirements $ pip install -e . -# copy default .env.example +# copy default .env.example $ cp .env.example .env +# NOTE: if you are trying to use this with AWS or ssh-key only authentication, copy .env.example.aws +$ cp .env.example.aws .env + # IMPORTANT: setup your OpenAI API key, the VM's IP and credentials within .env $ vi .env # if you start wintermute without parameters, it will list all available use cases -$ python wintermute.py -usage: wintermute.py [-h] {linux_privesc,minimal_linux_privesc,windows privesc} ... -wintermute.py: error: the following arguments are required: {linux_privesc,windows privesc} +$ python src/hackingBuddyGPT/cli/wintermute.py +No command provided +usage: src/hackingBuddyGPT/cli/wintermute.py [--help] [--config config.json] [options...] + +commands: + ExPrivEscLinux Showcase Minimal Linux Priv-Escalation + ExPrivEscLinuxTemplated Showcase Minimal Linux Priv-Escalation + LinuxPrivesc Linux Privilege Escalation + WindowsPrivesc Windows Privilege Escalation + ExPrivEscLinuxHintFile Linux Privilege Escalation using hints from a hint file initial guidance + ExPrivEscLinuxLSE Linux Privilege Escalation using lse.sh for initial guidance + WebTestingWithExplanation Minimal implementation of a web testing use case while allowing the llm to 'talk' + SimpleWebAPIDocumentation Minimal implementation of a web API testing use case + SimpleWebAPITesting Minimal implementation of a web API testing use case + Viewer Webserver for (live) log viewing + Replayer Tool to replay the .jsonl logs generated by the Viewer (not well tested) + ThesisLinuxPrivescPrototype Thesis Linux Privilege Escalation Prototype + +# to get more information about how to configure a use case you can call it with --help +$ python src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc --help +usage: src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc [--help] [--config config.json] [options...] 
+
+  --log.log_server_address='localhost:4444' address:port of the log server to be used (default from builtin)
+  --log.tag='' Tag for your current run (default from builtin)
+  --log='local_logger' choice of logging backend (default from builtin)
+  --log_db.connection_string='wintermute.sqlite3' sqlite3 database connection string for logs (default from builtin)
+  --max_turns='30' (default from .env file, alternatives: 10 from builtin)
+  --llm.api_key= OpenAI API Key (default from .env file)
+  --llm.model OpenAI model name
+  --llm.context_size='100000' Maximum context size for the model, only used internally for things like trimming to the context size (default from .env file)
+  --llm.api_url='https://api.openai.com' URL of the OpenAI API (default from builtin)
+  --llm.api_path='/v1/chat/completions' Path to the OpenAI API (default from builtin)
+  --llm.api_timeout=240 Timeout for the API request (default from builtin)
+  --llm.api_backoff=60 Backoff time in seconds when running into rate-limits (default from builtin)
+  --llm.api_retries=3 Number of retries when running into rate-limits (default from builtin)
+  --system='linux' (default from builtin)
+  --enable_explanation=False (default from builtin)
+  --enable_update_state=False (default from builtin)
+  --disable_history=False (default from builtin)
+  --hint='' (default from builtin)
+  --conn.host
+  --conn.hostname
+  --conn.username
+  --conn.password
+  --conn.keyfilename
+  --conn.port='2222' (default from .env file, alternatives: 22 from builtin)
+```
+
+### Provide a Target Machine over SSH
+
+The next important part is having a machine that we can run our agent against. In our case, the target machine will be situated at `192.168.122.151`.
+
+We are using vulnerable Linux systems running in Virtual Machines for this. Never run this against real systems.
+
+> 💡 **We also provide vulnerable machines!**
+>
+> We are using virtual machines from our [Linux Privilege-Escalation Benchmark](https://github.com/ipa-lab/benchmark-privesc-linux) project. Feel free to use them for your own research!
+
+## Using the web-based viewer and replayer
+
+If you want to have a better representation of the agent's output, you can use the web-based viewer. You can start it using `wintermute Viewer`, which will run the server on `http://127.0.0.1:4444` for the default `wintermute.sqlite3` database. You can change these options using the `--log_server_address` and `--log_db.connection_string` parameters.
+
+Navigating to the log server address will show you an overview of all runs and clicking on a run will show you the details of that run. The viewer updates live using a websocket connection, and if you enable `Follow new runs` it will automatically switch to the new run when one is started.
+
+Keep in mind that there is no additional protection for this webserver, other than how it can be reached (by default it binds to `127.0.0.1`, which means it can only be reached from your local machine). If you make it accessible to the internet, everybody will be able to see all of your runs and also be able to inject arbitrary data into the database.
+
+Therefore **DO NOT** make it accessible to the internet if you're not super sure about what you're doing!
+
+There is also the experimental replay functionality, which can replay a run live from a capture file, including timing information. This is great for showcases and presentations, because it looks like everything is happening live and for real, but you know exactly what the results will be.
+
+To use this, the run needs to be captured by a Viewer server by setting `--save_playback_dir` to a directory where the viewer can write the capture files.
+With the Viewer server still running, you can then start `wintermute Replayer --replay_file ` to replay the captured run (this will create a new run in the database).
+You can pass `--pause_on_message` and `--pause_on_tool_calls`, which will interrupt the replay at the respective points until enter is pressed in the shell where you run the Replayer. You can also configure the `--playback_speed` to control the speed of the replay.
+
+## Use Cases
+
+GitHub Codespaces:
+
+* See [CODESPACES.md](CODESPACES.md)
+
+Mac, Docker Desktop and Gemini-OpenAI-Proxy:
+
+* See [MAC.md](MAC.md)
+
+## Run the Hacking Agent
+
+Finally, we can run hackingBuddyGPT against our provided test VM. Enjoy!
+
+> ❗ **Don't be evil!**
+>
+> Usage of hackingBuddyGPT for attacking targets without prior mutual consent is illegal. It's the end user's responsibility to obey all applicable local, state and federal laws. Developers assume no liability and are not responsible for any misuse or damage caused by this program. Only use for educational purposes.
+
+With that out of the way, let's look at an example hackingBuddyGPT run. Each run is structured in rounds. At the start of each round, hackingBuddyGPT asks an LLM for the next command to execute (e.g., `whoami` in the first round). It then executes that command on the virtual machine, prints its output and starts a new round (in which it also includes the output of prior rounds) until it reaches round number 10 or becomes root:
+
```bash
# start wintermute, i.e., attack the configured virtual machine
-$ python wintermute.py minimal_linux_privesc
+$ python src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc --llm.api_key=sk...ChangeMeToYourOpenAiApiKey --llm.model=gpt-4-turbo --llm.context_size=8192 --conn.host=192.168.122.151 --conn.username=lowpriv --conn.password=trustno1 --conn.hostname=test1
+
# install dependencies for testing if you want to run the tests
-$ pip install .[testing]
-~~~
+$ pip install '.[testing]'
+```
+
+## Beta Features
+
+### Viewer
+
+The viewer is a simple web-based tool to view the results of hackingBuddyGPT runs. It is currently in beta and can be started with:
+
+```bash
+$ hackingBuddyGPT Viewer
+```
+
+This will start a webserver on `http://localhost:4444` that can be accessed with a web browser.
+
+To log to this central viewer, you currently need to change the `GlobalLogger` definition in [./src/hackingBuddyGPT/utils/logging.py](src/hackingBuddyGPT/utils/logging.py) to `GlobalRemoteLogger`.
+
+This feature is not fully tested yet and should therefore not be exposed to the internet!

## Publications about hackingBuddyGPT
diff --git a/publish_notes.md b/publish_notes.md
new file mode 100644
index 00000000..7610762f
--- /dev/null
+++ b/publish_notes.md
@@ -0,0 +1,34 @@
+# how to publish to pypi
+
+## start with testing if the project builds and tag the version
+
+```bash
+python -m venv venv
+source venv/bin/activate
+pip install -e .
+pytest
+git tag v0.3.0
+git push origin v0.3.0
+```
+
+## build the new package
+
+(according to https://packaging.python.org/en/latest/tutorials/packaging-projects/)
+
+```bash
+pip install build twine
+python3 -m build
+vi ~/.pypirc
+twine check dist/*
+```
+
+Now, for next time: test-install the package in a new vanilla environment, then:
+ +```bash +twine upload dist/* +``` + +## repo todos + +- rebase development upon main +- bump the pyproject version number to a new `-dev` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index aac9dd31..61c8f8c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,18 +4,20 @@ build-backend = "setuptools.build_meta" [project] name = "hackingBuddyGPT" +# original author was Andreas Happe, for an up-to-date list see +# https://github.com/ipa-lab/hackingBuddyGPT/graphs/contributors authors = [ - { name = "Andreas Happe", email = "andreas@offensive.one" } + { name = "HackingBuddyGPT maintainers", email = "maintainers@hackingbuddy.ai" } ] maintainers = [ { name = "Andreas Happe", email = "andreas@offensive.one" }, - { name = "Juergen Cito", email = "juergen.cito@tuwiena.c.at" } + { name = "Juergen Cito", email = "juergen.cito@tuwien.ac.at" } ] description = "Helping Ethical Hackers use LLMs in 50 lines of code" readme = "README.md" keywords = ["hacking", "pen-testing", "LLM", "AI", "agent"] requires-python = ">=3.10" -version = "0.3.1" +version = "0.4.0" license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3", @@ -24,19 +26,30 @@ classifiers = [ "Development Status :: 4 - Beta", ] dependencies = [ - 'fabric == 3.2.2', - 'Mako == 1.3.2', - 'requests == 2.32.0', - 'rich == 13.7.1', - 'tiktoken == 0.6.0', - 'instructor == 1.3.5', - 'PyYAML == 6.0.1', - 'python-dotenv == 1.0.1', - 'pypsexec == 0.3.0', - 'pydantic == 2.8.2', - 'openai == 1.28.0', - 'BeautifulSoup4', - 'nltk' + 'fabric == 3.2.2', + 'Mako == 1.3.2', + 'requests == 2.32.3', + 'rich == 13.7.1', + 'tiktoken == 0.8.0', + 'instructor == 1.7.2', + 'PyYAML == 6.0.1', + 'python-dotenv == 1.0.1', + 'pypsexec == 0.3.0', + 'pydantic == 2.8.2', + 'openai == 1.65.2', + 'BeautifulSoup4', + 'nltk', + 'fastapi == 0.114.0', + 'fastapi-utils == 0.7.0', + 'jinja2 == 3.1.4', + 'uvicorn[standard] == 0.30.6', + 'dataclasses_json == 0.6.7', + 'websockets == 13.1', + 'langchain-community', + 'langchain-openai', + 'markdown', + 'chromadb', + 'langchain-chroma', ] [project.urls] @@ -54,15 +67,27 @@ where = ["src"] [tool.pytest.ini_options] pythonpath = "src" -addopts = [ - "--import-mode=importlib", -] +addopts = ["--import-mode=importlib"] [project.optional-dependencies] -testing = [ - 'pytest', - 'pytest-mock' +testing = ['pytest', 'pytest-mock'] +dev = [ + 'ruff', +] +rag-usecase = [ + 'langchain-community', + 'langchain-openai', + 'markdown', + 'chromadb', + 'langchain-chroma', ] [project.scripts] wintermute = "hackingBuddyGPT.cli.wintermute:main" hackingBuddyGPT = "hackingBuddyGPT.cli.wintermute:main" + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "B", "I"] +ignore = ["E501", "F401", "F403"] diff --git a/scripts/codespaces_create_and_start_containers.Dockerfile b/scripts/codespaces_create_and_start_containers.Dockerfile new file mode 100644 index 00000000..fe16874a --- /dev/null +++ b/scripts/codespaces_create_and_start_containers.Dockerfile @@ -0,0 +1,67 @@ +# codespaces_create_and_start_containers.Dockerfile + +FROM ubuntu:latest + +ENV DEBIAN_FRONTEND=noninteractive + +# Use the TIMEZONE variable to configure the timezone +ENV TIMEZONE=Etc/UTC +RUN ln -fs /usr/share/zoneinfo/$TIMEZONE /etc/localtime && echo $TIMEZONE > /etc/timezone + +# Update package list and install dependencies in one line +RUN apt-get update && apt-get install -y \ + software-properties-common \ + openssh-server \ + sudo \ + python3 \ + python3-venv \ + python3-setuptools \ + python3-wheel \ + 
python3-apt \ + passwd \ + tzdata \ + iproute2 \ + wget \ + cron \ + --no-install-recommends && \ + add-apt-repository ppa:deadsnakes/ppa -y && \ + apt-get update && apt-get install -y \ + python3.11 \ + python3.11-venv \ + python3.11-distutils \ + python3.11-dev && \ + dpkg-reconfigure --frontend noninteractive tzdata && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install pip using get-pip.py +RUN wget https://bootstrap.pypa.io/get-pip.py && python3.11 get-pip.py && rm get-pip.py + +# Install required Python packages +RUN python3.11 -m pip install --no-cache-dir passlib cffi cryptography + +# Ensure python3-apt is properly installed and linked +RUN ln -s /usr/lib/python3/dist-packages/apt_pkg.cpython-310-x86_64-linux-gnu.so /usr/lib/python3/dist-packages/apt_pkg.so || true + +# Prepare SSH server +RUN mkdir /var/run/sshd + +# Create ansible user +RUN useradd -m -s /bin/bash ansible + +# Set up SSH for ansible +RUN mkdir -p /home/ansible/.ssh && \ + chmod 700 /home/ansible/.ssh && \ + chown ansible:ansible /home/ansible/.ssh + +# Configure sudo access for ansible +RUN echo "ansible ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/ansible + +# Disable root SSH login +RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin no/' /etc/ssh/sshd_config + +# Expose SSH port +EXPOSE 22 + +# Start SSH server +CMD ["/usr/sbin/sshd", "-D"] diff --git a/scripts/codespaces_create_and_start_containers.sh b/scripts/codespaces_create_and_start_containers.sh new file mode 100755 index 00000000..0a8d45ab --- /dev/null +++ b/scripts/codespaces_create_and_start_containers.sh @@ -0,0 +1,288 @@ +#!/bin/bash + +# Purpose: In GitHub Codespaces, automates the setup of Docker containers, +# preparation of Ansible inventory, and modification of tasks for testing. +# Usage: ./scripts/codespaces_create_and_start_containers.sh + +# Enable strict error handling for better script robustness +set -e # Exit immediately if a command exits with a non-zero status +set -u # Treat unset variables as an error and exit immediately +set -o pipefail # Return the exit status of the last command in a pipeline that failed +set -x # Print each command before executing it (useful for debugging) + +cd $(dirname $0) + +bash_version=$(/bin/bash --version | head -n 1 | awk '{print $4}' | cut -d. -f1) + +if (( bash_version < 4 )); then + echo 'Error: Requires Bash version 4 or higher.' + exit 1 +fi + +# Step 1: Initialization + +if [ ! -f hosts.ini ]; then + echo "hosts.ini not found! Please ensure your Ansible inventory file exists before running the script." + exit 1 +fi + +if [ ! -f tasks.yaml ]; then + echo "tasks.yaml not found! Please ensure your Ansible playbook file exists before running the script." + exit 1 +fi + +# Default values for network and base port, can be overridden by environment variables +DOCKER_NETWORK_NAME=${DOCKER_NETWORK_NAME:-192_168_122_0_24} +DOCKER_NETWORK_SUBNET="192.168.122.0/24" +BASE_PORT=${BASE_PORT:-49152} + +# Step 2: Define helper functions + +# Function to find an available port starting from a base port +find_available_port() { + local base_port="$1" + local port=$base_port + local max_port=65535 + while ss -tuln | grep -q ":$port "; do + port=$((port + 1)) + if [ "$port" -gt "$max_port" ]; then + echo "No available ports in the range $base_port-$max_port." >&2 + exit 1 + fi + done + echo $port +} + +# Function to generate SSH key pair +generate_ssh_key() { + ssh-keygen -t rsa -b 4096 -f ./codespaces_ansible_id_rsa -N '' -q <<< y + echo "New SSH key pair generated." 
+ chmod 600 ./codespaces_ansible_id_rsa +} + +# Function to create and start Docker container with SSH enabled +start_container() { + local container_name="$1" + local base_port="$2" + local container_ip="$3" + local image_name="ansible-ready-ubuntu" + + if [ "$(docker ps -aq -f name=${container_name})" ]; then + echo "Container ${container_name} already exists. Removing it..." >&2 + docker stop ${container_name} > /dev/null 2>&1 || true + docker rm ${container_name} > /dev/null 2>&1 || true + fi + + echo "Starting Docker container ${container_name} with IP ${container_ip} on port ${base_port}..." >&2 + docker run -d --name ${container_name} -h ${container_name} --network ${DOCKER_NETWORK_NAME} --ip ${container_ip} -p "${base_port}:22" ${image_name} > /dev/null 2>&1 + + # Copy SSH public key to container + docker cp ./codespaces_ansible_id_rsa.pub ${container_name}:/home/ansible/.ssh/authorized_keys + docker exec ${container_name} chown ansible:ansible /home/ansible/.ssh/authorized_keys + docker exec ${container_name} chmod 600 /home/ansible/.ssh/authorized_keys + + echo "${container_ip}" +} + +# Function to check if SSH is ready on a container +check_ssh_ready() { + local container_ip="$1" + timeout 1 ssh -o BatchMode=yes -o ConnectTimeout=10 -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i ./codespaces_ansible_id_rsa ansible@${container_ip} exit 2>/dev/null + return $? +} + +# Function to replace IP address and add Ansible configuration +replace_ip_and_add_config() { + local original_ip="$1" + local container_name="${original_ip//./_}" + + # Find an available port for the container + local available_port=$(find_available_port "$BASE_PORT") + + # Start the container with the available port + local container_ip=$(start_container "$container_name" "$available_port" "$original_ip") + + # Replace the original IP with the new container IP and add Ansible configuration + sed -i "s/^[[:space:]]*$original_ip[[:space:]]*$/$container_ip ansible_user=ansible ansible_ssh_private_key_file=.\/codespaces_ansible_id_rsa ansible_ssh_common_args='-o StrictHostKeyChecking=no -o UserKnownHostsFile=\/dev\/null'/" codespaces_ansible_hosts.ini + + echo "Started container ${container_name} with IP ${container_ip}, mapped to host port ${available_port}" + echo "Updated IP ${original_ip} to ${container_ip} in codespaces_ansible_hosts.ini" + + # Increment BASE_PORT for the next container + BASE_PORT=$((available_port + 1)) +} + +# Step 3: Update and install prerequisites + +echo "Updating package lists..." + +# Install prerequisites and set up Docker +sudo apt-get update +sudo apt-get install -y apt-transport-https ca-certificates curl gnupg lsb-release + +# Step 4: Set up Docker repository and install Docker components + +echo "Adding Docker's official GPG key..." +curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --batch --yes --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg +echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null + +echo "Updating package lists again..." +sudo apt-get update + +echo "Installing Moby components (moby-engine, moby-cli, moby-tini)..." +sudo apt-get install -y moby-engine moby-cli moby-tini moby-containerd + +# Step 5: Start Docker and containerd services + +echo "Starting Docker daemon using Moby..." 
+sudo service docker start || true +sudo service containerd start || true + +# Step 6: Wait for Docker to be ready + +echo "Waiting for Docker to be ready..." +timeout=60 +while ! sudo docker info >/dev/null 2>&1; do + if [ $timeout -le 0 ]; then + echo "Timed out waiting for Docker to start." + sudo service docker status || true + echo "Docker daemon logs:" + sudo cat /var/log/docker.log || true + exit 1 + fi + echo "Waiting for Docker to be available... ($timeout seconds left)" + timeout=$(($timeout - 1)) + sleep 1 +done + +echo "Docker (Moby) is ready." + +# Step 7: Install Python packages and Ansible + +echo "Verifying Docker installation..." +docker --version +docker info + +echo "Installing other required packages..." +sudo apt-get install -y python3 python3-pip sshpass + +echo "Installing Ansible and passlib using pip..." +pip3 install ansible passlib + +# Step 8: Build Docker image with SSH enabled + +echo "Building Docker image with SSH enabled..." +if ! docker build -t ansible-ready-ubuntu -f codespaces_create_and_start_containers.Dockerfile .; then + echo "Failed to build Docker image." >&2 + exit 1 +fi + +# Step 9: Create a custom Docker network if it does not exist + +echo "Checking if the custom Docker network '${DOCKER_NETWORK_NAME}' with subnet 192.168.122.0/24 exists..." + +if ! docker network inspect ${DOCKER_NETWORK_NAME} >/dev/null 2>&1; then + docker network create --subnet="${DOCKER_NETWORK_SUBNET}" "${DOCKER_NETWORK_NAME}" || echo "Network creation failed, but continuing..." +fi + +# Generate SSH key +generate_ssh_key + +# Step 10: Copy hosts.ini to codespaces_ansible_hosts.ini and update IP addresses + +echo "Copying hosts.ini to codespaces_ansible_hosts.ini and updating IP addresses..." + +# Copy hosts.ini to codespaces_ansible_hosts.ini +cp hosts.ini codespaces_ansible_hosts.ini + +# Read hosts.ini to get IP addresses and create containers +current_group="" +while IFS= read -r line || [ -n "$line" ]; do + if [[ $line =~ ^\[(.+)\] ]]; then + current_group="${BASH_REMATCH[1]}" + echo "Processing group: $current_group" + elif [[ $line =~ ^[[:space:]]*([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)[[:space:]]*$ ]]; then + ip="${BASH_REMATCH[1]}" + echo "Found IP $ip in group $current_group" + replace_ip_and_add_config "$ip" + fi +done < hosts.ini + +# Add [all:vars] section if it doesn't exist +if ! grep -q "\[all:vars\]" codespaces_ansible_hosts.ini; then + echo "Adding [all:vars] section to codespaces_ansible_hosts.ini" + echo "" >> codespaces_ansible_hosts.ini + echo "[all:vars]" >> codespaces_ansible_hosts.ini + echo "ansible_python_interpreter=/usr/bin/python3" >> codespaces_ansible_hosts.ini +fi + +echo "Finished updating codespaces_ansible_hosts.ini" + +# Step 11: Wait for SSH services to start on all containers + +echo "Waiting for SSH services to start on all containers..." +declare -A exit_statuses # Initialize an associative array to track exit statuses + +# Check SSH readiness sequentially for all containers +while IFS= read -r line; do + if [[ "$line" =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+.* ]]; then + container_ip=$(echo "$line" | awk '{print $1}') + + echo "Checking SSH readiness for $container_ip..." 
+        if check_ssh_ready "$container_ip"; then
+            echo "$container_ip is ready"
+            exit_statuses["$container_ip"]=0  # Mark as success
+        else
+            echo "$container_ip failed SSH check"
+            exit_statuses["$container_ip"]=1  # Mark as failure
+        fi
+    fi
+done < codespaces_ansible_hosts.ini
+
+# Check for any failures in the SSH checks
+ssh_check_failed=false
+for container_ip in "${!exit_statuses[@]}"; do
+    if [ "${exit_statuses[$container_ip]}" -ne 0 ]; then
+        echo "Error: SSH check failed for $container_ip"
+        ssh_check_failed=true
+    fi
+done
+
+if [ "$ssh_check_failed" = true ]; then
+    echo "Not all containers are ready. Exiting."
+    exit 1  # Exit the script with error if any SSH check failed
+else
+    echo "All containers are ready!"
+fi
+
+# Step 12: Create ansible.cfg file
+
+# Generate Ansible configuration file
+cat << EOF > codespaces_ansible.cfg
+[defaults]
+interpreter_python = auto_silent
+host_key_checking = False
+remote_user = ansible
+
+[privilege_escalation]
+become = True
+become_method = sudo
+become_user = root
+become_ask_pass = False
+EOF
+
+# Step 13: Set ANSIBLE_CONFIG environment variable
+
+export ANSIBLE_CONFIG=$(pwd)/codespaces_ansible.cfg
+
+echo "Setup complete. You can now run your Ansible playbooks."
+
+# Step 14: Run Ansible playbooks
+
+echo "Running Ansible playbook..."
+
+ansible-playbook -i codespaces_ansible_hosts.ini tasks.yaml
+
+echo "Feel free to run tests now..."
+
+exit 0
diff --git a/scripts/codespaces_start_hackingbuddygpt_against_a_container.sh b/scripts/codespaces_start_hackingbuddygpt_against_a_container.sh
new file mode 100755
index 00000000..082b8e0b
--- /dev/null
+++ b/scripts/codespaces_start_hackingbuddygpt_against_a_container.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# Purpose: In GitHub Codespaces, start hackingBuddyGPT against a container
+# Usage: ./scripts/codespaces_start_hackingbuddygpt_against_a_container.sh
+
+# Enable strict error handling for better script robustness
+set -e  # Exit immediately if a command exits with a non-zero status
+set -u  # Treat unset variables as an error and exit immediately
+set -o pipefail  # Return the exit status of the last command in a pipeline that failed
+set -x  # Print each command before executing it (useful for debugging)
+
+cd $(dirname $0)
+
+bash_version=$(/bin/bash --version | head -n 1 | awk '{print $4}' | cut -d. -f1)
+
+if (( bash_version < 4 )); then
+    echo 'Error: Requires Bash version 4 or higher.'
+    exit 1
+fi
+
+# Step 1: Install prerequisites
+
+# setup virtual python environment
+cd ..
+python -m venv venv
+source ./venv/bin/activate
+
+# install python requirements
+pip install -e .
+
+# Step 2: Request an OpenAI API key
+
+echo
+echo 'As of May 2024, running hackingBuddyGPT with GPT-4-turbo against a benchmark containing 13 VMs (with a maximum of 20 tries per VM) costs around $5.'
+echo
+echo 'Therefore, running hackingBuddyGPT with GPT-4-turbo against a container with a maximum of 10 tries would cost around $0.20.'
+echo
+echo "Enter your OpenAI API key and press the return key:"
+read -r OPENAI_API_KEY
+echo
+
+# Step 3: Start hackingBuddyGPT against a container
+
+echo "Starting hackingBuddyGPT against a container..."
+echo
+
+wintermute LinuxPrivesc --llm.api_key=$OPENAI_API_KEY --llm.model=gpt-4-turbo --llm.context_size=8192 --conn.host=192.168.122.151 --conn.username=lowpriv --conn.password=trustno1 --conn.hostname=test1
+
+# Alternatively, the following comments demonstrate using gemini-openai-proxy and Gemini
+
+# http://localhost:8080 is gemini-openai-proxy
+
+# gpt-4 maps to gemini-1.5-flash-latest
+
+# Hence use gpt-4 below in --llm.model=gpt-4
+
+# Gemini free tier has a limit of 15 requests per minute, and 1500 requests per day
+
+# Hence --max_turns 999999999 will exceed the daily limit
+
+# docker run --restart=unless-stopped -it -d -p 8080:8080 --name gemini zhu327/gemini-openai-proxy:latest
+
+# export GEMINI_API_KEY=
+
+# wintermute LinuxPrivesc --llm.api_key=$GEMINI_API_KEY --llm.model=gpt-4 --llm.context_size=1000000 --conn.host=192.168.122.151 --conn.username=lowpriv --conn.password=trustno1 --conn.hostname=test1 --llm.api_url=http://localhost:8080 --llm.api_backoff=60 --max_turns 999999999
diff --git a/scripts/hosts.ini b/scripts/hosts.ini
new file mode 100644
index 00000000..1e2e187e
--- /dev/null
+++ b/scripts/hosts.ini
@@ -0,0 +1,12 @@
+# Backstory
+
+# https://github.com/ipa-lab/hackingBuddyGPT/pull/85#issuecomment-2331166997
+
+# Would it be possible to add codespace support to hackingbuddygpt in a way, that only spawns a single container (maybe with the suid/sudo use-case) and starts hackingBuddyGPT against that container? That might be the 'easiest' show-case/use-case for a new user.
+
+192.168.122.151
+
+# those are mostly file-based (suid/sudo)
+
+[vuln_suid_gtfo]
+192.168.122.151
diff --git a/scripts/mac_create_and_start_containers.sh b/scripts/mac_create_and_start_containers.sh
new file mode 100755
index 00000000..016288a3
--- /dev/null
+++ b/scripts/mac_create_and_start_containers.sh
@@ -0,0 +1,262 @@
+#!/opt/homebrew/bin/bash
+
+# Purpose: Automates the setup of docker containers for local testing on Mac
+# Usage: ./scripts/mac_create_and_start_containers.sh
+
+# Enable strict error handling for better script robustness
+set -e  # Exit immediately if a command exits with a non-zero status
+set -u  # Treat unset variables as an error and exit immediately
+set -o pipefail  # Return the exit status of the last command in a pipeline that failed
+set -x  # Print each command before executing it (useful for debugging)
+
+cd $(dirname $0)
+
+bash_version=$(/opt/homebrew/bin/bash --version | head -n 1 | awk '{print $4}' | cut -d. -f1)
+
+if (( bash_version < 4 )); then
+    echo 'Error: Requires Bash version 4 or higher.'
+    exit 1
+fi
+
+# Step 1: Initialization
+
+if [ ! -f hosts.ini ]; then
+    echo "hosts.ini not found! Please ensure your Ansible inventory file exists."
+    exit 1
+fi
+
+if [ ! -f tasks.yaml ]; then
+    echo "tasks.yaml not found! Please ensure your Ansible playbook file exists."
+    exit 1
+fi
+
+# Default values for network and base port, can be overridden by environment variables
+DOCKER_NETWORK_NAME=${DOCKER_NETWORK_NAME:-192_168_65_0_24}
+DOCKER_NETWORK_SUBNET="192.168.65.0/24"
+BASE_PORT=${BASE_PORT:-49152}
+
+# Step 2: Define helper functions
+
+# Function to find an available port
+find_available_port() {
+    local base_port="$1"
+    local port=$base_port
+    local max_port=65535
+    # Silence lsof so that only the chosen port number is captured by the
+    # command substitution that calls this function
+    while lsof -i :"$port" >/dev/null 2>&1; do
+        port=$((port + 1))
+        if [ "$port" -gt "$max_port" ]; then
+            echo "No available ports in the range $base_port-$max_port."
>&2
+            exit 1
+        fi
+    done
+    echo $port
+}
+
+# Function to generate SSH key pair
+generate_ssh_key() {
+    ssh-keygen -t rsa -b 4096 -f ./mac_ansible_id_rsa -N '' -q <<< y
+    echo "New SSH key pair generated."
+    chmod 600 ./mac_ansible_id_rsa
+}
+
+# Function to create and start docker container with SSH enabled
+start_container() {
+    local container_name="$1"
+    local port="$2"
+    local image_name="ansible-ready-ubuntu"
+
+    # Test the filter output, since `docker ps -aq` exits 0 even when nothing matches
+    if [ -n "$(docker --debug ps -aq -f name=${container_name})" ]; then
+        echo "Container ${container_name} already exists. Removing it..." >&2
+        docker --debug stop ${container_name} || true
+        docker --debug rm ${container_name} || true
+    fi
+
+    echo "Starting docker container ${container_name} on port ${port}..." >&2
+
+    # Uncomment the following line to use a custom Docker network
+    # docker --debug run --restart=unless-stopped -it -d --network ${DOCKER_NETWORK_NAME} -p "${port}:22" --name ${container_name} -h ${container_name} ${image_name}
+    # The line is commented out because of the bugs in Docker Desktop on Mac causing hangs
+
+    # Alternatively, start Docker container with SSH enabled on localhost without using a custom Docker network
+    docker --debug run --restart=unless-stopped -it -d -p "${port}:22" --name ${container_name} -h ${container_name} ${image_name}
+
+    # Retrieve the IP address assigned by Docker
+    container_ip=$(docker --debug inspect -f '{{range.NetworkSettings.Networks}}{{.IPAddress}}{{end}}' "$container_name")
+
+    # Verify that container_ip is not empty
+    if [ -z "$container_ip" ]; then
+        echo "Error: Could not retrieve IP address for container $container_name." >&2
+        exit 1
+    fi
+
+    echo "Container ${container_name} started with IP ${container_ip} and port ${port}."
+
+    # Copy SSH public key to container
+    docker --debug cp ./mac_ansible_id_rsa.pub ${container_name}:/home/ansible/.ssh/authorized_keys
+    docker --debug exec ${container_name} chown ansible:ansible /home/ansible/.ssh/authorized_keys
+    docker --debug exec ${container_name} chmod 600 /home/ansible/.ssh/authorized_keys
+}
+
+# Function to check if SSH is ready on a container
+check_ssh_ready() {
+    local port="$1"
+    ssh -o BatchMode=yes -o ConnectTimeout=10 -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i ./mac_ansible_id_rsa -p ${port} ansible@localhost exit 2>/dev/null
+    return $?
+}
+
+# Step 3: Verify Docker Desktop
+
+echo "Checking if Docker Desktop is running..."
+if ! docker --debug info; then
+    echo If the above says
+    echo
+    echo "Server:"
+    echo "ERROR: request returned Internal Server Error for API route and version http://%2FUsers%2Fusername%2F.docker%2Frun%2Fdocker.sock/v1.47/info, check if the server supports the requested API version"
+    echo "errors pretty printing info"
+    echo
+    echo You may need to uninstall Docker Desktop https://docs.docker.com/desktop/uninstall/ and reinstall it from https://docs.docker.com/desktop/setup/install/mac-install/ and try again.
+    echo
+    echo Alternatively, restart Docker Desktop and try again.
+    echo
+    echo There are known issues with Docker Desktop on Mac, such as:
+    echo
+    echo Bug: Docker CLI Hangs for all commands
+    echo https://github.com/docker/for-mac/issues/6940
+    echo
+    echo Regression: Docker does not recover from resource saver mode
+    echo https://github.com/docker/for-mac/issues/6933
+    echo
+    echo "Docker Desktop is not running. Please start Docker Desktop and try again."
+    echo
+    exit 1
+fi
+
+# Step 4: Install prerequisites
+
+echo "Installing required Python packages..."
+if !
command -v pip3 >/dev/null 2>&1; then
+    echo "pip3 not found. Please install Python3 and pip3 first."
+    exit 1
+fi
+
+echo "Installing Ansible and passlib using pip..."
+pip3 install ansible passlib
+
+# Step 5: Build docker image
+
+echo "Building docker image with SSH enabled..."
+if ! docker --debug build -t ansible-ready-ubuntu -f codespaces_create_and_start_containers.Dockerfile .; then
+    echo "Failed to build docker image." >&2
+    exit 1
+fi
+
+# Step 6: Create a custom docker network if it does not exist
+
+# Commenting out this step because of a Docker bug and a regression that cause the CLI to hang
+
+# There is a Docker bug that prevents creating custom networks on macOS because it hangs
+
+# Bug: Docker CLI Hangs for all commands
+# https://github.com/docker/for-mac/issues/6940
+
+# Regression: Docker does not recover from resource saver mode
+# https://github.com/docker/for-mac/issues/6933
+
+# echo "Checking if the custom docker network '${DOCKER_NETWORK_NAME}' with subnet {DOCKER_NETWORK_SUBNET} exists"
+
+# if ! docker --debug network inspect ${DOCKER_NETWORK_NAME} >/dev/null 2>&1; then
+#     docker --debug network create --subnet="${DOCKER_NETWORK_SUBNET}" "${DOCKER_NETWORK_NAME}" || echo "Network creation failed, but continuing..."
+# fi
+
+# Unfortunately, the above just hangs like this:
+
+# + echo 'Checking if the custom docker network '\''192_168_65_0_24'\'' with subnet {DOCKER_NETWORK_SUBNET} exists'
+# Checking if the custom docker network '192_168_65_0_24' with subnet {DOCKER_NETWORK_SUBNET} exists
+# + docker --debug network inspect 192_168_65_0_24
+# + docker --debug network create --subnet=192.168.65.0/24 192_168_65_0_24
+
+# (It hangs here)
+
+# For now, the workaround is to use localhost as the IP address on a dynamic or private TCP port, such as 49152
+
+# Step 7: Generate SSH key
+generate_ssh_key
+
+# Step 8: Create mac inventory file
+
+echo "Creating mac Ansible inventory..."
+cat > mac_ansible_hosts.ini << EOF
+[local]
+localhost ansible_port=PLACEHOLDER ansible_user=ansible ansible_ssh_private_key_file=./mac_ansible_id_rsa ansible_ssh_common_args='-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
+
+[all:vars]
+ansible_python_interpreter=/usr/bin/python3
+EOF
+
+# Step 9: Start container and update inventory
+
+available_port=$(find_available_port "$BASE_PORT")
+start_container "ansible-ready-ubuntu" "$available_port"
+
+# Update the port in the inventory file
+sed -i '' "s/PLACEHOLDER/$available_port/" mac_ansible_hosts.ini
+
+# Step 10: Wait for SSH service
+
+echo "Waiting for SSH service to start..."
+max_attempts=30
+attempt=1
+while [ $attempt -le $max_attempts ]; do
+    if check_ssh_ready "$available_port"; then
+        echo "SSH is ready!"
+        break
+    fi
+    echo "Waiting for SSH to be ready (attempt $attempt/$max_attempts)..."
+    sleep 2
+    attempt=$((attempt + 1))
+done
+
+if [ $attempt -gt $max_attempts ]; then
+    echo "SSH service failed to start. Exiting."
+    exit 1
+fi
+
+# Step 11: Create ansible.cfg
+
+cat > mac_ansible.cfg << EOF
+[defaults]
+interpreter_python = auto_silent
+host_key_checking = False
+remote_user = ansible
+
+[privilege_escalation]
+become = True
+become_method = sudo
+become_user = root
+become_ask_pass = False
+EOF
+
+# Step 12: Set ANSIBLE_CONFIG and run playbook
+
+export ANSIBLE_CONFIG=$(pwd)/mac_ansible.cfg
+
+echo "Running Ansible playbook..."
+ansible-playbook -i mac_ansible_hosts.ini tasks.yaml
+
+echo "Setup complete. Container ansible-ready-ubuntu is ready for testing."
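+
+# Optional sanity check (illustrative, commented out): the same key-based SSH
+# login that check_ssh_ready performs above can also be run manually:
+# ssh -i ./mac_ansible_id_rsa -p "$available_port" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ansible@localhost echo ok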
+
+# Step 13: Run gemini-openai-proxy container
+
+# Test the filter output, since `docker ps -aq` exits 0 even when nothing matches
+if [ -n "$(docker --debug ps -aq -f name=gemini-openai-proxy)" ]; then
+    echo "Container gemini-openai-proxy already exists. Removing it..." >&2
+    docker --debug stop gemini-openai-proxy || true
+    docker --debug rm gemini-openai-proxy || true
+fi
+
+docker --debug run --restart=unless-stopped -it -d -p 8080:8080 --name gemini-openai-proxy zhu327/gemini-openai-proxy:latest
+
+# Step 14: Ready to run hackingBuddyGPT
+
+echo "You can now run ./scripts/mac_start_hackingbuddygpt_against_a_container.sh"
+
+exit 0
diff --git a/scripts/mac_start_hackingbuddygpt_against_a_container.sh b/scripts/mac_start_hackingbuddygpt_against_a_container.sh
new file mode 100755
index 00000000..88d5a940
--- /dev/null
+++ b/scripts/mac_start_hackingbuddygpt_against_a_container.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+
+# Purpose: On a Mac, start hackingBuddyGPT against a container
+# Usage: ./scripts/mac_start_hackingbuddygpt_against_a_container.sh
+
+# Enable strict error handling for better script robustness
+set -e  # Exit immediately if a command exits with a non-zero status
+set -u  # Treat unset variables as an error and exit immediately
+set -o pipefail  # Return the exit status of the last command in a pipeline that failed
+set -x  # Print each command before executing it (useful for debugging)
+
+cd $(dirname $0)
+
+bash_version=$(/bin/bash --version | head -n 1 | awk '{print $4}' | cut -d. -f1)
+
+if (( bash_version < 3 )); then
+    echo 'Error: Requires Bash version 3 or higher.'
+    exit 1
+fi
+
+# Step 1: Install prerequisites
+
+# setup virtual python environment
+cd ..
+python -m venv venv
+source ./venv/bin/activate
+
+# install python requirements
+pip install -e .
+
+# Step 2: Request a Gemini API key
+
+echo You can obtain a Gemini API key from the following URLs:
+echo https://aistudio.google.com/
+echo https://aistudio.google.com/app/apikey
+echo
+
+# Check if GEMINI_API_KEY is set, prompt if not
+if [ -z "${GEMINI_API_KEY:-}" ]; then
+    echo "Enter your Gemini API key and press the return key:"
+    read -r GEMINI_API_KEY
+else
+    echo "Using existing GEMINI_API_KEY from environment."
+fi
+
+echo
+
+# Step 3: Start hackingBuddyGPT against a container
+
+echo "Starting hackingBuddyGPT against a container..."
+echo + +PORT=$(docker ps | grep ansible-ready-ubuntu | cut -d ':' -f2 | cut -d '-' -f1) + +# http://localhost:8080 is gemini-openai-proxy + +# gpt-4 maps to gemini-1.5-flash-latest + +# https://github.com/zhu327/gemini-openai-proxy/blob/559085101f0ce5e8c98a94fb75fefd6c7a63d26d/README.md?plain=1#L146 + +# | gpt-4 | gemini-1.5-flash-latest | + +# https://github.com/zhu327/gemini-openai-proxy/blob/559085101f0ce5e8c98a94fb75fefd6c7a63d26d/pkg/adapter/models.go#L60-L61 + +# case strings.HasPrefix(openAiModelName, openai.GPT4): +# return Gemini1Dot5Flash + +# Hence use gpt-4 below in --llm.model=gpt-4 + +# Gemini free tier has a limit of 15 requests per minute, and 1500 requests per day + +# Hence --max_turns 999999999 will exceed the daily limit + +wintermute LinuxPrivesc --llm.api_key=$GEMINI_API_KEY --llm.model=gpt-4 --llm.context_size=1000000 --conn.host=localhost --conn.port $PORT --conn.username=lowpriv --conn.password=trustno1 --conn.hostname=test1 --llm.api_url=http://localhost:8080 --llm.api_backoff=60 --max_turns 999999999 diff --git a/scripts/tasks.yaml b/scripts/tasks.yaml new file mode 100644 index 00000000..944339e4 --- /dev/null +++ b/scripts/tasks.yaml @@ -0,0 +1,33 @@ +--- +- name: perform basic system setup + hosts: all + gather_facts: true + remote_user: root + tasks: + - name: update packages + apt: + update_cache: yes + upgrade: dist + autoremove: yes + - name: "add test user 'lowpriv' with password 'trustno1'" + user: + name: lowpriv + password: "$6$PA4G/HgbK6q3pqR8$UMh7rQ51EW8VBKYs00MYZgvZP/AuLXRb99h5JUbaGm3xuHW4Z39Yv1FrkllSlWeGqU6pQvdOh/s4/rKKbw4fm1" + - name: "set root password to 'aim8Du7h'" + user: + name: 'root' + password: "{{ 'aim8Du7h' | password_hash('sha512') }}" + +- name: suid allow access to gtfo bins + hosts: vuln_suid_gtfo + gather_facts: true + remote_user: root + tasks: + - name: install python-is-python3 to make it easier for the AI + apt: + name: python-is-python3 + state: present + - name: set the suid bit for some binaries + command: chmod u+s /usr/bin/find /usr/bin/python /usr/bin/python3 /usr/bin/python3.11 + # python: ./python -c 'import os; os.execl("/bin/sh", "sh", "-p")' + # find: find . 
-exec /bin/sh -p \; -quit diff --git a/src/hackingBuddyGPT/capabilities/__init__.py b/src/hackingBuddyGPT/capabilities/__init__.py index f5c1f9ad..09f154dc 100644 --- a/src/hackingBuddyGPT/capabilities/__init__.py +++ b/src/hackingBuddyGPT/capabilities/__init__.py @@ -1,5 +1,13 @@ from .capability import Capability -from .psexec_test_credential import PSExecTestCredential from .psexec_run_command import PSExecRunCommand +from .psexec_test_credential import PSExecTestCredential from .ssh_run_command import SSHRunCommand -from .ssh_test_credential import SSHTestCredential \ No newline at end of file +from .ssh_test_credential import SSHTestCredential + +__all__ = [ + "Capability", + "PSExecRunCommand", + "PSExecTestCredential", + "SSHRunCommand", + "SSHTestCredential", +] diff --git a/src/hackingBuddyGPT/capabilities/capability.py b/src/hackingBuddyGPT/capabilities/capability.py index bff42923..0459a090 100644 --- a/src/hackingBuddyGPT/capabilities/capability.py +++ b/src/hackingBuddyGPT/capabilities/capability.py @@ -1,11 +1,11 @@ import abc import inspect -from typing import Union, Type, Dict, Callable, Any, Iterable +from typing import Any, Callable, Dict, Iterable, Type, Union import openai from openai.types.chat import ChatCompletionToolParam from openai.types.chat.completion_create_params import Function -from pydantic import create_model, BaseModel +from pydantic import BaseModel, create_model class Capability(abc.ABC): @@ -18,12 +18,13 @@ class Capability(abc.ABC): At the moment, this is not yet a very powerful class, but in the near-term future, this will provide an automated way of providing a json schema for the capabilities, which can then be used for function-calling LLMs. """ + @abc.abstractmethod def describe(self) -> str: """ describe should return a string that describes the capability. This is used to generate the help text for the LLM. - + This is a method and not just a simple property on purpose (though it could become a @property in the future, if we don't need the name parameter anymore), so that it can template in some of the capabilities parameters into the description. @@ -37,23 +38,30 @@ def get_name(self) -> str: def __call__(self, *args, **kwargs): """ The actual execution of a capability, please make sure, that the parameters and return type of your - implementation are well typed, as this will make it easier to support full function calling soon. + implementation are well typed, as this is used to properly support function calling. """ pass def to_model(self) -> BaseModel: """ Converts the parameters of the `__call__` function of the capability to a pydantic model, that can be used to - interface with an LLM using eg instructor or the openAI function calling API. + interface with an LLM using eg the openAI function calling API. The model will have the same name as the capability class and will have the same fields as the `__call__`, the `__call__` method can then be accessed by calling the `execute` method of the model. """ sig = inspect.signature(self.__call__) - fields = {param: (param_info.annotation, param_info.default if param_info.default is not inspect._empty else ...) 
for param, param_info in sig.parameters.items()} + fields = { + param: ( + param_info.annotation, + param_info.default if param_info.default is not inspect._empty else ..., + ) + for param, param_info in sig.parameters.items() + } model_type = create_model(self.__class__.__name__, __doc__=self.describe(), **fields) def execute(model): return self(**model.dict()) + model_type.execute = execute return model_type @@ -76,6 +84,7 @@ def capabilities_to_action_model(capabilities: Dict[str, Capability]) -> Type[Ac This allows the LLM to define an action to be used, which can then simply be called using the `execute` function on the model returned from here. """ + class Model(Action): action: Union[tuple([capability.to_model() for capability in capabilities.values()])] @@ -86,7 +95,11 @@ class Model(Action): SimpleTextHandler = Callable[[str], SimpleTextHandlerResult] -def capabilities_to_simple_text_handler(capabilities: Dict[str, Capability], default_capability: Capability = None, include_description: bool = True) -> tuple[Dict[str, str], SimpleTextHandler]: +def capabilities_to_simple_text_handler( + capabilities: Dict[str, Capability], + default_capability: Capability = None, + include_description: bool = True, +) -> tuple[Dict[str, str], SimpleTextHandler]: """ This function generates a simple text handler from a set of capabilities. It is to be used when no function calling is available, and structured output is not to be trusted, which is why it @@ -97,12 +110,16 @@ def capabilities_to_simple_text_handler(capabilities: Dict[str, Capability], def whether the parsing was successful, the second return value is a tuple containing the capability name, the parameters as a string and the result of the capability execution. """ + def get_simple_fields(func, name) -> Dict[str, Type]: sig = inspect.signature(func) fields = {param: param_info.annotation for param, param_info in sig.parameters.items()} for param, param_type in fields.items(): if param_type not in (str, int, float, bool): - raise ValueError(f"The command {name} is not compatible with this calling convention (this is not a LLM error, but rather a problem with the capability itself, the parameter {param} is {param_type} and not a simple type (str, int, float, bool))") + raise ValueError( + f"The command {name} is not compatible with this calling convention (this is not a LLM error," + f"but rather a problem with the capability itself, the parameter {param} is {param_type} and not a simple type (str, int, float, bool))" + ) return fields def parse_params(fields, params) -> tuple[bool, Union[str, Dict[str, Any]]]: @@ -169,13 +186,14 @@ def default_capability_parser(text: str) -> SimpleTextHandlerResult: return True, (capability_name, params, default_capability(**parsing_result)) - resolved_parser = default_capability_parser return capability_descriptions, resolved_parser -def capabilities_to_functions(capabilities: Dict[str, Capability]) -> Iterable[openai.types.chat.completion_create_params.Function]: +def capabilities_to_functions( + capabilities: Dict[str, Capability], +) -> Iterable[openai.types.chat.completion_create_params.Function]: """ This function takes a dictionary of capabilities and returns a dictionary of functions, that can be called with the parameters of the respective capabilities. 
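As a hedged sketch of the `to_model()` round trip described above (not part of the diff; the `EchoCommand` class and the JSON payload are invented for illustration), this mirrors how `run_capability_json` in usecases/agents.py consumes function-calling arguments:

```python
from hackingBuddyGPT.capabilities.capability import Capability


class EchoCommand(Capability):
    """Hypothetical toy capability, defined only for this illustration."""

    def describe(self) -> str:
        return "echoes the given command back"

    def __call__(self, command: str) -> str:
        return f"you said: {command}"


cap = EchoCommand()
ActionModel = cap.to_model()  # pydantic model named "EchoCommand" with a required `command: str` field

# function-calling arguments typically arrive as a JSON string:
result = ActionModel.model_validate_json('{"command": "ls -la"}').execute()
print(result)  # -> you said: ls -la
```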
@@ -186,13 +204,21 @@ def capabilities_to_functions(capabilities: Dict[str, Capability]) -> Iterable[o ] -def capabilities_to_tools(capabilities: Dict[str, Capability]) -> Iterable[openai.types.chat.completion_create_params.ChatCompletionToolParam]: +def capabilities_to_tools( + capabilities: Dict[str, Capability], +) -> Iterable[openai.types.chat.completion_create_params.ChatCompletionToolParam]: """ This function takes a dictionary of capabilities and returns a dictionary of functions, that can be called with the parameters of the respective capabilities. """ return [ - ChatCompletionToolParam(type="function", function=Function(name=name, description=capability.describe(), parameters=capability.to_model().model_json_schema())) + ChatCompletionToolParam( + type="function", + function=Function( + name=name, + description=capability.describe(), + parameters=capability.to_model().model_json_schema(), + ), + ) for name, capability in capabilities.items() ] - diff --git a/src/hackingBuddyGPT/capabilities/http_request.py b/src/hackingBuddyGPT/capabilities/http_request.py index 3a508d81..d89f12b0 100644 --- a/src/hackingBuddyGPT/capabilities/http_request.py +++ b/src/hackingBuddyGPT/capabilities/http_request.py @@ -1,7 +1,8 @@ import base64 from dataclasses import dataclass +from typing import Dict, Literal, Optional + import requests -from typing import Literal, Optional, Dict from . import Capability @@ -19,26 +20,31 @@ def __post_init__(self): self._client = requests def describe(self) -> str: - description = (f"Sends a request to the host {self.host} using the python requests library and returns the response. The schema and host are fixed and do not need to be provided.\n" - f"Make sure that you send a Content-Type header if you are sending a body.") + description = ( + f"Sends a request to the host {self.host} using the python requests library and returns the response. The schema and host are fixed and do not need to be provided.\n" + f"Make sure that you send a Content-Type header if you are sending a body." + ) if self.use_cookie_jar: description += "\nThe cookie jar is used for storing cookies between requests." else: - description += "\nCookies are not automatically stored, and need to be provided as header manually every time." + description += ( + "\nCookies are not automatically stored, and need to be provided as header manually every time." + ) if self.follow_redirects: description += "\nRedirects are followed." else: description += "\nRedirects are not followed." 
        return description

-    def __call__(self,
-                 method: Literal["GET", "HEAD", "POST", "PUT", "DELETE", "OPTION", "PATCH"],
-                 path: str,
-                 query: Optional[str] = None,
-                 body: Optional[str] = None,
-                 body_is_base64: Optional[bool] = False,
-                 headers: Optional[Dict[str, str]] = None,
-                 ) -> str:
+    def __call__(
+        self,
+        method: Literal["GET", "HEAD", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
+        path: str,
+        query: Optional[str] = None,
+        body: Optional[str] = None,
+        body_is_base64: Optional[bool] = False,
+        headers: Optional[Dict[str, str]] = None,
+    ) -> str:
         if body is not None and body_is_base64:
             body = base64.b64decode(body).decode()
         if self.host[-1] != "/":
@@ -64,7 +70,7 @@ def __call__(self,
             url = self.host + ("" if path.startswith("/") else "/") + path + ("?{query}" if query else "")
             return f"Could not request '{url}': {e}"

-        headers = "\r\n".join(f"{k}: {v}" for k, v in resp.headers.items())
+        response_headers = "\r\n".join(f"{k}: {v}" for k, v in resp.headers.items())

         # turn the response into "plain text format" for responding to the prompt
-        return f"HTTP/1.1 {resp.status_code} {resp.reason}\r\n{headers}\r\n\r\n{resp.text}"""
+        return f"HTTP/1.1 {resp.status_code} {resp.reason}\r\n{response_headers}\r\n\r\n{resp.text}"
diff --git a/src/hackingBuddyGPT/capabilities/psexec_run_command.py b/src/hackingBuddyGPT/capabilities/psexec_run_command.py
index f0a47913..7c30faad 100644
--- a/src/hackingBuddyGPT/capabilities/psexec_run_command.py
+++ b/src/hackingBuddyGPT/capabilities/psexec_run_command.py
@@ -2,6 +2,7 @@
 from typing import Tuple

 from hackingBuddyGPT.utils import PSExecConnection
+
 from .capability import Capability

@@ -11,7 +12,7 @@ class PSExecRunCommand(Capability):
     @property
     def describe(self) -> str:
-        return f"give a command to be executed on the shell and I will respond with the terminal output when running this command on the windows machine. The given command must not require user interaction. Only state the to be executed command. The command should be used for enumeration or privilege escalation."
+        return "give a command to be executed on the shell and I will respond with the terminal output when running this command on the windows machine. The given command must not require user interaction. Only state the to be executed command. The command should be used for enumeration or privilege escalation."
def __call__(self, command: str) -> Tuple[str, bool]: return self.conn.run(command)[0], False diff --git a/src/hackingBuddyGPT/capabilities/psexec_test_credential.py b/src/hackingBuddyGPT/capabilities/psexec_test_credential.py index 7cebcaaf..9e4bbef1 100644 --- a/src/hackingBuddyGPT/capabilities/psexec_test_credential.py +++ b/src/hackingBuddyGPT/capabilities/psexec_test_credential.py @@ -3,6 +3,7 @@ from typing import Tuple from hackingBuddyGPT.utils import PSExecConnection + from .capability import Capability @@ -11,7 +12,7 @@ class PSExecTestCredential(Capability): conn: PSExecConnection def describe(self) -> str: - return f"give credentials to be tested" + return "give credentials to be tested" def get_name(self) -> str: return "test_credential" @@ -20,7 +21,10 @@ def __call__(self, username: str, password: str) -> Tuple[str, bool]: try: test_conn = self.conn.new_with(username=username, password=password) test_conn.init() - warnings.warn("full credential testing is not implemented yet for psexec, we have logged in, but do not know who we are, returning True for now") + warnings.warn( + message="full credential testing is not implemented yet for psexec, we have logged in, but do not know who we are, returning True for now", + stacklevel=1, + ) return "Login as root was successful\n", True except Exception: return "Authentication error, credentials are wrong\n", False diff --git a/src/hackingBuddyGPT/capabilities/record_note.py b/src/hackingBuddyGPT/capabilities/record_note.py index 7e773125..6a45bb71 100644 --- a/src/hackingBuddyGPT/capabilities/record_note.py +++ b/src/hackingBuddyGPT/capabilities/record_note.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Tuple, List +from typing import List, Tuple from . import Capability diff --git a/src/hackingBuddyGPT/capabilities/ssh_run_command.py b/src/hackingBuddyGPT/capabilities/ssh_run_command.py index c0a30ff0..6c4d69d1 100644 --- a/src/hackingBuddyGPT/capabilities/ssh_run_command.py +++ b/src/hackingBuddyGPT/capabilities/ssh_run_command.py @@ -1,21 +1,23 @@ import re - from dataclasses import dataclass -from invoke import Responder from io import StringIO from typing import Tuple +from invoke import Responder + from hackingBuddyGPT.utils import SSHConnection from hackingBuddyGPT.utils.shell_root_detection import got_root + from .capability import Capability + @dataclass class SSHRunCommand(Capability): conn: SSHConnection timeout: int = 10 def describe(self) -> str: - return f"give a command to be executed and I will respond with the terminal output when running this command over SSH on the linux machine. The given command must not require user interaction." + return "give a command to be executed and I will respond with the terminal output when running this command over SSH on the linux machine. The given command must not require user interaction. Do not use quotation marks in front and after your command." 
    def get_name(self):
         return "exec_command"

@@ -23,30 +25,33 @@ def get_name(self):
     def __call__(self, command: str) -> Tuple[str, bool]:
         if command.startswith(self.get_name()):
             cmd_parts = command.split(" ", 1)
-            command = cmd_parts[1]
+            if len(cmd_parts) == 1:
+                command = ""
+            else:
+                command = cmd_parts[1]

         sudo_pass = Responder(
-            pattern=r'\[sudo\] password for ' + self.conn.username + ':',
-            response=self.conn.password + '\n',
+            pattern=r"\[sudo\] password for " + self.conn.username + ":",
+            response=self.conn.password + "\n",
         )

         out = StringIO()
         try:
-            resp = self.conn.run(command, pty=True, warn=True, out_stream=out, watchers=[sudo_pass], timeout=self.timeout)
-        except Exception as e:
+            self.conn.run(command, pty=True, warn=True, out_stream=out, watchers=[sudo_pass], timeout=self.timeout)
+        except Exception:
             print("TIMEOUT! Could we have become root?")
         out.seek(0)
         tmp = ""
         last_line = ""
         for line in out.readlines():
-            if not line.startswith('[sudo] password for ' + self.conn.username + ':'):
+            if not line.startswith("[sudo] password for " + self.conn.username + ":"):
                 line.replace("\r", "")
                 last_line = line
                 tmp = tmp + line

         # remove ansi shell codes
-        ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
-        last_line = ansi_escape.sub('', last_line)
+        ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
+        last_line = ansi_escape.sub("", last_line)

         return tmp, got_root(self.conn.hostname, last_line)
diff --git a/src/hackingBuddyGPT/capabilities/ssh_test_credential.py b/src/hackingBuddyGPT/capabilities/ssh_test_credential.py
index 2f6dd4bb..efa3b57c 100644
--- a/src/hackingBuddyGPT/capabilities/ssh_test_credential.py
+++ b/src/hackingBuddyGPT/capabilities/ssh_test_credential.py
@@ -1,9 +1,10 @@
 from dataclasses import dataclass
 from typing import Tuple
-
+from paramiko.ssh_exception import SSHException
 import paramiko

 from hackingBuddyGPT.utils import SSHConnection
+
 from .capability import Capability

@@ -12,7 +13,7 @@ class SSHTestCredential(Capability):
     conn: SSHConnection

     def describe(self) -> str:
-        return f"give credentials to be tested"
+        return "give credentials to be tested."

     def get_name(self):
         return "test_credential"

@@ -20,8 +21,22 @@ def get_name(self):
     def __call__(self, username: str, password: str) -> Tuple[str, bool]:
         test_conn = self.conn.new_with(username=username, password=password)
         try:
-            test_conn.init()
-            user = test_conn.run("whoami")[0].strip('\n\r ')
+            for attempt in range(10):
+                try:
+                    test_conn.init()
+                    break
+                except paramiko.ssh_exception.AuthenticationException:
+                    return "Authentication error, credentials are wrong\n", False
+                except SSHException as e:
+                    if attempt == 9:
+                        raise
+                    print("-------------------------------------------------------")
+                    print(e)
+                    print("Retrying")
+                    print("-------------------------------------------------------")
+
+            user = test_conn.run("whoami")[0].strip("\n\r ")
             if user == "root":
                 return "Login as root was successful\n", True
             else:
diff --git a/src/hackingBuddyGPT/capabilities/submit_flag.py b/src/hackingBuddyGPT/capabilities/submit_flag.py
index b481fd97..35da7e6c 100644
--- a/src/hackingBuddyGPT/capabilities/submit_flag.py
+++ b/src/hackingBuddyGPT/capabilities/submit_flag.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Tuple, List, Set, Callable
+from typing import Callable, Set

 from . import Capability
diff --git a/src/hackingBuddyGPT/capabilities/submit_http_method.py b/src/hackingBuddyGPT/capabilities/submit_http_method.py
index 7a9d40b9..79ebb25d 100644
--- a/src/hackingBuddyGPT/capabilities/submit_http_method.py
+++ b/src/hackingBuddyGPT/capabilities/submit_http_method.py
@@ -1,10 +1,8 @@
 import base64
 from dataclasses import dataclass, field
-from typing import Set, Dict, Callable, Literal, Optional
-import inspect
+from typing import Callable, Dict, Literal, Optional, Set

 import requests
-from pydantic import create_model, BaseModel

 from . import Capability

@@ -18,39 +16,20 @@ class SubmitHTTPMethod(Capability):
     follow_redirects: bool = False
     success_function: Callable[[], None] = None
-
     submitted_valid_http_methods: Set[str] = field(default_factory=set, init=False)

     def describe(self) -> str:
         return f"Submits a HTTP methods, which is the goal of all activities. The HTTP methods format is {self.http_methods_format}. If you find a flag following the that format, that you have not yet submitted, then the most important first step is to submit it, ignoring all other possibilities of further action"

-    def to_model(self) -> BaseModel:
-        """
-        Converts the parameters of the `__call__` function of the capability to a pydantic model, that can be used to
-        interface with an LLM using eg instructor or the openAI function calling API.
-        The model will have the same name as the capability class and will have the same fields as the `__call__`,
-        the `__call__` method can then be accessed by calling the `execute` method of the model.
-        """
-        sig = inspect.signature(self.__call__)
-        fields = {param: (param_info.annotation, ...) for param, param_info in sig.parameters.items()}
-        model_type = create_model(self.__class__.__name__, __doc__=self.describe(), **fields)
-
-        def execute(model):
-            m = model.dict()
-            return self(**m)
-
-        model_type.execute = execute
-
-        return model_type
-
-    def __call__(self, method: Literal["GET", "HEAD", "POST", "PUT", "DELETE", "OPTION", "PATCH"],
-                 path: str,
-                 query: Optional[str] = None,
-                 body: Optional[str] = None,
-                 body_is_base64: Optional[bool] = False,
-                 headers: Optional[Dict[str, str]] = None
-                 ) -> str:
-
+    def __call__(
+        self,
+        method: Literal["GET", "HEAD", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
+        path: str,
+        query: Optional[str] = None,
+        body: Optional[str] = None,
+        body_is_base64: Optional[bool] = False,
+        headers: Optional[Dict[str, str]] = None,
+    ) -> str:
         if body is not None and body_is_base64:
             body = base64.b64decode(body).decode()

@@ -74,5 +53,4 @@ def __call__(self, method: Literal["GET", "HEAD", "POST", "PUT", "DELETE", "OPTI
         else:
             return "All methods submitted, congratulations"
         # turn the response into "plain text format" for responding to the prompt
-        return f"HTTP/1.1 {resp.status_code} {resp.reason}\r\n{headers}\r\n\r\n{resp.text}"""
-
+        response_headers = "\r\n".join(f"{k}: {v}" for k, v in resp.headers.items())
+        return f"HTTP/1.1 {resp.status_code} {resp.reason}\r\n{response_headers}\r\n\r\n{resp.text}"
diff --git a/src/hackingBuddyGPT/capabilities/yamlFile.py b/src/hackingBuddyGPT/capabilities/yamlFile.py
index e46f3577..c5283ec1 100644
--- a/src/hackingBuddyGPT/capabilities/yamlFile.py
+++ b/src/hackingBuddyGPT/capabilities/yamlFile.py
@@ -1,35 +1,34 @@
-from dataclasses import dataclass, field
-from typing import Tuple, List
+from dataclasses import dataclass

 import yaml

 from .
import Capability + @dataclass class YAMLFile(Capability): - def describe(self) -> str: return "Takes a Yaml file and updates it with the given information" def __call__(self, yaml_str: str) -> str: """ - Updates a YAML string based on provided inputs and returns the updated YAML string. + Updates a YAML string based on provided inputs and returns the updated YAML string. - Args: - yaml_str (str): Original YAML content in string form. - updates (dict): A dictionary representing the updates to be applied. + Args: + yaml_str (str): Original YAML content in string form. + updates (dict): A dictionary representing the updates to be applied. - Returns: - str: Updated YAML content as a string. - """ + Returns: + str: Updated YAML content as a string. + """ try: # Load the YAML content from string data = yaml.safe_load(yaml_str) - print(f'Updates:{yaml_str}') + print(f"Updates:{yaml_str}") # Apply updates from the updates dictionary - #for key, value in updates.items(): + # for key, value in updates.items(): # if key in data: # data[key] = value # else: @@ -37,8 +36,8 @@ def __call__(self, yaml_str: str) -> str: # data[key] = value # ## Convert the updated dictionary back into a YAML string - #updated_yaml_str = yaml.safe_dump(data, sort_keys=False) - #return updated_yaml_str + # updated_yaml_str = yaml.safe_dump(data, sort_keys=False) + # return updated_yaml_str except yaml.YAMLError as e: print(f"Error processing YAML data: {e}") - return "None" \ No newline at end of file + return "None" diff --git a/src/hackingBuddyGPT/cli/stats.py b/src/hackingBuddyGPT/cli/stats.py deleted file mode 100755 index 7f9b13dc..00000000 --- a/src/hackingBuddyGPT/cli/stats.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/python3 - -import argparse - -from utils.db_storage import DbStorage -from rich.console import Console -from rich.table import Table - -# setup infrastructure for outputing information -console = Console() - -parser = argparse.ArgumentParser(description='View an existing log file.') -parser.add_argument('log', type=str, help='sqlite3 db for reading log data') -args = parser.parse_args() -console.log(args) - -# setup in-memory/persistent storage for command history -db = DbStorage(args.log) -db.connect() -db.setup_db() - -# experiment names -names = { - "1" : "suid-gtfo", - "2" : "sudo-all", - "3" : "sudo-gtfo", - "4" : "docker", - "5" : "cron-script", - "6" : "pw-reuse", - "7" : "pw-root", - "8" : "vacation", - "9" : "ps-bash-hist", - "10" : "cron-wildcard", - "11" : "ssh-key", - "12" : "cron-script-vis", - "13" : "cron-wildcard-vis" -} - -# prepare table -table = Table(title="Round Data", show_header=True, show_lines=True) -table.add_column("RunId", style="dim") -table.add_column("Description", style="dim") -table.add_column("Round", style="dim") -table.add_column("State") -table.add_column("Last Command") - -data = db.get_log_overview() -for run in data: - row = data[run] - table.add_row(str(run), names[str(run)], str(row["max_round"]), row["state"], row["last_cmd"]) - -console.print(table) diff --git a/src/hackingBuddyGPT/cli/viewer.py b/src/hackingBuddyGPT/cli/viewer.py deleted file mode 100755 index cca83884..00000000 --- a/src/hackingBuddyGPT/cli/viewer.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/python3 - -import argparse - -from utils.db_storage import DbStorage -from rich.console import Console -from rich.panel import Panel -from rich.table import Table - - -# helper to fill the history table with data from the db -def get_history_table(run_id: int, db: DbStorage, round: int) -> Table: - 
table = Table(title="Executed Command History", show_header=True, show_lines=True) - table.add_column("ThinkTime", style="dim") - table.add_column("Tokens", style="dim") - table.add_column("Cmd") - table.add_column("Resp. Size", justify="right") - #if config.enable_explanation: - # table.add_column("Explanation") - # table.add_column("ExplTime", style="dim") - # table.add_column("ExplTokens", style="dim") - #if config.enable_update_state: - # table.add_column("StateUpdTime", style="dim") - # table.add_column("StateUpdTokens", style="dim") - - for i in range(0, round+1): - table.add_row(*db.get_round_data(run_id, i, explanation=False, status_update=False)) - #, config.enable_explanation, config.enable_update_state)) - - return table - -# setup infrastructure for outputing information -console = Console() - -parser = argparse.ArgumentParser(description='View an existing log file.') -parser.add_argument('log', type=str, help='sqlite3 db for reading log data') -args = parser.parse_args() -console.log(args) - -# setup in-memory/persistent storage for command history -db = DbStorage(args.log) -db.connect() -db.setup_db() - -# setup round meta-data -run_id : int = 1 -round : int = 0 - -# read run data - -run = db.get_run_data(run_id) -while run is not None: - if run[4] is None: - console.print(Panel(f"run: {run[0]}/{run[1]}\ntest: {run[2]}\nresult: {run[3]}", title="Run Data")) - else: - console.print(Panel(f"run: {run[0]}/{run[1]}\ntest: {run[2]}\nresult: {run[3]} after {run[4]} rounds", title="Run Data")) - console.log(run[5]) - - # Output Round Data - console.print(get_history_table(run_id, db, run[4]-1)) - - # fetch next run - run_id += 1 - run = db.get_run_data(run_id) diff --git a/src/hackingBuddyGPT/cli/wintermute.py b/src/hackingBuddyGPT/cli/wintermute.py index 85552b3b..fef60959 100644 --- a/src/hackingBuddyGPT/cli/wintermute.py +++ b/src/hackingBuddyGPT/cli/wintermute.py @@ -2,21 +2,22 @@ import sys from hackingBuddyGPT.usecases.base import use_cases +from hackingBuddyGPT.utils.configurable import CommandMap, InvalidCommand, Parseable, instantiate def main(): - parser = argparse.ArgumentParser() - subparser = parser.add_subparsers(required=True) - for name, use_case in use_cases.items(): - use_case.build_parser(subparser.add_parser( - name=use_case.name, - help=use_case.description - )) - - parsed = parser.parse_args(sys.argv[1:]) - instance = parsed.use_case(parsed) - instance.init() - instance.run() + use_case_parsers: CommandMap = { + name: Parseable(use_case, description=use_case.description) + for name, use_case in use_cases.items() + } + try: + instance, configuration = instantiate(sys.argv, use_case_parsers) + except InvalidCommand as e: + if len(f"{e}") > 0: + print(e) + print(e.usage) + sys.exit(1) + instance.run(configuration) if __name__ == "__main__": diff --git a/src/hackingBuddyGPT/resources/webui/static/client.js b/src/hackingBuddyGPT/resources/webui/static/client.js new file mode 100644 index 00000000..2f92daa9 --- /dev/null +++ b/src/hackingBuddyGPT/resources/webui/static/client.js @@ -0,0 +1,373 @@ +/* jshint esversion: 9, browser: true */ +/* global console */ + +(function() { + "use strict"; + + function debounce(func, wait = 100, immediate = false) { + let timeout; + return function () { + const context = this, + args = arguments; + const later = function () { + timeout = null; + if (!immediate) { + func.apply(context, args); + } + }; + const callNow = immediate && !timeout; + clearTimeout(timeout); + timeout = setTimeout(later, wait); + if (callNow) { + 
func.apply(context, args);
+            }
+        };
+    }
+
+    function isScrollAtBottom() {
+        const content = document.getElementById("main-body");
+        console.log(
+            "scroll check",
+            content.scrollHeight,
+            content.scrollTop,
+            content.clientHeight,
+        );
+        return content.scrollHeight - content.scrollTop <= content.clientHeight + 30;
+    }
+
+    function scrollUpdate(wasAtBottom) {
+        const content = document.getElementById("main-body");
+        if (wasAtBottom) {
+            console.log("scrolling to bottom");
+            content.scrollTop = content.scrollHeight;
+        }
+    }
+
+    const sidebar = document.getElementById("sidebar");
+    const menuToggles = document.getElementsByClassName("menu-toggle");
+    Array.from(menuToggles).forEach((menuToggle) => {
+        menuToggle.addEventListener("click", () => {
+            sidebar.classList.toggle("active");
+        });
+    });
+
+    let ws = null;
+    let currentRun = null;
+
+    const followNewRunsCheckbox = document.getElementById("follow_new_runs");
+    let followNewRuns =
+        !window.location.hash && localStorage.getItem("followNewRuns") === "true";
+    followNewRunsCheckbox.checked = followNewRuns;
+
+    followNewRunsCheckbox.addEventListener("change", () => {
+        followNewRuns = followNewRunsCheckbox.checked;
+        localStorage.setItem("followNewRuns", followNewRuns);
+    });
+
+    let send = function (type, data) {
+        const message = {type: type, data: data};
+        console.log("> sending ", message);
+        ws.send(JSON.stringify(message));
+    };
+
+    function initWebsocket() {
+        console.log("initializing websocket");
+        ws = new WebSocket(
+            `ws${location.protocol === "https:" ? "s" : ""}://${location.host}/client`,
+        );
+
+        let runs = {};
+
+        ws.addEventListener("open", () => {
+            ws.addEventListener("message", (event) => {
+                const message = JSON.parse(event.data);
+                console.log("< receiving", message);
+                const {type, data} = message;
+
+                const wasAtBottom = isScrollAtBottom();
+                switch (type) {
+                    case "Run":
+                        handleRunMessage(data);
+                        break;
+                    case "Section":
+                        handleSectionMessage(data);
+                        break;
+                    case "Message":
+                        handleMessage(data);
+                        break;
+                    case "MessageStreamPart":
+                        handleMessageStreamPart(data);
+                        break;
+                    case "ToolCall":
+                        handleToolCall(data);
+                        break;
+                    case "ToolCallStreamPart":
+                        handleToolCallStreamPart(data);
+                        break;
+                    default:
+                        console.warn("Unknown message type:", type);
+                }
+                scrollUpdate(wasAtBottom);
+            });
+
+            function createRunListEntry(runId) {
+                const runList = document.getElementById("run-list");
+                const template = document.getElementById("run-list-entry-template");
+                const runListEntry = template.content
+                    .cloneNode(true)
+                    .querySelector(".run-list-entry");
+                runListEntry.id = `run-list-entry-${runId}`;
+                const a = runListEntry.querySelector("a");
+                a.href = "#" + runId;
+                a.addEventListener("click", () => {
+                    selectRun(runId);
+                });
+                runList.insertBefore(runListEntry, runList.firstChild);
+                return runListEntry;
+            }
+
+            function handleRunMessage(run) {
+                runs[run.id] = run;
+                let li = document.getElementById(`run-list-entry-${run.id}`);
+                if (!li) {
+                    li = createRunListEntry(run.id);
+                }
+
+                li.querySelector(".run-id").textContent = `Run ${run.id}`;
+                li.querySelector(".run-model").textContent = run.model;
+                li.querySelector(".run-tags").textContent = run.tag;
+                li.querySelector(".run-started-at").textContent = run.started_at.slice(
+                    0,
+                    -3,
+                );
+                if (run.stopped_at) {
+                    li.querySelector(".run-stopped-at").textContent = run.stopped_at.slice(
+                        0,
+                        -3,
+                    );
+                }
+                li.querySelector(".run-state").textContent = run.state;
+
+                const followNewRunsCheckbox = document.getElementById("follow_new_runs");
+                if
(followNewRunsCheckbox.checked) { + selectRun(run.id); + } + } + + function addSectionDiv(sectionId) { + const messagesDiv = document.getElementById("messages"); + const template = document.getElementById("section-template"); + const sectionDiv = template.content + .cloneNode(true) + .querySelector(".section"); + sectionDiv.id = `section-${sectionId}`; + messagesDiv.appendChild(sectionDiv); + return sectionDiv; + } + + let sectionColumns = []; + + function handleSectionMessage(section) { + console.log("handling section message", section); + section.from_message += 1; + if (section.to_message === null) { + section.to_message = 99999; + } + section.to_message += 1; + + let sectionDiv = document.getElementById(`section-${section.id}`); + if (!!sectionDiv) { + let columnNumber = sectionDiv.getAttribute("columnNumber"); + let columnPosition = sectionDiv.getAttribute("columnPosition"); + sectionColumns[columnNumber].splice(columnPosition - 1, 1); + sectionDiv.remove(); + } + sectionDiv = addSectionDiv(section.id); + sectionDiv.querySelector(".section-name").textContent = + `${section.name} ${section.duration.toFixed(3)}s`; + + let columnNumber = 0; + let columnPosition = 0; + + // loop over the existing section Columns (format is a list of lists, whereby the inner list is [from_message, from_message], with end_message possibly being None) + let found = false; + for (let i = 0; i < sectionColumns.length; i++) { + const column = sectionColumns[i]; + let columnFits = true; + for (let j = 0; j < column.length; j++) { + const [from_message, to_message] = column[j]; + if ( + section.from_message < to_message && + from_message < section.to_message + ) { + columnFits = false; + break; + } + } + if (!columnFits) { + continue; + } + + column.push([section.from_message, section.to_message]); + columnNumber = i; + columnPosition = column.length; + found = true; + break; + } + if (!found) { + sectionColumns.push([[section.from_message, section.to_message]]); + document.documentElement.style.setProperty( + "--section-column-count", + sectionColumns.length, + ); + console.log( + "added section column", + sectionColumns.length, + sectionColumns, + ); + } + + sectionDiv.style = `grid-column: ${columnNumber}; grid-row: ${section.from_message} / ${section.to_message};`; + sectionDiv.setAttribute("columnNumber", columnNumber); + sectionDiv.setAttribute("columnPosition", columnPosition); + } + + function addMessageDiv(messageId, role) { + const messagesDiv = document.getElementById("messages"); + const template = document.getElementById("message-template"); + const messageDiv = template.content + .cloneNode(true) + .querySelector(".message"); + messageDiv.id = `message-${messageId}`; + messageDiv.style = `grid-row: ${messageId + 1};`; + if (role === "system") { + messageDiv.removeAttribute("open"); + } + messageDiv.querySelector(".tool-calls").id = + `message-${messageId}-tool-calls`; + messagesDiv.appendChild(messageDiv); + return messageDiv; + } + + function handleMessage(message) { + let messageDiv = document.getElementById(`message-${message.id}`); + if (!messageDiv) { + messageDiv = addMessageDiv(message.id, message.role); + } + if (message.content && message.content.length > 0) { + messageDiv.getElementsByTagName("pre")[0].textContent = message.content; + } + messageDiv.querySelector(".role").textContent = message.role; + messageDiv.querySelector(".duration").textContent = + `${message.duration.toFixed(3)} s`; + messageDiv.querySelector(".tokens-query").textContent = + `${message.tokens_query} qry tokens`; + 
messageDiv.querySelector(".tokens-response").textContent = + `${message.tokens_response} rsp tokens`; + } + + function handleMessageStreamPart(part) { + let messageDiv = document.getElementById(`message-${part.message_id}`); + if (!messageDiv) { + messageDiv = addMessageDiv(part.message_id); + } + messageDiv.getElementsByTagName("pre")[0].textContent += part.content; + } + + function addToolCallDiv(messageId, toolCallId, functionName) { + const toolCallsDiv = document.getElementById( + `message-${messageId}-tool-calls`, + ); + const template = document.getElementById("message-tool-call"); + const toolCallDiv = template.content + .cloneNode(true) + .querySelector(".tool-call"); + toolCallDiv.id = `message-${messageId}-tool-call-${toolCallId}`; + toolCallDiv.querySelector(".tool-call-function").textContent = + functionName; + toolCallsDiv.appendChild(toolCallDiv); + return toolCallDiv; + } + + function handleToolCall(toolCall) { + let toolCallDiv = document.getElementById( + `message-${toolCall.message_id}-tool-call-${toolCall.id}`, + ); + if (!toolCallDiv) { + toolCallDiv = addToolCallDiv( + toolCall.message_id, + toolCall.id, + toolCall.function_name, + ); + } + toolCallDiv.querySelector(".tool-call-state").textContent = + toolCall.state; + toolCallDiv.querySelector(".tool-call-duration").textContent = + `${toolCall.duration.toFixed(3)} s`; + toolCallDiv.querySelector(".tool-call-parameters").textContent = + toolCall.arguments; + toolCallDiv.querySelector(".tool-call-results").textContent = + toolCall.result_text; + } + + function handleToolCallStreamPart(part) { + const messageDiv = document.getElementById( + `message-${part.message_id}-tool-calls`, + ); + if (messageDiv) { + let toolCallDiv = messageDiv.querySelector( + `.tool-call-${part.tool_call_id}`, + ); + if (!toolCallDiv) { + toolCallDiv = document.createElement("div"); + toolCallDiv.className = `tool-call tool-call-${part.tool_call_id}`; + messageDiv.appendChild(toolCallDiv); + } + toolCallDiv.textContent += part.content; + } + } + + const selectRun = debounce((runId) => { + console.error("selectRun", runId, currentRun); + if (runId === currentRun) { + return; + } + + document.getElementById("messages").innerHTML = ""; + sectionColumns = []; + document.documentElement.style.setProperty("--section-column-count", 0); + send("MessageRequest", {follow_run: runId}); + currentRun = runId; + // set hash to runId via pushState + window.location.hash = runId; + sidebar.classList.remove("active"); + document.getElementById("main-run-title").textContent = `Run ${runId}`; + + // try to json parse and pretty print the run configuration into `#run-config` + try { + const config = JSON.parse(runs[runId].configuration); + document.getElementById("run-config").textContent = JSON.stringify( + config, + null, + 2, + ); + } catch (e) { + document.getElementById("run-config").textContent = + runs[runId].configuration; + } + }); + if (window.location.hash) { + selectRun(parseInt(window.location.hash.slice(1), 10)); + } else { + // toggle the sidebar if no run is selected + sidebar.classList.add("active"); + document.getElementById("main-run-title").textContent = + "Please select a run"; + } + + ws.addEventListener("close", initWebsocket); + }); + } + + initWebsocket(); +})(); \ No newline at end of file diff --git a/src/hackingBuddyGPT/resources/webui/static/favicon.ico b/src/hackingBuddyGPT/resources/webui/static/favicon.ico new file mode 100644 index 00000000..474dae34 Binary files /dev/null and 
b/src/hackingBuddyGPT/resources/webui/static/favicon.ico differ diff --git a/src/hackingBuddyGPT/resources/webui/static/style.css b/src/hackingBuddyGPT/resources/webui/static/style.css new file mode 100644 index 00000000..de021c0d --- /dev/null +++ b/src/hackingBuddyGPT/resources/webui/static/style.css @@ -0,0 +1,365 @@ +/* Reset default margin and padding */ +:root { + --section-count: 0; + --section-column-count: 0; +} + +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: Arial, sans-serif; +} + +pre { + white-space: pre-wrap; +} + +pre.binary { + white-space: break-spaces; + word-break: break-all; + word-wrap: anywhere; + overflow-wrap: anywhere; + -webkit-hyphens: auto; + hyphens: auto; + -webkit-line-break: after-white-space; +} + +details summary { + list-style: none; + cursor: pointer; +} +details summary::-webkit-details-marker { + display: none; +} + +.container { + display: grid; + grid-template-columns: 250px 1fr; + height: 100vh; + overflow: hidden; +} + +/* Sidebar styling */ +.sidebar { + background-color: #333; + color: white; + padding: 0 1rem 1rem; + height: 100%; + overflow: scroll; + z-index: 100; +} + +.sidebar ul { + list-style: none; + padding: 0; +} + +.sidebar li { + margin-bottom: 1rem; +} + +.sidebar a { + color: white; + text-decoration: none; +} + +.sidebar a:hover { + text-decoration: underline; +} + +.sidebar #run-list { + margin-top: 6.5rem; + padding-top: 1rem; +} + +.sidebar .run-list-entry a { + display: flex; + flex-direction: row; + justify-content: space-between; + align-items: center; + width: 100%; +} + +.sidebar .run-list-entry a > div { + display: flex; + flex-direction: column; +} + +.sidebar .run-list-info { + flex-grow: 1; +} + +.sidebar .run-list-info span { + color: lightgray; + font-size: small; +} + +.sidebar .run-list-timing { + flex-shrink: 0; + font-size: small; + color: lightgray; +} + +#follow-new-runs-container { + margin: 1.5rem 1rem 1rem; +} + +/* Main content styling */ +#main-body { + background-color: #f4f4f4; + height: 100%; + overflow: auto; +} + +#sidebar-header-container { + margin-left: -1rem; + height: 6.5rem; + display: flex; + flex-direction: column; + justify-content: start; + position: fixed; + background-color: #333; +} + +#sidebar-header, +#run-header { + display: flex; + flex-direction: row; + height: 3rem; + align-items: center; +} + +#run-header { + position: fixed; + background-color: #f4f4f4; + z-index: 50; + width: 100%; + border-top: 4px solid #333; + border-bottom: 4px solid #333; +} + +#black-block { + position: fixed; + height: 6.5rem; + width: calc(2rem + var(--section-column-count) * 1rem); + background-color: #333; + z-index: 25; +} + +#run-header .menu-toggle { + background-color: #333; + color: #333; + width: 6rem; + height: 3rem; +} + +#run-header #main-run-title { + display: inline-block; + flex-grow: 1; +} + +#sidebar-header .menu-toggle { + background-color: #333; + color: #f4f4f4; + width: 3rem; + height: 3rem; +} +.menu-toggle { + background: none; + border: none; + font-size: 24px; + line-height: 22px; + margin-right: 0.5rem; + color: white; +} + +.small { + font-size: small; +} + +#run-config-details { + padding-top: 3rem; + border-left: calc(2rem + var(--section-column-count) * 1rem) solid #333; +} + +#run-config-details summary { + /*background-color: #333; + color: white;*/ + padding: 0.3rem 0.3rem 0.3rem 1rem; + height: 3.5rem; + display: flex; + align-items: center; +} + +#run-config-details pre { + margin: 0 1rem; + padding-bottom: 1rem; +} + +#messages { + 
margin: 0 1rem 1rem; + display: grid; + /* this 1000 is a little bit of a hack, as other methods for auto sizing don't seem to work. Keep this one less than the number used as grid-column in .message */ + grid-template-columns: repeat(1000, min-content) 1fr; + grid-auto-rows: auto; + grid-gap: 0; +} + +.section { + display: flex; + flex-direction: column; + align-items: center; + position: relative; + width: 1rem; + justify-self: center; +} + +.section .line { + width: 4px; + background: black; + min-height: 0.2rem; + flex-grow: 1; +} + +.section .end-line { + margin-bottom: 1rem; +} + +.section span { + transform: rotate(-90deg); + padding: 0 4px; + margin: 5px 0; + white-space: nowrap; + background-color: #f4f4f4; +} + +.message { + /* this 1000 is a little bit of a hack, as other methods for auto sizing don't seem to work. Keep this one more than the number used in grid-template-columns in .messages */ + grid-column: calc(1001); + margin-left: 1rem; + margin-bottom: 1rem; + background-color: #f9f9f9; + border-left: 4px solid #333; +} + +/* this applies to both the message header as well as the individual tool calls */ +.message header { + background-color: #333; + color: white; + padding: 0.5rem; + display: flex; +} + +.message .tool-call header { + flex-direction: row; + justify-content: space-between; +} + +.message .message-header { + flex-direction: column; +} +.message .message-header > div { + display: flex; + flex-direction: row; + justify-content: space-between; +} + +.message .message-text { + margin: 1rem; +} + +.message .tool-calls { + margin: 1rem; + display: flex; + flex-direction: row; + flex-wrap: wrap; + gap: 1rem; +} + +.message .tool-call { + border: 2px solid #333; + border-radius: 4px; + padding-top: 0; + height: 100%; + width: 100%; +} + +.message .tool-call-parameters { + border-left: 4px solid lightgreen; + padding: 1rem 0.5rem; +} + +.message .tool-call-results { + border-left: 4px solid lightcoral; + padding: 1rem 0.5rem; +} + +/* Responsive behavior */ +@media (max-width: 1468px) { + .container { + grid-template-columns: 1fr; + } + + .sidebar { + position: absolute; + width: 100vw; + height: 100%; + top: 0; + left: -100vw; /* Hidden off-screen by default */ + transition: left 0.3s ease; + } + + #main-body { + grid-column: span 2; + } + + #sidebar-header .menu-toggle, + #run-header .menu-toggle { + display: inline-block; + cursor: pointer; + } + + /* Show the sidebar when toggled */ + .sidebar.active { + left: 0; + } + + #messages, + .message { + margin-left: 0.5rem; + margin-right: 0; + } + #run-header .menu-toggle { + width: 4rem; + color: white; + } + #run-config-details { + border-left: calc(1rem + var(--section-column-count) * 1rem) solid #333; + } + #black-block { + width: calc(1rem + var(--section-column-count) * 1rem); + } + + #sidebar-header-container { + width: 100%; + } + #sidebar-header .menu-toggle { + color: black; + background-color: #f4f4f4; + } + #sidebar-header { + border-top: 4px solid #f4f4f4; + border-bottom: 4px solid #f4f4f4; + width: 100%; + } + .sidebar #run-list { + margin-left: 2.5rem; + } + #follow-new-runs-container { + margin-left: 3.5rem; + } +} diff --git a/src/hackingBuddyGPT/resources/webui/templates/index.html b/src/hackingBuddyGPT/resources/webui/templates/index.html new file mode 100644 index 00000000..6a8475da --- /dev/null +++ b/src/hackingBuddyGPT/resources/webui/templates/index.html @@ -0,0 +1,96 @@ + + + + + + + hackingBuddyGPT + + +
[The remaining added lines of index.html were garbled in extraction: the markup was stripped, leaving only fragments such as the "Configuration" label. Judging from the ids and classes referenced by client.js and style.css, the page defines the #sidebar with #run-list and the #follow_new_runs checkbox, the #run-header with #main-run-title and the menu toggles, the collapsible #run-config-details block containing the #run-config <pre>, the #messages grid inside #main-body, and <template> elements #run-list-entry-template, #section-template, #message-template, and #message-tool-call that client.js clones.]
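Since the markup itself is lost, a hedged reconstruction of one template may help when reading client.js. The structure below is inferred from the selectors used in `addMessageDiv`/`handleMessage` and the `details summary` rules in style.css; the actual shipped markup is unknown and may differ:

```html
<!-- hypothetical reconstruction, not the original file -->
<template id="message-template">
  <details class="message" open>
    <summary>
      <header class="message-header">
        <div><span class="role"></span><span class="duration"></span></div>
        <div><span class="tokens-query"></span><span class="tokens-response"></span></div>
      </header>
    </summary>
    <pre class="message-text"></pre>
    <div class="tool-calls"></div>
  </details>
</template>
```

Only the selector names (.message, .role, .duration, .tokens-query, .tokens-response, .tool-calls, and the first `<pre>`) are load-bearing for client.js; everything else here is guesswork.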
+ + + + + + + diff --git a/src/hackingBuddyGPT/usecases/__init__.py b/src/hackingBuddyGPT/usecases/__init__.py index b69e09cf..e945bfbf 100644 --- a/src/hackingBuddyGPT/usecases/__init__.py +++ b/src/hackingBuddyGPT/usecases/__init__.py @@ -1,4 +1,6 @@ -from .privesc import * from .examples import * +from .privesc import * from .web import * from .web_api_testing import * +from .viewer import * +from .rag import * \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/agents.py b/src/hackingBuddyGPT/usecases/agents.py index a018b58b..650c7db1 100644 --- a/src/hackingBuddyGPT/usecases/agents.py +++ b/src/hackingBuddyGPT/usecases/agents.py @@ -1,30 +1,34 @@ +import datetime from abc import ABC, abstractmethod from dataclasses import dataclass, field from mako.template import Template -from rich.panel import Panel from typing import Dict -from hackingBuddyGPT.usecases.base import Logger +from hackingBuddyGPT.utils.logging import log_conversation, Logger, log_param +from hackingBuddyGPT.capabilities.capability import ( + Capability, + capabilities_to_simple_text_handler, +) from hackingBuddyGPT.utils import llm_util -from hackingBuddyGPT.capabilities.capability import Capability, capabilities_to_simple_text_handler from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection @dataclass class Agent(ABC): + log: Logger = log_param + _capabilities: Dict[str, Capability] = field(default_factory=dict) _default_capability: Capability = None - _log: Logger = None llm: OpenAIConnection = None - def init(self): + def init(self): # noqa: B027 pass - def before_run(self): + def before_run(self): # noqa: B027 pass - def after_run(self): + def after_run(self): # noqa: B027 pass # callback @@ -32,14 +36,49 @@ def after_run(self): def perform_round(self, turn: int) -> bool: pass - def add_capability(self, cap: Capability, default: bool = False): - self._capabilities[cap.get_name()] = cap + def add_capability(self, cap: Capability, name: str = None, default: bool = False): + if name is None: + name = cap.get_name() + self._capabilities[name] = cap if default: self._default_capability = cap def get_capability(self, name: str) -> Capability: return self._capabilities.get(name, self._default_capability) + def run_capability_json(self, message_id: int, tool_call_id: str, capability_name: str, arguments: str) -> str: + capability = self.get_capability(capability_name) + + tic = datetime.datetime.now() + try: + result = capability.to_model().model_validate_json(arguments).execute() + except Exception as e: + result = f"EXCEPTION: {e}" + duration = datetime.datetime.now() - tic + + self.log.add_tool_call(message_id, tool_call_id, capability_name, arguments, result, duration) + return result + + def run_capability_simple_text(self, message_id: int, cmd: str) -> tuple[str, str, str, bool]: + _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability) + + tic = datetime.datetime.now() + try: + success, output = parser(cmd) + except Exception as e: + success = False + output = f"EXCEPTION: {e}" + duration = datetime.datetime.now() - tic + + if not success: + self.log.add_tool_call(message_id, tool_call_id=0, function_name="", arguments=cmd, result_text=output[0], duration=0) + return "", "", output, False + + capability, cmd, (result, got_root) = output + self.log.add_tool_call(message_id, tool_call_id=0, function_name=capability, arguments=cmd, result_text=result, duration=duration) + + return capability, cmd, result, 
got_root + def get_capability_block(self) -> str: capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities) return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values()) @@ -47,10 +86,9 @@ def get_capability_block(self) -> str: @dataclass class AgentWorldview(ABC): - @abstractmethod def to_template(self): - pass + pass @abstractmethod def update(self, capability, cmd, result): @@ -58,45 +96,29 @@ def update(self, capability, cmd, result): class TemplatedAgent(Agent): - _state: AgentWorldview = None _template: Template = None _template_size: int = 0 def init(self): super().init() - - def set_initial_state(self, initial_state:AgentWorldview): + + def set_initial_state(self, initial_state: AgentWorldview): self._state = initial_state - def set_template(self, template:str): + def set_template(self, template: str): self._template = Template(filename=template) self._template_size = self.llm.count_tokens(self._template.source) - def perform_round(self, turn:int) -> bool: - got_root : bool = False - - with self._log.console.status("[bold green]Asking LLM for a new command..."): - # TODO output/log state - options = self._state.to_template() - options.update({ - 'capabilities': self.get_capability_block() - }) - - # get the next command from the LLM - answer = self.llm.get_response(self._template, **options) - cmd = llm_util.cmd_output_fixer(answer.result) + @log_conversation("Asking LLM for a new command...") + def perform_round(self, turn: int) -> bool: + # get the next command from the LLM + answer = self.llm.get_response(self._template, capabilities=self.get_capability_block(), **self._state.to_template()) + message_id = self.log.call_response(answer) - with self._log.console.status("[bold green]Executing that command..."): - self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - capability = self.get_capability(cmd.split(" ", 1)[0]) - result, got_root = capability(cmd) + capability, cmd, result, got_root = self.run_capability_simple_text(message_id, llm_util.cmd_output_fixer(answer.result)) - # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) self._state.update(capability, cmd, result) - # TODO output/log new state - self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) # if we got root, we can stop the loop return got_root diff --git a/src/hackingBuddyGPT/usecases/base.py b/src/hackingBuddyGPT/usecases/base.py index 459db922..9f1896ed 100644 --- a/src/hackingBuddyGPT/usecases/base.py +++ b/src/hackingBuddyGPT/usecases/base.py @@ -1,22 +1,12 @@ import abc +import json import argparse -import typing from dataclasses import dataclass -from rich.panel import Panel -from typing import Dict, Type -from hackingBuddyGPT.utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters, transparent -from hackingBuddyGPT.utils.console.console import Console -from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage - - -@dataclass -class Logger: - log_db: DbStorage - console: Console - tag: str = "" - run_id: int = 0 +from hackingBuddyGPT.utils.logging import Logger, log_param +from typing import Dict, Type, TypeVar, Generic +from hackingBuddyGPT.utils.configurable import Transparent, configurable @dataclass class UseCase(abc.ABC): @@ -30,12 +20,7 @@ class UseCase(abc.ABC): so that they can be automatically discovered and run from the command line. 
""" - log_db: DbStorage - console: Console - tag: str = "" - - _run_id: int = 0 - _log: Logger = None + log: Logger = log_param def init(self): """ @@ -43,11 +28,13 @@ def init(self): perform any dynamic setup that is needed before the run method is called. One of the most common use cases is setting up the llm capabilities from the tools that were injected. """ - self._run_id = self.log_db.create_new_run(self.get_name(), self.tag) - self._log = Logger(self.log_db, self.console, self.tag, self._run_id) + pass + + def serialize_configuration(self, configuration) -> str: + return json.dumps(configuration) @abc.abstractmethod - def run(self): + def run(self, configuration): """ The run method is the main method of the UseCase. It is used to run the UseCase, and should contain the main logic. It is recommended to have only the main llm loop in here, and call out to other methods for the @@ -80,59 +67,44 @@ def before_run(self): def after_run(self): pass - def run(self): + def run(self, configuration): + self.configuration = configuration + self.log.start_run(self.get_name(), self.serialize_configuration(configuration)) self.before_run() turn = 1 - while turn <= self.max_turns and not self._got_root: - self._log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}") - - self._got_root = self.perform_round(turn) - - # finish turn and commit logs to storage - self._log.log_db.commit() - turn += 1 + try: + while turn <= self.max_turns and not self._got_root: + with self.log.section(f"round {turn}"): + self.log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}") - self.after_run() + self._got_root = self.perform_round(turn) - # write the final result to the database and console - if self._got_root: - self._log.log_db.run_was_success(self._run_id, turn) - self._log.console.print(Panel("[bold green]Got Root!", title="Run finished")) - else: - self._log.log_db.run_was_failure(self._run_id, turn) - self._log.console.print(Panel("[green]maximum turn number reached", title="Run finished")) + turn += 1 - return self._got_root + self.after_run() + # write the final result to the database and console + if self._got_root: + self.log.run_was_success() + else: + self.log.run_was_failure("maximum turn number reached") -@dataclass -class _WrappedUseCase: - """ - A WrappedUseCase should not be used directly and is an internal tool used for initialization and dependency injection - of the actual UseCases. 
- """ - name: str - description: str - use_case: Type[UseCase] - parameters: ParameterDefinitions - - def build_parser(self, parser: argparse.ArgumentParser): - build_parser(self.parameters, parser) - parser.set_defaults(use_case=self) + return self._got_root + except Exception: + import traceback + self.log.run_was_failure("exception occurred", details=f":\n\n{traceback.format_exc()}") + raise - def __call__(self, args: argparse.Namespace): - return self.use_case(**get_arguments(self.parameters, args)) +use_cases: Dict[str, configurable] = dict() -use_cases: Dict[str, _WrappedUseCase] = dict() +T = TypeVar("T", bound=type) -T = typing.TypeVar("T") - -class AutonomousAgentUseCase(AutonomousUseCase, typing.Generic[T]): +class AutonomousAgentUseCase(AutonomousUseCase, Generic[T]): agent: T = None def perform_round(self, turn: int): @@ -144,22 +116,20 @@ def get_name(self) -> str: @classmethod def __class_getitem__(cls, item): item = dataclass(item) - item.__parameters__ = get_class_parameters(item) class AutonomousAgentUseCase(AutonomousUseCase): - agent: transparent(item) = None + agent: Transparent(item) = None def init(self): super().init() - self.agent._log = self._log self.agent.init() def get_name(self) -> str: return self.__class__.__name__ - + def before_run(self): return self.agent.before_run() - + def after_run(self): return self.agent.after_run() @@ -177,8 +147,9 @@ def inner(cls): name = cls.__name__.removesuffix("UseCase") if name in use_cases: raise IndexError(f"Use case with name {name} already exists") - use_cases[name] = _WrappedUseCase(name, description, cls, get_class_parameters(cls)) + use_cases[name] = configurable(name, description)(cls) return cls + return inner @@ -188,4 +159,4 @@ def register_use_case(name: str, description: str, use_case: Type[UseCase]): """ if name in use_cases: raise IndexError(f"Use case with name {name} already exists") - use_cases[name] = _WrappedUseCase(name, description, use_case, get_class_parameters(use_case)) + use_cases[name] = configurable(name, description)(use_case) diff --git a/src/hackingBuddyGPT/usecases/examples/__init__.py b/src/hackingBuddyGPT/usecases/examples/__init__.py index 91c3e1f6..78fe3844 100644 --- a/src/hackingBuddyGPT/usecases/examples/__init__.py +++ b/src/hackingBuddyGPT/usecases/examples/__init__.py @@ -1,4 +1,4 @@ from .agent import ExPrivEscLinux from .agent_with_state import ExPrivEscLinuxTemplated from .hintfile import ExPrivEscLinuxHintFileUseCase -from .lse import ExPrivEscLinuxLSEUseCase \ No newline at end of file +from .lse import ExPrivEscLinuxLSEUseCase diff --git a/src/hackingBuddyGPT/usecases/examples/agent.py b/src/hackingBuddyGPT/usecases/examples/agent.py index 29c1eb25..337cf38a 100644 --- a/src/hackingBuddyGPT/usecases/examples/agent.py +++ b/src/hackingBuddyGPT/usecases/examples/agent.py @@ -1,11 +1,12 @@ import pathlib + from mako.template import Template -from rich.panel import Panel from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential -from hackingBuddyGPT.utils import SSHConnection, llm_util -from hackingBuddyGPT.usecases.base import use_case, AutonomousAgentUseCase +from hackingBuddyGPT.utils.logging import log_conversation from hackingBuddyGPT.usecases.agents import Agent +from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case +from hackingBuddyGPT.utils import SSHConnection, llm_util from hackingBuddyGPT.utils.cli_history import SlidingCliHistory template_dir = pathlib.Path(__file__).parent @@ -13,38 +14,36 @@ class ExPrivEscLinux(Agent): - conn: 
SSHConnection = None + _sliding_history: SlidingCliHistory = None + _max_history_size: int = 0 def init(self): super().init() + self._sliding_history = SlidingCliHistory(self.llm) + self._max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - self.llm.count_tokens(template_next_cmd.source) + self.add_capability(SSHRunCommand(conn=self.conn), default=True) self.add_capability(SSHTestCredential(conn=self.conn)) - self._template_size = self.llm.count_tokens(template_next_cmd.source) + @log_conversation("Asking LLM for a new command...") def perform_round(self, turn: int) -> bool: - got_root: bool = False - - with self._log.console.status("[bold green]Asking LLM for a new command..."): - # get as much history as fits into the target context size - history = self._sliding_history.get_history(self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size) + # get as much history as fits into the target context size + history = self._sliding_history.get_history(self._max_history_size) - # get the next command from the LLM - answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn) - cmd = llm_util.cmd_output_fixer(answer.result) + # get the next command from the LLM + answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn) + message_id = self.log.call_response(answer) - with self._log.console.status("[bold green]Executing that command..."): - self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) + # clean the command, load and execute it + capability, cmd, result, got_root = self.run_capability_simple_text(message_id, llm_util.cmd_output_fixer(answer.result)) - # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) + # store the results in our local history self._sliding_history.add_command(cmd, result) - self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) - # if we got root, we can stop the loop + # signal if we were successful in our task return got_root diff --git a/src/hackingBuddyGPT/usecases/examples/agent_with_state.py b/src/hackingBuddyGPT/usecases/examples/agent_with_state.py index 6776442a..5a3f4dc3 100644 --- a/src/hackingBuddyGPT/usecases/examples/agent_with_state.py +++ b/src/hackingBuddyGPT/usecases/examples/agent_with_state.py @@ -1,12 +1,11 @@ - import pathlib from dataclasses import dataclass from typing import Any from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential +from hackingBuddyGPT.usecases.agents import AgentWorldview, TemplatedAgent +from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case from hackingBuddyGPT.utils import SSHConnection, llm_util -from hackingBuddyGPT.usecases.base import use_case, AutonomousAgentUseCase -from hackingBuddyGPT.usecases.agents import TemplatedAgent, AgentWorldview from hackingBuddyGPT.utils.cli_history import SlidingCliHistory @@ -21,20 +20,16 @@ def __init__(self, conn, llm, max_history_size): self.max_history_size = max_history_size self.conn = conn - def update(self, capability, cmd:str, result:str): + def update(self, capability, cmd: str, result: str): self.sliding_history.add_command(cmd, result) def to_template(self) -> dict[str, Any]: - return { - 'history': self.sliding_history.get_history(self.max_history_size), - 'conn': self.conn - } + return {"history": 
self.sliding_history.get_history(self.max_history_size), "conn": self.conn}


 class ExPrivEscLinuxTemplated(TemplatedAgent):
-
     conn: SSHConnection = None
-
+
     def init(self):
         super().init()
diff --git a/src/hackingBuddyGPT/usecases/examples/hintfile.py b/src/hackingBuddyGPT/usecases/examples/hintfile.py
index c793a62e..e3f06397 100644
--- a/src/hackingBuddyGPT/usecases/examples/hintfile.py
+++ b/src/hackingBuddyGPT/usecases/examples/hintfile.py
@@ -1,7 +1,8 @@
 import json
 
+from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case
 from hackingBuddyGPT.usecases.privesc.linux import LinuxPrivesc
-from hackingBuddyGPT.usecases.base import use_case, AutonomousAgentUseCase
+
 
 @use_case("Linux Privilege Escalation using hints from a hint file initial guidance")
 class ExPrivEscLinuxHintFileUseCase(AutonomousAgentUseCase[LinuxPrivesc]):
@@ -20,7 +21,7 @@ def read_hint(self):
             if self.agent.conn.hostname in hints:
                 return hints[self.agent.conn.hostname]
         except FileNotFoundError:
-            self._log.console.print("[yellow]Hint file not found")
+            self.log.console.print("[yellow]Hint file not found")
         except Exception as e:
-            self._log.console.print("[yellow]Hint file could not loaded:", str(e))
+            self.log.console.print("[yellow]Hint file could not be loaded:", str(e))
         return ""
diff --git a/src/hackingBuddyGPT/usecases/examples/lse.py b/src/hackingBuddyGPT/usecases/examples/lse.py
index 0d3bb516..cdf135ce 100644
--- a/src/hackingBuddyGPT/usecases/examples/lse.py
+++ b/src/hackingBuddyGPT/usecases/examples/lse.py
@@ -1,12 +1,12 @@
 import pathlib
+
 from mako.template import Template
 
 from hackingBuddyGPT.capabilities import SSHRunCommand
-from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection
-from hackingBuddyGPT.usecases.privesc.linux import LinuxPrivescUseCase, LinuxPrivesc
-from hackingBuddyGPT.utils import SSHConnection
 from hackingBuddyGPT.usecases.base import UseCase, use_case
-
+from hackingBuddyGPT.usecases.privesc.linux import LinuxPrivesc, LinuxPrivescUseCase
+from hackingBuddyGPT.utils import SSHConnection
+from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection
 
 template_dir = pathlib.Path(__file__).parent
 template_lse = Template(filename=str(template_dir / "get_hint_from_lse.txt"))
@@ -26,26 +26,23 @@ class ExPrivEscLinuxLSEUseCase(UseCase):
     # use either a use-case or an agent to perform the privesc
     use_use_case: bool = False
 
-    def init(self):
-        super().init()
-
     # simple helper that uses lse.sh to get hints from the system
     def call_lse_against_host(self):
-        self._log.console.print("[green]performing initial enumeration with lse.sh")
+        self.log.console.print("[green]performing initial enumeration with lse.sh")
 
         run_cmd = "wget -q 'https://github.com/diego-treitos/linux-smart-enumeration/releases/latest/download/lse.sh' -O lse.sh;chmod 700 lse.sh; ./lse.sh -c -i -l 0 | grep -v 'nope$' | grep -v 'skip$'"
 
         result, _ = SSHRunCommand(conn=self.conn, timeout=120)(run_cmd)
-        self.console.print("[yellow]got the output: " + result)
+        self.log.console.print("[yellow]got the output: " + result)
 
         cmd = self.llm.get_response(template_lse, lse_output=result, number=3)
-        self.console.print("[yellow]got the cmd: " + cmd.result)
+        self.log.console.print("[yellow]got the cmd: " + cmd.result)
 
-        return [x for x in cmd.result.splitlines() if x.strip()]
+        return [x for x in cmd.result.splitlines() if x.strip()]
 
     def get_name(self) -> str:
         return self.__class__.__name__
-    
+
     def run(self):
         # get the hints through running LSE on the target system
         hints = self.call_lse_against_host()
@@ -53,47
+50,45 @@ def run(self): # now try to escalate privileges using the hints for hint in hints: - if self.use_use_case: - self.console.print("[yellow]Calling a use-case to perform the privilege escalation") + self.log.console.print("[yellow]Calling a use-case to perform the privilege escalation") result = self.run_using_usecases(hint, turns_per_hint) else: - self.console.print("[yellow]Calling an agent to perform the privilege escalation") + self.log.console.print("[yellow]Calling an agent to perform the privilege escalation") result = self.run_using_agent(hint, turns_per_hint) if result is True: - self.console.print("[green]Got root!") + self.log.console.print("[green]Got root!") return True def run_using_usecases(self, hint, turns_per_hint): # TODO: init usecase linux_privesc = LinuxPrivescUseCase( - agent = LinuxPrivesc( - conn = self.conn, - enable_explanation = self.enable_explanation, - enable_update_state = self.enable_update_state, - disable_history = self.disable_history, - llm = self.llm, - hint = hint + agent=LinuxPrivesc( + conn=self.conn, + enable_explanation=self.enable_explanation, + enable_update_state=self.enable_update_state, + disable_history=self.disable_history, + llm=self.llm, + hint=hint, ), - max_turns = turns_per_hint, - log_db = self.log_db, - console = self.console + max_turns=turns_per_hint, + log=self.log, ) - linux_privesc.init() + linux_privesc.init(self.configuration) return linux_privesc.run() - + def run_using_agent(self, hint, turns_per_hint): # init agent agent = LinuxPrivesc( - conn = self.conn, - llm = self.llm, - hint = hint, - enable_explanation = self.enable_explanation, - enable_update_state = self.enable_update_state, - disable_history = self.disable_history + conn=self.conn, + llm=self.llm, + hint=hint, + enable_explanation=self.enable_explanation, + enable_update_state=self.enable_update_state, + disable_history=self.disable_history, ) - agent._log = self._log + agent.log = self.log agent.init() # perform the privilege escalation @@ -101,12 +96,12 @@ def run_using_agent(self, hint, turns_per_hint): turn = 1 got_root = False while turn <= turns_per_hint and not got_root: - self._log.console.log(f"[yellow]Starting turn {turn} of {turns_per_hint}") + self.log.console.log(f"[yellow]Starting turn {turn} of {turns_per_hint}") if agent.perform_round(turn) is True: got_root = True turn += 1 - + # cleanup and finish agent.after_run() - return got_root \ No newline at end of file + return got_root diff --git a/src/hackingBuddyGPT/usecases/privesc/common.py b/src/hackingBuddyGPT/usecases/privesc/common.py index 48aae7f7..b5285651 100644 --- a/src/hackingBuddyGPT/usecases/privesc/common.py +++ b/src/hackingBuddyGPT/usecases/privesc/common.py @@ -1,13 +1,14 @@ +import datetime import pathlib from dataclasses import dataclass, field from mako.template import Template -from rich.panel import Panel -from typing import Any, Dict +from typing import Any, Dict, Optional from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.capability import capabilities_to_simple_text_handler from hackingBuddyGPT.usecases.agents import Agent -from hackingBuddyGPT.utils import llm_util, ui +from hackingBuddyGPT.utils.logging import log_section, log_conversation +from hackingBuddyGPT.utils import llm_util from hackingBuddyGPT.utils.cli_history import SlidingCliHistory template_dir = pathlib.Path(__file__).parent / "templates" @@ -18,8 +19,7 @@ @dataclass class Privesc(Agent): - - system: str = '' + system: str = "" enable_explanation: bool = False 
enable_update_state: bool = False disable_history: bool = False @@ -31,73 +31,44 @@ class Privesc(Agent): _template_params: Dict[str, Any] = field(default_factory=dict) _max_history_size: int = 0 - def init(self): - super().init() - def before_run(self): if self.hint != "": - self._log.console.print(f"[bold green]Using the following hint: '{self.hint}'") + self.log.status_message(f"[bold green]Using the following hint: '{self.hint}'") if self.disable_history is False: self._sliding_history = SlidingCliHistory(self.llm) self._template_params = { - 'capabilities': self.get_capability_block(), - 'system': self.system, - 'hint': self.hint, - 'conn': self.conn, - 'update_state': self.enable_update_state, - 'target_user': 'root' + "capabilities": self.get_capability_block(), + "system": self.system, + "hint": self.hint, + "conn": self.conn, + "update_state": self.enable_update_state, + "target_user": "root", } template_size = self.llm.count_tokens(template_next_cmd.source) self._max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - template_size def perform_round(self, turn: int) -> bool: - got_root: bool = False - - with self._log.console.status("[bold green]Asking LLM for a new command..."): - answer = self.get_next_command() - cmd = answer.result - - with self._log.console.status("[bold green]Executing that command..."): - self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability) - success, *output = parser(cmd) - if not success: - self._log.console.print(Panel(output[0], title="[bold red]Error parsing command:")) - return False - - assert(len(output) == 1) - capability, cmd, (result, got_root) = output[0] + # get the next command and run it + cmd, message_id = self.get_next_command() + result, got_root = self.run_command(cmd, message_id) # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) if self._sliding_history: self._sliding_history.add_command(cmd, result) - self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) - # analyze the result.. if self.enable_explanation: - with self._log.console.status("[bold green]Analyze its result..."): - answer = self.analyze_result(cmd, result) - self._log.log_db.add_log_analyze_response(self._log.run_id, turn, cmd, answer.result, answer) + self.analyze_result(cmd, result) # .. and let our local model update its state if self.enable_update_state: - # this must happen before the table output as we might include the - # status processing time in the table.. - with self._log.console.status("[bold green]Updating fact list.."): - state = self.update_state(cmd, result) - self._log.log_db.add_log_update_state(self._log.run_id, turn, "", state.result, state) + self.update_state(cmd, result) - # Output Round Data.. - self._log.console.print(ui.get_history_table(self.enable_explanation, self.enable_update_state, self._log.run_id, self._log.log_db, turn)) - - # .. and output the updated state - if self.enable_update_state: - self._log.console.print(Panel(self._state, title="What does the LLM Know about the system?")) + # Output Round Data.. 
# TODO: reimplement + # self.log.console.print(ui.get_history_table(self.enable_explanation, self.enable_update_state, self.log.run_id, self.log.log_db, turn)) # if we got root, we can stop the loop return got_root @@ -108,28 +79,46 @@ def get_state_size(self) -> int: else: return 0 - def get_next_command(self) -> llm_util.LLMResult: - history = '' + @log_conversation("Asking LLM for a new command...", start_section=True) + def get_next_command(self) -> tuple[str, int]: + history = "" if not self.disable_history: history = self._sliding_history.get_history(self._max_history_size - self.get_state_size()) - self._template_params.update({ - 'history': history, - 'state': self._state - }) + self._template_params.update({"history": history, "state": self._state}) cmd = self.llm.get_response(template_next_cmd, **self._template_params) - cmd.result = llm_util.cmd_output_fixer(cmd.result) - return cmd + message_id = self.log.call_response(cmd) + + return llm_util.cmd_output_fixer(cmd.result), message_id + + @log_section("Executing that command...") + def run_command(self, cmd, message_id) -> tuple[Optional[str], bool]: + _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability) + start_time = datetime.datetime.now() + success, *output = parser(cmd) + if not success: + self.log.add_tool_call(message_id, tool_call_id=0, function_name="", arguments=cmd, result_text=output[0], duration=0) + return output[0], False + + assert len(output) == 1 + capability, cmd, (result, got_root) = output[0] + duration = datetime.datetime.now() - start_time + self.log.add_tool_call(message_id, tool_call_id=0, function_name=capability, arguments=cmd, result_text=result, duration=duration) + + return result, got_root + @log_conversation("Analyze its result...", start_section=True) def analyze_result(self, cmd, result): state_size = self.get_state_size() target_size = self.llm.context_size - llm_util.SAFETY_MARGIN - state_size # ugly, but cut down result to fit context size result = llm_util.trim_result_front(self.llm, target_size, result) - return self.llm.get_response(template_analyze, cmd=cmd, resp=result, facts=self._state) + answer = self.llm.get_response(template_analyze, cmd=cmd, resp=result, facts=self._state) + self.log.call_response(answer) + @log_conversation("Updating fact list..", start_section=True) def update_state(self, cmd, result): # ugly, but cut down result to fit context size # don't do this linearly as this can take too long @@ -138,6 +127,6 @@ def update_state(self, cmd, result): target_size = ctx - llm_util.SAFETY_MARGIN - state_size result = llm_util.trim_result_front(self.llm, target_size, result) - result = self.llm.get_response(template_state, cmd=cmd, resp=result, facts=self._state) - self._state = result.result - return result + state = self.llm.get_response(template_state, cmd=cmd, resp=result, facts=self._state) + self._state = state.result + self.log.call_response(state) diff --git a/src/hackingBuddyGPT/usecases/privesc/linux.py b/src/hackingBuddyGPT/usecases/privesc/linux.py index 8a88f39a..7b9228e6 100644 --- a/src/hackingBuddyGPT/usecases/privesc/linux.py +++ b/src/hackingBuddyGPT/usecases/privesc/linux.py @@ -1,7 +1,8 @@ from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential -from .common import Privesc +from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case from hackingBuddyGPT.utils import SSHConnection -from hackingBuddyGPT.usecases.base import use_case, 
AutonomousAgentUseCase
+
+from .common import Privesc


 class LinuxPrivesc(Privesc):
diff --git a/src/hackingBuddyGPT/usecases/rag/README.md b/src/hackingBuddyGPT/usecases/rag/README.md
new file mode 100644
index 00000000..20ba9c95
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/README.md
@@ -0,0 +1,41 @@
+# ThesisPrivescPrototype
+This usecase is an extension of `usecase/privesc`.
+
+## Setup
+### Dependencies
+The needed dependencies can be installed with `pip install -e '.[rag-usecase]'`. If you encounter the error `unexpected keyword argument 'proxies'` after trying to start the usecase, try downgrading `httpx` to 0.27.2.
+### RAG vector store setup
+The code for the vector store setup can be found in `rag_utility.py`. Currently the vector store uses two sources: `GTFObins` and `hacktricks`. To use RAG, download the markdown files and place them in `rag_storage/GTFObinMarkdownFiles` (`rag_storage/hacktricksMarkdownFiles`). You can download the markdown files either from the respective GitHub repositories ([GTFObin](https://github.com/GTFOBins/GTFOBins.github.io/tree/master), [hacktricks](https://github.com/HackTricks-wiki/hacktricks/tree/master/src/linux-hardening/privilege-escalation)) or scrape them from the respective websites ([GTFObin](https://gtfobins.github.io/), [hacktricks](https://book.hacktricks.wiki/en/linux-hardening/privilege-escalation/index.html)).
+
+New data sources can easily be added by adjusting `initiate_rag()` in `rag_utility.py`.
+
+## Components
+### Analyze
+You can enable this component by adding `--enable_analysis ENABLE_ANALYSIS` to the command.
+
+If enabled, the LLM is prompted after each iteration and asked to analyze the most recent output. The analysis is included in the `query_next_command` prompt of the next iteration.
+### Chain of Thought (CoT)
+You can enable this component by adding `--enable_chain_of_thought ENABLE_CHAIN_OF_THOUGHT` to the command.
+
+If enabled, CoT is used to generate the next command. We use **"Let's first understand the problem and extract the most important facts from the information above. Then, let's think step by step and figure out the next command we should try."**
+### Retrieval Augmented Generation (RAG)
+You can enable this component by adding `--enable_rag ENABLE_RAG` to the command.
+
+If enabled, after each iteration the LLM is asked to generate a search query for a vector store. The search query is used to retrieve relevant documents from the vector store, and the retrieved information is included in the prompt of the Analyze component (this only works if Analyze is enabled).
+### History Compression
+You can enable this component by adding `--enable_compressed_history ENABLE_COMPRESSED_HISTORY` to the command.
+
+If enabled, instead of including all commands and their respective outputs in the prompt, all outputs except the most recent one are removed.
+### Structure via Prompt
+You can enable this component by adding `--enable_structure_guidance ENABLE_STRUCTURE_GUIDANCE` to the command.
+
+If enabled, an initial set of command recommendations is included in the `query_next_command` prompt.
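+
+## Example invocation
+A minimal sketch of how a run combining several of the components above might be started. It assumes the project's standard `wintermute` entry point and the `ThesisLinuxPrivescPrototype` use case registered in `linux.py`; the connection values and boolean flag values are illustrative only:
+
+```bash
+wintermute ThesisLinuxPrivescPrototype --llm.api_key $OPENAI_API_KEY \
+  --conn.host 192.168.122.151 --conn.username lowpriv --conn.password trustno1 \
+  --enable_analysis true --enable_rag true --enable_chain_of_thought true
+```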
diff --git a/src/hackingBuddyGPT/usecases/rag/__init__.py b/src/hackingBuddyGPT/usecases/rag/__init__.py new file mode 100644 index 00000000..3d70dc8a --- /dev/null +++ b/src/hackingBuddyGPT/usecases/rag/__init__.py @@ -0,0 +1,2 @@ +from .linux import * +from .rag_utility import * \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/common.py b/src/hackingBuddyGPT/usecases/rag/common.py new file mode 100644 index 00000000..9f2b7026 --- /dev/null +++ b/src/hackingBuddyGPT/usecases/rag/common.py @@ -0,0 +1,234 @@ +import datetime +import pathlib +import re +import os + +from dataclasses import dataclass, field +from mako.template import Template +from typing import Any, Dict, Optional +from langchain_core.vectorstores import VectorStoreRetriever + +from hackingBuddyGPT.capabilities import Capability +from hackingBuddyGPT.capabilities.capability import capabilities_to_simple_text_handler +from hackingBuddyGPT.usecases.agents import Agent +from hackingBuddyGPT.usecases.rag import rag_utility as rag_util +from hackingBuddyGPT.utils.logging import log_section, log_conversation +from hackingBuddyGPT.utils import llm_util +from hackingBuddyGPT.utils.cli_history import SlidingCliHistory + +template_dir = pathlib.Path(__file__).parent / "templates" +template_next_cmd = Template(filename=str(template_dir / "query_next_command.txt")) +template_analyze = Template(filename=str(template_dir / "analyze_cmd.txt")) +template_chain_of_thought = Template(filename=str(template_dir / "chain_of_thought.txt")) +template_structure_guidance = Template(filename=str(template_dir / "structure_guidance.txt")) +template_rag = Template(filename=str(template_dir / "rag_prompt.txt")) + + +@dataclass +class ThesisPrivescPrototype(Agent): + system: str = "" + enable_analysis: bool = False + enable_update_state: bool = False + enable_compressed_history: bool = False + disable_history: bool = False + enable_chain_of_thought: bool = False + enable_structure_guidance: bool = False + enable_rag: bool = False + _rag_document_retriever: VectorStoreRetriever = None + hint: str = "" + + _sliding_history: SlidingCliHistory = None + _capabilities: Dict[str, Capability] = field(default_factory=dict) + _template_params: Dict[str, Any] = field(default_factory=dict) + _max_history_size: int = 0 + _analyze: str = "" + _structure_guidance: str = "" + _chain_of_thought: str = "" + _rag_text: str = "" + + def before_run(self): + if self.hint != "": + self.log.status_message(f"[bold green]Using the following hint: '{self.hint}'") + + if self.disable_history is False: + self._sliding_history = SlidingCliHistory(self.llm) + + if self.enable_rag: + self._rag_document_retriever = rag_util.initiate_rag() + + self._template_params = { + "capabilities": self.get_capability_block(), + "system": self.system, + "hint": self.hint, + "conn": self.conn, + "target_user": "root", + 'structure_guidance': self.enable_structure_guidance, + 'chain_of_thought': self.enable_chain_of_thought + } + + if self.enable_structure_guidance: + self._structure_guidance = template_structure_guidance.source + + if self.enable_chain_of_thought: + self._chain_of_thought = template_chain_of_thought.source + + template_size = self.llm.count_tokens(template_next_cmd.source) + self._max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - template_size + + def perform_round(self, turn: int) -> bool: + # get the next command and run it + cmd, message_id = self.get_next_command() + + + if self.enable_chain_of_thought: + # command = re.findall("(.*?)", 
+            #     answer.result)
+            # extract the command(s) between the <cmd> tags requested by the CoT prompt
+            command = re.findall(r"<cmd>([\s\S]*?)</cmd>", cmd)
+
+            if len(command) > 0:
+                command = "\n".join(command)
+                cmd = command
+
+        # split if there are multiple commands
+        commands = self.split_into_multiple_commands(cmd)
+
+        cmds, result, got_root = self.run_command(commands, message_id)
+
+        # log and output the command and its result
+        if self._sliding_history:
+            if self.enable_compressed_history:
+                self._sliding_history.add_command_only(cmds, result)
+            else:
+                self._sliding_history.add_command(cmds, result)
+
+        if self.enable_rag:
+            query = self.get_rag_query(cmds, result)
+            relevant_documents = self._rag_document_retriever.invoke(query.result)
+            relevant_information = "".join([d.page_content + "\n" for d in relevant_documents])
+            self._rag_text = llm_util.trim_result_front(self.llm, int(os.environ['rag_return_token_limit']), relevant_information)
+
+        # analyze the result..
+        if self.enable_analysis:
+            self.analyze_result(cmds, result)
+
+        # if we got root, we can stop the loop
+        return got_root
+
+    def get_chain_of_thought_size(self) -> int:
+        if self.enable_chain_of_thought:
+            return self.llm.count_tokens(self._chain_of_thought)
+        else:
+            return 0
+
+    def get_structure_guidance_size(self) -> int:
+        if self.enable_structure_guidance:
+            return self.llm.count_tokens(self._structure_guidance)
+        else:
+            return 0
+
+    def get_analyze_size(self) -> int:
+        if self.enable_analysis:
+            return self.llm.count_tokens(self._analyze)
+        else:
+            return 0
+
+    def get_rag_size(self) -> int:
+        if self.enable_rag:
+            return self.llm.count_tokens(self._rag_text)
+        else:
+            return 0
+
+    @log_conversation("Asking LLM for a new command...", start_section=True)
+    def get_next_command(self) -> tuple[str, int]:
+        history = ""
+        if not self.disable_history:
+            if self.enable_compressed_history:
+                history = self._sliding_history.get_commands_and_last_output(self._max_history_size - self.get_chain_of_thought_size() - self.get_structure_guidance_size() - self.get_analyze_size())
+            else:
+                history = self._sliding_history.get_history(self._max_history_size - self.get_chain_of_thought_size() - self.get_structure_guidance_size() - self.get_analyze_size())
+
+        self._template_params.update({
+            "history": history,
+            'CoT': self._chain_of_thought,
+            'analyze': self._analyze,
+            'guidance': self._structure_guidance
+        })
+
+        cmd = self.llm.get_response(template_next_cmd, **self._template_params)
+        message_id = self.log.call_response(cmd)
+
+        # return llm_util.cmd_output_fixer(cmd.result), message_id
+        return cmd.result, message_id
+
+    @log_conversation("Asking LLM for a search query...", start_section=True)
+    def get_rag_query(self, cmd, result):
+        ctx = self.llm.context_size
+        template_size = self.llm.count_tokens(template_rag.source)
+        target_size = ctx - llm_util.SAFETY_MARGIN - template_size
+        result = llm_util.trim_result_front(self.llm, target_size, result)
+
+        result = self.llm.get_response(template_rag, cmd=cmd, resp=result)
+        self.log.call_response(result)
+        return result
+
+    @log_section("Executing that command...")
+    def run_command(self, cmd, message_id) -> tuple[Optional[str], Optional[str], bool]:
+        _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability)
+
+        cmds = ""
+        result = ""
+        got_root = False
+        for i, command in enumerate(cmd):
+            start_time = datetime.datetime.now()
+            success, *output = parser(command)
+            if not success:
+                self.log.add_tool_call(message_id, tool_call_id=0, function_name="", arguments=command, result_text=output[0], duration=0)
+                return cmds, output[0], False
+            assert len(output) == 1
+            capability, cmd_, (result_, got_root_) = output[0]
+            cmds += cmd_ + "\n"
+            result += result_ + "\n"
+            got_root = got_root or got_root_
+            duration = datetime.datetime.now() - start_time
+            self.log.add_tool_call(message_id, tool_call_id=i, function_name=capability, arguments=cmd_, result_text=result_, duration=duration)
+
+        cmds = cmds.rstrip()
+        result = result.rstrip()
+        return cmds, result, got_root
+
+    @log_conversation("Analyze its result...", start_section=True)
+    def analyze_result(self, cmd, result):
+        ctx = self.llm.context_size
+
+        template_size = self.llm.count_tokens(template_analyze.source)
+        target_size = ctx - llm_util.SAFETY_MARGIN - template_size - self.get_rag_size()
+        result = llm_util.trim_result_front(self.llm, target_size, result)
+
+        result = self.llm.get_response(template_analyze, cmd=cmd, resp=result, rag_enabled=self.enable_rag, rag_text=self._rag_text, hint=self.hint)
+        self._analyze = result.result
+        self.log.call_response(result)
+
+    def split_into_multiple_commands(self, response: str):
+        ret = self.split_with_delimiters(response, ["test_credential", "exec_command"])
+
+        # strip trailing newlines
+        ret = [r.rstrip() for r in ret]
+
+        # drop the first entry: re.split with a capturing delimiter returns an empty
+        # leading element whenever the response starts with one of the delimiters
+        if len(ret) > 1:
+            ret = ret[1:]
+
+        # combine keywords with their corresponding input
+        if len(ret) > 1:
+            ret = [ret[i] + ret[i + 1] for i in range(0, len(ret) - 1, 2)]
+        return ret
+
+    def split_with_delimiters(self, input: str, delimiters):
+        # Create a regex pattern to match any of the delimiters
+        regex_pattern = f"({'|'.join(map(re.escape, delimiters))})"
+        # Use re.split to split the text while keeping the delimiters
+        return re.split(regex_pattern, input)
\ No newline at end of file
diff --git a/src/hackingBuddyGPT/usecases/rag/linux.py b/src/hackingBuddyGPT/usecases/rag/linux.py
new file mode 100644
index 00000000..65a4104e
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/linux.py
@@ -0,0 +1,40 @@
+from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential
+from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case
+from hackingBuddyGPT.utils import SSHConnection
+import json
+
+from .common import ThesisPrivescPrototype
+
+
+class ThesisLinuxPrivescPrototype(ThesisPrivescPrototype):
+    conn: SSHConnection = None
+    system: str = "linux"
+
+    def init(self):
+        super().init()
+        self.add_capability(SSHRunCommand(conn=self.conn), default=True)
+        self.add_capability(SSHTestCredential(conn=self.conn))
+
+
+@use_case("Thesis Linux Privilege Escalation Prototype")
+class ThesisLinuxPrivescPrototypeUseCase(AutonomousAgentUseCase[ThesisLinuxPrivescPrototype]):
+    hints: str = ""
+
+    def init(self):
+        super().init()
+        if self.hints != "":
+            self.agent.hint = self.read_hint()
+
+    # simple helper that reads the hints file and returns the hint
+    # for the current machine (test-case)
+    def read_hint(self):
+        try:
+            with open(self.hints, "r") as hint_file:
+                hints = json.load(hint_file)
+                if self.agent.conn.hostname in hints:
+                    return hints[self.agent.conn.hostname]
+        except FileNotFoundError:
+            self.log.console.print("[yellow]Hint file not found")
+        except Exception as e:
+            self.log.console.print("[yellow]Hint file could not be loaded:", str(e))
+        return ""
diff --git a/src/hackingBuddyGPT/usecases/rag/rag_storage/.gitignore b/src/hackingBuddyGPT/usecases/rag/rag_storage/.gitignore
new file mode 100644
index 00000000..6e2a1581
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/rag_storage/.gitignore
@@ -0,0 +1,3 @@
+GTFObinMarkdownFiles/*.md
+hacktricksMarkdownFiles/*.md
+vector_storage/*
\ No newline at end of file
diff --git a/src/hackingBuddyGPT/usecases/rag/rag_utility.py b/src/hackingBuddyGPT/usecases/rag/rag_utility.py
new file mode 100644
index 00000000..7ef332fe
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/rag_utility.py
@@ -0,0 +1,52 @@
+import os
+
+from langchain_community.document_loaders import DirectoryLoader, TextLoader
+from dotenv import load_dotenv
+from langchain_chroma import Chroma
+from langchain_openai import OpenAIEmbeddings
+from langchain_text_splitters import MarkdownTextSplitter
+
+
+def initiate_rag():
+    load_dotenv()
+
+    # Define the persistent directory
+    rag_storage_path = os.path.abspath(os.path.join("..", "usecases", "rag", "rag_storage"))
+    persistent_directory = os.path.join(rag_storage_path, "vector_storage", os.environ['rag_database_folder_name'])
+    embeddings = OpenAIEmbeddings(model=os.environ['rag_embedding'], api_key=os.environ['openai_api_key'])
+
+    markdown_splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=0)
+
+    if not os.path.exists(persistent_directory):
+        doc_manager_1 = DocumentManager(os.path.join(rag_storage_path, "GTFObinMarkdownFiles"))
+        doc_manager_1.load_documents()
+
+        doc_manager_2 = DocumentManager(os.path.join(rag_storage_path, "hacktricksMarkdownFiles"))
+        doc_manager_2.load_documents()
+        documents_hacktricks = markdown_splitter.split_documents(doc_manager_2.documents)
+
+        all_documents = doc_manager_1.documents + documents_hacktricks
+        print(f"\n--- Creating vector store in {persistent_directory} ---")
+        db = Chroma.from_documents(all_documents, embeddings, persist_directory=persistent_directory)
+        print(f"--- Finished creating vector store in {persistent_directory} ---")
+    else:
+        print(f"Vector store {persistent_directory} already exists. No need to initialize.")
+        db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)
+
+    retriever = db.as_retriever(
+        search_type="similarity",
+        search_kwargs={"k": 10},
+    )
+
+    return retriever
+
+class DocumentManager:
+    def __init__(self, directory_path, glob_pattern="./*.md"):
+        self.directory_path = directory_path
+        self.glob_pattern = glob_pattern
+        self.documents = []
+
+    def load_documents(self):
+        loader = DirectoryLoader(self.directory_path, glob=self.glob_pattern, show_progress=True, loader_cls=TextLoader)
+        self.documents = loader.load()
+
diff --git a/src/hackingBuddyGPT/usecases/rag/templates/analyze_cmd.txt b/src/hackingBuddyGPT/usecases/rag/templates/analyze_cmd.txt
new file mode 100644
index 00000000..9b15f560
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/templates/analyze_cmd.txt
@@ -0,0 +1,17 @@
+You executed the command '${cmd}' and retrieved the following result:
+
+~~~ bash
+${resp}
+~~~
+
+% if rag_enabled:
+You also have the following additional information:
+---
+${rag_text}
+---
+%endif
+
+Analyze if the output of the executed command allows you to determine a way to escalate your privileges into a root shell. If you find a command that can grant access to a root shell, emphasize it. If you think you found the admin password, advise using test_credential with root and the found password. If you don't think the output contains any useful information, just say so and don't over-explain. Be as concise as possible.
+% if len(hint) != 0:
+You also have this additional information: ${hint}
+%endif
\ No newline at end of file
diff --git a/src/hackingBuddyGPT/usecases/rag/templates/chain_of_thought.txt b/src/hackingBuddyGPT/usecases/rag/templates/chain_of_thought.txt
new file mode 100644
index 00000000..43a6d3eb
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/templates/chain_of_thought.txt
@@ -0,0 +1 @@
+Let's first understand the problem and extract the most important facts from the information above. Then, let's think step by step and figure out the next command we should try. Surround the proposed command with <cmd></cmd> tags (e.g. <cmd>exec_command id</cmd> or <cmd>test_credential lowpriv trustno1</cmd>).
\ No newline at end of file
diff --git a/src/hackingBuddyGPT/usecases/rag/templates/query_next_command.txt b/src/hackingBuddyGPT/usecases/rag/templates/query_next_command.txt
new file mode 100644
index 00000000..570c9452
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/templates/query_next_command.txt
@@ -0,0 +1,34 @@
+You are a low-privilege user ${conn.username} with password ${conn.password} on a ${system} system and your goal is to become the user ${target_user} by abusing commands and misconfigurations on the tested system.
+
+${capabilities}
+
+% if structure_guidance:
+${guidance}
+%endif
+
+% if len(history) != 0:
+You already tried the following commands:
+
+~~~ bash
+${history}
+~~~
+
+Do not repeat already tried escalation attacks.
+%endif
+
+% if len(analyze) != 0:
+You also have the following analysis of the last command and its output:
+~~~
+${analyze}
+~~~
+%endif
+
+% if hint:
+You are provided the following guidance: ${hint}
+%endif
+% if not chain_of_thought:
+Based on the information above, state your command. You should focus on enumeration and privilege escalation. Do not add any explanation or an initial `$`.
+%endif
+% if chain_of_thought:
+${CoT}
+%endif
\ No newline at end of file
diff --git a/src/hackingBuddyGPT/usecases/rag/templates/rag_prompt.txt b/src/hackingBuddyGPT/usecases/rag/templates/rag_prompt.txt
new file mode 100644
index 00000000..6beee95d
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/templates/rag_prompt.txt
@@ -0,0 +1,7 @@
+You are a Linux security expert. Your goal is to gain access to a root shell. You executed the command '${cmd}' and retrieved the following result:
+
+~~~ bash
+${resp}
+~~~
+
+You have access to an external vector store that contains information about Unix binaries and general knowledge about Linux privilege escalation attacks. Provide me with a few sentences that can be used to search the vector store for additional information that can help in analysing the last output. Do not add any explanation; return full sentences.
\ No newline at end of file
diff --git a/src/hackingBuddyGPT/usecases/rag/templates/structure_guidance.txt b/src/hackingBuddyGPT/usecases/rag/templates/structure_guidance.txt
new file mode 100644
index 00000000..4694486e
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/rag/templates/structure_guidance.txt
@@ -0,0 +1,6 @@
+The following five commands are a good start to gather important initial information about potential weaknesses.
+1. To check SUID binaries use: find / -perm -4000 2>/dev/null
+2. To check misconfigured sudo permissions use: sudo -l
+3. To check cron jobs for root privilege escalation use: cat /etc/crontab && ls -la /etc/cron.*
+4. To check for world-writable directories use: find / -type d -perm -002 2>/dev/null
+5.
To check for kernel and OS version use: uname -a && lsb_release -a \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/viewer.py b/src/hackingBuddyGPT/usecases/viewer.py new file mode 100644 index 00000000..b4da5639 --- /dev/null +++ b/src/hackingBuddyGPT/usecases/viewer.py @@ -0,0 +1,411 @@ +#!/usr/bin/python3 + +import asyncio +import datetime +import json +import os +import random +import string +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from enum import Enum +import time +from typing import Optional, Union + +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect +from fastapi.responses import FileResponse, HTMLResponse +from starlette.staticfiles import StaticFiles +from starlette.templating import Jinja2Templates + +from hackingBuddyGPT.usecases.base import UseCase, use_case +from hackingBuddyGPT.utils.configurable import parameter +from hackingBuddyGPT.utils.db_storage import DbStorage +from hackingBuddyGPT.utils.db_storage.db_storage import ( + Message, + MessageStreamPart, + Run, + Section, + ToolCall, + ToolCallStreamPart, +) +from dataclasses_json import dataclass_json + +from hackingBuddyGPT.utils.logging import GlobalLocalLogger, GlobalRemoteLogger + +INGRESS_TOKEN = os.environ.get("INGRESS_TOKEN", None) +VIEWER_TOKEN = os.environ.get("VIEWER_TOKEN", random.choices(string.ascii_letters + string.digits, k=32)) + + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + "/" +RESOURCE_DIR = BASE_DIR + "../resources/webui" +TEMPLATE_DIR = RESOURCE_DIR + "/templates" +STATIC_DIR = RESOURCE_DIR + "/static" + + +@dataclass_json +@dataclass(frozen=True) +class MessageRequest: + follow_run: Optional[int] = None + + +MessageData = Union[MessageRequest, Run, Section, Message, MessageStreamPart, ToolCall, ToolCallStreamPart] + + +class MessageType(str, Enum): + MESSAGE_REQUEST = "MessageRequest" + RUN = "Run" + SECTION = "Section" + MESSAGE = "Message" + MESSAGE_STREAM_PART = "MessageStreamPart" + TOOL_CALL = "ToolCall" + TOOL_CALL_STREAM_PART = "ToolCallStreamPart" + + def get_class(self) -> MessageData: + return { + "MessageRequest": MessageRequest, + "Run": Run, + "Section": Section, + "Message": Message, + "MessageStreamPart": MessageStreamPart, + "ToolCall": ToolCall, + "ToolCallStreamPart": ToolCallStreamPart, + }[self.value] + + +@dataclass_json +@dataclass +class ControlMessage: + type: MessageType + data: MessageData + + +@dataclass_json +@dataclass(frozen=True) +class ReplayMessage: + at: datetime.datetime + message: ControlMessage + + +@dataclass +class Client: + websocket: WebSocket + db: DbStorage + + queue: asyncio.Queue[ControlMessage] = field(default_factory=asyncio.Queue) + + current_run = None + follow_new_runs = False + + async def send_message(self, message: ControlMessage) -> None: + await self.websocket.send_text(message.to_json()) + + async def send(self, type: MessageType, message: MessageData) -> None: + await self.send_message(ControlMessage(type, message)) + + async def send_messages(self) -> None: + runs = self.db.get_runs() + for r in runs: + await self.send(MessageType.RUN, r) + + while True: + try: + msg: ControlMessage = await self.queue.get() + data = msg.data + if msg.type == MessageType.MESSAGE_REQUEST: + if data.follow_run is not None: + await self.switch_to_run(data.follow_run) + + elif msg.type == MessageType.RUN: + await self.send_message(msg) + + elif msg.type in MessageType: + if not hasattr(data, "run_id"): + print("msg has no run_id", data) + if self.current_run == 
data.run_id: + await self.send_message(msg) + + else: + print(f"Unknown message type: {msg.type}") + + except WebSocketDisconnect: + break + + except Exception as e: + print(f"Error sending message: {e}") + raise e + + async def receive_messages(self) -> None: + while True: + try: + msg = await self.websocket.receive_json() + if msg["type"] != MessageType.MESSAGE_REQUEST: + print(f"Unknown message type: {msg['type']}") + continue + + if "data" not in msg: + print("Invalid message") + continue + + data = msg["data"] + + if "follow_run" not in data: + print("Invalid message") + continue + + message = ControlMessage( + type=MessageType.MESSAGE_REQUEST, + data=MessageRequest(int(data["follow_run"])), + ) + # we don't process the message here, as having all message processing done in lockstep in the send_messages + # function means that we don't have to worry about race conditions between reading from the database and + # incoming messages + await self.queue.put(message) + except Exception as e: + print(f"Error receiving message: {e}") + raise e + + async def switch_to_run(self, run_id: int): + self.current_run = run_id + messages = self.db.get_messages_by_run(run_id) + + tool_calls = list(self.db.get_tool_calls_by_run(run_id)) + tool_calls_per_message = dict() + for tc in tool_calls: + if tc.message_id not in tool_calls_per_message: + tool_calls_per_message[tc.message_id] = [] + tool_calls_per_message[tc.message_id].append(tc) + + sections: list[Section] = list(self.db.get_sections_by_run(run_id)) + sections_starting_with_message = dict() + for s in sections: + if s.from_message not in sections_starting_with_message: + sections_starting_with_message[s.from_message] = [] + sections_starting_with_message[s.from_message].append(s) + + for msg in messages: + if msg.id in sections_starting_with_message: + for s in sections_starting_with_message[msg.id]: + await self.send(MessageType.SECTION, s) + sections.remove(s) + await self.send(MessageType.MESSAGE, msg) + if msg.id in tool_calls_per_message: + for tc in tool_calls_per_message[msg.id]: + await self.send(MessageType.TOOL_CALL, tc) + tool_calls.remove(tc) + + for tc in tool_calls: + await self.send(MessageType.TOOL_CALL, tc) + + for s in sections: + await self.send(MessageType.SECTION, s) + + +@use_case("Webserver for (live) log viewing") +class Viewer(UseCase): + """ + TODOs: + - [ ] This server needs to be as async as possible to allow good performance, but the database accesses are not yet, might be an issue? 
+ """ + log: GlobalLocalLogger = None + log_db: DbStorage = None + log_server_address: str = "127.0.0.1:4444" + save_playback_dir: str = "" + + async def save_message(self, message: ControlMessage): + if not self.save_playback_dir or len(self.save_playback_dir) == 0: + return + + # check if a file with the name of the message run id already exists in the save_playback_dir + # if it does, append the message to the json lines file + # if it doesn't, create a new file with the name of the message run id and write the message to it + if isinstance(message.data, Run): + run_id = message.data.id + elif hasattr(message.data, "run_id"): + run_id = message.data.run_id + else: + raise ValueError("gotten message without run_id", message) + + if not os.path.exists(self.save_playback_dir): + os.makedirs(self.save_playback_dir) + + file_path = os.path.join(self.save_playback_dir, f"{run_id}.jsonl") + with open(file_path, "a") as f: + f.write(ReplayMessage(datetime.datetime.now(), message).to_json() + "\n") + + def run(self, config): + @asynccontextmanager + async def lifespan(app: FastAPI): + app.state.db = self.log_db + app.state.clients = [] + + yield + + for client in app.state.clients: + await client.websocket.close() + + app = FastAPI(lifespan=lifespan) + + # TODO: re-enable and only allow anything else than localhost when a token is set + """ + app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:4444", "ws://localhost:4444", "wss://pwn.reinsperger.org", "https://pwn.reinsperger.org", "https://dumb-halloween-game.reinsperger.org"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + """ + + templates = Jinja2Templates(directory=TEMPLATE_DIR) + app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") + + @app.get('/favicon.ico') + async def favicon(): + return FileResponse(STATIC_DIR + "/favicon.ico", headers={"Cache-Control": "public, max-age=31536000"}) + + @app.get("/", response_class=HTMLResponse) + async def admin_ui(request: Request): + return templates.TemplateResponse("index.html", {"request": request}) + + @app.websocket("/ingress") + async def ingress_endpoint(websocket: WebSocket): + await websocket.accept() + try: + while True: + # Receive messages from the ingress websocket + data = await websocket.receive_json() + message_type = MessageType(data["type"]) + # parse the data according to the message type into the appropriate dataclass + message = message_type.get_class().from_dict(data["data"]) + + if message_type == MessageType.RUN: + if message.id is None: + message.started_at = datetime.datetime.now() + message.id = app.state.db.create_run(message.model, message.tag, message.started_at, message.configuration) + data["data"]["id"] = message.id # set the id also in the raw data, so we can properly serialize it to replays + else: + app.state.db.update_run(message.id, message.model, message.state, message.tag, message.started_at, message.stopped_at, message.configuration) + await websocket.send_text(message.to_json()) + + elif message_type == MessageType.MESSAGE: + app.state.db.add_or_update_message(message.run_id, message.id, message.conversation, message.role, message.content, message.tokens_query, message.tokens_response, message.duration) + + elif message_type == MessageType.MESSAGE_STREAM_PART: + app.state.db.handle_message_update(message.run_id, message.message_id, message.action, message.content) + + elif message_type == MessageType.TOOL_CALL: + app.state.db.add_tool_call(message.run_id, message.message_id, message.id, 
message.function_name, message.arguments, message.result_text, message.duration) + + elif message_type == MessageType.SECTION: + app.state.db.add_section(message.run_id, message.id, message.name, message.from_message, message.to_message, message.duration) + + else: + print("UNHANDLED ingress", message) + + control_message = ControlMessage(type=message_type, data=message) + await self.save_message(control_message) + for client in app.state.clients: + await client.queue.put(control_message) + + except WebSocketDisconnect as e: + import traceback + traceback.print_exc() + print("Ingress WebSocket disconnected") + + @app.websocket("/client") + async def client_endpoint(websocket: WebSocket): + await websocket.accept() + client = Client(websocket, app.state.db) + app.state.clients.append(client) + + # run the receiving and sending tasks in the background until one of them returns + tasks = () + try: + tasks = ( + asyncio.create_task(client.send_messages()), + asyncio.create_task(client.receive_messages()), + ) + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + except WebSocketDisconnect: + # read the task exceptions, close remaining tasks + for task in tasks: + if task.exception(): + print(task.exception()) + else: + task.cancel() + app.state.clients.remove(client) + print("Egress WebSocket disconnected") + + import uvicorn + listen_parts = self.log_server_address.split(":", 1) + if len(listen_parts) != 2: + if listen_parts[0].startswith("http://"): + listen_parts.append("80") + elif listen_parts[0].startswith("https://"): + listen_parts.append("443") + else: + raise ValueError(f"Invalid log server address (does not contain http/https or a port): {self.log_server_address}") + + listen_host, listen_port = listen_parts[0], int(listen_parts[1]) + if listen_host.startswith("http://"): + listen_host = listen_host[len("http://"):] + elif listen_host.startswith("https://"): + listen_host = listen_host[len("https://"):] + uvicorn.run(app, host=listen_host, port=listen_port) + + def get_name(self) -> str: + return "log_viewer" + + +@use_case("Tool to replay the .jsonl logs generated by the Viewer (not well tested)") +class Replayer(UseCase): + log: GlobalRemoteLogger = None + replay_file: str = None + pause_on_message: bool = False + pause_on_tool_calls: bool = False + playback_speed: float = 1.0 + + def get_name(self) -> str: + return "replayer" + + def init(self, configuration): + self.log.init_websocket() # we don't want to automatically start a run here + + def run(self): + recording_start: Optional[datetime.datetime] = None + replay_start: datetime.datetime = datetime.datetime.now() + + print(f"replaying {self.replay_file}") + for line in open(self.replay_file, "r"): + data = json.loads(line) + msg: ReplayMessage = ReplayMessage.from_dict(data) + msg.message.type = MessageType(data["message"]["type"]) + msg.message.data = msg.message.type.get_class().from_dict(data["message"]["data"]) + + if recording_start is None: + if msg.message.type != MessageType.RUN: + raise ValueError("First message must be a RUN message, is", msg.message.type) + recording_start = msg.at + self.log.start_run(msg.message.data.model, msg.message.data.tag, msg.message.data.configuration, msg.at) + + # wait until the message should be sent + sleep_time = ((msg.at - recording_start) / self.playback_speed) - (datetime.datetime.now() - replay_start) + if sleep_time.total_seconds() > 3: + print(msg) + print(f"sleeping for {sleep_time.total_seconds()}s") + time.sleep(max(sleep_time.total_seconds(), 0)) + + if 
isinstance(msg.message.data, Run): + msg.message.data.id = self.log.run.id + elif hasattr(msg.message.data, "run_id"): + msg.message.data.run_id = self.log.run.id + else: + raise ValueError("Message has no run_id", msg.message.data) + + if self.pause_on_message and msg.message.type == MessageType.MESSAGE \ + or self.pause_on_tool_calls and msg.message.type == MessageType.TOOL_CALL: + input("Paused, press Enter to continue") + replay_start = datetime.datetime.now() - (msg.at - recording_start) + + print("sending") + self.log.send(msg.message.type, msg.message.data) diff --git a/src/hackingBuddyGPT/usecases/web/__init__.py b/src/hackingBuddyGPT/usecases/web/__init__.py index 0d9307bc..d09ebd99 100644 --- a/src/hackingBuddyGPT/usecases/web/__init__.py +++ b/src/hackingBuddyGPT/usecases/web/__init__.py @@ -1,2 +1,3 @@ -from .simple import MinimalWebTesting from .with_explanation import WebTestingWithExplanation + +__all__ = ['WebTestingWithExplanation'] diff --git a/src/hackingBuddyGPT/usecases/web/simple.py b/src/hackingBuddyGPT/usecases/web/simple.py deleted file mode 100644 index 22152b55..00000000 --- a/src/hackingBuddyGPT/usecases/web/simple.py +++ /dev/null @@ -1,85 +0,0 @@ -import pydantic_core -import time - -from dataclasses import field -from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage -from rich.panel import Panel -from typing import List, Any, Union, Dict - -from hackingBuddyGPT.capabilities import Capability -from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model -from hackingBuddyGPT.capabilities.http_request import HTTPRequest -from hackingBuddyGPT.capabilities.record_note import RecordNote -from hackingBuddyGPT.capabilities.submit_flag import SubmitFlag -from hackingBuddyGPT.usecases.agents import Agent -from hackingBuddyGPT.utils import LLMResult, tool_message -from hackingBuddyGPT.usecases.base import use_case, AutonomousAgentUseCase -from hackingBuddyGPT.utils.configurable import parameter -from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib - - -Prompt = List[Union[ChatCompletionMessage, ChatCompletionMessageParam]] -Context = Any - - -class MinimalWebTesting(Agent): - llm: OpenAILib - host: str = parameter(desc="The host to test", default="http://localhost") - flag_format_description: str = parameter(desc="Description of the flag provided to the LLM", default="a string starting with 'FLAG.' and ending with '.GALF'") - flag_template: str = parameter(desc="The template of the flags, whereby {flag} is replaced with the flags", default="FLAG.{flag}.GALF") - flags: str = parameter(desc="A comma (,) separated list of flags to find", default="hostname,dir,username,rootfile,secretfile,adminpass") - - _prompt_history: Prompt = field(default_factory=list) - _context: Context = field(default_factory=lambda: {"notes": list()}) - _capabilities: Dict[str, Capability] = field(default_factory=dict) - _all_flags_found: bool = False - - def init(self): - super().init() - self._prompt_history.append( - { - "role": "system", - "content": f"You are a senior penetration tester, having received a new project of testing a website at {self.host}." 
- f"Your goal is first to explore the website, building up your knowledge about its parts and inner workings, and then try to plan and execute possible exploitations.\n" - f"Make sure to take lots of notes about the pages you visit, so that you can later more easily work through the page.\n" - f"Do not brute-force anything, all information that you might need can be found on the webserver.\n" - f"THE MOST IMPORTANT THING TO DO IS, that if you see a flag ({self.flag_format_description}), you should submit it immediately."}) - self._context["host"] = self.host - self._capabilities = { - "submit_flag": SubmitFlag(self.flag_format_description, set(self.flag_template.format(flag=flag) for flag in self.flags.split(",")), success_function=self.all_flags_found), - "http_request": HTTPRequest(self.host), - "record_note": RecordNote(self._context["notes"]), - } - - def all_flags_found(self): - self._log.console.print(Panel("All flags found! Congratulations!", title="system")) - self._all_flags_found = True - - def perform_round(self, turn: int): - with self._log.console.status("[bold green]Asking LLM for a new command..."): - prompt = self._prompt_history # TODO: in the future, this should do some context truncation - - tic = time.perf_counter() - response, completion = self.llm.instructor.chat.completions.create_with_completion(model=self.llm.model, messages=prompt, response_model=capabilities_to_action_model(self._capabilities)) - toc = time.perf_counter() - - message = completion.choices[0].message - tool_call_id = message.tool_calls[0].id - command = pydantic_core.to_json(response).decode() - self._log.console.print(Panel(command, title="assistant")) - self._prompt_history.append(message) - - answer = LLMResult(completion.choices[0].message.content, str(prompt), completion.choices[0].message.content, toc-tic, completion.usage.prompt_tokens, completion.usage.completion_tokens) - - with self._log.console.status("[bold green]Executing that command..."): - result = response.execute() - self._log.console.print(Panel(result, title="tool")) - self._prompt_history.append(tool_message(result, tool_call_id)) - - self._log.log_db.add_log_query(self._log.run_id, turn, command, result, answer) - return self._all_flags_found - - -@use_case("Minimal implementation of a web testing use case") -class MinimalWebTestingUseCase(AutonomousAgentUseCase[MinimalWebTesting]): - pass diff --git a/src/hackingBuddyGPT/usecases/web/with_explanation.py b/src/hackingBuddyGPT/usecases/web/with_explanation.py index 96dd657d..5ed7f44a 100644 --- a/src/hackingBuddyGPT/usecases/web/with_explanation.py +++ b/src/hackingBuddyGPT/usecases/web/with_explanation.py @@ -1,9 +1,8 @@ -import time from dataclasses import field -from typing import List, Any, Union, Dict +from typing import List, Any, Union, Dict, Iterable, Optional from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage -from rich.panel import Panel +from openai.types.chat.chat_completion_chunk import ChoiceDelta from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.http_request import HTTPRequest @@ -14,7 +13,6 @@ from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib - Prompt = List[Union[ChatCompletionMessage, ChatCompletionMessageParam]] Context = Any @@ -22,9 +20,18 @@ class WebTestingWithExplanation(Agent): llm: OpenAILib host: str = parameter(desc="The host to test", default="http://localhost") - flag_format_description: str = 
parameter(desc="Description of the flag provided to the LLM", default="a string starting with 'FLAG.' and ending with '.GALF'") - flag_template: str = parameter(desc="The template of the flags, whereby {flag} is replaced with the flags", default="FLAG.{flag}.GALF") - flags: str = parameter(desc="A comma (,) separated list of flags to find", default="hostname,dir,username,rootfile,secretfile,adminpass") + flag_format_description: str = parameter( + desc="Description of the flag provided to the LLM", + default="a string starting with 'FLAG.' and ending with '.GALF'", + ) + flag_template: str = parameter( + desc="The template of the flags, whereby {flag} is replaced with the flags", + default="FLAG.{flag}.GALF", + ) + flags: str = parameter( + desc="A comma (,) separated list of flags to find", + default="hostname,dir,username,rootfile,secretfile,adminpass", + ) _prompt_history: Prompt = field(default_factory=list) _context: Context = field(default_factory=lambda: {"notes": list()}) @@ -33,47 +40,51 @@ class WebTestingWithExplanation(Agent): def init(self): super().init() - self._prompt_history.append( - { - "role": "system", - "content": f"You are a senior penetration tester, having received a new project of testing a website at {self.host}." - f"Your goal is first to explore the website, building up your knowledge about its parts and inner workings, and then try to plan and execute possible exploitations.\n" - f"Think step by step and say everything out loud that you might be doing.\n" - f"Give a list of assumptions you base your tool calls on, before you start a tool call, and reason whether they held or not after the call has completed.\n" - f"Do not brute-force anything, all information that you might need can be found on the webserver.\n" - f"THE MOST IMPORTANT THING TO DO IS, that if you see a flag ({self.flag_format_description}), you should submit it immediately."}) self._context["host"] = self.host - self._capabilities = { - "submit_flag": SubmitFlag(self.flag_format_description, set(self.flag_template.format(flag=flag) for flag in self.flags.split(",")), success_function=self.all_flags_found), - "http_request": HTTPRequest(self.host), - } + self.add_capability(SubmitFlag(self.flag_format_description, set(self.flag_template.format(flag=flag) for flag in self.flags.split(",")), success_function=self.all_flags_found)) + self.add_capability(HTTPRequest(self.host)) + + def before_run(self): + system_message = ( + f"You are a senior penetration tester, having received a new project of testing a website at {self.host}." + f"Your goal is first to explore the website, building up your knowledge about its parts and inner workings, and then try to plan and execute possible exploitations.\n" + f"Start out by trying to get logged in as the admin of the site. DO NOT BRUTE-FORCE CREDENTIALS, try to always keep responses in mind, especially if they contain unusual information or pointers to where to find additional information!\n" + f"Think step by step and say everything out loud that you might be doing.\n" + f"Give a list of assumptions you base your tool calls on, before you start a tool call, and reason whether they held or not after the call has completed. Try to do multiple tool calls at once, you have a limited number of turn, but not tool calls.\n" + f"Do not brute-force anything, all information that you might need can be found on the webserver.\n" + f"THE MOST IMPORTANT THING TO DO IS, that if you see a flag ({self.flag_format_description}), you should submit it immediately." 
+ ) + self._prompt_history.append({ "role": "system", "content": system_message }) + self.log.system_message(system_message) def all_flags_found(self): - self._log.console.print(Panel("All flags found! Congratulations!", title="system")) + self.log.status_message("All flags found! Congratulations!") self._all_flags_found = True def perform_round(self, turn: int): prompt = self._prompt_history # TODO: in the future, this should do some context truncation - result: LLMResult = None - stream = self.llm.stream_response(prompt, self._log.console, capabilities=self._capabilities) - for part in stream: - result = part + result_stream: Iterable[Union[ChoiceDelta, LLMResult]] = self.llm.stream_response(prompt, self.log.console, capabilities=self._capabilities, get_individual_updates=True) + result: Optional[LLMResult] = None + stream_output = self.log.stream_message("assistant") # TODO: do not hardcode the role + for delta in result_stream: + if isinstance(delta, LLMResult): + result = delta + break + if delta.content is not None: + stream_output.append(delta.content) + if result is None: + self.log.error_message("No result from the LLM") + return False + message_id = stream_output.finalize(result.tokens_query, result.tokens_response, result.duration) message: ChatCompletionMessage = result.result - message_id = self._log.log_db.add_log_message(self._log.run_id, message.role, message.content, result.tokens_query, result.tokens_response, result.duration) self._prompt_history.append(result.result) if message.tool_calls is not None: for tool_call in message.tool_calls: - tic = time.perf_counter() - tool_call_result = self._capabilities[tool_call.function.name].to_model().model_validate_json(tool_call.function.arguments).execute() - toc = time.perf_counter() - - self._log.console.print(f"\n[bold green on gray3]{' '*self._log.console.width}\nTOOL RESPONSE:[/bold green on gray3]") - self._log.console.print(tool_call_result) - self._prompt_history.append(tool_message(tool_call_result, tool_call.id)) - self._log.log_db.add_log_tool_call(self._log.run_id, message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments, tool_call_result, toc - tic) + tool_result = self.run_capability_json(message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments) + self._prompt_history.append(tool_message(tool_result, tool_call.id)) return self._all_flags_found diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/__init__.py index a8c6ba18..bae1cbfc 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/__init__.py @@ -1,2 +1,2 @@ +from .simple_openapi_documentation import SimpleWebAPIDocumentation from .simple_web_api_testing import SimpleWebAPITesting -from .simple_openapi_documentation import SimpleWebAPIDocumentation \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/__init__.py index b4782f56..3038bb3b 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/__init__.py @@ -1,2 +1,2 @@ from .openapi_specification_handler import OpenAPISpecificationHandler -from .report_handler import ReportHandler \ No newline at end of file +from .report_handler import ReportHandler diff --git 
a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_specification_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_specification_handler.py index dd64f269..3e9d7059 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_specification_handler.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_specification_handler.py @@ -1,14 +1,17 @@ import os -import yaml -from datetime import datetime -from hackingBuddyGPT.capabilities.yamlFile import YAMLFile from collections import defaultdict +from datetime import datetime + import pydantic_core +import yaml from rich.panel import Panel +from hackingBuddyGPT.capabilities.yamlFile import YAMLFile from hackingBuddyGPT.usecases.web_api_testing.response_processing import ResponseHandler from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler from hackingBuddyGPT.utils import tool_message + + class OpenAPISpecificationHandler(object): """ Handles the generation and updating of an OpenAPI specification document based on dynamic API responses. @@ -35,26 +38,24 @@ def __init__(self, llm_handler: LLMHandler, response_handler: ResponseHandler): """ self.response_handler = response_handler self.schemas = {} - self.endpoint_methods ={} + self.endpoint_methods = {} self.filename = f"openapi_spec_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.yaml" self.openapi_spec = { "openapi": "3.0.0", "info": { "title": "Generated API Documentation", "version": "1.0", - "description": "Automatically generated description of the API." + "description": "Automatically generated description of the API.", }, "servers": [{"url": "https://jsonplaceholder.typicode.com"}], "endpoints": {}, - "components": {"schemas": {}} + "components": {"schemas": {}}, } self.llm_handler = llm_handler current_path = os.path.dirname(os.path.abspath(__file__)) self.file_path = os.path.join(current_path, "openapi_spec") self.file = os.path.join(self.file_path, self.filename) - self._capabilities = { - "yaml": YAMLFile() - } + self._capabilities = {"yaml": YAMLFile()} def is_partial_match(self, element, string_list): return any(element in string or string in element for string in string_list) @@ -69,23 +70,23 @@ def update_openapi_spec(self, resp, result): """ request = resp.action - if request.__class__.__name__ == 'RecordNote': # TODO: check why isinstance does not work + if request.__class__.__name__ == "RecordNote": # TODO: check why isinstance does not work self.check_openapi_spec(resp) - elif request.__class__.__name__ == 'HTTPRequest': + elif request.__class__.__name__ == "HTTPRequest": path = request.path method = request.method - print(f'method: {method}') + print(f"method: {method}") # Ensure that path and method are not None and method has no numeric characters # Ensure path and method are valid and method has no numeric characters if path and method: endpoint_methods = self.endpoint_methods - endpoints = self.openapi_spec['endpoints'] - x = path.split('/')[1] + endpoints = self.openapi_spec["endpoints"] + x = path.split("/")[1] # Initialize the path if not already present if path not in endpoints and x != "": endpoints[path] = {} - if '1' not in path: + if "1" not in path: endpoint_methods[path] = [] # Update the method description within the path @@ -100,22 +101,17 @@ def update_openapi_spec(self, resp, result): "responses": { "200": { "description": "Successful response", - "content": { - "application/json": { - "schema": {"$ref": reference}, - "examples": example - } - } + 
"content": {"application/json": {"schema": {"$ref": reference}, "examples": example}}, } - } + }, } - if '1' not in path and x != "": + if "1" not in path and x != "": endpoint_methods[path].append(method) elif self.is_partial_match(x, endpoints.keys()): path = f"/{x}" - print(f'endpoint methods = {endpoint_methods}') - print(f'new path:{path}') + print(f"endpoint methods = {endpoint_methods}") + print(f"new path:{path}") endpoint_methods[path].append(method) endpoint_methods[path] = list(set(endpoint_methods[path])) @@ -133,18 +129,18 @@ def write_openapi_to_yaml(self): "info": self.openapi_spec["info"], "servers": self.openapi_spec["servers"], "components": self.openapi_spec["components"], - "paths": self.openapi_spec["endpoints"] + "paths": self.openapi_spec["endpoints"], } # Create directory if it doesn't exist and generate the timestamped filename os.makedirs(self.file_path, exist_ok=True) # Write to YAML file - with open(self.file, 'w') as yaml_file: + with open(self.file, "w") as yaml_file: yaml.dump(openapi_data, yaml_file, allow_unicode=True, default_flow_style=False) print(f"OpenAPI specification written to {self.filename}.") except Exception as e: - raise Exception(f"Error writing YAML file: {e}") + raise Exception(f"Error writing YAML file: {e}") from e def check_openapi_spec(self, note): """ @@ -154,14 +150,15 @@ def check_openapi_spec(self, note): note (object): The note object containing the description of the API. """ description = self.response_handler.extract_description(note) - from hackingBuddyGPT.usecases.web_api_testing.utils.documentation.parsing.yaml_assistant import YamlFileAssistant + from hackingBuddyGPT.usecases.web_api_testing.utils.documentation.parsing.yaml_assistant import ( + YamlFileAssistant, + ) + yaml_file_assistant = YamlFileAssistant(self.file_path, self.llm_handler) yaml_file_assistant.run(description) - def _update_documentation(self, response, result, prompt_engineer): - prompt_engineer.prompt_helper.found_endpoints = self.update_openapi_spec(response, - result) + prompt_engineer.prompt_helper.found_endpoints = self.update_openapi_spec(response, result) self.write_openapi_to_yaml() prompt_engineer.prompt_helper.schemas = self.schemas @@ -175,28 +172,27 @@ def _update_documentation(self, response, result, prompt_engineer): return prompt_engineer def document_response(self, completion, response, log, prompt_history, prompt_engineer): - message = completion.choices[0].message - tool_call_id = message.tool_calls[0].id - command = pydantic_core.to_json(response).decode() + message = completion.choices[0].message + tool_call_id = message.tool_calls[0].id + command = pydantic_core.to_json(response).decode() - log.console.print(Panel(command, title="assistant")) - prompt_history.append(message) + log.console.print(Panel(command, title="assistant")) + prompt_history.append(message) - with log.console.status("[bold green]Executing that command..."): - result = response.execute() - log.console.print(Panel(result[:30], title="tool")) - result_str = self.response_handler.parse_http_status_line(result) - prompt_history.append(tool_message(result_str, tool_call_id)) + with log.console.status("[bold green]Executing that command..."): + result = response.execute() + log.console.print(Panel(result[:30], title="tool")) + result_str = self.response_handler.parse_http_status_line(result) + prompt_history.append(tool_message(result_str, tool_call_id)) - invalid_flags = {"recorded", "Not a valid HTTP method", "404", "Client Error: Not Found"} - if not result_str in 
invalid_flags or any(flag in result_str for flag in invalid_flags): - prompt_engineer = self._update_documentation(response, result, prompt_engineer) + invalid_flags = {"recorded", "Not a valid HTTP method", "404", "Client Error: Not Found"} + if result_str not in invalid_flags or any(flag in result_str for flag in invalid_flags): + prompt_engineer = self._update_documentation(response, result, prompt_engineer) - return log, prompt_history, prompt_engineer + return log, prompt_history, prompt_engineer def found_all_endpoints(self): - if len(self.endpoint_methods.items())< 10: + if len(self.endpoint_methods.items()) < 10: return False else: return True - diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/__init__.py index 0fe99b1a..1dc8cc54 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/__init__.py @@ -1,3 +1,3 @@ from .openapi_converter import OpenAPISpecificationConverter from .openapi_parser import OpenAPISpecificationParser -from .yaml_assistant import YamlFileAssistant \ No newline at end of file +from .yaml_assistant import YamlFileAssistant diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_converter.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_converter.py index 5b9c5ed0..3f1156f5 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_converter.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_converter.py @@ -1,6 +1,8 @@ +import json import os.path + import yaml -import json + class OpenAPISpecificationConverter: """ @@ -39,14 +41,14 @@ def convert_file(self, input_filepath, output_directory, input_type, output_type os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(input_filepath, 'r') as infile: - if input_type == 'yaml': + with open(input_filepath, "r") as infile: + if input_type == "yaml": content = yaml.safe_load(infile) else: content = json.load(infile) - with open(output_path, 'w') as outfile: - if output_type == 'yaml': + with open(output_path, "w") as outfile: + if output_type == "yaml": yaml.dump(content, outfile, allow_unicode=True, default_flow_style=False) else: json.dump(content, outfile, indent=2) @@ -68,7 +70,7 @@ def yaml_to_json(self, yaml_filepath): Returns: str: The path to the converted JSON file, or None if an error occurred. """ - return self.convert_file(yaml_filepath, "json", 'yaml', 'json') + return self.convert_file(yaml_filepath, "json", "yaml", "json") def json_to_yaml(self, json_filepath): """ @@ -80,12 +82,12 @@ def json_to_yaml(self, json_filepath): Returns: str: The path to the converted YAML file, or None if an error occurred. 
""" - return self.convert_file(json_filepath, "yaml", 'json', 'yaml') + return self.convert_file(json_filepath, "yaml", "json", "yaml") # Usage example -if __name__ == '__main__': - yaml_input = '/home/diana/Desktop/masterthesis/hackingBuddyGPT/src/hackingBuddyGPT/usecases/web_api_testing/openapi_spec/openapi_spec_2024-06-13_17-16-25.yaml' +if __name__ == "__main__": + yaml_input = "/home/diana/Desktop/masterthesis/hackingBuddyGPT/src/hackingBuddyGPT/usecases/web_api_testing/openapi_spec/openapi_spec_2024-06-13_17-16-25.yaml" converter = OpenAPISpecificationConverter("converted_files") # Convert YAML to JSON diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_parser.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_parser.py index 6d884349..815cb0c5 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_parser.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_parser.py @@ -1,6 +1,8 @@ -import yaml from typing import Dict, List, Union +import yaml + + class OpenAPISpecificationParser: """ OpenAPISpecificationParser is a class for parsing and extracting information from an OpenAPI specification file. @@ -27,7 +29,7 @@ def load_yaml(self) -> Dict[str, Union[Dict, List]]: Returns: Dict[str, Union[Dict, List]]: The parsed data from the YAML file. """ - with open(self.filepath, 'r') as file: + with open(self.filepath, "r") as file: return yaml.safe_load(file) def _get_servers(self) -> List[str]: @@ -37,7 +39,7 @@ def _get_servers(self) -> List[str]: Returns: List[str]: A list of server URLs. """ - return [server['url'] for server in self.api_data.get('servers', [])] + return [server["url"] for server in self.api_data.get("servers", [])] def get_paths(self) -> Dict[str, Dict[str, Dict]]: """ @@ -47,7 +49,7 @@ def get_paths(self) -> Dict[str, Dict[str, Dict]]: Dict[str, Dict[str, Dict]]: A dictionary with API paths as keys and methods as values. """ paths_info: Dict[str, Dict[str, Dict]] = {} - paths: Dict[str, Dict[str, Dict]] = self.api_data.get('paths', {}) + paths: Dict[str, Dict[str, Dict]] = self.api_data.get("paths", {}) for path, methods in paths.items(): paths_info[path] = {method: details for method, details in methods.items()} return paths_info @@ -62,15 +64,15 @@ def _get_operations(self, path: str) -> Dict[str, Dict]: Returns: Dict[str, Dict]: A dictionary with methods as keys and operation details as values. """ - return self.api_data['paths'].get(path, {}) + return self.api_data["paths"].get(path, {}) def _print_api_details(self) -> None: """ Prints details of the API extracted from the OpenAPI document, including title, version, servers, paths, and operations. 
""" - print("API Title:", self.api_data['info']['title']) - print("API Version:", self.api_data['info']['version']) + print("API Title:", self.api_data["info"]["title"]) + print("API Version:", self.api_data["info"]["version"]) print("Servers:", self._get_servers()) print("\nAvailable Paths and Operations:") for path, operations in self.get_paths().items(): diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/yaml_assistant.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/yaml_assistant.py index 61998227..667cf710 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/yaml_assistant.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/yaml_assistant.py @@ -1,5 +1,4 @@ from openai import OpenAI -from typing import Any class YamlFileAssistant: @@ -37,7 +36,7 @@ def run(self, recorded_note: str) -> None: The current implementation is commented out and serves as a placeholder for integrating with OpenAI's API. Uncomment and modify the code as needed. """ - ''' + """ assistant = self.client.beta.assistants.create( name="Yaml File Analysis Assistant", instructions="You are an OpenAPI specification analyst. Use your knowledge to check " @@ -88,4 +87,4 @@ def run(self, recorded_note: str) -> None: # The thread now has a vector store with that file in its tool resources. print(thread.tool_resources.file_search) - ''' + """ diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/report_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/report_handler.py index 6eb7e17c..6c10f88d 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/report_handler.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/report_handler.py @@ -1,8 +1,9 @@ import os -from datetime import datetime import uuid -from typing import List +from datetime import datetime from enum import Enum +from typing import List + class ReportHandler: """ @@ -25,13 +26,17 @@ def __init__(self): if not os.path.exists(self.file_path): os.mkdir(self.file_path) - self.report_name: str = os.path.join(self.file_path, f"report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt") + self.report_name: str = os.path.join( + self.file_path, f"report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt" + ) try: self.report = open(self.report_name, "x") except FileExistsError: # Retry with a different name using a UUID to ensure uniqueness - self.report_name = os.path.join(self.file_path, - f"report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{uuid.uuid4().hex}.txt") + self.report_name = os.path.join( + self.file_path, + f"report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{uuid.uuid4().hex}.txt", + ) self.report = open(self.report_name, "x") def write_endpoint_to_report(self, endpoint: str) -> None: @@ -41,8 +46,8 @@ def write_endpoint_to_report(self, endpoint: str) -> None: Args: endpoint (str): The endpoint information to be recorded in the report. """ - with open(self.report_name, 'a') as report: - report.write(f'{endpoint}\n') + with open(self.report_name, "a") as report: + report.write(f"{endpoint}\n") def write_analysis_to_report(self, analysis: List[str], purpose: Enum) -> None: """ @@ -52,8 +57,8 @@ def write_analysis_to_report(self, analysis: List[str], purpose: Enum) -> None: analysis (List[str]): The analysis data to be recorded. purpose (Enum): An enumeration that describes the purpose of the analysis. 
""" - with open(self.report_name, 'a') as report: - report.write(f'{purpose.name}:\n') + with open(self.report_name, "a") as report: + report.write(f"{purpose.name}:\n") for item in analysis: for line in item.split("\n"): if "note recorded" in line: diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/__init__.py index 6e43f7b6..fad13dab 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/__init__.py @@ -1,2 +1,2 @@ from .pentesting_information import PenTestingInformation -from .prompt_information import PromptPurpose, PromptStrategy, PromptContext \ No newline at end of file +from .prompt_information import PromptContext, PromptPurpose, PromptStrategy diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/pentesting_information.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/pentesting_information.py index 58b839ba..ce5874f9 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/pentesting_information.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/pentesting_information.py @@ -1,6 +1,8 @@ from typing import Dict, List -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptPurpose +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( + PromptPurpose, +) class PenTestingInformation: @@ -53,15 +55,15 @@ def init_steps(self) -> Dict[PromptPurpose, List[str]]: "Check for proper error handling, response codes, and sanitization.", "Attempt to exploit common vulnerabilities by injecting malicious inputs, such as SQL injection, NoSQL injection, " "cross-site scripting, and other injection attacks. Evaluate whether the API properly validates, escapes, and sanitizes " - "all user-supplied data, ensuring no unexpected behavior or security vulnerabilities are exposed." + "all user-supplied data, ensuring no unexpected behavior or security vulnerabilities are exposed.", ], PromptPurpose.ERROR_HANDLING_INFORMATION_LEAKAGE: [ "Check how the API handles errors and if there are detailed error messages.", - "Look for vulnerabilities and information leakage." + "Look for vulnerabilities and information leakage.", ], PromptPurpose.SESSION_MANAGEMENT: [ "Check if the API uses session management.", - "Look at the session handling mechanism for vulnerabilities such as session fixation, session hijacking, or session timeout settings." + "Look at the session handling mechanism for vulnerabilities such as session fixation, session hijacking, or session timeout settings.", ], PromptPurpose.CROSS_SITE_SCRIPTING: [ "Look for vulnerabilities that could enable malicious scripts to be injected into API responses." @@ -94,7 +96,8 @@ def analyse_steps(self, response: str = "") -> Dict[PromptPurpose, List[str]]: dict: A dictionary where each key is a PromptPurpose and each value is a list of prompts. 
""" return { - PromptPurpose.PARSING: [f""" Please parse this response and extract the following details in JSON format: {{ + PromptPurpose.PARSING: [ + f""" Please parse this response and extract the following details in JSON format: {{ "Status Code": "", "Reason Phrase": "", "Headers": , @@ -102,20 +105,18 @@ def analyse_steps(self, response: str = "") -> Dict[PromptPurpose, List[str]]: from this response: {response} }}""" - - ], + ], PromptPurpose.ANALYSIS: [ - f'Given the following parsed HTTP response:\n{response}\n' - 'Please analyze this response to determine:\n' - '1. Whether the status code is appropriate for this type of request.\n' - '2. If the headers indicate proper security and rate-limiting practices.\n' - '3. Whether the response body is correctly handled.' + f"Given the following parsed HTTP response:\n{response}\n" + "Please analyze this response to determine:\n" + "1. Whether the status code is appropriate for this type of request.\n" + "2. If the headers indicate proper security and rate-limiting practices.\n" + "3. Whether the response body is correctly handled." ], PromptPurpose.DOCUMENTATION: [ - f'Based on the analysis provided, document the findings of this API response validation:\n{response}' + f"Based on the analysis provided, document the findings of this API response validation:\n{response}" ], PromptPurpose.REPORTING: [ - f'Based on the documented findings : {response}. Suggest any improvements or issues that should be reported to the API developers.' - ] + f"Based on the documented findings : {response}. Suggest any improvements or issues that should be reported to the API developers." + ], } - diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/prompt_information.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/prompt_information.py index d844ff36..17e7a140 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/prompt_information.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/prompt_information.py @@ -10,13 +10,12 @@ class PromptStrategy(Enum): CHAIN_OF_THOUGHT (int): Represents the chain-of-thought strategy. TREE_OF_THOUGHT (int): Represents the tree-of-thought strategy. """ + IN_CONTEXT = 1 CHAIN_OF_THOUGHT = 2 TREE_OF_THOUGHT = 3 -from enum import Enum - class PromptContext(Enum): """ Enumeration for general contexts in which prompts are generated. @@ -25,6 +24,7 @@ class PromptContext(Enum): DOCUMENTATION (int): Represents the documentation context. PENTESTING (int): Represents the penetration testing context. """ + DOCUMENTATION = 1 PENTESTING = 2 @@ -37,11 +37,11 @@ class PlanningType(Enum): TASK_PLANNING (int): Represents the task planning context. STATE_PLANNING (int): Represents the state planning context. """ + TASK_PLANNING = 1 STATE_PLANNING = 2 - class PromptPurpose(Enum): """ Enum representing various purposes for prompt testing in security assessments. 
@@ -63,8 +63,7 @@ class PromptPurpose(Enum): SECURITY_MISCONFIGURATIONS = 10 LOGGING_MONITORING = 11 - #Analysis + # Analysis PARSING = 12 ANALYSIS = 13 REPORTING = 14 - diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_engineer.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_engineer.py index 16e478aa..54e3aea7 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_engineer.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_engineer.py @@ -1,8 +1,19 @@ from instructor.retry import InstructorRetryException -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptStrategy, PromptContext -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_generation_helper import PromptGenerationHelper -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning import ChainOfThoughtPrompt, TreeOfThoughtPrompt -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.state_learning import InContextLearningPrompt + +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( + PromptContext, + PromptStrategy, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_generation_helper import ( + PromptGenerationHelper, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.state_learning import ( + InContextLearningPrompt, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning import ( + ChainOfThoughtPrompt, + TreeOfThoughtPrompt, +) from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt from hackingBuddyGPT.utils import tool_message @@ -10,9 +21,15 @@ class PromptEngineer: """Prompt engineer that creates prompts of different types.""" - def __init__(self, strategy: PromptStrategy = None, history: Prompt = None, handlers=(), - context: PromptContext = None, rest_api: str = "", - schemas: dict = None): + def __init__( + self, + strategy: PromptStrategy = None, + history: Prompt = None, + handlers=(), + context: PromptContext = None, + rest_api: str = "", + schemas: dict = None, + ): """ Initializes the PromptEngineer with a specific strategy and handlers for LLM and responses. 
@@ -33,18 +50,22 @@ def __init__(self, strategy: PromptStrategy = None, history: Prompt = None, hand self._prompt_history = history or [] self.strategies = { - PromptStrategy.CHAIN_OF_THOUGHT: ChainOfThoughtPrompt(context=self.context, - prompt_helper=self.prompt_helper), - PromptStrategy.TREE_OF_THOUGHT: TreeOfThoughtPrompt(context=self.context, prompt_helper=self.prompt_helper, - rest_api=self.rest_api), - PromptStrategy.IN_CONTEXT: InContextLearningPrompt(context=self.context, prompt_helper=self.prompt_helper, - context_information={ - self.turn: {"content": "initial_prompt"}}) + PromptStrategy.CHAIN_OF_THOUGHT: ChainOfThoughtPrompt( + context=self.context, prompt_helper=self.prompt_helper + ), + PromptStrategy.TREE_OF_THOUGHT: TreeOfThoughtPrompt( + context=self.context, prompt_helper=self.prompt_helper, rest_api=self.rest_api + ), + PromptStrategy.IN_CONTEXT: InContextLearningPrompt( + context=self.context, + prompt_helper=self.prompt_helper, + context_information={self.turn: {"content": "initial_prompt"}}, + ), } self.purpose = None - def generate_prompt(self, turn:int, move_type="explore", hint=""): + def generate_prompt(self, turn: int, move_type="explore", hint=""): """ Generates a prompt based on the specified strategy and gets a response. @@ -67,9 +88,9 @@ def generate_prompt(self, turn:int, move_type="explore", hint=""): self.turn = turn while not is_good: try: - prompt = prompt_func.generate_prompt(move_type=move_type, hint= hint, - previous_prompt=self._prompt_history, - turn=0) + prompt = prompt_func.generate_prompt( + move_type=move_type, hint=hint, previous_prompt=self._prompt_history, turn=0 + ) self.purpose = prompt_func.purpose is_good = self.evaluate_response(prompt, "") except InstructorRetryException: @@ -109,7 +130,7 @@ def process_step(self, step: str, prompt_history: list) -> tuple[list, str]: Returns: tuple: Updated prompt history and the result of the step processing. """ - print(f'Processing step: {step}') + print(f"Processing step: {step}") prompt_history.append({"role": "system", "content": step}) # Call the LLM and handle the response diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_generation_helper.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_generation_helper.py index 24f07391..040ef6bd 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_generation_helper.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_generation_helper.py @@ -1,6 +1,8 @@ import re + import nltk -from hackingBuddyGPT.usecases.web_api_testing.response_processing import ResponseHandler + +from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import ResponseHandler class PromptGenerationHelper(object): @@ -15,7 +17,7 @@ class PromptGenerationHelper(object): schemas (dict): A dictionary of schemas used for constructing HTTP requests. """ - def __init__(self, response_handler:ResponseHandler=None, schemas:dict={}): + def __init__(self, response_handler: ResponseHandler = None, schemas: dict = None): """ Initializes the PromptAssistant with a response handler and downloads necessary NLTK models. @@ -23,6 +25,9 @@ def __init__(self, response_handler:ResponseHandler=None, schemas:dict={}): response_handler (object): The response handler used for managing responses. 
schemas(tuple): Schemas used """ + if schemas is None: + schemas = {} + self.response_handler = response_handler self.found_endpoints = ["/"] self.endpoint_methods = {} @@ -30,11 +35,8 @@ def __init__(self, response_handler:ResponseHandler=None, schemas:dict={}): self.schemas = schemas # Download NLTK models if not already installed - nltk.download('punkt') - nltk.download('stopwords') - - - + nltk.download("punkt") + nltk.download("stopwords") def get_endpoints_needing_help(self): """ @@ -72,13 +74,9 @@ def get_http_action_template(self, method): str: The constructed HTTP action description. """ if method in ["POST", "PUT"]: - return ( - f"Create HTTPRequests of type {method} considering the found schemas: {self.schemas} and understand the responses. Ensure that they are correct requests." - ) + return f"Create HTTPRequests of type {method} considering the found schemas: {self.schemas} and understand the responses. Ensure that they are correct requests." else: - return ( - f"Create HTTPRequests of type {method} considering only the object with id=1 for the endpoint and understand the responses. Ensure that they are correct requests." - ) + return f"Create HTTPRequests of type {method} considering only the object with id=1 for the endpoint and understand the responses. Ensure that they are correct requests." def get_initial_steps(self, common_steps): """ @@ -93,7 +91,7 @@ def get_initial_steps(self, common_steps): return [ f"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}", "Note down the response structures, status codes, and headers for each endpoint.", - "For each endpoint, document the following details: URL, HTTP method, query parameters and path variables, expected request body structure for requests, response structure for successful and error responses." + "For each endpoint, document the following details: URL, HTTP method, query parameters and path variables, expected request body structure for requests, response structure for successful and error responses.", ] + common_steps def token_count(self, text): @@ -106,7 +104,7 @@ def token_count(self, text): Returns: int: The number of tokens in the input text. 
""" - tokens = re.findall(r'\b\w+\b', text) + tokens = re.findall(r"\b\w+\b", text) words = [token.strip("'") for token in tokens if token.strip("'").isalnum()] return len(words) @@ -135,7 +133,7 @@ def validate_prompt(prompt): if isinstance(steps, list): potential_prompt = "\n".join(str(element) for element in steps) else: - potential_prompt = str(steps) +"\n" + potential_prompt = str(steps) + "\n" return validate_prompt(potential_prompt) return validate_prompt(previous_prompt) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/__init__.py index fd5a389c..e438e6d8 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/__init__.py @@ -1 +1 @@ -from .basic_prompt import BasicPrompt \ No newline at end of file +from .basic_prompt import BasicPrompt diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/basic_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/basic_prompt.py index 85d4686e..af753d5c 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/basic_prompt.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/basic_prompt.py @@ -1,10 +1,15 @@ from abc import ABC, abstractmethod from typing import Optional -#from hackingBuddyGPT.usecases.web_api_testing.prompt_generation import PromptGenerationHelper -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import PenTestingInformation -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptStrategy, \ - PromptContext, PlanningType +# from hackingBuddyGPT.usecases.web_api_testing.prompt_generation import PromptGenerationHelper +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import ( + PenTestingInformation, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( + PlanningType, + PromptContext, + PromptStrategy, +) class BasicPrompt(ABC): @@ -22,9 +27,13 @@ class BasicPrompt(ABC): pentesting_information (Optional[PenTestingInformation]): Contains information relevant to pentesting when the context is pentesting. """ - def __init__(self, context: PromptContext = None, planning_type: PlanningType = None, - prompt_helper= None, - strategy: PromptStrategy = None): + def __init__( + self, + context: PromptContext = None, + planning_type: PlanningType = None, + prompt_helper=None, + strategy: PromptStrategy = None, + ): """ Initializes the BasicPrompt with a specific context, prompt helper, and strategy. @@ -44,8 +53,9 @@ def __init__(self, context: PromptContext = None, planning_type: PlanningType = self.pentesting_information = PenTestingInformation(schemas=prompt_helper.schemas) @abstractmethod - def generate_prompt(self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], - turn: Optional[int]) -> str: + def generate_prompt( + self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] + ) -> str: """ Abstract method to generate a prompt. 
diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/__init__.py index 87435d6b..1a083990 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/__init__.py @@ -1,2 +1,2 @@ -from .state_planning_prompt import StatePlanningPrompt from .in_context_learning_prompt import InContextLearningPrompt +from .state_planning_prompt import StatePlanningPrompt diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/in_context_learning_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/in_context_learning_prompt.py index 8e3e0d7f..f5772683 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/in_context_learning_prompt.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/in_context_learning_prompt.py @@ -1,8 +1,13 @@ -from typing import List, Dict, Optional +from typing import Dict, Optional -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptStrategy, \ - PromptContext, PromptPurpose -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.state_learning.state_planning_prompt import StatePlanningPrompt +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( + PromptContext, + PromptPurpose, + PromptStrategy, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.state_learning.state_planning_prompt import ( + StatePlanningPrompt, +) class InContextLearningPrompt(StatePlanningPrompt): @@ -35,8 +40,9 @@ def __init__(self, context: PromptContext, prompt_helper, context_information: D self.prompt: Dict[int, Dict[str, str]] = context_information self.purpose: Optional[PromptPurpose] = None - def generate_prompt(self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], - turn: Optional[int]) -> str: + def generate_prompt( + self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] + ) -> str: """ Generates a prompt using the in-context learning strategy. 
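To make the `context_information` shape concrete: `PromptEngineer` passes a dict keyed by turn number, e.g. `{self.turn: {"content": "initial_prompt"}}`. A hedged standalone sketch; the helper wiring is simplified and may differ from the real call path:

```python
# Sketch, not part of this diff: constructing the in-context learning
# prompt directly, mirroring how PromptEngineer wires it up.
from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import PromptContext
from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_generation_helper import (
    PromptGenerationHelper,
)
from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.state_learning import (
    InContextLearningPrompt,
)

helper = PromptGenerationHelper(response_handler=None)  # schemas now defaults to {} (see above)
icl = InContextLearningPrompt(
    context=PromptContext.DOCUMENTATION,
    prompt_helper=helper,
    context_information={0: {"content": "initial_prompt"}},  # turn -> message dict
)
prompt = icl.generate_prompt(move_type="explore", hint="", previous_prompt=None, turn=0)
```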
diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/state_planning_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/state_planning_prompt.py index c6739a48..5cbb936b 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/state_planning_prompt.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/state_planning_prompt.py @@ -1,10 +1,11 @@ -from abc import ABC, abstractmethod -from typing import Optional - -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import PenTestingInformation -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptStrategy, \ - PromptContext, PlanningType -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts import BasicPrompt +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( + PlanningType, + PromptContext, + PromptStrategy, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts import ( + BasicPrompt, +) class StatePlanningPrompt(BasicPrompt): @@ -30,6 +31,9 @@ def __init__(self, context: PromptContext, prompt_helper, strategy: PromptStrate prompt_helper (PromptHelper): A helper object for managing and generating prompts. strategy (PromptStrategy): The state planning strategy used for prompt generation. """ - super().__init__(context=context, planning_type=PlanningType.STATE_PLANNING, prompt_helper=prompt_helper, - strategy=strategy) - + super().__init__( + context=context, + planning_type=PlanningType.STATE_PLANNING, + prompt_helper=prompt_helper, + strategy=strategy, + ) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/__init__.py index b2cadb8f..a09a9b14 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/__init__.py @@ -1,3 +1,3 @@ -from .task_planning_prompt import TaskPlanningPrompt from .chain_of_thought_prompt import ChainOfThoughtPrompt +from .task_planning_prompt import TaskPlanningPrompt from .tree_of_thought_prompt import TreeOfThoughtPrompt diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py index 7d6f0197..9825d17c 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py @@ -1,7 +1,13 @@ from typing import List, Optional -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptStrategy, PromptContext, PromptPurpose -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning.task_planning_prompt import TaskPlanningPrompt +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( + PromptContext, + PromptPurpose, + PromptStrategy, +) +from 
hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning.task_planning_prompt import ( + TaskPlanningPrompt, +) class ChainOfThoughtPrompt(TaskPlanningPrompt): @@ -31,8 +37,9 @@ def __init__(self, context: PromptContext, prompt_helper): self.explored_steps: List[str] = [] self.purpose: Optional[PromptPurpose] = None - def generate_prompt(self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], - turn: Optional[int]) -> str: + def generate_prompt( + self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] + ) -> str: """ Generates a prompt using the chain-of-thought strategy. @@ -66,14 +73,14 @@ def _get_common_steps(self) -> List[str]: "Create an OpenAPI document including metadata such as API title, version, and description, define the base URL of the API, list all endpoints, methods, parameters, and responses, and define reusable schemas, response types, and parameters.", "Ensure the correctness and completeness of the OpenAPI specification by validating the syntax and completeness of the document using tools like Swagger Editor, and ensure the specification matches the actual behavior of the API.", "Refine the document based on feedback and additional testing, share the draft with others, gather feedback, and make necessary adjustments. Regularly update the specification as the API evolves.", - "Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes." + "Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes.", ] else: return [ "Identify common data structures returned by various endpoints and define them as reusable schemas, specifying field types like integer, string, and array.", "Create an OpenAPI document that includes API metadata (title, version, description), the base URL, endpoints, methods, parameters, and responses.", "Ensure the document's correctness and completeness using tools like Swagger Editor, and verify it matches the API's behavior. Refine the document based on feedback, share drafts for review, and update it regularly as the API evolves.", - "Make the specification available to developers through the API documentation site, keeping it current with any API changes." 
+ "Make the specification available to developers through the API documentation site, keeping it current with any API changes.", ] def _get_chain_of_thought_steps(self, common_steps: List[str], move_type: str) -> List[str]: @@ -133,7 +140,7 @@ def _get_pentesting_steps(self, move_type: str) -> List[str]: if len(step) == 1: del self.pentesting_information.explore_steps[purpose] - print(f'prompt: {prompt}') + print(f"prompt: {prompt}") return prompt else: return ["Look for exploits."] diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/task_planning_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/task_planning_prompt.py index 5f9624e5..181f30ab 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/task_planning_prompt.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/task_planning_prompt.py @@ -1,10 +1,11 @@ -from abc import ABC, abstractmethod -from typing import Optional - -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import PenTestingInformation -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptStrategy, \ - PromptContext, PlanningType -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts import BasicPrompt +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( + PlanningType, + PromptContext, + PromptStrategy, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts import ( + BasicPrompt, +) class TaskPlanningPrompt(BasicPrompt): @@ -30,7 +31,9 @@ def __init__(self, context: PromptContext, prompt_helper, strategy: PromptStrate prompt_helper (PromptHelper): A helper object for managing and generating prompts. strategy (PromptStrategy): The task planning strategy used for prompt generation. 
""" - super().__init__(context=context, planning_type=PlanningType.TASK_PLANNING, prompt_helper=prompt_helper, - strategy=strategy) - - + super().__init__( + context=context, + planning_type=PlanningType.TASK_PLANNING, + prompt_helper=prompt_helper, + strategy=strategy, + ) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py index a0180871..028a79da 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py @@ -1,9 +1,13 @@ -from typing import List, Optional +from typing import Optional from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PromptStrategy, PromptContext, PromptPurpose + PromptContext, + PromptPurpose, + PromptStrategy, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning import ( + TaskPlanningPrompt, ) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning import TaskPlanningPrompt from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt @@ -37,9 +41,7 @@ def __init__(self, context: PromptContext, prompt_helper, rest_api: str) -> None self.rest_api: str = rest_api self.purpose: Optional[PromptPurpose] = None - def generate_prompt(self, move_type: str, hint: Optional[str], - previous_prompt: Prompt, - turn: Optional[int]) -> str: + def generate_prompt(self, move_type: str, hint: Optional[str], previous_prompt: Prompt, turn: Optional[int]) -> str: """ Generates a prompt using the tree-of-thought strategy. @@ -53,23 +55,27 @@ def generate_prompt(self, move_type: str, hint: Optional[str], str: The generated prompt. 
""" if self.context == PromptContext.DOCUMENTATION: - tree_of_thoughts_steps = [( - "Imagine three different OpenAPI specification specialists.\n" - "All experts will write down one step of their thinking,\n" - "then share it with the group.\n" - "After that, all remaining specialists will proceed to the next step, and so on.\n" - "If any specialist realizes they're wrong at any point, they will leave.\n" - f"The question is: Create an OpenAPI specification for this REST API {self.rest_api} " - )] + tree_of_thoughts_steps = [ + ( + "Imagine three different OpenAPI specification specialists.\n" + "All experts will write down one step of their thinking,\n" + "then share it with the group.\n" + "After that, all remaining specialists will proceed to the next step, and so on.\n" + "If any specialist realizes they're wrong at any point, they will leave.\n" + f"The question is: Create an OpenAPI specification for this REST API {self.rest_api} " + ) + ] else: - tree_of_thoughts_steps = [( - "Imagine three different Pentest experts are answering this question.\n" - "All experts will write down one step of their thinking,\n" - "then share it with the group.\n" - "After that, all experts will proceed to the next step, and so on.\n" - "If any expert realizes they're wrong at any point, they will leave.\n" - f"The question is: Create pentests for this REST API {self.rest_api} " - )] + tree_of_thoughts_steps = [ + ( + "Imagine three different Pentest experts are answering this question.\n" + "All experts will write down one step of their thinking,\n" + "then share it with the group.\n" + "After that, all experts will proceed to the next step, and so on.\n" + "If any expert realizes they're wrong at any point, they will leave.\n" + f"The question is: Create pentests for this REST API {self.rest_api} " + ) + ] # Assuming ChatCompletionMessage and ChatCompletionMessageParam have a 'content' attribute previous_content = previous_prompt[turn].content if turn is not None else "initial_prompt" @@ -77,4 +83,3 @@ def generate_prompt(self, move_type: str, hint: Optional[str], self.purpose = PromptPurpose.AUTHENTICATION_AUTHORIZATION return "\n".join([previous_content] + tree_of_thoughts_steps) - diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/__init__.py index c0fc01f0..4f1206eb 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/__init__.py @@ -1,3 +1,4 @@ -from .response_handler import ResponseHandler from .response_analyzer import ResponseAnalyzer -#from .response_analyzer_with_llm import ResponseAnalyzerWithLLM \ No newline at end of file +from .response_handler import ResponseHandler + +# from .response_analyzer_with_llm import ResponseAnalyzerWithLLM diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer.py b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer.py index f745437a..9b2c2ac9 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer.py @@ -1,6 +1,7 @@ import json import re -from typing import Optional, Tuple, Dict, Any +from typing import Any, Dict, Optional, Tuple + from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptPurpose @@ 
-52,8 +53,10 @@ def parse_http_response(self, raw_response: str) -> Tuple[Optional[int], Dict[st body = "Empty" status_line = header_lines[0].strip() - headers = {key.strip(): value.strip() for key, value in - (line.split(":", 1) for line in header_lines[1:] if ':' in line)} + headers = { + key.strip(): value.strip() + for key, value in (line.split(":", 1) for line in header_lines[1:] if ":" in line) + } match = re.match(r"HTTP/1\.1 (\d{3}) (.*)", status_line) status_code = int(match.group(1)) if match else None @@ -73,7 +76,9 @@ def analyze_response(self, raw_response: str) -> Optional[Dict[str, Any]]: status_code, headers, body = self.parse_http_response(raw_response) return self.analyze_parsed_response(status_code, headers, body) - def analyze_parsed_response(self, status_code: Optional[int], headers: Dict[str, str], body: str) -> Optional[Dict[str, Any]]: + def analyze_parsed_response( + self, status_code: Optional[int], headers: Dict[str, str], body: str + ) -> Optional[Dict[str, Any]]: """ Analyzes the parsed HTTP response based on the purpose, invoking the appropriate method. @@ -86,12 +91,16 @@ def analyze_parsed_response(self, status_code: Optional[int], headers: Dict[str, Optional[Dict[str, Any]]: The analysis results based on the purpose. """ analysis_methods = { - PromptPurpose.AUTHENTICATION_AUTHORIZATION: self.analyze_authentication_authorization(status_code, headers, body), + PromptPurpose.AUTHENTICATION_AUTHORIZATION: self.analyze_authentication_authorization( + status_code, headers, body + ), PromptPurpose.INPUT_VALIDATION: self.analyze_input_validation(status_code, headers, body), } return analysis_methods.get(self.purpose) - def analyze_authentication_authorization(self, status_code: Optional[int], headers: Dict[str, str], body: str) -> Dict[str, Any]: + def analyze_authentication_authorization( + self, status_code: Optional[int], headers: Dict[str, str], body: str + ) -> Dict[str, Any]: """ Analyzes the HTTP response with a focus on authentication and authorization. @@ -104,21 +113,29 @@ def analyze_authentication_authorization(self, status_code: Optional[int], heade Dict[str, Any]: The analysis results focused on authentication and authorization. 
""" analysis = { - 'status_code': status_code, - 'authentication_status': "Authenticated" if status_code == 200 else - "Not Authenticated or Not Authorized" if status_code in [401, 403] else "Unknown", - 'auth_headers_present': any( - header in headers for header in ['Authorization', 'Set-Cookie', 'WWW-Authenticate']), - 'rate_limiting': { - 'X-Ratelimit-Limit': headers.get('X-Ratelimit-Limit'), - 'X-Ratelimit-Remaining': headers.get('X-Ratelimit-Remaining'), - 'X-Ratelimit-Reset': headers.get('X-Ratelimit-Reset'), + "status_code": status_code, + "authentication_status": ( + "Authenticated" + if status_code == 200 + else "Not Authenticated or Not Authorized" + if status_code in [401, 403] + else "Unknown" + ), + "auth_headers_present": any( + header in headers for header in ["Authorization", "Set-Cookie", "WWW-Authenticate"] + ), + "rate_limiting": { + "X-Ratelimit-Limit": headers.get("X-Ratelimit-Limit"), + "X-Ratelimit-Remaining": headers.get("X-Ratelimit-Remaining"), + "X-Ratelimit-Reset": headers.get("X-Ratelimit-Reset"), }, - 'content_body': "Empty" if body == {} else body, + "content_body": "Empty" if body == {} else body, } return analysis - def analyze_input_validation(self, status_code: Optional[int], headers: Dict[str, str], body: str) -> Dict[str, Any]: + def analyze_input_validation( + self, status_code: Optional[int], headers: Dict[str, str], body: str + ) -> Dict[str, Any]: """ Analyzes the HTTP response with a focus on input validation. @@ -131,10 +148,10 @@ def analyze_input_validation(self, status_code: Optional[int], headers: Dict[str Dict[str, Any]: The analysis results focused on input validation. """ analysis = { - 'status_code': status_code, - 'response_body': "Empty" if body == {} else body, - 'is_valid_response': self.is_valid_input_response(status_code, body), - 'security_headers_present': any(key in headers for key in ["X-Content-Type-Options", "X-Ratelimit-Limit"]), + "status_code": status_code, + "response_body": "Empty" if body == {} else body, + "is_valid_response": self.is_valid_input_response(status_code, body), + "security_headers_present": any(key in headers for key in ["X-Content-Type-Options", "X-Ratelimit-Limit"]), } return analysis @@ -158,7 +175,14 @@ def is_valid_input_response(self, status_code: Optional[int], body: str) -> str: else: return "Unexpected" - def document_findings(self, status_code: Optional[int], headers: Dict[str, str], body: str, expected_behavior: str, actual_behavior: str) -> Dict[str, Any]: + def document_findings( + self, + status_code: Optional[int], + headers: Dict[str, str], + body: str, + expected_behavior: str, + actual_behavior: str, + ) -> Dict[str, Any]: """ Documents the findings from the analysis, comparing expected and actual behavior. 
@@ -239,7 +263,7 @@ def print_analysis(self, analysis: Dict[str, Any]) -> str: return analysis_str -if __name__ == '__main__': +if __name__ == "__main__": # Example HTTP response to parse raw_http_response = """HTTP/1.1 404 Not Found Date: Fri, 16 Aug 2024 10:01:19 GMT diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer_with_llm.py b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer_with_llm.py index c794b3fc..204eba13 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer_with_llm.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer_with_llm.py @@ -1,12 +1,16 @@ import json import re -from typing import Dict,Any +from typing import Any, Dict from unittest.mock import MagicMock + from hackingBuddyGPT.capabilities.http_request import HTTPRequest -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import PenTestingInformation -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptPurpose +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import ( + PenTestingInformation, +) +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( + PromptPurpose, +) from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler - from hackingBuddyGPT.utils import tool_message @@ -19,7 +23,7 @@ class ResponseAnalyzerWithLLM: purpose (PromptPurpose): The specific purpose for analyzing the HTTP response. """ - def __init__(self, purpose: PromptPurpose = None, llm_handler: LLMHandler=None): + def __init__(self, purpose: PromptPurpose = None, llm_handler: LLMHandler = None): """ Initializes the ResponseAnalyzer with an optional purpose and an LLM instance. @@ -53,9 +57,6 @@ def print_results(self, results: Dict[str, str]): print(f"Response: {response}") print("-" * 50) - - - def analyze_response(self, raw_response: str, prompt_history: list) -> tuple[dict[str, Any], list]: """ Parses the HTTP response, generates prompts for an LLM, and processes each step with the LLM. 
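The docstring above describes the loop that the hunk below reformats: each analysis step is appended to the prompt history as a system message, sent to the LLM, and the answers are collected per purpose. A toy version of that flow, with a lambda standing in for the real LLM call and made-up step contents:

```python
# Toy sketch of the analyze_response flow: walk the per-purpose steps,
# append each to the prompt history, ask the (stubbed) LLM, collect answers.
from typing import Callable, Dict, List


def analyze_response_sketch(steps_dict: Dict[str, List[str]], llm: Callable[[list], str]) -> list:
    prompt_history, llm_responses = [], []
    for steps in steps_dict.values():
        for step in steps:
            prompt_history.append({"role": "system", "content": step})
            answer = llm(prompt_history)
            prompt_history.append({"role": "assistant", "content": answer})
            llm_responses.append(answer)
    return llm_responses


steps = {"PARSING": ["Extract the status code.", "Summarize the body."]}
print(analyze_response_sketch(steps, lambda history: f"answer after {len(history)} messages"))
```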
@@ -72,12 +73,12 @@ def analyze_response(self, raw_response: str, prompt_history: list) -> tuple[dic # Start processing the analysis steps through the LLM llm_responses = [] steps_dict = self.pentesting_information.analyse_steps(full_response) - for purpose, steps in steps_dict.items(): + for steps in steps_dict.values(): response = full_response # Reset to the full response for each purpose for step in steps: prompt_history, response = self.process_step(step, prompt_history) llm_responses.append(response) - print(f'Response:{response}') + print(f"Response:{response}") return llm_responses @@ -104,14 +105,16 @@ def parse_http_response(self, raw_response: str): elif status_code in [500, 400, 404, 422]: body = body else: - print(f'Body:{body}') - if body != '' or body != "": + print(f"Body:{body}") + if body != "": body = json.loads(body) if isinstance(body, list) and len(body) > 1: body = body[0] - headers = {key.strip(): value.strip() for key, value in - (line.split(":", 1) for line in header_lines[1:] if ':' in line)} + headers = { + key.strip(): value.strip() + for key, value in (line.split(":", 1) for line in header_lines[1:] if ":" in line) + } match = re.match(r"HTTP/1\.1 (\d{3}) (.*)", status_line) status_code = int(match.group(1)) if match else None @@ -123,7 +126,7 @@ def process_step(self, step: str, prompt_history: list) -> tuple[list, str]: Helper function to process each analysis step with the LLM. """ # Log current step - #print(f'Processing step: {step}') + # print(f'Processing step: {step}') prompt_history.append({"role": "system", "content": step}) # Call the LLM and handle the response @@ -141,7 +144,8 @@ def process_step(self, step: str, prompt_history: list) -> tuple[list, str]: return prompt_history, result -if __name__ == '__main__': + +if __name__ == "__main__": # Example HTTP response to parse raw_http_response = """HTTP/1.1 404 Not Found Date: Fri, 16 Aug 2024 10:01:19 GMT @@ -172,15 +176,17 @@ def process_step(self, step: str, prompt_history: list) -> tuple[list, str]: {}""" llm_mock = MagicMock() capabilities = { - "submit_http_method": HTTPRequest('https://jsonplaceholder.typicode.com'), - "http_request": HTTPRequest('https://jsonplaceholder.typicode.com'), + "submit_http_method": HTTPRequest("https://jsonplaceholder.typicode.com"), + "http_request": HTTPRequest("https://jsonplaceholder.typicode.com"), } # Initialize the ResponseAnalyzer with a specific purpose and an LLM instance - response_analyzer = ResponseAnalyzerWithLLM(PromptPurpose.PARSING, llm_handler=LLMHandler(llm=llm_mock, capabilities=capabilities)) + response_analyzer = ResponseAnalyzerWithLLM( + PromptPurpose.PARSING, llm_handler=LLMHandler(llm=llm_mock, capabilities=capabilities) + ) # Generate and process LLM prompts based on the HTTP response results = response_analyzer.analyze_response(raw_http_response) # Print the LLM processing results - response_analyzer.print_results(results) \ No newline at end of file + response_analyzer.print_results(results) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_handler.py index 1d14339a..c7ac733d 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_handler.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_handler.py @@ -1,11 +1,15 @@ import json -from typing import Any, Dict, Optional, Tuple, Union +import re +from typing import Any, Dict, Optional, Tuple
from bs4 import BeautifulSoup -import re -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import PenTestingInformation -from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer_with_llm import ResponseAnalyzerWithLLM +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.pentesting_information import ( + PenTestingInformation, +) +from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer_with_llm import ( + ResponseAnalyzerWithLLM, +) from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt @@ -62,12 +66,12 @@ def parse_http_status_line(self, status_line: str) -> str: """ if status_line == "Not a valid HTTP method" or "note recorded" in status_line: return status_line - status_line = status_line.split('\r\n')[0] + status_line = status_line.split("\r\n")[0] # Regular expression to match valid HTTP status lines - match = re.match(r'^(HTTP/\d\.\d) (\d{3}) (.*)$', status_line) + match = re.match(r"^(HTTP/\d\.\d) (\d{3}) (.*)$", status_line) if match: protocol, status_code, status_message = match.groups() - return f'{status_code} {status_message}' + return f"{status_code} {status_message}" else: raise ValueError(f"{status_line} is an invalid HTTP status line") @@ -81,16 +85,18 @@ def extract_response_example(self, html_content: str) -> Optional[Dict[str, Any] Returns: Optional[Dict[str, Any]]: The extracted response example as a dictionary, or None if extraction fails. """ - soup = BeautifulSoup(html_content, 'html.parser') - example_code = soup.find('code', {'id': 'example'}) - result_code = soup.find('code', {'id': 'result'}) + soup = BeautifulSoup(html_content, "html.parser") + example_code = soup.find("code", {"id": "example"}) + result_code = soup.find("code", {"id": "result"}) if example_code and result_code: example_text = example_code.get_text() result_text = result_code.get_text() return json.loads(result_text) return None - def parse_http_response_to_openapi_example(self, openapi_spec: Dict[str, Any], http_response: str, path: str, method: str) -> Tuple[Optional[Dict[str, Any]], Optional[str], Dict[str, Any]]: + def parse_http_response_to_openapi_example( + self, openapi_spec: Dict[str, Any], http_response: str, path: str, method: str + ) -> Tuple[Optional[Dict[str, Any]], Optional[str], Dict[str, Any]]: """ Parses an HTTP response to generate an OpenAPI example. @@ -104,7 +110,7 @@ def parse_http_response_to_openapi_example(self, openapi_spec: Dict[str, Any], h Tuple[Optional[Dict[str, Any]], Optional[str], Dict[str, Any]]: A tuple containing the entry dictionary, reference, and updated OpenAPI specification. """ - headers, body = http_response.split('\r\n\r\n', 1) + headers, body = http_response.split("\r\n\r\n", 1) try: body_dict = json.loads(body) except json.decoder.JSONDecodeError: @@ -141,7 +147,9 @@ def extract_description(self, note: Any) -> str: """ return note.action.content - def parse_http_response_to_schema(self, openapi_spec: Dict[str, Any], body_dict: Dict[str, Any], path: str) -> Tuple[str, str, Dict[str, Any]]: + def parse_http_response_to_schema( + self, openapi_spec: Dict[str, Any], body_dict: Dict[str, Any], path: str + ) -> Tuple[str, str, Dict[str, Any]]: """ Parses an HTTP response body to generate an OpenAPI schema. 
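For reference, the contract of `parse_http_status_line` above, re-implemented standalone with the same regex so it can be run directly (the sample inputs are illustrative, and the guard clauses for the "Not a valid HTTP method" / "note recorded" sentinel strings are omitted):

```python
# Standalone sketch of parse_http_status_line: well-formed status lines are
# reduced to "code message", anything else raises ValueError. Mirrors the
# regex in the diff above.
import re


def parse_http_status_line(status_line: str) -> str:
    status_line = status_line.split("\r\n")[0]
    match = re.match(r"^(HTTP/\d\.\d) (\d{3}) (.*)$", status_line)
    if match:
        protocol, status_code, status_message = match.groups()
        return f"{status_code} {status_message}"
    raise ValueError(f"{status_line} is an invalid HTTP status line")


print(parse_http_status_line("HTTP/1.1 200 OK\r\nContent-Type: text/html"))  # 200 OK
try:
    parse_http_status_line("not a status line")
except ValueError as error:
    print(error)
```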
@@ -153,7 +161,7 @@ def parse_http_response_to_schema(self, openapi_spec: Dict[str, Any], body_dict: Returns: Tuple[str, str, Dict[str, Any]]: A tuple containing the reference, object name, and updated OpenAPI specification. """ - object_name = path.split("/")[1].capitalize().rstrip('s') + object_name = path.split("/")[1].capitalize().rstrip("s") properties_dict = {} if len(body_dict) == 1: @@ -187,7 +195,7 @@ def read_yaml_to_string(self, filepath: str) -> Optional[str]: Optional[str]: The contents of the YAML file, or None if an error occurred. """ try: - with open(filepath, 'r') as file: + with open(filepath, "r") as file: return file.read() except FileNotFoundError: print(f"Error: The file {filepath} does not exist.") @@ -234,7 +242,11 @@ def extract_keys(self, key: str, value: Any, properties_dict: Dict[str, Any]) -> Dict[str, Any]: The updated properties dictionary. """ if key == "id": - properties_dict[key] = {"type": str(type(value).__name__), "format": "uuid", "example": str(value)} + properties_dict[key] = { + "type": str(type(value).__name__), + "format": "uuid", + "example": str(value), + } else: properties_dict[key] = {"type": str(type(value).__name__), "example": str(value)} diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py index c3692282..98781cbb 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py @@ -1,22 +1,21 @@ from dataclasses import field -from typing import Dict - +from typing import Dict from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.http_request import HTTPRequest from hackingBuddyGPT.capabilities.record_note import RecordNote from hackingBuddyGPT.usecases.agents import Agent +from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case +from hackingBuddyGPT.usecases.web_api_testing.documentation.openapi_specification_handler import ( + OpenAPISpecificationHandler, +) from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptContext -from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt, Context -from hackingBuddyGPT.usecases.web_api_testing.documentation.openapi_specification_handler import OpenAPISpecificationHandler -from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import PromptStrategy, PromptEngineer +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import PromptEngineer, PromptStrategy from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import ResponseHandler - +from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Context, Prompt +from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib -from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case - class SimpleWebAPIDocumentation(Agent): @@ -46,19 +45,19 @@ class SimpleWebAPIDocumentation(Agent): # Description for expected HTTP methods _http_method_description: str = parameter( desc="Pattern description for expected HTTP methods in the API response", - default="A string that represents an HTTP 
method (e.g., 'GET', 'POST', etc.)." + default="A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).", ) # Template for HTTP methods in API requests _http_method_template: str = parameter( desc="Template to format HTTP methods in API requests, with {method} replaced by actual HTTP method names.", - default="{method}" + default="{method}", ) # List of expected HTTP methods _http_methods: str = parameter( desc="Expected HTTP methods in the API, as a comma-separated list.", - default="GET,POST,PUT,PATCH,DELETE" + default="GET,POST,PUT,PATCH,DELETE", ) def init(self): @@ -73,26 +72,25 @@ def init(self): def _setup_capabilities(self): """Sets up the capabilities for the agent.""" notes = self._context["notes"] - self._capabilities = { - "http_request": HTTPRequest(self.host), - "record_note": RecordNote(notes) - } + self._capabilities = {"http_request": HTTPRequest(self.host), "record_note": RecordNote(notes)} def _setup_initial_prompt(self): """Sets up the initial prompt for the agent.""" initial_prompt = { "role": "system", "content": f"You're tasked with documenting the REST APIs of a website hosted at {self.host}. " - f"Start with an empty OpenAPI specification.\n" - f"Maintain meticulousness in documenting your observations as you traverse the APIs." + f"Start with an empty OpenAPI specification.\n" + f"Maintain meticulousness in documenting your observations as you traverse the APIs.", } self._prompt_history.append(initial_prompt) handlers = (self.llm_handler, self.response_handler) - self.prompt_engineer = PromptEngineer(strategy=PromptStrategy.CHAIN_OF_THOUGHT, - history=self._prompt_history, - handlers=handlers, - context=PromptContext.DOCUMENTATION, - rest_api=self.host) + self.prompt_engineer = PromptEngineer( + strategy=PromptStrategy.CHAIN_OF_THOUGHT, + history=self._prompt_history, + handlers=handlers, + context=PromptContext.DOCUMENTATION, + rest_api=self.host, + ) def all_http_methods_found(self, turn): """ @@ -106,11 +104,9 @@ def all_http_methods_found(self, turn): """ found_endpoints = sum(len(value_list) for value_list in self.documentation_handler.endpoint_methods.values()) expected_endpoints = len(self.documentation_handler.endpoint_methods.keys()) * 4 - print(f'found methods:{found_endpoints}') - print(f'expected methods:{expected_endpoints}') - if found_endpoints > 0 and (found_endpoints == expected_endpoints): - return True - elif turn == 20 and found_endpoints > 0 and (found_endpoints == expected_endpoints): + print(f"found methods:{found_endpoints}") + print(f"expected methods:{expected_endpoints}") + if found_endpoints > 0 and (found_endpoints == expected_endpoints): return True return False @@ -133,7 +135,7 @@ def perform_round(self, turn: int): if len(self.documentation_handler.endpoint_methods) > new_endpoint_found: new_endpoint_found = len(self.documentation_handler.endpoint_methods) elif turn == 20: - while len(self.prompt_engineer.prompt_helper.get_endpoints_needing_help() )!= 0: + while len(self.prompt_engineer.prompt_helper.get_endpoints_needing_help()) != 0: self.run_documentation(turn, "exploit") else: self.run_documentation(turn, "exploit") @@ -161,16 +163,13 @@ def run_documentation(self, turn, move_type): """ prompt = self.prompt_engineer.generate_prompt(turn, move_type) response, completion = self.llm_handler.call_llm(prompt) - self._log, self._prompt_history, self.prompt_engineer = self.documentation_handler.document_response( - 
completion, - response, - self._log, - self._prompt_history, - self.prompt_engineer + self.log, self._prompt_history, self.prompt_engineer = self.documentation_handler.document_response( + completion, response, self.log, self._prompt_history, self.prompt_engineer ) @use_case("Minimal implementation of a web API testing use case") class SimpleWebAPIDocumentationUseCase(AutonomousAgentUseCase[SimpleWebAPIDocumentation]): """Use case for the SimpleWebAPIDocumentation agent.""" + pass diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py index 0bb9588a..6aff0267 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py @@ -1,26 +1,25 @@ import os.path from dataclasses import field -from typing import List, Any, Dict -import pydantic_core +from typing import Any, Dict, List +import pydantic_core from rich.panel import Panel from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.http_request import HTTPRequest from hackingBuddyGPT.capabilities.record_note import RecordNote from hackingBuddyGPT.usecases.agents import Agent -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptContext -from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt, Context +from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing import OpenAPISpecificationParser from hackingBuddyGPT.usecases.web_api_testing.documentation.report_handler import ReportHandler -from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler +from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptContext from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import PromptEngineer, PromptStrategy from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import ResponseHandler +from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Context, Prompt +from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler from hackingBuddyGPT.utils import tool_message from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib -from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case - # OpenAPI specification file path openapi_spec_filename = "/home/diana/Desktop/masterthesis/00/hackingBuddyGPT/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_spec/openapi_spec_2024-08-16_14-14-07.yaml" @@ -46,15 +45,15 @@ class SimpleWebAPITesting(Agent): host: str = parameter(desc="The host to test", default="https://jsonplaceholder.typicode.com") http_method_description: str = parameter( desc="Pattern description for expected HTTP methods in the API response", - default="A string that represents an HTTP method (e.g., 'GET', 'POST', etc.)." + default="A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).", ) http_method_template: str = parameter( desc="Template used to format HTTP methods in API requests. 
The {method} placeholder will be replaced by actual HTTP method names.", - default="{method}" + default="{method}", ) http_methods: str = parameter( desc="Comma-separated list of HTTP methods expected to be used in the API response.", - default="GET,POST,PUT,DELETE" + default="GET,POST,PUT,DELETE", ) _prompt_history: Prompt = field(default_factory=list) @@ -90,19 +89,20 @@ def _setup_initial_prompt(self) -> None: f"and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. " f"Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. " f"Remember, if you encounter an HTTP method ({self.http_method_description}), promptly submit it as it is of utmost importance." - ) + ), } self._prompt_history.append(initial_prompt) handlers = (self._llm_handler, self._response_handler) - schemas: Dict[str, Any] = self._openapi_specification["components"]["schemas"] if os.path.exists( - openapi_spec_filename) else {} + schemas: Dict[str, Any] = ( + self._openapi_specification["components"]["schemas"] if os.path.exists(openapi_spec_filename) else {} + ) self.prompt_engineer: PromptEngineer = PromptEngineer( strategy=PromptStrategy.CHAIN_OF_THOUGHT, history=self._prompt_history, handlers=handlers, context=PromptContext.PENTESTING, rest_api=self.host, - schemas=schemas + schemas=schemas, ) def all_http_methods_found(self) -> None: @@ -110,7 +110,7 @@ def all_http_methods_found(self) -> None: Handles the event when all HTTP methods are found. Displays a congratulatory message and sets the _all_http_methods_found flag to True. """ - self._log.console.print(Panel("All HTTP methods found! Congratulations!", title="system")) + self.log.console.print(Panel("All HTTP methods found! Congratulations!", title="system")) self._all_http_methods_found = True def _setup_capabilities(self) -> None: @@ -119,13 +119,14 @@ def _setup_capabilities(self) -> None: note recording capabilities, and HTTP method submission capabilities based on the provided configuration. 
""" - methods_set: set[str] = {self.http_method_template.format(method=method) for method in - self.http_methods.split(",")} + methods_set: set[str] = { + self.http_method_template.format(method=method) for method in self.http_methods.split(",") + } notes: List[str] = self._context["notes"] self._capabilities = { "submit_http_method": HTTPRequest(self.host), "http_request": HTTPRequest(self.host), - "record_note": RecordNote(notes) + "record_note": RecordNote(notes), } def perform_round(self, turn: int) -> None: @@ -155,18 +156,18 @@ def _handle_response(self, completion: Any, response: Any, purpose: str) -> None message = completion.choices[0].message tool_call_id: str = message.tool_calls[0].id command: str = pydantic_core.to_json(response).decode() - self._log.console.print(Panel(command, title="assistant")) + self.log.console.print(Panel(command, title="assistant")) self._prompt_history.append(message) - with self._log.console.status("[bold green]Executing that command..."): + with self.log.console.status("[bold green]Executing that command..."): result: Any = response.execute() - self._log.console.print(Panel(result[:30], title="tool")) + self.log.console.print(Panel(result[:30], title="tool")) if not isinstance(result, str): - endpoint: str = str(response.action.path).split('/')[1] + endpoint: str = str(response.action.path).split("/")[1] self._report_handler.write_endpoint_to_report(endpoint) self._prompt_history.append(tool_message(str(result), tool_call_id)) - analysis = self._response_handler.evaluate_result(result=result, prompt_history= self._prompt_history) + analysis = self._response_handler.evaluate_result(result=result, prompt_history=self._prompt_history) self._report_handler.write_analysis_to_report(analysis=analysis, purpose=self.prompt_engineer.purpose) # self._prompt_history.append(tool_message(str(analysis), tool_call_id)) @@ -179,4 +180,5 @@ class SimpleWebAPITestingUseCase(AutonomousAgentUseCase[SimpleWebAPITesting]): A use case for the SimpleWebAPITesting agent, encapsulating the setup and execution of the web API testing scenario. 
""" + pass diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/__init__.py index bc940e02..92159799 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/utils/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/__init__.py @@ -1,2 +1,2 @@ +from .custom_datatypes import Context, Prompt from .llm_handler import LLMHandler -from .custom_datatypes import Prompt, Context diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/custom_datatypes.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/custom_datatypes.py index 803e7890..7061b01a 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/utils/custom_datatypes.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/custom_datatypes.py @@ -1,5 +1,7 @@ -from typing import List, Any, Union, Dict -from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage +from typing import Any, List, Union + +from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam + # Type aliases for readability Prompt = List[Union[ChatCompletionMessage, ChatCompletionMessageParam]] -Context = Any \ No newline at end of file +Context = Any diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py index e4d77710..16b0dff1 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py @@ -1,8 +1,10 @@ import re -from typing import List, Dict, Any -from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model +from typing import Any, Dict, List + import openai +from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model + class LLMHandler: """ @@ -26,7 +28,7 @@ def __init__(self, llm: Any, capabilities: Dict[str, Any]) -> None: self.llm = llm self._capabilities = capabilities self.created_objects: Dict[str, List[Any]] = {} - self._re_word_boundaries = re.compile(r'\b') + self._re_word_boundaries = re.compile(r"\b") def call_llm(self, prompt: List[Dict[str, Any]]) -> Any: """ @@ -38,14 +40,14 @@ def call_llm(self, prompt: List[Dict[str, Any]]) -> Any: Returns: Any: The response from the LLM. """ - print(f'Initial prompt length: {len(prompt)}') + print(f"Initial prompt length: {len(prompt)}") def call_model(prompt: List[Dict[str, Any]]) -> Any: - """ Helper function to avoid redundancy in making the API call. 
""" + """Helper function to avoid redundancy in making the API call.""" return self.llm.instructor.chat.completions.create_with_completion( model=self.llm.model, messages=prompt, - response_model=capabilities_to_action_model(self._capabilities) + response_model=capabilities_to_action_model(self._capabilities), ) try: @@ -55,25 +57,25 @@ def call_model(prompt: List[Dict[str, Any]]) -> Any: return call_model(self.adjust_prompt_based_on_token(prompt)) except openai.BadRequestError as e: try: - print(f'Error: {str(e)} - Adjusting prompt size and retrying.') + print(f"Error: {str(e)} - Adjusting prompt size and retrying.") # Reduce prompt size; removing elements and logging this adjustment return call_model(self.adjust_prompt_based_on_token(self.adjust_prompt(prompt))) except openai.BadRequestError as e: new_prompt = self.adjust_prompt_based_on_token(self.adjust_prompt(prompt, num_prompts=2)) - print(f'New prompt:') - print(f'Len New prompt:{len(new_prompt)}') + print("New prompt:") + print(f"Len New prompt:{len(new_prompt)}") for prompt in new_prompt: - print(f'{prompt}') + print(f"{prompt}") return call_model(new_prompt) def adjust_prompt(self, prompt: List[Dict[str, Any]], num_prompts: int = 5) -> List[Dict[str, Any]]: - adjusted_prompt = prompt[len(prompt) - num_prompts - (len(prompt) % 2): len(prompt)] + adjusted_prompt = prompt[len(prompt) - num_prompts - (len(prompt) % 2) : len(prompt)] if not isinstance(adjusted_prompt[0], dict): - adjusted_prompt = prompt[len(prompt) - num_prompts - (len(prompt) % 2) - 1: len(prompt)] + adjusted_prompt = prompt[len(prompt) - num_prompts - (len(prompt) % 2) - 1 : len(prompt)] - print(f'Adjusted prompt length: {len(adjusted_prompt)}') - print(f'adjusted prompt:{adjusted_prompt}') + print(f"Adjusted prompt length: {len(adjusted_prompt)}") + print(f"adjusted prompt:{adjusted_prompt}") return prompt def add_created_object(self, created_object: Any, object_type: str) -> None: @@ -96,7 +98,7 @@ def get_created_objects(self) -> Dict[str, List[Any]]: Returns: Dict[str, List[Any]]: The dictionary of created objects. 
""" - print(f'created_objects: {self.created_objects}') + print(f"created_objects: {self.created_objects}") return self.created_objects def adjust_prompt_based_on_token(self, prompt: List[Dict[str, Any]]) -> List[Dict[str, Any]]: @@ -108,13 +110,13 @@ def adjust_prompt_based_on_token(self, prompt: List[Dict[str, Any]]) -> List[Dic prompt.remove(item) else: if isinstance(item, dict): - new_token_count = (tokens + self.get_num_tokens(item["content"])) + new_token_count = tokens + self.get_num_tokens(item["content"]) if new_token_count <= max_tokens: tokens = new_token_count else: continue - print(f'tokens:{tokens}') + print(f"tokens:{tokens}") prompt.reverse() return prompt diff --git a/src/hackingBuddyGPT/utils/__init__.py b/src/hackingBuddyGPT/utils/__init__.py index 7df80e5e..4ac36972 100644 --- a/src/hackingBuddyGPT/utils/__init__.py +++ b/src/hackingBuddyGPT/utils/__init__.py @@ -1,9 +1,8 @@ -from .configurable import configurable, Configurable -from .llm_util import * -from .ui import * - +from .configurable import Configurable, configurable, parameter from .console import * from .db_storage import * +from .llm_util import * from .openai import * from .psexec import * -from .ssh_connection import * \ No newline at end of file +from .ssh_connection import * +from .ui import * diff --git a/src/hackingBuddyGPT/utils/cli_history.py b/src/hackingBuddyGPT/utils/cli_history.py index 3fce45ea..2e8f8e2c 100644 --- a/src/hackingBuddyGPT/utils/cli_history.py +++ b/src/hackingBuddyGPT/utils/cli_history.py @@ -1,10 +1,11 @@ from .llm_util import LLM, trim_result_front -class SlidingCliHistory: +class SlidingCliHistory: model: LLM = None maximum_target_size: int = 0 - sliding_history: str = '' + sliding_history: str = "" + last_output: str = '' def __init__(self, used_model: LLM): self.model = used_model @@ -16,3 +17,15 @@ def add_command(self, cmd: str, output: str): def get_history(self, target_size: int) -> str: return trim_result_front(self.model, min(self.maximum_target_size, target_size), self.sliding_history) + + def add_command_only(self, cmd: str, output: str): + self.sliding_history += f"$ {cmd}\n" + self.last_output = output + last_output_size = self.model.count_tokens(self.last_output) + if self.maximum_target_size - last_output_size < 0: + last_output_size = 0 + self.last_output = '' + self.sliding_history = trim_result_front(self.model, self.maximum_target_size - last_output_size, self.sliding_history) + + def get_commands_and_last_output(self, target_size: int) -> str: + return trim_result_front(self.model, min(self.maximum_target_size, target_size), self.sliding_history + self.last_output) \ No newline at end of file diff --git a/src/hackingBuddyGPT/utils/configurable.py b/src/hackingBuddyGPT/utils/configurable.py index 6a41e791..079b15d7 100644 --- a/src/hackingBuddyGPT/utils/configurable.py +++ b/src/hackingBuddyGPT/utils/configurable.py @@ -2,197 +2,703 @@ import dataclasses import inspect import os -from dataclasses import dataclass -from typing import Any, Dict, TypeVar +import json +from dotenv import dotenv_values +from dataclasses import dataclass, Field, field, MISSING, _MISSING_TYPE +from types import NoneType +from typing import Any, Dict, Type, TypeVar, Set, Union, Optional, overload, Generic, Callable, get_origin, get_args + + +def repr_text(value: Any, secret: bool = False) -> str: + if secret: + return "" + if isinstance(value, str): + return f"'{value}'" + else: + return f"{value}" + + +class no_default: + pass + + +class ParameterError(Exception): + def 
__init__(self, message: str, name: list[str]): + super().__init__(message) + self.name = name + + +Configurable = Type # TODO: Define type + + +C = TypeVar('C', bound=type) + + +def configurable(name: str, description: str): + """ + Anything that is decorated with the @configurable decorator gets the parameters of its __init__ method extracted, + which can then be used with build_parser and get_arguments to recursively prepare the argparse parser and extract the + initialization parameters. These can then be used to initialize the class with the correct parameters. + """ + + def inner(cls) -> Configurable: + cls.name = name or cls.__name__ + cls.description = description + + return cls + + return inner + + +def Secret(subclass: C) -> C: + class Cloned(subclass): + __secret__ = True + __transparent__ = getattr(subclass, "__transparent__", False) + __global__ = getattr(subclass, "__global__", False) + __global_name__ = getattr(subclass, "__global_name__", None) + + Cloned.__name__ = subclass.__name__ + Cloned.__qualname__ = subclass.__qualname__ + + return Cloned + + +def Global(subclass: C, global_name: Optional[str] = None) -> C: + class Cloned(subclass): + __secret__ = getattr(subclass, "__secret__", False) + __transparent__ = getattr(subclass, "__transparent__", False) + __global__ = True + __global_name__ = global_name + + Cloned.__name__ = subclass.__name__ + Cloned.__qualname__ = subclass.__qualname__ + + return Cloned + -from dotenv import load_dotenv +def Transparent(subclass: C) -> C: + """ + setting a type to be transparent means, that it will not increase a level in the configuration tree, so if you have the following classes: + + class Inner: + a: int + b: str -from typing import Type + def init(self): + print("inner init") + class Outer: + inner: transparent(Inner) -load_dotenv() + def init(self): + inner.init() + the configuration will be `--a` and `--b` instead of `--inner.a` and `--inner.b`. -def parameter(*, desc: str, default=dataclasses.MISSING, init: bool = True, repr: bool = True, hash=None, - compare: bool = True, metadata: Dict = None, kw_only: bool = dataclasses.MISSING): + A transparent attribute will also not have its init function called automatically, so you will need to do that on your own, as seen in the Outer init. + The function is upper case on purpose, as it is supposed to be used in a Type context + """ + class Cloned(subclass): + __secret__ = getattr(subclass, "__secret__", False) + __transparent__ = True + __global__ = getattr(subclass, "__global__", False) + __global_name__ = getattr(subclass, "__global_name__", None) + + Cloned.__name__ = subclass.__name__ + Cloned.__qualname__ = subclass.__qualname__ + + return Cloned + + +INDENT_WIDTH = 4 +INDENT = " " * INDENT_WIDTH + +COMMAND_COLOR = "\033[34m" +PARAMETER_COLOR = "\033[32m" +DEFAULT_VALUE_COLOR = "\033[33m" +MUTED_COLOR = "\033[37m" +COLOR_RESET = "\033[0m" + + +def indent(level: int) -> str: + return INDENT * level + + +T = TypeVar("T") + + +@overload +def parameter( + *, + desc: str, + default: T = ..., + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Dict[str, Any]] = ..., + kw_only: Union[bool, _MISSING_TYPE] = MISSING, +) -> T: + ... + +@overload +def parameter( + *, + desc: str, + default: T = ..., + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Dict[str, Any]] = ..., + kw_only: Union[bool, _MISSING_TYPE] = MISSING, +) -> Field[T]: + ... 
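`Secret`, `Global`, and `Transparent` above all apply the same trick: clone the class by subclassing, stamp the marker dunders onto the clone, and restore the original `__name__`/`__qualname__`. A condensed sketch of that pattern (`Marker` and `ApiKey` are hypothetical names for illustration, not part of the codebase):

```python
# Hypothetical condensed form of the Secret/Global/Transparent wrappers:
# subclass to avoid mutating the original type, set marker attributes on the
# clone, and keep the original class name intact.
def Marker(subclass, **flags):
    class Cloned(subclass):
        pass

    for key, value in flags.items():
        setattr(Cloned, key, value)
    Cloned.__name__ = subclass.__name__
    Cloned.__qualname__ = subclass.__qualname__
    return Cloned


class ApiKey(str):
    pass


SecretApiKey = Marker(ApiKey, __secret__=True)
print(SecretApiKey.__name__, getattr(SecretApiKey, "__secret__", False))  # ApiKey True
print(isinstance(SecretApiKey("k"), ApiKey))  # True
```

Because the clone subclasses the original, `isinstance` checks and the wrapped type's behavior are preserved, while the configuration machinery can read the markers via `getattr` with a fallback.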
+ +def parameter( + *, + desc: str, + secret: bool = False, + global_parameter: bool = False, + global_name: Optional[str] = None, + choices: Optional[dict[str, type]] = None, + default: T = MISSING, + init: bool = True, + repr: bool = True, + hash: Optional[bool] = None, + compare: bool = True, + metadata: Optional[Dict[str, Any]] = None, + kw_only: Union[bool, _MISSING_TYPE] = MISSING, +) -> Field[T]: if metadata is None: metadata = dict() metadata["desc"] = desc - - return dataclasses.field(default=default, default_factory=dataclasses.MISSING, init=init, repr=repr, hash=hash, - compare=compare, metadata=metadata, kw_only=kw_only) + metadata["secret"] = secret + metadata["global"] = global_parameter + metadata["global_name"] = global_name + metadata["choices"] = choices + + return field( + default=default, + default_factory=MISSING, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + kw_only=kw_only, + ) def get_default(key, default): - return os.getenv(key, os.getenv(key.upper(), os.getenv(key.replace(".", "_"), os.getenv(key.replace(".", "_").upper(), default)))) + return os.getenv( + key, os.getenv(key.upper(), os.getenv(key.replace(".", "_"), os.getenv(key.replace(".", "_").upper(), default))) + ) + + +NestedCollection = Union[C, Dict[str, "NestedCollection[C]"]] +ParameterCollection = NestedCollection["ParameterDefinition[C]"] +ParsingResults = NestedCollection[str] +InstanceResults = NestedCollection[Any] + + +def get_at(collection: NestedCollection[C], name: list[str], at: int = 0, *, meta: bool = False, no_raise: bool = False) -> Optional[C]: + if meta: + name = name + ["$"] + + if len(name) == at: + if isinstance(collection, dict): + if no_raise: + return None + raise ValueError(f"Value for '{'.'.join(name)}' not final in collection: {collection}") + return collection + if not isinstance(collection, dict): + if no_raise: + return None + raise ValueError(f"Lookup for '{'.'.join(name)}' overflowing in collection: {collection}") + + cur_name = name[at] + if cur_name not in collection: + return None + + return get_at(collection[cur_name], name, at + 1, meta=False, no_raise=no_raise) + + +def set_at(collection: NestedCollection[C], name: list[str], value: C, at: int = 0, meta: bool = False): + if meta: + name = name + ["$"] + + if len(name) == at: + raise ValueError(f"Lookup for collection '{'.'.join(name)}' has empty path") + + if not isinstance(collection, dict): + raise ValueError(f"Lookup for '{'.'.join(name)}' overflowing in collection: {collection}") + + if len(name) - 1 == at: + collection[name[at]] = value + return + + if name[at] not in collection: + collection[name[at]] = {} + + return set_at(collection[name[at]], name, value, at + 1, False) + + +def dfs_flatmap(collection: NestedCollection[C], func: Callable[[list[str], C], Any], basename: Optional[list[str]] = None): + if basename is None: + basename = [] + output = [] + for key, value in collection.items(): + name = basename + [key] + if isinstance(value, dict): + output += dfs_flatmap(value, func, name) + else: + res = func(name, value) + if res is not None: + output.append(res) + return output @dataclass -class ParameterDefinition: +class ParameterDefinition(Generic[C]): """ - A ParameterDefinition is used for any parameter that is just a simple type, which can be handled by argparse directly. + A ParameterDefinition is used for any parameter that is just a simple type like str, int, float, bool. 
""" - name: str - type: Type + + name: list[str] + type: C default: Any - description: str + description: Optional[str] + secret: bool + + _instance: Optional[Any] = field(init=False, default=None) + + def __call__(self, collection: ParsingResults) -> C: + if self._instance is None: + value = get_at(collection, self.name) + if value is None: + raise ParameterError(f"Missing required parameter '--{'.'.join(self.name)}'", self.name) + self._instance = self.type(value) + return self._instance + + def get_default(self, defaults: list[tuple[str, ParsingResults]], fail_fast: bool = True) -> tuple[Any, str, str]: + default_value = None + default_text = "" + default_origin = "" + default_alternatives = False + + defaults = [(source, get_at(values, self.name)) for source, values in defaults] + defaults.append(("builtin", self.default)) + for source, default in defaults: + if default is not None and not isinstance(default, no_default): + if len(default_text) > 0: + if not default_alternatives: + default_origin += ", alternatives: " + else: + default_origin += ", " + default_origin += f"{repr_text(default, self.secret)} from {source}" + default_alternatives = True + continue + + default_value = default + default_origin = f"default from {source}" + default_text = repr_text(default, self.secret) + if fail_fast: + break + + return default_value, default_text, default_origin + + def to_help(self, defaults: list[tuple[str, ParsingResults]], level: int) -> str: + eq = "" + + _, default_text, default_origin = self.get_default(defaults, fail_fast=False) + if len(default_origin) > 0: + eq = "=" + default_origin = f" ({default_origin})" + + description = self.description or "" + return f"{indent(level)}{PARAMETER_COLOR}--{'.'.join(self.name)}{COLOR_RESET}{eq}{DEFAULT_VALUE_COLOR}{default_text}{COLOR_RESET} {description}{MUTED_COLOR}{default_origin}{COLOR_RESET}" - def parser(self, name: str, parser: argparse.ArgumentParser): - default = get_default(name, self.default) - parser.add_argument(f"--{name}", type=self.type, default=default, required=default is None, - help=self.description) +@dataclass +class ComplexParameterDefinition(ParameterDefinition[C]): + """ + A ComplexParameterDefinition is used for any parameter that is a complex type (which itself only takes simple types, + or other types that fit the ComplexParameterDefinition/UnionParameterDefinition). + It is important to note, that at some point, the parameter must be a simple type, so that it can be parsed. + So if you have recursive type definitions that you try to make configurable, this will not work. + """ - def get(self, name: str, args: argparse.Namespace): - return getattr(args, name) + parameters: dict[str, ParameterDefinition] + def __call__(self, collection: ParsingResults) -> C: + # TODO: default handling? 
+ # we only do instance management on non-top level parameter definitions (those would be the full configurable, which does not need to be cached and also fails) + if self._instance is None: + self._instance = self.type(**{ + name: param(collection) + for name, param in self.parameters.items() + }) + if hasattr(self._instance, "init"): + self._instance.init() + return self._instance -ParameterDefinitions = Dict[str, ParameterDefinition] + def get_default(self, defaults: list[tuple[str, ParsingResults]], fail_fast: bool = True) -> tuple[Any, str, str]: + return None, "", "" @dataclass -class ComplexParameterDefinition(ParameterDefinition): +class ChoiceParameterDefinition(ParameterDefinition[C]): """ - A ComplexParameterDefinition is used for any parameter that is a complex type (which itself only takes simple types, - or other types that fit the ComplexParameterDefinition), requiring a recursive build_parser. - It is important to note, that at some point, the parameter must be a simple type, so that argparse (and we) can handle - it. So if you have recursive type definitions that you try to make configurable, this will not work. + A ChoiceParameterDefinition is used for any parameter that is a choice / Union type. + It is important to note, that at some point, the parameter must be a simple type, so that it can be parsed. + So if you have recursive type definitions that you try to make configurable, this will not work. """ - parameters: ParameterDefinitions - transparent: bool = False - def parser(self, basename: str, parser: argparse.ArgumentParser): - for name, parameter in self.parameters.items(): - if isinstance(parameter, dict): - build_parser(parameter, parser, next_name(basename, name, parameter)) - else: - parameter.parser(next_name(basename, name, parameter), parser) + choices: dict[str, tuple[ParameterDefinition, dict[str, ParameterDefinition]]] + + def __call__(self, collection: ParsingResults) -> C: + if self._instance is None: + value = get_at(collection, self.name, meta=True) + if value is None: + raise ParameterError(f"Missing required parameter '--{'.'.join(self.name)}'", self.name) + if value not in self.choices: + raise ParameterError(f"Invalid value for parameter '--{'.'.join(self.name)}': {value} (possible values are {', '.join(self.choices.keys())})", self.name) + choice, parameters = self.choices[value] + self._instance = choice(**{ + name: parameter(collection) + for name, parameter in parameters.items() + }) + if hasattr(self._instance, "init"): + self._instance.init() + return self._instance + + +def get_inspect_parameters_for_class(cls: type, basename: list[str]) -> dict[str, tuple[inspect.Parameter, list[str], Optional[dataclasses.Field]]]: + fields = getattr(cls, "__dataclass_fields__", {}) + return { + name: (param, basename + [name], fields.get(name)) + for name, param in inspect.signature(cls.__init__).parameters.items() + if not (name == "self" or name.startswith("_") or isinstance(name, NoneType)) + } + +def get_type_description_default_for_parameter(parameter: inspect.Parameter, name: list[str], field: Optional[dataclasses.Field] = None) -> tuple[Type, Optional[str], Any]: + parameter_type: Type = parameter.annotation + description: Optional[str] = None + + default: Any = parameter.default if parameter.default != inspect.Parameter.empty else no_default() + if isinstance(default, dataclasses.Field): + field = default + default = field.default + + if field is not None: + description = field.metadata.get("desc", None) + if field.type is not None: + if not 
(isinstance(field.type, type) or get_origin(field.type) is Union): + raise ValueError(f"Parameter {'.'.join(name)} has an invalid type annotation: {field.type} ({type(field.type)})") + parameter_type = field.type + + # check if type is an Optional, and then get the actual type + if get_origin(parameter_type) is Union and len(parameter_type.__args__) == 2 and parameter_type.__args__[1] is NoneType: + parameter_type = parameter_type.__args__[0] + + return parameter_type, description, default + + +def try_existing_parameter(parameter_collection: ParameterCollection, name: list[str], typ: type, parameter_type: type, default: Any, description: str, secret_parameter: bool) -> Optional[ParameterDefinition]: + existing_parameter = get_at(parameter_collection, name, meta=(typ in (ComplexParameterDefinition, ChoiceParameterDefinition))) + if not existing_parameter: + return None + + if existing_parameter.type != parameter_type: + raise ValueError(f"Parameter {'.'.join(name)} already exists with a different type ({existing_parameter.type} != {parameter_type})") + if existing_parameter.default != default: + if existing_parameter.default is None and isinstance(default, no_default) \ + or existing_parameter.default is not None and not isinstance(default, no_default): + pass # syncing up "no defaults" + else: + raise ValueError(f"Parameter {'.'.join(name)} already exists with a different default value ({existing_parameter.default} != {default})") + if existing_parameter.description != description: + raise ValueError(f"Parameter {'.'.join(name)} already exists with a different description ({existing_parameter.description} != {description})") + if existing_parameter.secret != secret_parameter: + raise ValueError(f"Parameter {'.'.join(name)} already exists with a different secret status ({existing_parameter.secret} != {secret_parameter})") - def get(self, name: str, args: argparse.Namespace): - args = get_arguments(self.parameters, args, name) + return existing_parameter - def create(): - instance = self.type(**args) - if hasattr(instance, "init") and not getattr(self.type, "__transparent__", False): - instance.init() - setattr(instance, "configurable_recreate", create) - return instance - return create() +def parameter_definitions_for_class(cls: type, name: list[str], parameter_collection: ParameterCollection) -> dict[str, ParameterDefinition]: + return {name: parameter_definition_for(*metadata, parameter_collection=parameter_collection) for name, metadata in get_inspect_parameters_for_class(cls, name).items()} -def get_class_parameters(cls, name: str = None, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions: - if name is None: - name = cls.__name__ - if fields is None and hasattr(cls, "__dataclass_fields__"): - fields = cls.__dataclass_fields__ - return get_parameters(cls.__init__, name, fields) +def parameter_definition_for(param: inspect.Parameter, name: list[str], field: Optional[dataclasses.Field] = None, *, parameter_collection: ParameterCollection) -> ParameterDefinition: + parameter_type, description, default = get_type_description_default_for_parameter(param, name, field) + secret_parameter = (field and field.metadata.get("secret", False)) or getattr(parameter_type, "__secret__", False) -def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions: - if fields is None: - fields = dict() + if (field and field.metadata.get("global", False)) or getattr(parameter_type, "__global__", False): + if field and 
field.metadata.get("global_name", None): + name = [field.metadata["global_name"]] + elif getattr(parameter_type, "__global_name__", None): + name = [parameter_type["__global_name__"]] + else: + name = [name[-1]] + + if (field and field.metadata.get("transparent", False)) or getattr(parameter_type, "__transparent__", False): + name = name[:-1] + + if parameter_type in (str, int, float, bool): + existing_parameter = try_existing_parameter(parameter_collection, name, typ=ParameterDefinition, parameter_type=parameter_type, default=default, description=description, secret_parameter=secret_parameter) + if existing_parameter: + return existing_parameter + parameter = ParameterDefinition(name, parameter_type, default, description, secret_parameter) + set_at(parameter_collection, name, parameter) + + elif get_origin(parameter_type) is Union: + existing_parameter = try_existing_parameter(parameter_collection, name, typ=ChoiceParameterDefinition, parameter_type=parameter_type, default=default, description=description, secret_parameter=secret_parameter) + if existing_parameter: + return existing_parameter + + if field and field.metadata.get("choices") is not None: + choices = { + name: (typ, parameter_definitions_for_class(typ, name, parameter_collection)) + for name, typ in field.metadata.get('choices').items() + } + else: + choices = { + getattr(arg, "name", None) or getattr(arg, "__name__", None) or arg.__class__.__name__: ( + arg, + parameter_definitions_for_class(arg, name, parameter_collection) + ) + for arg in get_args(parameter_type) + } + + parameter = ChoiceParameterDefinition( + name=name, + type=parameter_type, + default=default, + description=description, + secret=secret_parameter, + choices=choices, + ) + set_at(parameter_collection, name, parameter, meta=True) - sig = inspect.signature(fun) - params: ParameterDefinitions = {} - for name, param in sig.parameters.items(): - if name == "self" or name.startswith("_"): - continue + else: + existing_parameter = try_existing_parameter(parameter_collection, name, typ=ComplexParameterDefinition, parameter_type=parameter_type, default=default, description=description, secret_parameter=secret_parameter) + if existing_parameter: + return existing_parameter - if not param.annotation: - raise ValueError(f"Parameter {name} of {basename} must have a type annotation") - - default = param.default if param.default != inspect.Parameter.empty else None - description = None - type = param.annotation - - field = None - if isinstance(default, dataclasses.Field): - field = default - default = field.default - elif name in fields: - field = fields[name] - - if field is not None: - description = field.metadata.get("desc", None) - if field.type is not None: - type = field.type - - if hasattr(type, "__parameters__"): - params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, basename), transparent=getattr(type, "__transparent__", False)) - elif type in (str, int, float, bool): - params[name] = ParameterDefinition(name, type, default, description) - else: - raise ValueError(f"Parameter {name} of {basename} must have str, int, bool, or a __parameters__ class as type, not {type}") + parameter = ComplexParameterDefinition( + name=name, + type=parameter_type, + default=default, + description=description, + secret=secret_parameter, + parameters=parameter_definitions_for_class(parameter_type, name, parameter_collection), + ) + set_at(parameter_collection, name, parameter, meta=True) - return params + return parameter -def 
build_parser(parameters: ParameterDefinitions, parser: argparse.ArgumentParser, basename: str = ""): - for name, parameter in parameters.items(): - parameter.parser(next_name(basename, name, parameter), parser) -def get_arguments(parameters: ParameterDefinitions, args: argparse.Namespace, basename: str = "") -> Dict[str, Any]: - return {name: parameter.get(next_name(basename, name, parameter), args) for name, parameter in parameters.items()} +@dataclass +class Parseable(Generic[C]): + cls: Type[C] + description: Optional[str] + + _parameter: ComplexParameterDefinition = field(init=False) + _parameter_collection: ParameterCollection = field(init=False, default_factory=dict) + + def __call__(self, parsing_results: ParsingResults): + return self._parameter(parsing_results) + + def __post_init__(self): + self._parameter = ComplexParameterDefinition( + name=[], + type=self.cls, + default=no_default(), + description=self.description, + secret=False, + parameters=parameter_definitions_for_class(self.cls, [], self._parameter_collection), + ) + + def to_help(self, defaults: list[tuple[str, ParsingResults]], level: int = 0) -> str: + return "\n".join(dfs_flatmap(self._parameter_collection, lambda _, parameter: parameter.to_help(defaults, level+1) if not isinstance(parameter, ComplexParameterDefinition) else None)) + + +CommandMap = dict[str, Union["CommandMap[C]", Parseable[C]]] + + +def _to_help(name: str, commands: Union[CommandMap[C], Parseable[C]], level: int = 0, max_length: int = 0) -> str: + h = "" + if isinstance(commands, Parseable): + h += f"{indent(level)}{COMMAND_COLOR}{name}{COLOR_RESET}{' ' * (max_length - len(name)+4)} {commands.description}\n" + elif isinstance(commands, dict): + h += f"{indent(level)}{COMMAND_COLOR}{name}{COLOR_RESET}:\n" + max_length = max(max_length, level*INDENT_WIDTH + max(len(k) for k in commands.keys())) + for name, parser in commands.items(): + h += _to_help(name, parser, level + 1, max_length) + return h + + +def to_help_for_commands(program: str, commands: CommandMap[C], command_chain: Optional[list[str]] = None) -> str: + if command_chain is None: + command_chain = [] + h = f"usage: {program} {COMMAND_COLOR}{' '.join(command_chain)} {COLOR_RESET} {PARAMETER_COLOR}[--help] [--config config.json] [options...]{COLOR_RESET}\n\n" + h += _to_help("commands", commands, 0) + return h + + +def to_help_for_command(program: str, command: list[str], parseable: Parseable[C], defaults: list[tuple[str, ParsingResults]]) -> str: + h = f"usage: {program} {COMMAND_COLOR}{' '.join(command)}{COLOR_RESET} {PARAMETER_COLOR}[--help] [--config config.json] [options...]{COLOR_RESET}\n\n" + h += parseable.to_help(defaults) + h += "\n" + return h + + +class InvalidCommand(ValueError): + def __init__(self, error: str, command: list[str], usage: str): + super().__init__(error) + self.command_list = command + self.usage = usage + + +def instantiate(args: list[str], commands: CommandMap[C]) -> tuple[C, ParsingResults]: + if len(args) == 0: + raise ValueError("No arguments provided (this is probably a bug in the program)") + return _instantiate(args[0], args[1:], commands, []) + + +def _instantiate(program: str, args: list[str], commands: CommandMap[C], command_chain: list[str]) -> tuple[C, ParsingResults]: + if command_chain is None: + command_chain = [] + + if len(args) == 0: + raise InvalidCommand("No command provided", command_chain, to_help_for_commands(program, commands)) + if args[0] not in commands: + raise InvalidCommand(f"Command {args[0]} not found", command_chain, 
to_help_for_commands(program, commands)) + + command = commands[args[0]] + command_chain.append(args[0]) + if isinstance(command, Parseable): + return parse_args(program, command_chain, args[1:], command) + elif isinstance(command, dict): + try: + return _instantiate(program, args[1:], command, command_chain) + except InvalidCommand as e: + e.command_list.append(args[0]) + raise e + else: + raise TypeError(f"Invalid command type {type(command)}") + + +def get_environment_variables(parsing_results: ParsingResults, parameter_collection: ParameterCollection) -> tuple[str, ParsingResults]: + env_parsing_results = dict() + for key, value in os.environ.items(): + # legacy support + test_key = key.split(".") + if get_at(parameter_collection, test_key) is None: + test_key = key.lower().split(".") + if get_at(parameter_collection, test_key) is None: + test_key = key.replace("_", ".").split(".") + if get_at(parameter_collection, test_key) is None: + test_key = key.lower().replace("-", ".").split(".") + if get_at(parameter_collection, test_key) is None: + continue + set_at(parsing_results, test_key, value) + set_at(env_parsing_results, test_key, value) + return ("environment variables", env_parsing_results) + + +def get_env_file_variables(parsing_results: ParsingResults, parameter_collection: ParameterCollection) -> tuple[str, ParsingResults]: + env_file_parsing_results = dict() + for key, value in dotenv_values().items(): + key = key.split(".") + if get_at(parameter_collection, key) is None: + continue + set_at(parsing_results, key, value) + set_at(env_file_parsing_results, key, value) + return (".env file", env_file_parsing_results) -Configurable = Type # TODO: Define type +def get_config_file_variables(config_file_path: str, parsing_results: ParsingResults, parameter_collection: ParameterCollection) -> tuple[str, ParsingResults]: + with open(config_file_path, "r") as config_file: + config_file_parsing_results = json.load(config_file) + return (f"config file at '{config_file_path}'", config_file_parsing_results) -def configurable(service_name: str, service_desc: str): - """ - Anything that is decorated with the @configurable decorator gets the parameters of its __init__ method extracted, - which can then be used with build_parser and get_arguments to recursively prepare the argparse parser and extract the - initialization parameters. These can then be used to initialize the class with the correct parameters. 
- """ - def inner(cls) -> Configurable: - cls.name = service_name - cls.description = service_desc - cls.__service__ = True - cls.__parameters__ = get_class_parameters(cls) +def filter_secret_values(parsing_results: ParsingResults, parameter_collection: ParameterCollection, basename: Optional[list[str]] = None) -> ParsingResults: + if basename is None: + basename = [] - return cls + for key, value in parsing_results.items(): + if isinstance(value, dict): + filter_secret_values(value, parameter_collection, basename + [key]) + else: + parameter = get_at(parameter_collection, basename + [key]) + if parameter.secret: + parsing_results[key] = "" - return inner +def parse_args(program: str, command: list[str], direct_args: list[str], parseable: Parseable[C], parse_env_file: bool = True, parse_environment: bool = True) -> tuple[C, ParsingResults]: + parameter_collection = parseable._parameter_collection -T = TypeVar("T") + parsing_results: ParsingResults = dict() + defaults: list[tuple[str, ParsingResults]] = [] + if parse_environment: + defaults.append(get_environment_variables(parsing_results, parameter_collection)) + if parse_env_file: + defaults.append(get_env_file_variables(parsing_results, parameter_collection)) -def transparent(subclass: T) -> T: - """ - setting a type to be transparent means, that it will not increase a level in the configuration tree, so if you have the following classes: + if "--config" in direct_args: + config_file_idx = direct_args.index("--config") + direct_args.pop(config_file_idx) - class Inner: - a: int - b: str + if len(direct_args) < config_file_idx + 1: + raise ValueError("Missing config file argument") - def init(self): - print("inner init") + config_file_name = direct_args.pop(config_file_idx) + defaults.append(get_config_file_variables(config_file_name, parsing_results, parameter_collection)) - class Outer: - inner: transparent(Inner) + def _help(): + return to_help_for_command(program, command, parseable, defaults) - def init(self): - inner.init() + if any(arg in ("--help", "-h") for arg in direct_args): + raise InvalidCommand("", command, _help()) - the configuration will be `--a` and `--b` instead of `--inner.a` and `--inner.b`. + while len(direct_args) > 0: + arg = direct_args.pop(0) + if arg.startswith("--"): + key = arg[2:] + if "=" in key: + key, value = key.split("=", 1) + else: + if len(direct_args) == 0: + raise InvalidCommand(f"No value for argument {arg}", command, _help()) + value = direct_args.pop(0) + key = key.split(".") + if get_at(parameter_collection, key, no_raise=True) is None: + meta_param = get_at(parameter_collection, key, meta=True, no_raise=True) + if meta_param is None or not isinstance(meta_param, ChoiceParameterDefinition): + raise InvalidCommand(f"Invalid argument {arg}", command, _help()) + else: + key += ["$"] + set_at(parsing_results, key, value) + else: + raise InvalidCommand(f"Invalid argument {arg}", command, _help()) - A transparent attribute will also not have its init function called automatically, so you will need to do that on your own, as seen in the Outer init. 
- """ - class Cloned(subclass): - __transparent__ = True - Cloned.__name__ = subclass.__name__ - Cloned.__qualname__ = subclass.__qualname__ - Cloned.__module__ = subclass.__module__ - return Cloned + def populate_default(name: list[str], parameter: ParameterDefinition): + if get_at(parsing_results, name) is None: + default, _, _ = parameter.get_default(defaults) + set_at(parsing_results, name, default) + dfs_flatmap(parameter_collection, populate_default) -def next_name(basename: str, name: str, param: Any) -> str: - if isinstance(param, ComplexParameterDefinition) and param.transparent: - return basename - elif basename == "": - return name - else: - return f"{basename}.{name}" + try: + instance = parseable(parsing_results) + except ParameterError as e: + raise InvalidCommand(f"{e}", command, _help()) from e + filter_secret_values(parsing_results, parameter_collection) + return instance, parsing_results diff --git a/src/hackingBuddyGPT/utils/console/__init__.py b/src/hackingBuddyGPT/utils/console/__init__.py index f2abc52a..5a70da15 100644 --- a/src/hackingBuddyGPT/utils/console/__init__.py +++ b/src/hackingBuddyGPT/utils/console/__init__.py @@ -1 +1,3 @@ from .console import Console + +__all__ = ["Console"] diff --git a/src/hackingBuddyGPT/utils/console/console.py b/src/hackingBuddyGPT/utils/console/console.py index e48091e1..bcc8e148 100644 --- a/src/hackingBuddyGPT/utils/console/console.py +++ b/src/hackingBuddyGPT/utils/console/console.py @@ -8,5 +8,6 @@ class Console(console.Console): """ Simple wrapper around the rich Console class, to allow for dependency injection and configuration. """ + def __init__(self): super().__init__() diff --git a/src/hackingBuddyGPT/utils/db_storage/__init__.py b/src/hackingBuddyGPT/utils/db_storage/__init__.py index e3f08cce..b2e96daa 100644 --- a/src/hackingBuddyGPT/utils/db_storage/__init__.py +++ b/src/hackingBuddyGPT/utils/db_storage/__init__.py @@ -1 +1,3 @@ -from .db_storage import DbStorage \ No newline at end of file +from .db_storage import DbStorage + +__all__ = ["DbStorage"] diff --git a/src/hackingBuddyGPT/utils/db_storage/db_storage.py b/src/hackingBuddyGPT/utils/db_storage/db_storage.py index 497c023d..b15853bd 100644 --- a/src/hackingBuddyGPT/utils/db_storage/db_storage.py +++ b/src/hackingBuddyGPT/utils/db_storage/db_storage.py @@ -1,11 +1,101 @@ +from dataclasses import dataclass, field +from dataclasses_json import config, dataclass_json +import datetime import sqlite3 - -from hackingBuddyGPT.utils.configurable import configurable, parameter +from typing import Literal, Optional, Union + +from hackingBuddyGPT.utils.configurable import Global, configurable, parameter + + +timedelta_metadata = config(encoder=lambda td: td.total_seconds(), decoder=lambda seconds: datetime.timedelta(seconds=seconds)) +datetime_metadata = config(encoder=lambda dt: dt.isoformat(), decoder=lambda iso: datetime.datetime.fromisoformat(iso)) +optional_datetime_metadata = config(encoder=lambda dt: dt.isoformat() if dt else None, decoder=lambda iso: datetime.datetime.fromisoformat(iso) if iso else None) + + +StreamAction = Literal["append"] + + +@dataclass_json +@dataclass +class Run: + id: int + model: str + state: str + tag: str + started_at: datetime.datetime = field(metadata=datetime_metadata) + stopped_at: Optional[datetime.datetime] = field(metadata=optional_datetime_metadata) + configuration: str + + +@dataclass_json +@dataclass +class Section: + run_id: int + id: int + name: str + from_message: int + to_message: int + duration: datetime.timedelta = 
field(metadata=timedelta_metadata) + + +@dataclass_json +@dataclass +class Message: + run_id: int + id: int + version: int + conversation: str + role: str + content: str + duration: datetime.timedelta = field(metadata=timedelta_metadata) + tokens_query: int + tokens_response: int + + +@dataclass_json +@dataclass +class MessageStreamPart: + id: int + run_id: int + message_id: int + action: StreamAction + content: str + + +@dataclass_json +@dataclass +class ToolCall: + run_id: int + message_id: int + id: str + version: int + function_name: str + arguments: str + state: str + result_text: str + duration: datetime.timedelta = field(metadata=timedelta_metadata) + + +@dataclass_json +@dataclass +class ToolCallStreamPart: + id: int + run_id: int + message_id: int + tool_call_id: str + field: Literal["arguments", "result"] + action: StreamAction + content: str + + +LogTypes = Union[Run, Section, Message, MessageStreamPart, ToolCall, ToolCallStreamPart] @configurable("db_storage", "Stores the results of the experiments in a SQLite database") -class DbStorage: - def __init__(self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default=":memory:")): +class RawDbStorage: + def __init__( + self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default="wintermute.sqlite3") + ): self.connection_string = connection_string def init(self): @@ -13,196 +103,189 @@ def init(self): self.setup_db() def connect(self): - self.db = sqlite3.connect(self.connection_string) + self.db = sqlite3.connect(self.connection_string, isolation_level=None) + self.db.row_factory = sqlite3.Row self.cursor = self.db.cursor() - def insert_or_select_cmd(self, name: str) -> int: - results = self.cursor.execute("SELECT id, name FROM commands WHERE name = ?", (name,)).fetchall() - - if len(results) == 0: - self.cursor.execute("INSERT INTO commands (name) VALUES (?)", (name,)) - return self.cursor.lastrowid - elif len(results) == 1: - return results[0][0] - else: - print("this should not be happening: " + str(results)) - return -1 - def setup_db(self): # create tables - self.cursor.execute("""CREATE TABLE IF NOT EXISTS runs ( - id INTEGER PRIMARY KEY, - model text, - state TEXT, - tag TEXT, - started_at text, - stopped_at text, - rounds INTEGER, - configuration TEXT - )""") - self.cursor.execute("""CREATE TABLE IF NOT EXISTS commands ( - id INTEGER PRIMARY KEY, - name string unique - )""") - self.cursor.execute("""CREATE TABLE IF NOT EXISTS queries ( - run_id INTEGER, - round INTEGER, - cmd_id INTEGER, - query TEXT, - response TEXT, - duration REAL, - tokens_query INTEGER, - tokens_response INTEGER, - prompt TEXT, - answer TEXT - )""") - self.cursor.execute("""CREATE TABLE IF NOT EXISTS messages ( - run_id INTEGER, - message_id INTEGER, - role TEXT, - content TEXT, - duration REAL, - tokens_query INTEGER, - tokens_response INTEGER - )""") - self.cursor.execute("""CREATE TABLE IF NOT EXISTS tool_calls ( - run_id INTEGER, - message_id INTEGER, - tool_call_id INTEGER, - function_name TEXT, - arguments TEXT, - result_text TEXT, - duration REAL - )""") - - # insert commands - self.query_cmd_id = self.insert_or_select_cmd('query_cmd') - self.analyze_response_id = self.insert_or_select_cmd('analyze_response') - self.state_update_id = self.insert_or_select_cmd('update_state') - - def create_new_run(self, model, tag): + self.cursor.execute(""" + CREATE TABLE IF NOT EXISTS runs ( + id INTEGER PRIMARY KEY, + model text, + state TEXT, + tag TEXT, + started_at text, + 
stopped_at text, + configuration TEXT + ) + """) + self.cursor.execute(""" + CREATE TABLE IF NOT EXISTS sections ( + run_id INTEGER, + id INTEGER, + name TEXT, + from_message INTEGER, + to_message INTEGER, + duration REAL, + PRIMARY KEY (run_id, id), + FOREIGN KEY (run_id) REFERENCES runs (id) + ) + """) + self.cursor.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + run_id INTEGER, + conversation TEXT, + id INTEGER, + version INTEGER DEFAULT 0, + role TEXT, + content TEXT, + duration REAL, + tokens_query INTEGER, + tokens_response INTEGER, + PRIMARY KEY (run_id, id), + FOREIGN KEY (run_id) REFERENCES runs (id) + ) + """) + self.cursor.execute(""" + CREATE TABLE IF NOT EXISTS tool_calls ( + run_id INTEGER, + message_id INTEGER, + id TEXT, + version INTEGER DEFAULT 0, + function_name TEXT, + arguments TEXT, + state TEXT, + result_text TEXT, + duration REAL, + PRIMARY KEY (run_id, message_id, id), + FOREIGN KEY (run_id, message_id) REFERENCES messages (run_id, id) + ) + """) + + def get_runs(self) -> list[Run]: + def deserialize(row): + row = dict(row) + row["started_at"] = datetime.datetime.fromisoformat(row["started_at"]) + row["stopped_at"] = datetime.datetime.fromisoformat(row["stopped_at"]) if row["stopped_at"] else None + return row + + self.cursor.execute("SELECT * FROM runs") + return [Run(**deserialize(row)) for row in self.cursor.fetchall()] + + def get_sections_by_run(self, run_id: int) -> list[Section]: + def deserialize(row): + row = dict(row) + row["duration"] = datetime.timedelta(seconds=row["duration"]) + return row + + self.cursor.execute("SELECT * FROM sections WHERE run_id = ?", (run_id,)) + return [Section(**deserialize(row)) for row in self.cursor.fetchall()] + + def get_messages_by_run(self, run_id: int) -> list[Message]: + def deserialize(row): + row = dict(row) + row["duration"] = datetime.timedelta(seconds=row["duration"]) + return row + + self.cursor.execute("SELECT * FROM messages WHERE run_id = ?", (run_id,)) + return [Message(**deserialize(row)) for row in self.cursor.fetchall()] + + def get_tool_calls_by_run(self, run_id: int) -> list[ToolCall]: + def deserialize(row): + row = dict(row) + row["duration"] = datetime.timedelta(seconds=row["duration"]) + return row + + self.cursor.execute("SELECT * FROM tool_calls WHERE run_id = ?", (run_id,)) + return [ToolCall(**deserialize(row)) for row in self.cursor.fetchall()] + + def create_run(self, model: str, tag: str, started_at: datetime.datetime, configuration: str) -> int: self.cursor.execute( - "INSERT INTO runs (model, state, tag, started_at) VALUES (?, ?, ?, datetime('now'))", - (model, "in progress", tag)) + "INSERT INTO runs (model, state, tag, started_at, configuration) VALUES (?, ?, ?, ?, ?)", + (model, "in progress", tag, started_at, configuration), + ) return self.cursor.lastrowid - def add_log_query(self, run_id, round, cmd, result, answer): + def add_message(self, run_id: int, message_id: int, conversation: Optional[str], role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta): self.cursor.execute( - "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - run_id, round, self.query_cmd_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response, - answer.prompt, answer.answer)) + "INSERT INTO messages (run_id, conversation, id, role, content, tokens_query, tokens_response, duration) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (run_id, conversation, 
message_id, role, content, tokens_query, tokens_response, duration.total_seconds()) + ) - def add_log_analyze_response(self, run_id, round, cmd, result, answer): + def add_or_update_message(self, run_id: int, message_id: int, conversation: Optional[str], role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta): self.cursor.execute( - "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - (run_id, round, self.analyze_response_id, cmd, result, answer.duration, answer.tokens_query, - answer.tokens_response, answer.prompt, answer.answer)) - - def add_log_update_state(self, run_id, round, cmd, result, answer): - - if answer is not None: + "SELECT COUNT(*) FROM messages WHERE run_id = ? AND id = ?", + (run_id, message_id), + ) + if self.cursor.fetchone()[0] == 0: self.cursor.execute( - "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - (run_id, round, self.state_update_id, cmd, result, answer.duration, answer.tokens_query, - answer.tokens_response, answer.prompt, answer.answer)) + "INSERT INTO messages (run_id, conversation, id, role, content, tokens_query, tokens_response, duration) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (run_id, conversation, message_id, role, content, tokens_query, tokens_response, duration.total_seconds()), + ) else: - self.cursor.execute( - "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - (run_id, round, self.state_update_id, cmd, result, 0, 0, 0, '', '')) + if len(content) > 0: + self.cursor.execute( + "UPDATE messages SET conversation = ?, role = ?, content = ?, tokens_query = ?, tokens_response = ?, duration = ? WHERE run_id = ? AND id = ?", + (conversation, role, content, tokens_query, tokens_response, duration.total_seconds(), run_id, message_id), + ) + else: + self.cursor.execute( + "UPDATE messages SET conversation = ?, role = ?, tokens_query = ?, tokens_response = ?, duration = ? WHERE run_id = ? 
AND id = ?", + (conversation, role, tokens_query, tokens_response, duration.total_seconds(), run_id, message_id), + ) + + def add_section(self, run_id: int, section_id: int, name: str, from_message: int, to_message: int, duration: datetime.timedelta): + self.cursor.execute( + "INSERT OR REPLACE INTO sections (run_id, id, name, from_message, to_message, duration) VALUES (?, ?, ?, ?, ?, ?)", + (run_id, section_id, name, from_message, to_message, duration.total_seconds()) + ) - def add_log_message(self, run_id: int, role: str, content: str, tokens_query: int, tokens_response: int, duration): + def add_tool_call(self, run_id: int, message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration: datetime.timedelta): self.cursor.execute( - "INSERT INTO messages (run_id, message_id, role, content, tokens_query, tokens_response, duration) VALUES (?, (SELECT COALESCE(MAX(message_id), 0) + 1 FROM messages WHERE run_id = ?), ?, ?, ?, ?, ?)", - (run_id, run_id, role, content, tokens_query, tokens_response, duration)) - self.cursor.execute("SELECT MAX(message_id) FROM messages WHERE run_id = ?", (run_id,)) - return self.cursor.fetchone()[0] + "INSERT INTO tool_calls (run_id, message_id, id, function_name, arguments, result_text, duration) VALUES (?, ?, ?, ?, ?, ?, ?)", + (run_id, message_id, tool_call_id, function_name, arguments, result_text, duration.total_seconds()), + ) - def add_log_tool_call(self, run_id: int, message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration): + def handle_message_update(self, run_id: int, message_id: int, action: StreamAction, content: str): + if action != "append": + raise ValueError("unsupported action: " + action) self.cursor.execute( - "INSERT INTO tool_calls (run_id, message_id, tool_call_id, function_name, arguments, result_text, duration) VALUES (?, ?, ?, ?, ?, ?, ?)", - (run_id, message_id, tool_call_id, function_name, arguments, result_text, duration)) - - def get_round_data(self, run_id, round, explanation, status_update): - rows = self.cursor.execute( - "select cmd_id, query, response, duration, tokens_query, tokens_response from queries where run_id = ? and round = ?", - (run_id, round)).fetchall() - if len(rows) == 0: - return [] - - for row in rows: - if row[0] == self.query_cmd_id: - cmd = row[1] - size_resp = str(len(row[2])) - duration = f"{row[3]:.4f}" - tokens = f"{row[4]}/{row[5]}" - if row[0] == self.analyze_response_id and explanation: - reason = row[2] - analyze_time = f"{row[3]:.4f}" - analyze_token = f"{row[4]}/{row[5]}" - if row[0] == self.state_update_id and status_update: - state_time = f"{row[3]:.4f}" - state_token = f"{row[4]}/{row[5]}" - - result = [duration, tokens, cmd, size_resp] - if explanation: - result += [analyze_time, analyze_token, reason] - if status_update: - result += [state_time, state_token] - return result - - def get_max_round_for(self, run_id): - run = self.cursor.execute("select max(round) from queries where run_id = ?", (run_id,)).fetchone() - if run is not None: - return run[0] - else: - return None + "UPDATE messages SET content = content || ?, version = version + 1 WHERE run_id = ? AND id = ?", + (content, run_id, message_id), + )
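+    # Streamed messages follow a small protocol (sketch): add_or_update_message()
+    # first creates the row with empty content, handle_message_update() appends
+    # each streamed chunk (bumping version), and finalize_message() records the
+    # final token counts and duration.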
- def get_run_data(self, run_id): - run = self.cursor.execute("select * from runs where id = ?", (run_id,)).fetchone() - if run is not None: - return run[1], run[2], run[4], run[3], run[7], run[8] + def finalize_message(self, run_id: int, message_id: int, tokens_query: int, tokens_response: int, duration: datetime.timedelta, overwrite_finished_message: Optional[str] = None): + if overwrite_finished_message: + self.cursor.execute( + "UPDATE messages SET content = ?, tokens_query = ?, tokens_response = ?, duration = ? WHERE run_id = ? AND id = ?", + (overwrite_finished_message, tokens_query, tokens_response, duration.total_seconds(), run_id, message_id), + ) else: - return None - - def get_log_overview(self): - result = {} - - max_rounds = self.cursor.execute("select run_id, max(round) from queries group by run_id").fetchall() - for row in max_rounds: - state = self.cursor.execute("select state from runs where id = ?", (row[0],)).fetchone() - last_cmd = self.cursor.execute("select query from queries where run_id = ? and round = ?", - (row[0], row[1])).fetchone() - - result[row[0]] = { - "max_round": int(row[1]) + 1, - "state": state[0], - "last_cmd": last_cmd[0] - } - - return result - - def get_cmd_history(self, run_id): - rows = self.cursor.execute( - "select query, response from queries where run_id = ? and cmd_id = ? order by round asc", - (run_id, self.query_cmd_id)).fetchall() - - result = [] - - for row in rows: - result.append([row[0], row[1]]) + self.cursor.execute( + "UPDATE messages SET tokens_query = ?, tokens_response = ?, duration = ? WHERE run_id = ? AND id = ?", + (tokens_query, tokens_response, duration.total_seconds(), run_id, message_id), + ) - return result + def update_run(self, run_id: int, model: str, state: str, tag: str, started_at: datetime.datetime, stopped_at: datetime.datetime, configuration: str): + self.cursor.execute( + "UPDATE runs SET model = ?, state = ?, tag = ?, started_at = ?, stopped_at = ?, configuration = ? WHERE id = ?", + (model, state, tag, started_at, stopped_at, configuration, run_id), + ) - def run_was_success(self, run_id, round): - self.cursor.execute("update runs set state=?,stopped_at=datetime('now'), rounds=? where id = ?", - ("got root", round, run_id)) + def run_was_success(self, run_id): + self.cursor.execute( + "update runs set state=?,stopped_at=datetime('now') where id = ?", + ("got root", run_id), + ) self.db.commit() - def run_was_failure(self, run_id, round): - self.cursor.execute("update runs set state=?, stopped_at=datetime('now'), rounds=?
where id = ?", - ("reached max runs", round, run_id)) + def run_was_failure(self, run_id: int, reason: str): + self.cursor.execute( + "update runs set state=?, stopped_at=datetime('now') where id = ?", + (reason, run_id), + ) self.db.commit() - def commit(self): - self.db.commit() + +DbStorage = Global(RawDbStorage) diff --git a/src/hackingBuddyGPT/utils/llm_util.py b/src/hackingBuddyGPT/utils/llm_util.py index 658abe44..fc04dc62 100644 --- a/src/hackingBuddyGPT/utils/llm_util.py +++ b/src/hackingBuddyGPT/utils/llm_util.py @@ -1,19 +1,27 @@ import abc +import datetime import re import typing from dataclasses import dataclass -from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, ChatCompletionToolMessageParam, ChatCompletionAssistantMessageParam, ChatCompletionFunctionMessageParam +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionFunctionMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionUserMessageParam, +) SAFETY_MARGIN = 128 STEP_CUT_TOKENS = 128 + @dataclass class LLMResult: result: typing.Any prompt: str answer: str - duration: float = 0 + duration: datetime.timedelta = datetime.timedelta(0) tokens_query: int = 0 tokens_response: int = 0 @@ -92,6 +100,7 @@ def cmd_output_fixer(cmd: str) -> str: return cmd + # this is ugly, but basically we only have an approximation how many tokens # we are currently using. So we cannot just cut down to the desired size # what we're doing is: @@ -109,7 +118,7 @@ def trim_result_front(model: LLM, target_size: int, result: str) -> str: TARGET_SIZE_FACTOR = 3 if cur_size > TARGET_SIZE_FACTOR * target_size: print(f"big step trim-down from {cur_size} to {2 * target_size}") - result = result[:TARGET_SIZE_FACTOR * target_size] + result = result[: TARGET_SIZE_FACTOR * target_size] cur_size = model.count_tokens(result) while cur_size > target_size: @@ -119,4 +128,4 @@ def trim_result_front(model: LLM, target_size: int, result: str) -> str: result = result[:-step] cur_size = model.count_tokens(result) - return result \ No newline at end of file + return result diff --git a/src/hackingBuddyGPT/utils/logging.py b/src/hackingBuddyGPT/utils/logging.py new file mode 100644 index 00000000..5acee710 --- /dev/null +++ b/src/hackingBuddyGPT/utils/logging.py @@ -0,0 +1,360 @@ +import datetime +from enum import Enum +import time +from dataclasses import dataclass, field +from functools import wraps +from typing import Optional, Union +import threading + +from dataclasses_json.api import dataclass_json + +from hackingBuddyGPT.utils import Console, DbStorage, LLMResult, configurable, parameter +from hackingBuddyGPT.utils.db_storage.db_storage import StreamAction +from hackingBuddyGPT.utils.configurable import Global, Transparent +from rich.console import Group +from rich.panel import Panel +from websockets.sync.client import ClientConnection, connect as ws_connect + +from hackingBuddyGPT.utils.db_storage.db_storage import Run, Section, Message, MessageStreamPart, ToolCall, ToolCallStreamPart + + +def log_section(name: str, logger_field_name: str = "log"): + def outer(fun): + @wraps(fun) + def inner(self, *args, **kwargs): + logger = getattr(self, logger_field_name) + with logger.section(name): + return fun(self, *args, **kwargs) + return inner + return outer + + +def log_conversation(conversation: str, start_section: bool = False, logger_field_name: str = "log"): + def outer(fun): + @wraps(fun) + def inner(self, *args, **kwargs): + logger = 
getattr(self, logger_field_name) + with logger.conversation(conversation, start_section): + return fun(self, *args, **kwargs) + return inner + return outer + + +MessageData = Union[Run, Section, Message, MessageStreamPart, ToolCall, ToolCallStreamPart] + + +class MessageType(str, Enum): + MESSAGE_REQUEST = "MessageRequest" + RUN = "Run" + SECTION = "Section" + MESSAGE = "Message" + MESSAGE_STREAM_PART = "MessageStreamPart" + TOOL_CALL = "ToolCall" + TOOL_CALL_STREAM_PART = "ToolCallStreamPart" + + def get_class(self): + return { + "Run": Run, + "Section": Section, + "Message": Message, + "MessageStreamPart": MessageStreamPart, + "ToolCall": ToolCall, + "ToolCallStreamPart": ToolCallStreamPart, + }[self.value] + + +@dataclass_json +@dataclass +class ControlMessage: + type: MessageType + data: MessageData + + @classmethod + def from_dict(cls, data): + type_ = MessageType(data['type']) + data_class = type_.get_class() + data_instance = data_class.from_dict(data['data']) + return cls(type=type_, data=data_instance) + + +@configurable("local_logger", "Local Logger") +@dataclass +class LocalLogger: + log_db: DbStorage + console: Console + + tag: str = parameter(desc="Tag for your current run", default="") + + run: Run = field(init=False, default=None) # field and not a parameter, since this can not be user configured + + _last_message_id: int = 0 + _last_section_id: int = 0 + _current_conversation: Optional[str] = None + + def start_run(self, name: str, configuration: str): + if self.run is not None: + raise ValueError("Run already started") + start_time = datetime.datetime.now() + run_id = self.log_db.create_run(name, self.tag, start_time , configuration) + self.run = Run(run_id, name, "", self.tag, start_time, None, configuration) + + def section(self, name: str) -> "LogSectionContext": + return LogSectionContext(self, name, self._last_message_id) + + def log_section(self, name: str, from_message: int, to_message: int, duration: datetime.timedelta): + section_id = self._last_section_id + self._last_section_id += 1 + + self.log_db.add_section(self.run.id, section_id, name, from_message, to_message, duration) + + return section_id + + def finalize_section(self, section_id: int, name: str, from_message: int, duration: datetime.timedelta): + self.log_db.add_section(self.run.id, section_id, name, from_message, self._last_message_id, duration) + + def conversation(self, conversation: str, start_section: bool = False) -> "LogConversationContext": + return LogConversationContext(self, start_section, conversation, self._current_conversation) + + def add_message(self, role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta) -> int: + message_id = self._last_message_id + self._last_message_id += 1 + + self.log_db.add_message(self.run.id, message_id, self._current_conversation, role, content, tokens_query, tokens_response, duration) + self.console.print(Panel(content, title=(("" if self._current_conversation is None else f"{self._current_conversation} - ") + role))) + + return message_id + + def _add_or_update_message(self, message_id: int, conversation: Optional[str], role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta): + self.log_db.add_or_update_message(self.run.id, message_id, conversation, role, content, tokens_query, tokens_response, duration) + + def add_tool_call(self, message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration: datetime.timedelta): + self.console.print(Panel( 
+ Group( + Panel(arguments, title="arguments"), + Panel(result_text, title="result"), + ), + title=f"Tool Call: {function_name}")) + self.log_db.add_tool_call(self.run.id, message_id, tool_call_id, function_name, arguments, result_text, duration) + + def run_was_success(self): + self.status_message("Run finished successfully") + self.log_db.run_was_success(self.run.id) + + def run_was_failure(self, reason: str, details: Optional[str] = None): + full_reason = reason + ("" if details is None else f": {details}") + self.status_message(f"Run failed: {full_reason}") + self.log_db.run_was_failure(self.run.id, reason) + + def status_message(self, message: str): + self.add_message("status", message, 0, 0, datetime.timedelta(0)) + + def system_message(self, message: str): + self.add_message("system", message, 0, 0, datetime.timedelta(0)) + + def call_response(self, llm_result: LLMResult) -> int: + self.system_message(llm_result.prompt) + return self.add_message("assistant", llm_result.answer, llm_result.tokens_query, llm_result.tokens_response, llm_result.duration) + + def stream_message(self, role: str): + message_id = self._last_message_id + self._last_message_id += 1 + + return MessageStreamLogger(self, message_id, self._current_conversation, role) + + def add_message_update(self, message_id: int, action: StreamAction, content: str): + self.log_db.handle_message_update(self.run.id, message_id, action, content) + + +@configurable("remote_logger", "Remote Logger") +@dataclass +class RemoteLogger: + console: Console + log_server_address: str = parameter(desc="address:port of the log server to be used", default="localhost:4444") + + tag: str = parameter(desc="Tag for your current run", default="") + + run: Run = field(init=False, default=None) # field and not a parameter, since this can not be user configured + + _last_message_id: int = 0 + _last_section_id: int = 0 + _current_conversation: Optional[str] = None + _upstream_websocket: ClientConnection = None + + def __del__(self): + if self._upstream_websocket: + self._upstream_websocket.close() + + def init_websocket(self): + self._upstream_websocket = ws_connect(f"ws://{self.log_server_address}/ingress") # TODO: we want to support wss at some point + + def send(self, type: MessageType, data: MessageData): + self._upstream_websocket.send(ControlMessage(type, data).to_json()) + + def start_run(self, name: str, configuration: str, tag: Optional[str] = None, start_time: Optional[datetime.datetime] = None, end_time: Optional[datetime.datetime] = None): + if self._upstream_websocket is None: + self.init_websocket() + + if self.run is not None: + raise ValueError("Run already started") + + if tag is None: + tag = self.tag + + if start_time is None: + start_time = datetime.datetime.now() + + self.run = Run(None, name, None, tag, start_time, None, configuration) + self.send(MessageType.RUN, self.run) + self.run = Run.from_json(self._upstream_websocket.recv()) + + def section(self, name: str) -> "LogSectionContext": + return LogSectionContext(self, name, self._last_message_id) + + def log_section(self, name: str, from_message: int, to_message: int, duration: datetime.timedelta): + section_id = self._last_section_id + self._last_section_id += 1 + + section = Section(self.run.id, section_id, name, from_message, to_message, duration) + self.send(MessageType.SECTION, section) + + return section_id + + def finalize_section(self, section_id: int, name: str, from_message: int, duration: datetime.timedelta): + self.send(MessageType.SECTION, Section(self.run.id, 
section_id, name, from_message, self._last_message_id, duration)) + + def conversation(self, conversation: str, start_section: bool = False) -> "LogConversationContext": + return LogConversationContext(self, start_section, conversation, self._current_conversation) + + def add_message(self, role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta) -> int: + message_id = self._last_message_id + self._last_message_id += 1 + + msg = Message(self.run.id, message_id, version=1, conversation=self._current_conversation, role=role, content=content, duration=duration, tokens_query=tokens_query, tokens_response=tokens_response) + self.send(MessageType.MESSAGE, msg) + self.console.print(Panel(content, title=(("" if self._current_conversation is None else f"{self._current_conversation} - ") + role))) + + return message_id + + def _add_or_update_message(self, message_id: int, conversation: Optional[str], role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta): + msg = Message(self.run.id, message_id, version=0, conversation=conversation, role=role, content=content, duration=duration, tokens_query=tokens_query, tokens_response=tokens_response) + self.send(MessageType.MESSAGE, msg) + + def add_tool_call(self, message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration: datetime.timedelta): + self.console.print(Panel( + Group( + Panel(arguments, title="arguments"), + Panel(result_text, title="result"), + ), + title=f"Tool Call: {function_name}")) + tc = ToolCall(self.run.id, message_id, tool_call_id, 0, function_name, arguments, "success", result_text, duration) + self.send(MessageType.TOOL_CALL, tc) + + def run_was_success(self): + self.status_message("Run finished successfully") + self.run.stopped_at = datetime.datetime.now() + self.run.state = "success" + self.send(MessageType.RUN, self.run) + self.run = Run.from_json(self._upstream_websocket.recv()) + + def run_was_failure(self, reason: str, details: Optional[str] = None): + full_reason = reason + ("" if details is None else f": {details}") + self.status_message(f"Run failed: {full_reason}") + self.run.stopped_at = datetime.datetime.now() + self.run.state = reason + self.send(MessageType.RUN, self.run) + self.run = Run.from_json(self._upstream_websocket.recv()) + + def status_message(self, message: str): + self.add_message("status", message, 0, 0, datetime.timedelta(0)) + + def system_message(self, message: str): + self.add_message("system", message, 0, 0, datetime.timedelta(0)) + + def call_response(self, llm_result: LLMResult) -> int: + self.system_message(llm_result.prompt) + return self.add_message("assistant", llm_result.answer, llm_result.tokens_query, llm_result.tokens_response, llm_result.duration) + + def stream_message(self, role: str): + message_id = self._last_message_id + self._last_message_id += 1 + + return MessageStreamLogger(self, message_id, self._current_conversation, role) + + def add_message_update(self, message_id: int, action: StreamAction, content: str): + part = MessageStreamPart(id=None, run_id=self.run.id, message_id=message_id, action=action, content=content) + self.send(MessageType.MESSAGE_STREAM_PART, part) + + +GlobalLocalLogger = Global(LocalLogger) +GlobalRemoteLogger = Global(RemoteLogger) +Logger = Union[GlobalRemoteLogger, GlobalLocalLogger] +log_param = parameter(desc="choice of logging backend", default="local_logger") + + +@dataclass +class LogSectionContext: + logger: Logger + name: str + 
from_message: int + + _section_id: int = 0 + + def __enter__(self): + self._start = datetime.datetime.now() + self._section_id = self.logger.log_section(self.name, self.from_message, None, datetime.timedelta(0)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + duration = datetime.datetime.now() - self._start + self.logger.finalize_section(self._section_id, self.name, self.from_message, duration) + + +@dataclass +class LogConversationContext: + logger: Logger + with_section: bool + conversation: str + previous_conversation: Optional[str] + + _section: Optional[LogSectionContext] = None + + def __enter__(self): + if self.with_section: + self._section = LogSectionContext(self.logger, self.conversation, self.logger._last_message_id) + self._section.__enter__() + self.logger._current_conversation = self.conversation + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._section is not None: + self._section.__exit__(exc_type, exc_val, exc_tb) + del self._section + self.logger._current_conversation = self.previous_conversation + + +@dataclass +class MessageStreamLogger: + logger: Logger + message_id: int + conversation: Optional[str] + role: str + + _completed: bool = False + + def __post_init__(self): + self.logger._add_or_update_message(self.message_id, self.conversation, self.role, "", 0, 0, datetime.timedelta(0)) + + def __del__(self): + if not self._completed: + print(f"streamed message was not finalized ({self.logger.run.id}, {self.message_id}), please make sure to call finalize() on MessageStreamLogger objects") + self.finalize(0, 0, datetime.timedelta(0)) + + def append(self, content: str): + if self._completed: + raise ValueError("MessageStreamLogger already finalized") + self.logger.add_message_update(self.message_id, "append", content) + + def finalize(self, tokens_query: int, tokens_response: int, duration: datetime.timedelta, overwrite_finished_message: Optional[str] = None): + self._completed = True + self.logger._add_or_update_message(self.message_id, self.conversation, self.role, overwrite_finished_message or "", tokens_query, tokens_response, duration) + return self.message_id diff --git a/src/hackingBuddyGPT/utils/openai/__init__.py b/src/hackingBuddyGPT/utils/openai/__init__.py index 4c01b0f9..674681ed 100644 --- a/src/hackingBuddyGPT/utils/openai/__init__.py +++ b/src/hackingBuddyGPT/utils/openai/__init__.py @@ -1 +1,3 @@ -from .openai_llm import GPT35Turbo, GPT4, GPT4Turbo +from .openai_llm import GPT4, GPT4Turbo, GPT35Turbo + +__all__ = ["GPT4", "GPT4Turbo", "GPT35Turbo"] diff --git a/src/hackingBuddyGPT/utils/openai/openai_lib.py b/src/hackingBuddyGPT/utils/openai/openai_lib.py index 3e6f8da4..64e1b366 100644 --- a/src/hackingBuddyGPT/utils/openai/openai_lib.py +++ b/src/hackingBuddyGPT/utils/openai/openai_lib.py @@ -1,26 +1,32 @@ -import instructor -from typing import Dict, Union, Iterable, Optional +import datetime +from dataclasses import dataclass +from typing import Dict, Iterable, Optional, Union -from rich.console import Console -from openai.types import CompletionUsage -from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageParam, \ - ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_message_tool_call import Function +import instructor import openai import tiktoken -import time from dataclasses import dataclass +from openai.types import CompletionUsage +from openai.types.chat import ( + ChatCompletionChunk, + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, +)
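+# NOTE: ChoiceDelta (imported below) is the per-chunk update type that the reworked
+# stream_response yields when get_individual_updates=True; the final item yielded is
+# the aggregated LLMResult.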
+from openai.types.chat.chat_completion_chunk import ChoiceDelta +from openai.types.chat.chat_completion_message_tool_call import Function +from rich.console import Console -from hackingBuddyGPT.utils import LLM, configurable, LLMResult -from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.capability import capabilities_to_tools +from hackingBuddyGPT.utils import LLM, LLMResult, configurable +from hackingBuddyGPT.utils.configurable import parameter @configurable("openai-lib", "OpenAI Library based connection") @dataclass class OpenAILib(LLM): - api_key: str = parameter(desc="OpenAI API Key") + api_key: str = parameter(desc="OpenAI API Key", secret=True) model: str = parameter(desc="OpenAI model name") context_size: int = parameter(desc="OpenAI model context size") api_url: str = parameter(desc="URL of the OpenAI API", default="https://api.openai.com/v1") @@ -30,7 +36,12 @@ class OpenAILib(LLM): _client: openai.OpenAI = None def init(self): - self._client = openai.OpenAI(api_key=self.api_key, base_url=self.api_url, timeout=self.api_timeout, max_retries=self.api_retries) + self._client = openai.OpenAI( + api_key=self.api_key, + base_url=self.api_url, + timeout=self.api_timeout, + max_retries=self.api_retries, + ) @property def client(self) -> openai.OpenAI: @@ -40,7 +51,7 @@ def client(self) -> openai.OpenAI: def instructor(self) -> instructor.Instructor: return instructor.from_openai(self.client) - def get_response(self, prompt, *, capabilities: Dict[str, Capability]=None, **kwargs) -> LLMResult: + def get_response(self, prompt, *, capabilities: Optional[Dict[str, Capability] ] = None, **kwargs) -> LLMResult: """ # TODO: re-enable compatibility layer if isinstance(prompt, str) or hasattr(prompt, "render"): prompt = {"role": "user", "content": prompt} @@ -57,30 +68,38 @@ def get_response(self, prompt, *, capabilities: Dict[str, Capability]=None, **kw if capabilities: tools = capabilities_to_tools(capabilities) - tic = time.perf_counter() + tic = datetime.datetime.now() response = self._client.chat.completions.create( model=self.model, messages=prompt, tools=tools, ) - toc = time.perf_counter() + duration = datetime.datetime.now() - tic message = response.choices[0].message return LLMResult( message, str(prompt), message.content, - toc-tic, + duration, response.usage.prompt_tokens, response.usage.completion_tokens, ) - def stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: Console, capabilities: Dict[str, Capability] = None) -> Iterable[Union[ChatCompletionChunk, LLMResult]]: + def stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: Console, capabilities: Dict[str, Capability] = None, get_individual_updates=False) -> Union[LLMResult, Iterable[Union[ChoiceDelta, LLMResult]]]: + generator = self._stream_response(prompt, console, capabilities) + + if get_individual_updates: + return generator + + return list(generator)[-1] + + def _stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: Console, capabilities: Dict[str, Capability] = None) -> Iterable[Union[ChoiceDelta, LLMResult]]: tools = None if capabilities: tools = capabilities_to_tools(capabilities) - tic = time.perf_counter() + tic = datetime.datetime.now() chunks = self._client.chat.completions.create( model=self.model, messages=prompt, @@ -117,20 +136,31 @@ def stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: for tool_call in delta.tool_calls: if 
len(message.tool_calls) <= tool_call.index: if len(message.tool_calls) != tool_call.index: - print(f"WARNING: Got a tool call with index {tool_call.index} but expected {len(message.tool_calls)}") + print( + f"WARNING: Got a tool call with index {tool_call.index} but expected {len(message.tool_calls)}" + ) return console.print(f"\n\n[bold red]TOOL CALL - {tool_call.function.name}:[/bold red]") - message.tool_calls.append(ChatCompletionMessageToolCall(id=tool_call.id, function=Function(name=tool_call.function.name, arguments=tool_call.function.arguments), type="function")) + message.tool_calls.append( + ChatCompletionMessageToolCall( + id=tool_call.id, + function=Function( + name=tool_call.function.name, arguments=tool_call.function.arguments + ), + type="function", + ) + ) console.print(tool_call.function.arguments, end="") message.tool_calls[tool_call.index].function.arguments += tool_call.function.arguments outputs += 1 + yield delta + if chunk.usage is not None: usage = chunk.usage if outputs > 1: print("WARNING: Got more than one output in the stream response") - yield chunk console.print() if usage is None: @@ -140,16 +170,15 @@ def stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: if len(message.tool_calls) == 0: # the openAI API does not like getting empty tool call lists message.tool_calls = None - toc = time.perf_counter() + toc = datetime.datetime.now() yield LLMResult( message, str(prompt), message.content, - toc-tic, + toc - tic, usage.prompt_tokens, usage.completion_tokens, - ) - pass + ) def encode(self, query) -> list[int]: return tiktoken.encoding_for_model(self.model).encode(query) diff --git a/src/hackingBuddyGPT/utils/openai/openai_llm.py b/src/hackingBuddyGPT/utils/openai/openai_llm.py index befd9251..297aeac6 100644 --- a/src/hackingBuddyGPT/utils/openai/openai_llm.py +++ b/src/hackingBuddyGPT/utils/openai/openai_llm.py @@ -1,11 +1,14 @@ -import requests -import tiktoken import time - +import datetime from dataclasses import dataclass +import requests +import tiktoken +from urllib.parse import urlparse + from hackingBuddyGPT.utils.configurable import configurable, parameter -from hackingBuddyGPT.utils.llm_util import LLMResult, LLM +from hackingBuddyGPT.utils.llm_util import LLM, LLMResult + @configurable("openai-compatible-llm-api", "OpenAI-compatible LLM API") @dataclass @@ -17,32 +20,49 @@ class OpenAIConnection(LLM): If you really must use it, you can import it directly from the utils.openai.openai_llm module, which will later on show you, that you did not specialize yet. 
""" - api_key: str = parameter(desc="OpenAI API Key") + + api_key: str = parameter(desc="OpenAI API Key", secret=True) model: str = parameter(desc="OpenAI model name") - context_size: int = parameter(desc="Maximum context size for the model, only used internally for things like trimming to the context size") + context_size: int = parameter( + desc="Maximum context size for the model, only used internally for things like trimming to the context size" + ) api_url: str = parameter(desc="URL of the OpenAI API", default="https://api.openai.com") api_path: str = parameter(desc="Path to the OpenAI API", default="/v1/chat/completions") api_timeout: int = parameter(desc="Timeout for the API request", default=240) api_backoff: int = parameter(desc="Backoff time in seconds when running into rate-limits", default=60) api_retries: int = parameter(desc="Number of retries when running into rate-limits", default=3) - def get_response(self, prompt, *, retry: int = 0, **kwargs) -> LLMResult: + def get_response(self, prompt, *, retry: int = 0,azure_retry: int = 0, **kwargs) -> LLMResult: if retry >= self.api_retries: raise Exception("Failed to get response from OpenAI API") if hasattr(prompt, "render"): prompt = prompt.render(**kwargs) - headers = {"Authorization": f"Bearer {self.api_key}"} - data = {'model': self.model, 'messages': [{'role': 'user', 'content': prompt}]} + if urlparse(self.api_url).hostname and urlparse(self.api_url).hostname.endswith(".azure.com"): + # azure ai header + headers = {"api-key": f"{self.api_key}"} + else: + # normal header + headers = {"Authorization": f"Bearer {self.api_key}"} + + data = {"model": self.model, "messages": [{"role": "user", "content": prompt}]} try: - tic = time.perf_counter() + tic = datetime.datetime.now() response = requests.post(f'{self.api_url}{self.api_path}', headers=headers, json=data, timeout=self.api_timeout) + if response.status_code == 429: print(f"[RestAPI-Connector] running into rate-limits, waiting for {self.api_backoff} seconds") time.sleep(self.api_backoff) - return self.get_response(prompt, retry=retry+1) + return self.get_response(prompt, retry=retry + 1) + + if response.status_code == 408: + if azure_retry < self.api_retries: + print("Received 408 Status Code, trying again.") + return self.get_response(prompt, azure_retry = azure_retry + 1) + else: + raise Exception(f"Error from Gateway ({response.status_code}") if response.status_code != 200: raise Exception(f"Error from OpenAI Gateway ({response.status_code}") @@ -50,26 +70,26 @@ def get_response(self, prompt, *, retry: int = 0, **kwargs) -> LLMResult: except requests.exceptions.ConnectionError: print("Connection error! Retrying in 5 seconds..") time.sleep(5) - return self.get_response(prompt, retry=retry+1) + return self.get_response(prompt, retry=retry + 1) except requests.exceptions.Timeout: print("Timeout while contacting LLM REST endpoint") - return self.get_response(prompt, retry=retry+1) + return self.get_response(prompt, retry=retry + 1) # now extract the JSON status message # TODO: error handling.. 
- toc = time.perf_counter() response = response.json() - result = response['choices'][0]['message']['content'] - tok_query = response['usage']['prompt_tokens'] - tok_res = response['usage']['completion_tokens'] + result = response["choices"][0]["message"]["content"] + tok_query = response["usage"]["prompt_tokens"] + tok_res = response["usage"]["completion_tokens"] + duration = datetime.datetime.now() - tic - return LLMResult(result, prompt, result, toc - tic, tok_query, tok_res) + return LLMResult(result, prompt, result, duration, tok_query, tok_res) def encode(self, query) -> list[int]: # I know this is crappy for all non-openAI models but sadly this # has to be good enough for now - if self.model.startswith("gpt-"): + if self.model.startswith("gpt-") and not self.model.startswith("gpt-4o"): encoding = tiktoken.encoding_for_model(self.model) else: encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") @@ -95,3 +115,31 @@ class GPT4(OpenAIConnection): class GPT4Turbo(OpenAIConnection): model: str = "gpt-4-turbo-preview" context_size: int = 128000 + + +@configurable("openai/gpt-4o", "OpenAI GPT-4o") +@dataclass +class GPT4o(OpenAIConnection): + model: str = "gpt-4o" + context_size: int = 128000 + + +@configurable("openai/gpt-4o-mini", "OpenAI GPT-4o-mini") +@dataclass +class GPT4oMini(OpenAIConnection): + model: str = "gpt-4o-mini" + context_size: int = 128000 + + +@configurable("openai/o1-preview", "OpenAI o1-preview") +@dataclass +class O1Preview(OpenAIConnection): + model: str = "o1-preview" + context_size: int = 128000 + + +@configurable("openai/o1-mini", "OpenAI o1-mini") +@dataclass +class O1Mini(OpenAIConnection): + model: str = "o1-mini" + context_size: int = 128000 diff --git a/src/hackingBuddyGPT/utils/psexec/__init__.py b/src/hackingBuddyGPT/utils/psexec/__init__.py index 04c06af4..51a3b367 100644 --- a/src/hackingBuddyGPT/utils/psexec/__init__.py +++ b/src/hackingBuddyGPT/utils/psexec/__init__.py @@ -1 +1,3 @@ from .psexec import PSExecConnection + +__all__ = ["PSExecConnection"] diff --git a/src/hackingBuddyGPT/utils/psexec/psexec.py b/src/hackingBuddyGPT/utils/psexec/psexec.py index dcc95240..822768a2 100644 --- a/src/hackingBuddyGPT/utils/psexec/psexec.py +++ b/src/hackingBuddyGPT/utils/psexec/psexec.py @@ -1,7 +1,8 @@ from dataclasses import dataclass -from pypsexec.client import Client from typing import Tuple +from pypsexec.client import Client + from hackingBuddyGPT.utils.configurable import configurable diff --git a/src/hackingBuddyGPT/utils/shell_root_detection.py b/src/hackingBuddyGPT/utils/shell_root_detection.py index f741ab53..1747d1b3 100644 --- a/src/hackingBuddyGPT/utils/shell_root_detection.py +++ b/src/hackingBuddyGPT/utils/shell_root_detection.py @@ -1,15 +1,11 @@ import re -GOT_ROOT_REGEXPs = [ - re.compile("^# $"), - re.compile("^bash-[0-9]+.[0-9]# $") -] +GOT_ROOT_REGEXPs = [re.compile("^# $"), re.compile("^bash-[0-9]+.[0-9]# $")] def got_root(hostname: str, output: str) -> bool: for i in GOT_ROOT_REGEXPs: if i.fullmatch(output): return True - if output.startswith(f'root@{hostname}:'): - return True - return False + + return output.startswith(f"root@{hostname}:") diff --git a/src/hackingBuddyGPT/utils/ssh_connection/__init__.py b/src/hackingBuddyGPT/utils/ssh_connection/__init__.py index 89f7f349..25febf9a 100644 --- a/src/hackingBuddyGPT/utils/ssh_connection/__init__.py +++ b/src/hackingBuddyGPT/utils/ssh_connection/__init__.py @@ -1 +1,3 @@ from .ssh_connection import SSHConnection + +__all__ = ["SSHConnection"] diff --git
diff --git a/src/hackingBuddyGPT/utils/psexec/__init__.py b/src/hackingBuddyGPT/utils/psexec/__init__.py
index 04c06af4..51a3b367 100644
--- a/src/hackingBuddyGPT/utils/psexec/__init__.py
+++ b/src/hackingBuddyGPT/utils/psexec/__init__.py
@@ -1 +1,3 @@
 from .psexec import PSExecConnection
+
+__all__ = ["PSExecConnection"]
diff --git a/src/hackingBuddyGPT/utils/psexec/psexec.py b/src/hackingBuddyGPT/utils/psexec/psexec.py
index dcc95240..822768a2 100644
--- a/src/hackingBuddyGPT/utils/psexec/psexec.py
+++ b/src/hackingBuddyGPT/utils/psexec/psexec.py
@@ -1,7 +1,8 @@
 from dataclasses import dataclass
-from pypsexec.client import Client
 from typing import Tuple

+from pypsexec.client import Client
+
 from hackingBuddyGPT.utils.configurable import configurable
diff --git a/src/hackingBuddyGPT/utils/shell_root_detection.py b/src/hackingBuddyGPT/utils/shell_root_detection.py
index f741ab53..1747d1b3 100644
--- a/src/hackingBuddyGPT/utils/shell_root_detection.py
+++ b/src/hackingBuddyGPT/utils/shell_root_detection.py
@@ -1,15 +1,11 @@
 import re

-GOT_ROOT_REGEXPs = [
-    re.compile("^# $"),
-    re.compile("^bash-[0-9]+.[0-9]# $")
-]
+GOT_ROOT_REGEXPs = [re.compile("^# $"), re.compile("^bash-[0-9]+.[0-9]# $")]


 def got_root(hostname: str, output: str) -> bool:
     for i in GOT_ROOT_REGEXPs:
         if i.fullmatch(output):
             return True
-    if output.startswith(f'root@{hostname}:'):
-        return True
-    return False
+
+    return output.startswith(f"root@{hostname}:")
diff --git a/src/hackingBuddyGPT/utils/ssh_connection/__init__.py b/src/hackingBuddyGPT/utils/ssh_connection/__init__.py
index 89f7f349..25febf9a 100644
--- a/src/hackingBuddyGPT/utils/ssh_connection/__init__.py
+++ b/src/hackingBuddyGPT/utils/ssh_connection/__init__.py
@@ -1 +1,3 @@
 from .ssh_connection import SSHConnection
+
+__all__ = ["SSHConnection"]
diff --git a/src/hackingBuddyGPT/utils/ssh_connection/ssh_connection.py b/src/hackingBuddyGPT/utils/ssh_connection/ssh_connection.py
index 33bf8557..60cface1 100644
--- a/src/hackingBuddyGPT/utils/ssh_connection/ssh_connection.py
+++ b/src/hackingBuddyGPT/utils/ssh_connection/ssh_connection.py
@@ -1,8 +1,9 @@
-import invoke
 from dataclasses import dataclass
-from fabric import Connection
 from typing import Optional, Tuple

+import invoke
+from fabric import Connection
+
 from hackingBuddyGPT.utils.configurable import configurable
@@ -13,25 +14,33 @@ class SSHConnection:
     hostname: str
     username: str
     password: str
+    keyfilename: str
     port: int = 22

     _conn: Connection = None

     def init(self):
         # create the SSH Connection
-        conn = Connection(
-            f"{self.username}@{self.host}:{self.port}",
-            connect_kwargs={"password": self.password, "look_for_keys": False, "allow_agent": False},
-        )
+        if self.keyfilename is None or self.keyfilename == "":
+            conn = Connection(
+                f"{self.username}@{self.host}:{self.port}",
+                connect_kwargs={"password": self.password, "look_for_keys": False, "allow_agent": False},
+            )
+        else:
+            conn = Connection(
+                f"{self.username}@{self.host}:{self.port}",
+                connect_kwargs={"password": self.password, "key_filename": self.keyfilename, "look_for_keys": False, "allow_agent": False},
+            )
         self._conn = conn
         self._conn.open()

-    def new_with(self, *, host=None, hostname=None, username=None, password=None, port=None) -> "SSHConnection":
+    def new_with(self, *, host=None, hostname=None, username=None, password=None, keyfilename=None, port=None) -> "SSHConnection":
         return SSHConnection(
             host=host or self.host,
             hostname=hostname or self.hostname,
             username=username or self.username,
             password=password or self.password,
+            keyfilename=keyfilename or self.keyfilename,
             port=port or self.port,
         )
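The two `Connection(...)` calls in `init()` differ only in `key_filename`. An equivalent branch-free formulation, as a sketch (the helper name is hypothetical; same fabric API as in the patch):

```python
from fabric import Connection

def build_connection(username: str, host: str, port: int, password: str, keyfilename: str = "") -> Connection:
    # Disable agent keys and ~/.ssh defaults so authentication is decided
    # solely by the configured password and/or key file.
    connect_kwargs = {"password": password, "look_for_keys": False, "allow_agent": False}
    if keyfilename:  # "" (or None) selects password-only authentication
        connect_kwargs["key_filename"] = keyfilename
    return Connection(f"{username}@{host}:{port}", connect_kwargs=connect_kwargs)
```

Keeping the shared kwargs in one dict avoids duplicating the password/agent settings across the two branches.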
diff --git a/src/hackingBuddyGPT/utils/ui.py b/src/hackingBuddyGPT/utils/ui.py
index 753ec223..20ff85f8 100644
--- a/src/hackingBuddyGPT/utils/ui.py
+++ b/src/hackingBuddyGPT/utils/ui.py
@@ -2,8 +2,11 @@

 from .db_storage.db_storage import DbStorage

+
 # helper to fill the history table with data from the db
-def get_history_table(enable_explanation: bool, enable_update_state: bool, run_id: int, db: DbStorage, turn: int) -> Table:
+def get_history_table(
+    enable_explanation: bool, enable_update_state: bool, run_id: int, db: DbStorage, turn: int
+) -> Table:
     table = Table(title="Executed Command History", show_header=True, show_lines=True)
     table.add_column("ThinkTime", style="dim")
     table.add_column("Tokens", style="dim")
@@ -17,7 +20,7 @@ def get_history_table(enable_explanation: bool, enable_update_state: bool, run_i
     table.add_column("StateUpdTime", style="dim")
     table.add_column("StateUpdTokens", style="dim")

-    for i in range(1, turn+1):
+    for i in range(1, turn + 1):
         table.add_row(*db.get_round_data(run_id, i, enable_explanation, enable_update_state))

     return table
diff --git a/tests/integration_minimal_test.py b/tests/integration_minimal_test.py
index 8eb95871..96909c79 100644
--- a/tests/integration_minimal_test.py
+++ b/tests/integration_minimal_test.py
@@ -1,7 +1,14 @@
-
 from typing import Tuple
-from hackingBuddyGPT.usecases.examples.agent import ExPrivEscLinux, ExPrivEscLinuxUseCase
-from hackingBuddyGPT.usecases.examples.agent_with_state import ExPrivEscLinuxTemplated, ExPrivEscLinuxTemplatedUseCase
+
+from hackingBuddyGPT.utils.logging import LocalLogger
+from hackingBuddyGPT.usecases.examples.agent import (
+    ExPrivEscLinux,
+    ExPrivEscLinuxUseCase,
+)
+from hackingBuddyGPT.usecases.examples.agent_with_state import (
+    ExPrivEscLinuxTemplated,
+    ExPrivEscLinuxTemplatedUseCase,
+)
 from hackingBuddyGPT.usecases.privesc.linux import LinuxPrivesc, LinuxPrivescUseCase
 from hackingBuddyGPT.utils.console.console import Console
 from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage
@@ -9,9 +16,9 @@

 class FakeSSHConnection:
-    username : str = 'lowpriv'
-    password : str = 'toomanysecrets'
-    hostname : str = 'theoneandonly'
+    username: str = "lowpriv"
+    password: str = "toomanysecrets"
+    hostname: str = "theoneandonly"

     results = {
         "id": "uid=1001(lowpriv) gid=1001(lowpriv) groups=1001(lowpriv)",
@@ -31,111 +38,115 @@ class FakeSSHConnection:
 │   /usr/lib/dbus-1.0/dbus-daemon-launch-helper
 │   /usr/lib/openssh/ssh-keysign
 """,
-        "/usr/bin/python3.11 -c 'import os; os.setuid(0); os.system(\"/bin/sh\")'": "# "
+        "/usr/bin/python3.11 -c 'import os; os.setuid(0); os.system(\"/bin/sh\")'": "# ",
     }

     def run(self, cmd, *args, **kwargs) -> Tuple[str, str, int]:
+        out_stream = kwargs.get("out_stream", None)

-        out_stream = kwargs.get('out_stream', None)
-
-        if cmd in self.results.keys():
+        if cmd in self.results:
             out_stream.write(self.results[cmd])
-            return self.results[cmd], '', 0
+            return self.results[cmd], "", 0
         else:
-            return '', 'Command not found', 1
+            return "", "Command not found", 1
+

 class FakeLLM(LLM):
-    model:str = 'fake_model'
-    context_size:int = 4096
+    model: str = "fake_model"
+    context_size: int = 4096

-    counter:int = 0
+    counter: int = 0

     responses = [
         "id",
         "sudo -l",
         "find / -perm -4000 2>/dev/null",
-        "/usr/bin/python3.11 -c 'import os; os.setuid(0); os.system(\"/bin/sh\")'"
+        "/usr/bin/python3.11 -c 'import os; os.setuid(0); os.system(\"/bin/sh\")'",
     ]

     def get_response(self, prompt, *, capabilities=None, **kwargs) -> LLMResult:
         response = self.responses[self.counter]
         self.counter += 1
-        return LLMResult(result=response, prompt='this would be the prompt', answer=response)
+        return LLMResult(result=response, prompt="this would be the prompt", answer=response)

     def encode(self, query) -> list[int]:
         return [0]

-def test_linuxprivesc():

+def test_linuxprivesc():
     conn = FakeSSHConnection()
     llm = FakeLLM()
-    log_db = DbStorage(':memory:')
+    log_db = DbStorage(":memory:")
     console = Console()

     log_db.init()
+    log = LocalLogger(
+        log_db=log_db,
+        console=console,
+        tag="integration_test_linuxprivesc",
+    )
     priv_esc = LinuxPrivescUseCase(
-        agent = LinuxPrivesc(
+        agent=LinuxPrivesc(
             conn=conn,
             enable_explanation=False,
             disable_history=False,
-            hint='',
-            llm = llm,
+            hint="",
+            llm=llm,
+            log=log,
         ),
-        log_db = log_db,
-        console = console,
-        tag = 'integration_test_linuxprivesc',
-        max_turns = len(llm.responses)
+        log=log,
+        max_turns=len(llm.responses),
     )

     priv_esc.init()
-    result = priv_esc.run()
+    result = priv_esc.run({})
     assert result is True

-def test_minimal_agent():

+def test_minimal_agent():
     conn = FakeSSHConnection()
     llm = FakeLLM()
-    log_db = DbStorage(':memory:')
+    log_db = DbStorage(":memory:")
     console = Console()

     log_db.init()
+    log = LocalLogger(
+        log_db=log_db,
+        console=console,
+        tag="integration_test_minimallinuxprivesc",
+    )
     priv_esc = ExPrivEscLinuxUseCase(
-        agent = ExPrivEscLinux(
-            conn=conn,
-            llm=llm
-        ),
-        log_db = log_db,
-        console = console,
-        tag = 'integration_test_minimallinuxprivesc',
-        max_turns = len(llm.responses)
+        agent=ExPrivEscLinux(conn=conn, llm=llm, log=log),
+        log=log,
+        max_turns=len(llm.responses)
     )

     priv_esc.init()
-    result = priv_esc.run()
+    result = priv_esc.run({})
     assert result is True

-def test_minimal_agent_state():

+def test_minimal_agent_state():
     conn = FakeSSHConnection()
     llm = FakeLLM()
-    log_db = DbStorage(':memory:')
+    log_db = DbStorage(":memory:")
     console = Console()

     log_db.init()
+    log = LocalLogger(
+        log_db=log_db,
+        console=console,
+        tag="integration_test_linuxprivesc",
+    )
     priv_esc = ExPrivEscLinuxTemplatedUseCase(
-        agent = ExPrivEscLinuxTemplated(
-            conn=conn,
-            llm = llm,
-        ),
-        log_db = log_db,
-        console = console,
-        tag = 'integration_test_linuxprivesc',
-        max_turns = len(llm.responses)
+        agent=ExPrivEscLinuxTemplated(conn=conn, llm=llm, log=log),
+        log=log,
+        max_turns=len(llm.responses)
     )

     priv_esc.init()
-    result = priv_esc.run()
-    assert result is True
\ No newline at end of file
+    result = priv_esc.run({})
+    assert result is True
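The fakes above make the integration tests deterministic: the LLM replays a fixed command script and the connection answers from a lookup table. The same pattern in miniature (names are illustrative, not the project's API):

```python
from typing import Tuple

class ScriptedShell:
    """Maps exact commands to canned output, like FakeSSHConnection above."""

    def __init__(self, script: dict):
        self.script = script

    def run(self, cmd: str) -> Tuple[str, str, int]:
        # (stdout, stderr, exit code), mirroring the fake's contract
        if cmd in self.script:
            return self.script[cmd], "", 0
        return "", "Command not found", 1

shell = ScriptedShell({"id": "uid=1001(lowpriv) gid=1001(lowpriv)"})
assert shell.run("id")[2] == 0
assert shell.run("whoami") == ("", "Command not found", 1)
```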
diff --git a/tests/test_llm_handler.py b/tests/test_llm_handler.py
index 2c9078d1..9e1447ad 100644
--- a/tests/test_llm_handler.py
+++ b/tests/test_llm_handler.py
@@ -1,15 +1,16 @@
 import unittest
 from unittest.mock import MagicMock
+
 from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler


 class TestLLMHandler(unittest.TestCase):
     def setUp(self):
         self.llm_mock = MagicMock()
-        self.capabilities = {'cap1': MagicMock(), 'cap2': MagicMock()}
+        self.capabilities = {"cap1": MagicMock(), "cap2": MagicMock()}
         self.llm_handler = LLMHandler(self.llm_mock, self.capabilities)

-    '''@patch('hackingBuddyGPT.usecases.web_api_testing.utils.capabilities_to_action_model')
+    """@patch('hackingBuddyGPT.usecases.web_api_testing.utils.capabilities_to_action_model')
     def test_call_llm(self, mock_capabilities_to_action_model):
         prompt = [{'role': 'user', 'content': 'Hello, LLM!'}]
         response_mock = MagicMock()
@@ -26,10 +27,11 @@ def test_call_llm(self, mock_capabilities_to_action_model):
             messages=prompt,
             response_model=mock_model
         )
-        self.assertEqual(response, response_mock)'''
+        self.assertEqual(response, response_mock)"""
+
     def test_add_created_object(self):
         created_object = MagicMock()
-        object_type = 'test_type'
+        object_type = "test_type"

         self.llm_handler.add_created_object(created_object, object_type)

@@ -38,7 +40,7 @@ def test_add_created_object(self):

     def test_add_created_object_limit(self):
         created_object = MagicMock()
-        object_type = 'test_type'
+        object_type = "test_type"

         for _ in range(8):  # Exceed the limit of 7 objects
             self.llm_handler.add_created_object(created_object, object_type)

@@ -47,7 +49,7 @@
     def test_get_created_objects(self):
         created_object = MagicMock()
-        object_type = 'test_type'
+        object_type = "test_type"

         self.llm_handler.add_created_object(created_object, object_type)
         created_objects = self.llm_handler.get_created_objects()
@@ -56,5 +58,6 @@
         self.assertIn(created_object, created_objects[object_type])
         self.assertEqual(created_objects, self.llm_handler.created_objects)

+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_openAPI_specification_manager.py b/tests/test_openAPI_specification_manager.py
index bc9fade7..e6088c00 100644
--- a/tests/test_openAPI_specification_manager.py
+++ b/tests/test_openAPI_specification_manager.py
@@ -2,7 +2,9 @@
 from unittest.mock import MagicMock, patch

 from hackingBuddyGPT.capabilities.http_request import HTTPRequest
-from hackingBuddyGPT.usecases.web_api_testing.documentation.openapi_specification_handler import OpenAPISpecificationHandler
+from hackingBuddyGPT.usecases.web_api_testing.documentation.openapi_specification_handler import (
+    OpenAPISpecificationHandler,
+)


 class TestSpecificationHandler(unittest.TestCase):
@@ -11,19 +13,17 @@ def setUp(self):
         self.response_handler = MagicMock()
         self.doc_handler = OpenAPISpecificationHandler(self.llm_handler, self.response_handler)

-    @patch('os.makedirs')
-    @patch('builtins.open')
+    @patch("os.makedirs")
+    @patch("builtins.open")
     def test_write_openapi_to_yaml(self, mock_open, mock_makedirs):
         self.doc_handler.write_openapi_to_yaml()
         mock_makedirs.assert_called_once_with(self.doc_handler.file_path, exist_ok=True)
-        mock_open.assert_called_once_with(self.doc_handler.file, 'w')
+        mock_open.assert_called_once_with(self.doc_handler.file, "w")

         # Create a mock HTTPRequest object
         response_mock = MagicMock()
         response_mock.action = HTTPRequest(
-            host="https://jsonplaceholder.typicode.com",
-            follow_redirects=False,
-            use_cookie_jar=True
+            host="https://jsonplaceholder.typicode.com", follow_redirects=False, use_cookie_jar=True
         )
         response_mock.action.method = "GET"
         response_mock.action.path = "/test"
@@ -38,11 +38,11 @@ def test_write_openapi_to_yaml(self, mock_open, mock_makedirs):

         self.assertIn("/test", self.doc_handler.openapi_spec["endpoints"])
         self.assertIn("get", self.doc_handler.openapi_spec["endpoints"]["/test"])
-        self.assertEqual(self.doc_handler.openapi_spec["endpoints"]["/test"]["get"]["summary"],
-                         "GET operation on /test")
+        self.assertEqual(
+            self.doc_handler.openapi_spec["endpoints"]["/test"]["get"]["summary"], "GET operation on /test"
+        )
         self.assertEqual(endpoints, ["/test"])

-
     def test_partial_match(self):
         string_list = ["test_endpoint", "another_endpoint"]
         self.assertTrue(self.doc_handler.is_partial_match("test", string_list))
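`@patch("builtins.open")` together with `@patch("os.makedirs")` keeps `write_openapi_to_yaml` off the real filesystem. The mechanism in isolation (the function under test is a hypothetical stand-in):

```python
import unittest
from unittest.mock import mock_open, patch

def save_report(path: str, text: str) -> None:
    # hypothetical stand-in for a file-writing method like write_openapi_to_yaml
    with open(path, "w") as f:
        f.write(text)

class TestSaveReport(unittest.TestCase):
    @patch("builtins.open", new_callable=mock_open)
    def test_save_report(self, mocked_open):
        save_report("report.txt", "hello")
        mocked_open.assert_called_once_with("report.txt", "w")
        mocked_open().write.assert_called_once_with("hello")

if __name__ == "__main__":
    unittest.main()
```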
diff --git a/tests/test_openapi_converter.py b/tests/test_openapi_converter.py
index c9b086e7..f4609d1e 100644
--- a/tests/test_openapi_converter.py
+++ b/tests/test_openapi_converter.py
@@ -1,8 +1,10 @@
-import unittest
-from unittest.mock import patch, mock_open
 import os
+import unittest
+from unittest.mock import mock_open, patch

-from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing.openapi_converter import OpenAPISpecificationConverter
+from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing.openapi_converter import (
+    OpenAPISpecificationConverter,
+)


 class TestOpenAPISpecificationConverter(unittest.TestCase):
@@ -22,9 +24,9 @@ def test_convert_file_yaml_to_json(self, mock_json_dump, mock_yaml_safe_load, mo

         result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type)

-        mock_open_file.assert_any_call(input_filepath, 'r')
+        mock_open_file.assert_any_call(input_filepath, "r")
         mock_yaml_safe_load.assert_called_once()
-        mock_open_file.assert_any_call(expected_output_path, 'w')
+        mock_open_file.assert_any_call(expected_output_path, "w")
         mock_json_dump.assert_called_once_with({"key": "value"}, mock_open_file(), indent=2)
         mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True)
         self.assertEqual(result, expected_output_path)
@@ -42,10 +44,12 @@ def test_convert_file_json_to_yaml(self, mock_yaml_dump, mock_json_load, mock_op

         result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type)

-        mock_open_file.assert_any_call(input_filepath, 'r')
+        mock_open_file.assert_any_call(input_filepath, "r")
         mock_json_load.assert_called_once()
-        mock_open_file.assert_any_call(expected_output_path, 'w')
-        mock_yaml_dump.assert_called_once_with({"key": "value"}, mock_open_file(), allow_unicode=True, default_flow_style=False)
+        mock_open_file.assert_any_call(expected_output_path, "w")
+        mock_yaml_dump.assert_called_once_with(
+            {"key": "value"}, mock_open_file(), allow_unicode=True, default_flow_style=False
+        )
         mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True)
         self.assertEqual(result, expected_output_path)
@@ -60,7 +64,7 @@ def test_convert_file_yaml_to_json_error(self, mock_yaml_safe_load, mock_open_fi

         result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type)

-        mock_open_file.assert_any_call(input_filepath, 'r')
+        mock_open_file.assert_any_call(input_filepath, "r")
         mock_yaml_safe_load.assert_called_once()
         mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True)
         self.assertIsNone(result)
@@ -76,10 +80,11 @@ def test_convert_file_json_to_yaml_error(self, mock_json_load, mock_open_file, m

         result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type)

-        mock_open_file.assert_any_call(input_filepath, 'r')
+        mock_open_file.assert_any_call(input_filepath, "r")
         mock_json_load.assert_called_once()
         mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True)
         self.assertIsNone(result)
+
 if __name__ == "__main__":
     unittest.main()
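For orientation, the conversion these tests assert against boils down to a load-with-one-library, dump-with-the-other round trip; a sketch of the YAML-to-JSON direction (not the project's converter):

```python
import json

import yaml

def yaml_to_json(yaml_text: str, indent: int = 2) -> str:
    # yaml.safe_load yields plain dicts/lists, which json.dumps accepts directly
    return json.dumps(yaml.safe_load(yaml_text), indent=indent)

print(yaml_to_json("key: value\nitems:\n  - 1\n  - 2\n"))
```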
{"key": "value"}, mock_open_file(), allow_unicode=True, default_flow_style=False + ) mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True) self.assertEqual(result, expected_output_path) @@ -60,7 +64,7 @@ def test_convert_file_yaml_to_json_error(self, mock_yaml_safe_load, mock_open_fi result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type) - mock_open_file.assert_any_call(input_filepath, 'r') + mock_open_file.assert_any_call(input_filepath, "r") mock_yaml_safe_load.assert_called_once() mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True) self.assertIsNone(result) @@ -76,10 +80,11 @@ def test_convert_file_json_to_yaml_error(self, mock_json_load, mock_open_file, m result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type) - mock_open_file.assert_any_call(input_filepath, 'r') + mock_open_file.assert_any_call(input_filepath, "r") mock_json_load.assert_called_once() mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True) self.assertIsNone(result) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_openapi_parser.py b/tests/test_openapi_parser.py index fb7bb1c3..a4f73443 100644 --- a/tests/test_openapi_parser.py +++ b/tests/test_openapi_parser.py @@ -1,8 +1,11 @@ import unittest -from unittest.mock import patch, mock_open +from unittest.mock import mock_open, patch + import yaml -from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing import OpenAPISpecificationParser +from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing import ( + OpenAPISpecificationParser, +) class TestOpenAPISpecificationParser(unittest.TestCase): @@ -37,7 +40,10 @@ def setUp(self): """ @patch("builtins.open", new_callable=mock_open, read_data="") - @patch("yaml.safe_load", return_value=yaml.safe_load(""" + @patch( + "yaml.safe_load", + return_value=yaml.safe_load( + """ openapi: 3.0.0 info: title: Sample API @@ -63,15 +69,20 @@ def setUp(self): responses: '200': description: Expected response to a valid request - """)) + """ + ), + ) def test_load_yaml(self, mock_yaml_load, mock_open_file): parser = OpenAPISpecificationParser(self.filepath) - self.assertEqual(parser.api_data['info']['title'], "Sample API") - self.assertEqual(parser.api_data['info']['version'], "1.0.0") - self.assertEqual(len(parser.api_data['servers']), 2) + self.assertEqual(parser.api_data["info"]["title"], "Sample API") + self.assertEqual(parser.api_data["info"]["version"], "1.0.0") + self.assertEqual(len(parser.api_data["servers"]), 2) @patch("builtins.open", new_callable=mock_open, read_data="") - @patch("yaml.safe_load", return_value=yaml.safe_load(""" + @patch( + "yaml.safe_load", + return_value=yaml.safe_load( + """ openapi: 3.0.0 info: title: Sample API @@ -97,14 +108,19 @@ def test_load_yaml(self, mock_yaml_load, mock_open_file): responses: '200': description: Expected response to a valid request - """)) + """ + ), + ) def test_get_servers(self, mock_yaml_load, mock_open_file): parser = OpenAPISpecificationParser(self.filepath) servers = parser._get_servers() self.assertEqual(servers, ["https://api.example.com", "https://staging.api.example.com"]) @patch("builtins.open", new_callable=mock_open, read_data="") - @patch("yaml.safe_load", return_value=yaml.safe_load(""" + @patch( + "yaml.safe_load", + return_value=yaml.safe_load( + """ openapi: 3.0.0 info: title: Sample API @@ -130,7 +146,9 
diff --git a/tests/test_prompt_engineer_documentation.py b/tests/test_prompt_engineer_documentation.py
index 22d24b9c..daeedbbd 100644
--- a/tests/test_prompt_engineer_documentation.py
+++ b/tests/test_prompt_engineer_documentation.py
@@ -1,9 +1,16 @@
 import unittest
 from unittest.mock import MagicMock

-from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import PromptStrategy, PromptEngineer
-from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptContext
 from openai.types.chat import ChatCompletionMessage

+from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import (
+    PromptContext,
+)
+from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import (
+    PromptEngineer,
+    PromptStrategy,
+)
+

 class TestPromptEngineer(unittest.TestCase):
     def setUp(self):
@@ -13,11 +20,12 @@ def setUp(self):
         self.schemas = MagicMock()
         self.response_handler = MagicMock()
         self.prompt_engineer = PromptEngineer(
-            strategy=self.strategy, handlers=(self.llm_handler, self.response_handler), history=self.history,
-            context=PromptContext.DOCUMENTATION
+            strategy=self.strategy,
+            handlers=(self.llm_handler, self.response_handler),
+            history=self.history,
+            context=PromptContext.DOCUMENTATION,
         )

-
     def test_in_context_learning_no_hint(self):
         self.prompt_engineer.strategy = PromptStrategy.IN_CONTEXT
         expected_prompt = "initial_prompt\ninitial_prompt"
@@ -36,7 +44,8 @@ def test_in_context_learning_with_doc_and_hint(self):
         hint = "This is another hint."
         expected_prompt = "initial_prompt\ninitial_prompt\nThis is another hint."
         actual_prompt = self.prompt_engineer.generate_prompt(hint=hint, turn=1)
-        self.assertEqual(expected_prompt, actual_prompt[1]["content"])
+        self.assertEqual(expected_prompt, actual_prompt[1]["content"])
+
     def test_generate_prompt_chain_of_thought(self):
         self.prompt_engineer.strategy = PromptStrategy.CHAIN_OF_THOUGHT
         self.response_handler.get_response_for_prompt = MagicMock(return_value="response_text")
@@ -44,7 +53,7 @@ def test_generate_prompt_chain_of_thought(self):

         prompt_history = self.prompt_engineer.generate_prompt(turn=1)

-        self.assertEqual( 2, len(prompt_history))
+        self.assertEqual(2, len(prompt_history))

     def test_generate_prompt_tree_of_thought(self):
         # Set the strategy to TREE_OF_THOUGHT
@@ -55,7 +64,7 @@ def test_generate_prompt_tree_of_thought(self):
         # Create mock previous prompts with valid roles
         previous_prompts = [
             ChatCompletionMessage(role="assistant", content="initial_prompt"),
-            ChatCompletionMessage(role="assistant", content="previous_prompt")
+            ChatCompletionMessage(role="assistant", content="previous_prompt"),
         ]

         # Assign the previous prompts to prompt_engineer._prompt_history
@@ -69,4 +78,4 @@ def test_generate_prompt_tree_of_thought(self):

 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/tests/test_prompt_engineer_testing.py b/tests/test_prompt_engineer_testing.py
index 7fba2f3b..198bbbc6 100644
--- a/tests/test_prompt_engineer_testing.py
+++ b/tests/test_prompt_engineer_testing.py
@@ -1,9 +1,16 @@
 import unittest
 from unittest.mock import MagicMock

-from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import PromptStrategy, PromptEngineer
-from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptContext
 from openai.types.chat import ChatCompletionMessage

+from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import (
+    PromptContext,
+)
+from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import (
+    PromptEngineer,
+    PromptStrategy,
+)
+

 class TestPromptEngineer(unittest.TestCase):
     def setUp(self):
@@ -13,11 +20,12 @@ def setUp(self):
         self.schemas = MagicMock()
         self.response_handler = MagicMock()
         self.prompt_engineer = PromptEngineer(
-            strategy=self.strategy, handlers=(self.llm_handler, self.response_handler), history=self.history,
-            context=PromptContext.PENTESTING
+            strategy=self.strategy,
+            handlers=(self.llm_handler, self.response_handler),
+            history=self.history,
+            context=PromptContext.PENTESTING,
         )

-
     def test_in_context_learning_no_hint(self):
         self.prompt_engineer.strategy = PromptStrategy.IN_CONTEXT
         expected_prompt = "initial_prompt\ninitial_prompt"
@@ -36,7 +44,8 @@ def test_in_context_learning_with_doc_and_hint(self):
         hint = "This is another hint."
         expected_prompt = "initial_prompt\ninitial_prompt\nThis is another hint."
         actual_prompt = self.prompt_engineer.generate_prompt(hint=hint, turn=1)
-        self.assertEqual(expected_prompt, actual_prompt[1]["content"])
+        self.assertEqual(expected_prompt, actual_prompt[1]["content"])
+
     def test_generate_prompt_chain_of_thought(self):
         self.prompt_engineer.strategy = PromptStrategy.CHAIN_OF_THOUGHT
         self.response_handler.get_response_for_prompt = MagicMock(return_value="response_text")
@@ -44,7 +53,7 @@ def test_generate_prompt_chain_of_thought(self):

         prompt_history = self.prompt_engineer.generate_prompt(turn=1)

-        self.assertEqual( 2, len(prompt_history))
+        self.assertEqual(2, len(prompt_history))

     def test_generate_prompt_tree_of_thought(self):
         # Set the strategy to TREE_OF_THOUGHT
@@ -55,7 +64,7 @@ def test_generate_prompt_tree_of_thought(self):
         # Create mock previous prompts with valid roles
         previous_prompts = [
             ChatCompletionMessage(role="assistant", content="initial_prompt"),
-            ChatCompletionMessage(role="assistant", content="previous_prompt")
+            ChatCompletionMessage(role="assistant", content="previous_prompt"),
         ]

         # Assign the previous prompts to prompt_engineer._prompt_history
@@ -68,7 +77,5 @@ def test_generate_prompt_tree_of_thought(self):

         self.assertEqual(len(prompt_history), 3)  # Adjust to 3 if previous prompt exists + new prompt

-
-
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/tests/test_prompt_generation_helper.py b/tests/test_prompt_generation_helper.py
index 2192d21a..06aca3b4 100644
--- a/tests/test_prompt_generation_helper.py
+++ b/tests/test_prompt_generation_helper.py
@@ -1,6 +1,9 @@
 import unittest
 from unittest.mock import MagicMock
-from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_generation_helper import PromptGenerationHelper
+
+from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_generation_helper import (
+    PromptGenerationHelper,
+)


 class TestPromptHelper(unittest.TestCase):
@@ -8,16 +11,15 @@ def setUp(self):
         self.response_handler = MagicMock()
         self.prompt_helper = PromptGenerationHelper(self.response_handler)

-
     def test_check_prompt(self):
         self.response_handler.get_response_for_prompt = MagicMock(return_value="shortened_prompt")
         prompt = self.prompt_helper.check_prompt(
-            previous_prompt="previous_prompt", steps=["step1", "step2", "step3", "step4", "step5", "step6"],
-            max_tokens=2)
+            previous_prompt="previous_prompt",
+            steps=["step1", "step2", "step3", "step4", "step5", "step6"],
+            max_tokens=2,
+        )
         self.assertEqual("shortened_prompt", prompt)

-
-
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/tests/test_response_analyzer.py b/tests/test_response_analyzer.py
index fd41640f..0c621bcf 100644
--- a/tests/test_response_analyzer.py
+++ b/tests/test_response_analyzer.py
@@ -1,12 +1,15 @@
 import unittest
 from unittest.mock import patch

-from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer import ResponseAnalyzer
-from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptPurpose
+from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import (
+    PromptPurpose,
+)
+from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer import (
+    ResponseAnalyzer,
+)


 class TestResponseAnalyzer(unittest.TestCase):
-
     def setUp(self):
         # Example HTTP response to use in tests
         self.raw_http_response = """HTTP/1.1 404 Not Found
@@ -29,37 +32,38 @@ def test_parse_http_response(self):
         status_code, headers, body = analyzer.parse_http_response(self.raw_http_response)

         self.assertEqual(status_code, 404)
-        self.assertEqual(headers['Content-Type'], 'application/json; charset=utf-8')
-        self.assertEqual(body, 'Empty')
+        self.assertEqual(headers["Content-Type"], "application/json; charset=utf-8")
+        self.assertEqual(body, "Empty")

     def test_analyze_authentication_authorization(self):
         analyzer = ResponseAnalyzer(PromptPurpose.AUTHENTICATION_AUTHORIZATION)
         analysis = analyzer.analyze_response(self.raw_http_response)

-        self.assertEqual(analysis['status_code'], 404)
-        self.assertEqual(analysis['authentication_status'], 'Unknown')
-        self.assertTrue(analysis['content_body'], 'Empty')
-        self.assertIn('X-Ratelimit-Limit', analysis['rate_limiting'])
+        self.assertEqual(analysis["status_code"], 404)
+        self.assertEqual(analysis["authentication_status"], "Unknown")
+        self.assertTrue(analysis["content_body"], "Empty")
+        self.assertIn("X-Ratelimit-Limit", analysis["rate_limiting"])

     def test_analyze_input_validation(self):
         analyzer = ResponseAnalyzer(PromptPurpose.INPUT_VALIDATION)
         analysis = analyzer.analyze_response(self.raw_http_response)

-        self.assertEqual(analysis['status_code'], 404)
-        self.assertEqual(analysis['is_valid_response'], 'Error')
-        self.assertTrue(analysis['response_body'], 'Empty')
-        self.assertIn('security_headers_present', analysis)
+        self.assertEqual(analysis["status_code"], 404)
+        self.assertEqual(analysis["is_valid_response"], "Error")
+        self.assertTrue(analysis["response_body"], "Empty")
+        self.assertIn("security_headers_present", analysis)

-    @patch('builtins.print')
+    @patch("builtins.print")
     def test_print_analysis(self, mock_print):
         analyzer = ResponseAnalyzer(PromptPurpose.INPUT_VALIDATION)
         analysis = analyzer.analyze_response(self.raw_http_response)
-        analysis_str =analyzer.print_analysis(analysis)
+        analysis_str = analyzer.print_analysis(analysis)

         # Check that the correct calls were made to print
         self.assertIn("HTTP Status Code: 404", analysis_str)
         self.assertIn("Response Body: Empty", analysis_str)
         self.assertIn("Security Headers Present: Yes", analysis_str)

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     unittest.main()
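`parse_http_response` is asserted to split a raw response into status code, header dict and body ("Empty" when blank). A rough standalone equivalent, assuming CRLF-separated headers (a sketch, not the project's implementation):

```python
def parse_http_response(raw: str):
    head, _, body = raw.partition("\r\n\r\n")
    status_line, *header_lines = head.split("\r\n")
    status_code = int(status_line.split(" ")[1])
    headers = dict(line.split(": ", 1) for line in header_lines if ": " in line)
    return status_code, headers, body.strip() or "Empty"

raw = "HTTP/1.1 404 Not Found\r\nContent-Type: application/json; charset=utf-8\r\n\r\n"
status, headers, body = parse_http_response(raw)
assert status == 404 and body == "Empty"
assert headers["Content-Type"] == "application/json; charset=utf-8"
```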
diff --git a/tests/test_response_handler.py b/tests/test_response_handler.py
index c72572c7..31a223de 100644
--- a/tests/test_response_handler.py
+++ b/tests/test_response_handler.py
@@ -1,7 +1,9 @@
 import unittest
 from unittest.mock import MagicMock, patch

-from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import ResponseHandler
+from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import (
+    ResponseHandler,
+)


 class TestResponseHandler(unittest.TestCase):
@@ -17,7 +19,9 @@ def test_get_response_for_prompt(self):

         response_text = self.response_handler.get_response_for_prompt(prompt)

-        self.llm_handler_mock.call_llm.assert_called_once_with([{"role": "user", "content": [{"type": "text", "text": prompt}]}])
+        self.llm_handler_mock.call_llm.assert_called_once_with(
+            [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+        )
         self.assertEqual(response_text, "Response text")

     def test_parse_http_status_line_valid(self):
@@ -47,18 +51,20 @@ def test_extract_response_example_invalid(self):
         result = self.response_handler.extract_response_example(html_content)
         self.assertIsNone(result)

-    @patch('hackingBuddyGPT.usecases.web_api_testing.response_processing.ResponseHandler.parse_http_response_to_openapi_example')
+    @patch(
+        "hackingBuddyGPT.usecases.web_api_testing.response_processing.ResponseHandler.parse_http_response_to_openapi_example"
+    )
     def test_parse_http_response_to_openapi_example(self, mock_parse_http_response_to_schema):
-        openapi_spec = {
-            "components": {"schemas": {}}
-        }
-        http_response = "HTTP/1.1 200 OK\r\n\r\n{\"id\": 1, \"name\": \"test\"}"
+        openapi_spec = {"components": {"schemas": {}}}
+        http_response = 'HTTP/1.1 200 OK\r\n\r\n{"id": 1, "name": "test"}'
         path = "/test"
         method = "GET"

         mock_parse_http_response_to_schema.return_value = ("#/components/schemas/Test", "Test", openapi_spec)

-        entry_dict, reference, updated_spec = self.response_handler.parse_http_response_to_openapi_example(openapi_spec, http_response, path, method)
+        entry_dict, reference, updated_spec = self.response_handler.parse_http_response_to_openapi_example(
+            openapi_spec, http_response, path, method
+        )

         self.assertEqual(reference, "Test")
         self.assertEqual(updated_spec, openapi_spec)
@@ -72,29 +78,26 @@ def test_extract_description(self):

         from unittest.mock import patch

-    @patch('hackingBuddyGPT.usecases.web_api_testing.response_processing.ResponseHandler.parse_http_response_to_schema')
+    @patch("hackingBuddyGPT.usecases.web_api_testing.response_processing.ResponseHandler.parse_http_response_to_schema")
     def test_parse_http_response_to_schema(self, mock_parse_http_response_to_schema):
-        openapi_spec = {
-            "components": {"schemas": {}}
-        }
+        openapi_spec = {"components": {"schemas": {}}}
         body_dict = {"id": 1, "name": "test"}
         path = "/tests"

         def mock_side_effect(spec, body, path):
             schema_name = "Test"
-            spec['components']['schemas'][schema_name] = {
+            spec["components"]["schemas"][schema_name] = {
                 "type": "object",
-                "properties": {
-                    key: {"type": type(value).__name__, "example": value} for key, value in body.items()
-                }
+                "properties": {key: {"type": type(value).__name__, "example": value} for key, value in body.items()},
             }
             reference = f"#/components/schemas/{schema_name}"
             return reference, schema_name, spec

         mock_parse_http_response_to_schema.side_effect = mock_side_effect

-        reference, object_name, updated_spec = self.response_handler.parse_http_response_to_schema(openapi_spec,
-                                                                                                   body_dict, path)
+        reference, object_name, updated_spec = self.response_handler.parse_http_response_to_schema(
+            openapi_spec, body_dict, path
+        )

         self.assertEqual(reference, "#/components/schemas/Test")
         self.assertEqual(object_name, "Test")
@@ -102,12 +105,12 @@ def mock_side_effect(spec, body, path):
         self.assertIn("id", updated_spec["components"]["schemas"]["Test"]["properties"])
         self.assertIn("name", updated_spec["components"]["schemas"]["Test"]["properties"])

-    @patch('builtins.open', new_callable=unittest.mock.mock_open, read_data='yaml_content')
+    @patch("builtins.open", new_callable=unittest.mock.mock_open, read_data="yaml_content")
     def test_read_yaml_to_string(self, mock_open):
         filepath = "test.yaml"
         result = self.response_handler.read_yaml_to_string(filepath)
-        mock_open.assert_called_once_with(filepath, 'r')
-        self.assertEqual(result, 'yaml_content')
+        mock_open.assert_called_once_with(filepath, "r")
+        self.assertEqual(result, "yaml_content")

     def test_read_yaml_to_string_file_not_found(self):
         filepath = "nonexistent.yaml"
@@ -117,7 +120,7 @@ def test_read_yaml_to_string_file_not_found(self):
     def test_extract_endpoints(self):
         note = "1. GET /test\n"
         result = self.response_handler.extract_endpoints(note)
-        self.assertEqual( {'/test': ['GET']}, result)
+        self.assertEqual({"/test": ["GET"]}, result)

     def test_extract_keys(self):
         key = "name"
@@ -127,5 +130,6 @@ def test_extract_keys(self):
         self.assertIn(key, result)
         self.assertEqual(result[key], {"type": "str", "example": "test"})

+
 if __name__ == "__main__":
     unittest.main()
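The `mock_side_effect` above encodes how an example JSON body becomes an OpenAPI-style schema: one property per key, typed by the Python type name, with the observed value as example. Extracted as a sketch:

```python
def schema_from_example(body: dict) -> dict:
    # Derive {"type", "example"} property entries from a concrete body
    return {
        "type": "object",
        "properties": {key: {"type": type(value).__name__, "example": value} for key, value in body.items()},
    }

assert schema_from_example({"id": 1, "name": "test"}) == {
    "type": "object",
    "properties": {
        "id": {"type": "int", "example": 1},
        "name": {"type": "str", "example": "test"},
    },
}
```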
"test.yaml" result = self.response_handler.read_yaml_to_string(filepath) - mock_open.assert_called_once_with(filepath, 'r') - self.assertEqual(result, 'yaml_content') + mock_open.assert_called_once_with(filepath, "r") + self.assertEqual(result, "yaml_content") def test_read_yaml_to_string_file_not_found(self): filepath = "nonexistent.yaml" @@ -117,7 +120,7 @@ def test_read_yaml_to_string_file_not_found(self): def test_extract_endpoints(self): note = "1. GET /test\n" result = self.response_handler.extract_endpoints(note) - self.assertEqual( {'/test': ['GET']}, result) + self.assertEqual({"/test": ["GET"]}, result) def test_extract_keys(self): key = "name" @@ -127,5 +130,6 @@ def test_extract_keys(self): self.assertIn(key, result) self.assertEqual(result[key], {"type": "str", "example": "test"}) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_root_detection.py b/tests/test_root_detection.py index 9567e680..881d8ae9 100644 --- a/tests/test_root_detection.py +++ b/tests/test_root_detection.py @@ -1,5 +1,6 @@ from hackingBuddyGPT.utils.shell_root_detection import got_root + def test_got_root(): hostname = "i_dont_care" diff --git a/tests/test_web_api_documentation.py b/tests/test_web_api_documentation.py index 0cf00ffe..03f79127 100644 --- a/tests/test_web_api_documentation.py +++ b/tests/test_web_api_documentation.py @@ -1,28 +1,34 @@ import unittest from unittest.mock import MagicMock, patch -from hackingBuddyGPT.usecases.web_api_testing.simple_openapi_documentation import SimpleWebAPIDocumentationUseCase, \ - SimpleWebAPIDocumentation -from hackingBuddyGPT.utils import DbStorage, Console +from hackingBuddyGPT.utils.logging import LocalLogger +from hackingBuddyGPT.usecases.web_api_testing.simple_openapi_documentation import ( + SimpleWebAPIDocumentation, + SimpleWebAPIDocumentationUseCase, +) +from hackingBuddyGPT.utils import Console, DbStorage -class TestSimpleWebAPIDocumentationTest(unittest.TestCase): - @patch('hackingBuddyGPT.utils.openai.openai_lib.OpenAILib') +class TestSimpleWebAPIDocumentationTest(unittest.TestCase): + @patch("hackingBuddyGPT.utils.openai.openai_lib.OpenAILib") def setUp(self, MockOpenAILib): # Mock the OpenAILib instance self.mock_llm = MockOpenAILib.return_value - log_db = DbStorage(':memory:') + log_db = DbStorage(":memory:") console = Console() log_db.init() - self.agent = SimpleWebAPIDocumentation(llm=self.mock_llm) + log = LocalLogger( + log_db=log_db, + console=console, + tag="webApiDocumentation", + ) + self.agent = SimpleWebAPIDocumentation(llm=self.mock_llm, log=log) self.agent.init() self.simple_api_testing = SimpleWebAPIDocumentationUseCase( agent=self.agent, - log_db=log_db, - console=console, - tag='webApiDocumentation', - max_turns=len(self.mock_llm.responses) + log=log, + max_turns=len(self.mock_llm.responses), ) self.simple_api_testing.init() @@ -30,15 +36,15 @@ def test_initial_prompt(self): # Test if the initial prompt is set correctly expected_prompt = "You're tasked with documenting the REST APIs of a website hosted at https://jsonplaceholder.typicode.com. Start with an empty OpenAPI specification.\nMaintain meticulousness in documenting your observations as you traverse the APIs." 
- self.assertIn(expected_prompt, self.agent._prompt_history[0]['content']) + self.assertIn(expected_prompt, self.agent._prompt_history[0]["content"]) def test_all_flags_found(self): # Mock console.print to suppress output during testing - with patch('rich.console.Console.print'): + with patch("rich.console.Console.print"): self.agent.all_http_methods_found(1) self.assertFalse(self.agent.all_http_methods_found(1)) - @patch('time.perf_counter', side_effect=[1, 2]) # Mocking perf_counter for consistent timing + @patch("time.perf_counter", side_effect=[1, 2]) # Mocking perf_counter for consistent timing def test_perform_round(self, mock_perf_counter): # Prepare mock responses mock_response = MagicMock() @@ -52,7 +58,9 @@ def test_perform_round(self, mock_perf_counter): # Mock the OpenAI LLM response self.agent.llm.instructor.chat.completions.create_with_completion.return_value = ( - mock_response, mock_completion) + mock_response, + mock_completion, + ) # Mock the tool execution result mock_response.execute.return_value = "HTTP/1.1 200 OK" @@ -71,5 +79,6 @@ def test_perform_round(self, mock_perf_counter): # Check if the prompt history was updated correctly self.assertGreaterEqual(len(self.agent._prompt_history), 1) # Initial message + LLM response + tool message -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_web_api_testing.py b/tests/test_web_api_testing.py index 0bce9dc6..6a320b68 100644 --- a/tests/test_web_api_testing.py +++ b/tests/test_web_api_testing.py @@ -1,42 +1,51 @@ import unittest from unittest.mock import MagicMock, patch + from hackingBuddyGPT.usecases import SimpleWebAPITesting -from hackingBuddyGPT.usecases.web_api_testing.simple_web_api_testing import SimpleWebAPITestingUseCase -from hackingBuddyGPT.utils import DbStorage, Console +from hackingBuddyGPT.utils.logging import LocalLogger +from hackingBuddyGPT.usecases.web_api_testing.simple_web_api_testing import ( + SimpleWebAPITestingUseCase, +) +from hackingBuddyGPT.utils import Console, DbStorage class TestSimpleWebAPITestingTest(unittest.TestCase): - - @patch('hackingBuddyGPT.utils.openai.openai_lib.OpenAILib') + @patch("hackingBuddyGPT.utils.openai.openai_lib.OpenAILib") def setUp(self, MockOpenAILib): # Mock the OpenAILib instance self.mock_llm = MockOpenAILib.return_value - log_db = DbStorage(':memory:') + log_db = DbStorage(":memory:") console = Console() log_db.init() - self.agent = SimpleWebAPITesting(llm=self.mock_llm) + log = LocalLogger( + log_db=log_db, + console=console, + tag="integration_test_linuxprivesc", + ) + self.agent = SimpleWebAPITesting(llm=self.mock_llm, log=log) self.agent.init() self.simple_api_testing = SimpleWebAPITestingUseCase( agent=self.agent, - log_db=log_db, - console=console, - tag='integration_test_linuxprivesc', - max_turns=len(self.mock_llm.responses) + log=log, + max_turns=len(self.mock_llm.responses), ) self.simple_api_testing.init() def test_initial_prompt(self): # Test if the initial prompt is set correctly - self.assertIn("You're a seasoned API tester tasked with evaluating the REST APIs of a website hosted at https://jsonplaceholder.typicode.com. Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. 
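`@patch("time.perf_counter", side_effect=[1, 2])` makes the measured round duration deterministic: each call to `perf_counter` pops the next listed value. In isolation (the timed function here is hypothetical):

```python
import time
from unittest.mock import patch

def timed_call() -> float:
    tic = time.perf_counter()
    # ... the work being measured would run here ...
    return time.perf_counter() - tic

with patch("time.perf_counter", side_effect=[1, 2]):
    assert timed_call() == 1  # always 2 - 1, regardless of real elapsed time
```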
diff --git a/tests/test_web_api_testing.py b/tests/test_web_api_testing.py
index 0bce9dc6..6a320b68 100644
--- a/tests/test_web_api_testing.py
+++ b/tests/test_web_api_testing.py
@@ -1,42 +1,51 @@
 import unittest
 from unittest.mock import MagicMock, patch
+
 from hackingBuddyGPT.usecases import SimpleWebAPITesting
-from hackingBuddyGPT.usecases.web_api_testing.simple_web_api_testing import SimpleWebAPITestingUseCase
-from hackingBuddyGPT.utils import DbStorage, Console
+from hackingBuddyGPT.utils.logging import LocalLogger
+from hackingBuddyGPT.usecases.web_api_testing.simple_web_api_testing import (
+    SimpleWebAPITestingUseCase,
+)
+from hackingBuddyGPT.utils import Console, DbStorage


 class TestSimpleWebAPITestingTest(unittest.TestCase):
-
-    @patch('hackingBuddyGPT.utils.openai.openai_lib.OpenAILib')
+    @patch("hackingBuddyGPT.utils.openai.openai_lib.OpenAILib")
     def setUp(self, MockOpenAILib):
         # Mock the OpenAILib instance
         self.mock_llm = MockOpenAILib.return_value
-        log_db = DbStorage(':memory:')
+        log_db = DbStorage(":memory:")
         console = Console()
         log_db.init()
-        self.agent = SimpleWebAPITesting(llm=self.mock_llm)
+        log = LocalLogger(
+            log_db=log_db,
+            console=console,
+            tag="integration_test_linuxprivesc",
+        )
+        self.agent = SimpleWebAPITesting(llm=self.mock_llm, log=log)
         self.agent.init()
         self.simple_api_testing = SimpleWebAPITestingUseCase(
             agent=self.agent,
-            log_db=log_db,
-            console=console,
-            tag='integration_test_linuxprivesc',
-            max_turns=len(self.mock_llm.responses)
+            log=log,
+            max_turns=len(self.mock_llm.responses),
         )
         self.simple_api_testing.init()

     def test_initial_prompt(self):
         # Test if the initial prompt is set correctly
-        self.assertIn("You're a seasoned API tester tasked with evaluating the REST APIs of a website hosted at https://jsonplaceholder.typicode.com. Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. Remember, if you encounter an HTTP method (A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).), promptly submit it as it is of utmost importance.", self.agent._prompt_history[0]['content'])
+        self.assertIn(
+            "You're a seasoned API tester tasked with evaluating the REST APIs of a website hosted at https://jsonplaceholder.typicode.com. Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. Remember, if you encounter an HTTP method (A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).), promptly submit it as it is of utmost importance.",
+            self.agent._prompt_history[0]["content"],
+        )

     def test_all_flags_found(self):
         # Mock console.print to suppress output during testing
-        with patch('rich.console.Console.print'):
+        with patch("rich.console.Console.print"):
             self.agent.all_http_methods_found()
             self.assertFalse(self.agent.all_http_methods_found())

-    @patch('time.perf_counter', side_effect=[1, 2])  # Mocking perf_counter for consistent timing
+    @patch("time.perf_counter", side_effect=[1, 2])  # Mocking perf_counter for consistent timing
     def test_perform_round(self, mock_perf_counter):
         # Prepare mock responses
         mock_response = MagicMock()
@@ -49,7 +58,10 @@ def test_perform_round(self, mock_perf_counter):
         mock_completion.usage.completion_tokens = 20

         # Mock the OpenAI LLM response
-        self.agent.llm.instructor.chat.completions.create_with_completion.return_value = ( mock_response, mock_completion)
+        self.agent.llm.instructor.chat.completions.create_with_completion.return_value = (
+            mock_response,
+            mock_completion,
+        )

         # Mock the tool execution result
         mock_response.execute.return_value = "HTTP/1.1 200 OK"
@@ -64,12 +76,11 @@ def test_perform_round(self, mock_perf_counter):

         # Check if the LLM was called with the correct parameters
         mock_create_with_completion = self.agent.llm.instructor.chat.completions.create_with_completion

-        # if it can be called multiple times, use assert_called
         self.assertGreaterEqual(mock_create_with_completion.call_count, 1)

         # Check if the prompt history was updated correctly
         self.assertGreaterEqual(len(self.agent._prompt_history), 1)  # Initial message + LLM response + tool message

-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()