File _service:obs_scm:generative-ai-python-0.1.0~rc2.obscpio of Package python-generative-ai-python
generative-ai-python-0.1.0~rc2/.gitignore:

/venv/
/.eggs/
/.idea/
/.pytype/
/build/
/docs/api
*.egg-info
.DS_Store
__pycache__

generative-ai-python-0.1.0~rc2/CODE_OF_CONDUCT.md:

<!-- # Generated by synthtool. DO NOT EDIT! -->

# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. This Code of Conduct also applies outside the project spaces when the Project Steward has a reasonable belief that an individual's behavior may have a negative impact on the project or its community.

## Conflict Resolution

We do not believe that all conflict is bad; healthy debate and disagreement often yield positive results. However, it is never okay to be disrespectful or to engage in behavior that violates the project's code of conduct.

If you see someone violating the code of conduct, you are encouraged to address the behavior directly with those involved. Many issues can be resolved quickly and easily, and this gives people more control over the outcome of their dispute. If you are unable to resolve the matter for any reason, or if the behavior is threatening or harassing, report it. We are dedicated to providing an environment where participants feel welcome and safe.

Reports should be directed to *googleapis-stewards@google.com*, the Project Steward(s) for *Google Cloud Client Libraries*. It is the Project Steward's duty to receive and address reported violations of the code of conduct.
They will then work with a committee consisting of representatives from the Open Source Programs Office and the Google Open Source Strategy team. If for any reason you are uncomfortable reaching out to the Project Steward, please email opensource@google.com.

We will investigate every complaint, but you may not receive a direct response. We will use our discretion in determining when and how to follow up on reported incidents, which may range from not taking action to permanent expulsion from the project and project-sponsored spaces. We will notify the accused of the report and provide them an opportunity to discuss it before any action is taken. The identity of the reporter will be omitted from the details of the report supplied to the accused. In potentially harmful situations, such as ongoing harassment or threats to anyone's safety, we may take action without notice.

## Attribution

This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

generative-ai-python-0.1.0~rc2/CONTRIBUTING.md:

# How to become a contributor and submit your own code

**Table of contents**

* [Contributor License Agreements](#contributor-license-agreements)
* [Contributing a patch](#contributing-a-patch)
* [Running the tests](#running-the-tests)
* [Releasing the library](#releasing-the-library)

## Contributor License Agreements

We'd love to accept your sample apps and patches! Before we can take them, we have to jump a couple of legal hurdles. Please fill out either the individual or corporate Contributor License Agreement (CLA).

* If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://developers.google.com/open-source/cla/individual).
* If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://developers.google.com/open-source/cla/corporate).

Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.

## Contributing A Patch

1. Submit an issue describing your proposed change to the repo in question.
1. The repo owner will respond to your issue promptly.
1. If your proposed change is accepted, and you haven't already done so, sign a Contributor License Agreement (see details above).
1. Fork the desired repo, develop and test your code changes.
1. Ensure that your code adheres to the existing style in the code to which you are contributing.
1. Ensure that your code has an appropriate set of tests which all pass.
1. Title your pull request following [Conventional Commits](https://www.conventionalcommits.org/) styling.
1. Submit a pull request.

### Before you begin

1. [Select or create a Cloud Platform project][projects].
1. [Enable billing for your project][billing].
1. [Enable the Generative Language API][enable_api].
1. [Set up authentication with a service account][auth] so you can access the API from your local workstation. You can use an API key, but remember never to save it in your source files.

## Development

### Local install

Install the source in "editable" mode, with testing requirements:

```
pip install -e .[dev]
```

This "editable" mode lets you edit the source without needing to reinstall the package.

### Testing

Use the builtin unittest package:

```
python -m unittest discover --pattern '*test*.py'
```

Or to debug, use:

```commandline
nose2 --debugger
```

### Type checking

Use `pytype` (configured in `pyproject.toml`):

```
pip install pytype
pytype
```

### Formatting

Use black:

```
pip install black
black .
```

### Generate api reference

```
python docs/build_docs.py
```

[setup]: https://cloud.google.com/nodejs/docs/setup
[projects]: https://console.cloud.google.com/project
[billing]: https://support.google.com/cloud/answer/6293499#enable-billing
[enable_api]: https://console.cloud.google.com/flows/enableapi?apiid=generativelanguage.googleapis.com
[auth]: https://cloud.google.com/docs/authentication/getting-started

generative-ai-python-0.1.0~rc2/LICENSE:

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License.
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License.

Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5.
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

generative-ai-python-0.1.0~rc2/README.md:

This is a repo for the generative language client library.

generative-ai-python-0.1.0~rc2/docs/build_docs.py:

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Api reference docs generation script, using tensorflow_docs

This script generates API reference docs for the reference doc generator.
$> pip install -U git+https://github.com/tensorflow/docs
$> python build_docs.py
"""

import os
import pathlib
import textwrap

from absl import app
from absl import flags

import google
from google import generativeai
from google.ai import generativelanguage

from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api

import yaml

# del google.ai.generativelanguage_v1beta2

google.ai.generativelanguage.__doc__ = """\
This package, `google.ai.generativelanguage`, is a low-level auto-generated client
library for the PaLM API.

```posix-terminal
pip install google.ai.generativelanguage
```

It is built using the same tooling as Google Cloud client libraries, and will be
quite familiar if you've used those before.

While we encourage Python users to access the PaLM API using the
`google.generativeai` package (aka `palm`), this lower level package is also
available.

Each method in the PaLM API is connected to one of the client classes. Pass your
API key to the class' `client_options` when initializing a client:

```
from google.ai import generativelanguage as glm

client = glm.DiscussServiceClient(
    client_options={'api_key': 'YOUR_API_KEY'})
```

To call the api, pass an appropriate request-proto-object. For the
`DiscussServiceClient.generate_message` pass a
`generativelanguage.GenerateMessageRequest` instance:

```
request = glm.GenerateMessageRequest(
    model='models/chat-bison-001',
    prompt=glm.MessagePrompt(
        messages=[glm.Message(content='Hello!')]))

client.generate_message(request)
```

```
candidates {
  author: "1"
  content: "Hello! How can I help you today?"
}
...
```

For simplicity:

* The API methods also accept key-word arguments.
* Anywhere you might pass a proto-object, the library will also accept simple
  python structures.

So the following is equivalent to the previous example:

```
client.generate_message(
    model='models/chat-bison-001',
    prompt={'messages': [{'content': 'Hello!'}]})
```

```
candidates {
  author: "1"
  content: "Hello! How can I help you today?"
}
...
```
"""

HERE = pathlib.Path(__file__).parent

PROJECT_SHORT_NAME = "genai"
PROJECT_FULL_NAME = "Generative AI - Python"

_OUTPUT_DIR = flags.DEFINE_string(
    "output_dir",
    default=str(HERE / "api/"),
    help="Where to write the resulting docs to.",
)
_SEARCH_HINTS = flags.DEFINE_bool(
    "search_hints", True, "Include metadata search hints in the generated files"
)
_SITE_PATH = flags.DEFINE_string(
    "site_path", "/api/python", "Path prefix in the _toc.yaml"
)


class MyFilter:
    def __init__(self, base_dirs):
        self.filter_base_dirs = public_api.FilterBaseDirs(base_dirs)

    def drop_staticmethods(self, parent, children):
        parent = dict(parent.__dict__)
        for name, value in children:
            if not isinstance(parent.get(name, None), staticmethod):
                yield name, value

    def __call__(self, path, parent, children):
        if "generativelanguage" in path or "generativeai" in path:
            children = self.filter_base_dirs(path, parent, children)
            children = public_api.explicit_package_contents_filter(
                path, parent, children
            )

        if "generativelanguage" in path:
            if "ServiceClient" in path[-1]:
                children = list(self.drop_staticmethods(parent, children))

        return children


class MyDocGenerator(generate_lib.DocGenerator):
    def make_default_filters(self):
        return [
            # filter the api.
            public_api.FailIfNestedTooDeep(10),
            public_api.filter_module_all,
            public_api.add_proto_fields,
            public_api.filter_builtin_modules,
            public_api.filter_private_symbols,
            MyFilter(self._base_dir),  # public_api.FilterBaseDirs(self._base_dir),
            public_api.FilterPrivateMap(self._private_map),
            public_api.filter_doc_controls_skip,
            public_api.ignore_typing,
        ]


def gen_api_docs():
    """Generates api docs for the tensorflow docs package."""
    for name in dir(google):
        if name not in ("generativeai", "ai"):
            delattr(google, name)
    google.__name__ = "google"
    google.__doc__ = textwrap.dedent(
        """\
        This is the top-level google namespace.
        """
    )

    doc_generator = MyDocGenerator(
        root_title=PROJECT_FULL_NAME,
        # Replace `tensorflow_docs.api_generator` with your module, here.
        py_modules=[("google", google)],
        # Replace `tensorflow_docs.api_generator` with your module, here.
        base_dir=(
            pathlib.Path(google.generativeai.__file__).parent,
            pathlib.Path(google.ai.generativelanguage.__file__).parent.parent,
        ),
        code_url_prefix=(None, None),
        search_hints=_SEARCH_HINTS.value,
        site_path=_SITE_PATH.value,
        # This callback ensures that docs are only generated for objects that
        # are explicitly imported in your __init__.py files. There are other
        # options but this is a good starting point.
        callbacks=[],
    )

    out_path = pathlib.Path(_OUTPUT_DIR.value)
    doc_generator.build(out_path)

    # Fixup the toc file.
    toc_path = out_path / "google/_toc.yaml"
    toc = yaml.safe_load(toc_path.read_text())
    toc["toc"] = toc["toc"][1:]
    toc["toc"][0]["title"] = "google.ai.generativelanguage"
    toc["toc"][0]["section"] = toc["toc"][0]["section"][1]["section"]
    toc["toc"][0], toc["toc"][1] = toc["toc"][1], toc["toc"][0]
    toc_path.write_text(yaml.dump(toc))

    # remove some dummy files and redirect them to `api/`
    (out_path / "google.md").unlink()
    (out_path / "google/ai.md").unlink()
    redirects_path = out_path / "_redirects.yaml"
    redirects = {"redirects": []}
    redirects["redirects"].insert(0, {"from": "/api/python/google/ai", "to": "/api/"})
    redirects["redirects"].insert(0, {"from": "/api/python/google", "to": "/api/"})
    redirects["redirects"].insert(0, {"from": "/api/python", "to": "/api/"})
    redirects_path.write_text(yaml.dump(redirects))

    print("Output docs to: ", _OUTPUT_DIR.value)


def main(_):
    gen_api_docs()


if __name__ == "__main__":
    app.run(main)

generative-ai-python-0.1.0~rc2/google/__init__.py:

# Copyright
# 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Play nicely with other google.* namespaced-packages.
try:
    import pkg_resources

    pkg_resources.declare_namespace(__name__)
except ImportError:
    import pkgutil

    __path__ = pkgutil.extend_path(__path__, __name__)

generative-ai-python-0.1.0~rc2/google/generativeai/__init__.py:

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A high level client library for generative AI.
## Setup

```posix-terminal
pip install google-generativeai
```

```
import google.generativeai as genai

genai.configure(api_key=os.environ['API_KEY'])
```

## Chat

Use the `genai.chat` function to have a discussion with a model:

```
response = genai.chat(messages=["Hello."])
print(response.last)  # 'Hello! What can I help you with?'
response.reply("Can you tell me a joke?")
```

## Models

Use the model service to discover models and find out more about them.

Use `genai.get_model` to get details if you know a model's name:

```
model = genai.get_model('chat-bison-001')  # 🦬
```

Use `genai.list_models` to discover models:

```
import pprint
for model in genai.list_models():
    pprint.pprint(model)  # 🦎🦦🦬🦄
```
"""

from google.generativeai import types
from google.generativeai import version

from google.generativeai.discuss import chat
from google.generativeai.discuss import chat_async
from google.generativeai.discuss import count_message_tokens

from google.generativeai.text import generate_text
from google.generativeai.text import generate_embeddings

from google.generativeai.models import list_models
from google.generativeai.models import get_model

from google.generativeai.client import configure

__version__ = version.__version__

del discuss
del text
del models
del client
del version

generative-ai-python-0.1.0~rc2/google/generativeai/client.py:

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import cast, Optional, Union

import google.ai.generativelanguage as glm

from google.auth import credentials as ga_credentials
from google.api_core import client_options as client_options_lib
from google.api_core import gapic_v1

from google.generativeai import version

USER_AGENT = "genai-py"

default_client_config = {}

default_discuss_client = None
default_discuss_async_client = None
default_model_client = None
default_text_client = None


def configure(
    *,
    api_key: Optional[str] = None,
    credentials: Union[ga_credentials.Credentials, dict, None] = None,
    # The user can pass a string to choose `rest`, `grpc`, or `grpc_asyncio`.
    # See `_transport_registry` in `DiscussServiceClientMeta`.
    # Since the transport classes align with the client classes it wouldn't make
    # sense to accept a `Transport` object here even though the client classes can.
    # We could accept a dict since all the `Transport` classes take the same args,
    # but that seems rare. Users that need it can just switch to the low level API.
    transport: Union[str, None] = None,
    client_options: Union[client_options_lib.ClientOptions, dict, None] = None,
    client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
):
    """Captures default client configuration.

    If no API key has been provided (either directly, or on `client_options`) and
    the `GOOGLE_API_KEY` environment variable is set, it will be used as the API
    key. Refer to `glm.DiscussServiceClient` and `glm.ModelServiceClient` for
    details on additional arguments.

    Args:
        api_key: The API key to use when creating the default clients (each
            service uses a separate client). This is a shortcut for
            `client_options={"api_key": api_key}`. If omitted, and the
            `GOOGLE_API_KEY` environment variable is set, it will be used.
    """
    global default_client_config
    global default_discuss_client
    global default_model_client
    global default_text_client

    if isinstance(client_options, dict):
        client_options = client_options_lib.from_dict(client_options)
    if client_options is None:
        client_options = client_options_lib.ClientOptions()
    client_options = cast(client_options_lib.ClientOptions, client_options)

    had_api_key_value = getattr(client_options, "api_key", None)
    if had_api_key_value:
        if api_key is not None:
            raise ValueError(
                "You can't set both `api_key` and `client_options['api_key']`."
            )
    else:
        if api_key is None:
            # If no key is provided explicitly, attempt to load one from the
            # environment.
            api_key = os.getenv("GOOGLE_API_KEY")
        client_options.api_key = api_key

    user_agent = f"{USER_AGENT}/{version.__version__}"
    if client_info:
        # Be respectful of any existing agent setting.
        if client_info.user_agent:
            client_info.user_agent += f" {user_agent}"
        else:
            client_info.user_agent = user_agent
    else:
        client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)

    new_default_client_config = {
        "credentials": credentials,
        "transport": transport,
        "client_options": client_options,
        "client_info": client_info,
    }
    new_default_client_config = {
        key: value
        for key, value in new_default_client_config.items()
        if value is not None
    }

    default_client_config = new_default_client_config
    default_discuss_client = None
    default_text_client = None
    default_model_client = None


def get_default_discuss_client() -> glm.DiscussServiceClient:
    global default_discuss_client
    if default_discuss_client is None:
        # Attempt to configure using defaults.
        if not default_client_config:
            configure()
        default_discuss_client = glm.DiscussServiceClient(**default_client_config)
    return default_discuss_client


def get_default_text_client():
    global default_text_client
    if default_text_client is None:
        # Attempt to configure using defaults.
if not default_client_config: configure() default_text_client = glm.TextServiceClient(**default_client_config) return default_text_client def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient: global default_discuss_async_client if default_discuss_async_client is None: # Attempt to configure using defaults. if not default_client_config: configure() default_discuss_async_client = glm.DiscussServiceAsyncClient( **default_client_config ) return default_discuss_async_client def get_default_model_client(): global default_model_client if default_model_client is None: # Attempt to configure using defaults. if not default_client_config: configure() default_model_client = glm.ModelServiceClient(**default_client_config) return default_model_client 0707010000000C000081A400000000000000000000000164598395000044D9000000000000000000000000000000000000003E00000000generative-ai-python-0.1.0~rc2/google/generativeai/discuss.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
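The `get_default_*_client` helpers above all follow one pattern: configure lazily on first use, cache the client at module scope, and drop the cache whenever `configure()` is called with new settings. A minimal self-contained sketch of that pattern, using a stand-in `FakeClient` class (invented here for illustration) in place of the real `glm.*ServiceClient` classes:

```python
# Sketch of the lazy default-client pattern used above.
# `FakeClient` is a stand-in for illustration; the real module
# caches glm.*ServiceClient instances instead.

from typing import Optional


class FakeClient:
    def __init__(self, **config):
        self.config = config


_default_config: dict = {}
_default_client: Optional[FakeClient] = None


def configure(**config) -> None:
    """Capture config and invalidate any cached client."""
    global _default_config, _default_client
    _default_config = config
    _default_client = None  # force re-creation with the new config


def get_default_client() -> FakeClient:
    """Create the client on first use, then reuse it."""
    global _default_client
    if _default_client is None:
        _default_client = FakeClient(**_default_config)
    return _default_client
```

Repeated calls return the same cached instance until `configure()` is called again, which is why `configure()` must reset every cached client it knows about.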
import dataclasses import sys import textwrap from typing import Iterable, List, Optional, Union import google.ai.generativelanguage as glm from google.generativeai.client import get_default_discuss_client from google.generativeai.client import get_default_discuss_async_client from google.generativeai.types import discuss_types from google.generativeai.types import model_types from google.generativeai.types import safety_types def _make_message(content: discuss_types.MessageOptions) -> glm.Message: if isinstance(content, glm.Message): return content if isinstance(content, str): return glm.Message(content=content) else: return glm.Message(content) def _make_messages(messages: discuss_types.MessagesOptions) -> List[glm.Message]: if isinstance(messages, (str, dict, glm.Message)): messages = [_make_message(messages)] else: messages = [_make_message(message) for message in messages] even_authors = set(msg.author for msg in messages[::2] if msg.author) if not even_authors: even_author = "0" elif len(even_authors) == 1: even_author = even_authors.pop() else: raise discuss_types.AuthorError("Authors are not strictly alternating") odd_authors = set(msg.author for msg in messages[1::2] if msg.author) if not odd_authors: odd_author = "1" elif len(odd_authors) == 1: odd_author = odd_authors.pop() else: raise discuss_types.AuthorError("Authors are not strictly alternating") if all(msg.author for msg in messages): return messages authors = [even_author, odd_author] for i, msg in enumerate(messages): msg.author = authors[i % 2] return messages def _make_example(item: discuss_types.ExampleOptions) -> glm.Example: if isinstance(item, glm.Example): return item if isinstance(item, dict): item = item.copy() item["input"] = _make_message(item["input"]) item["output"] = _make_message(item["output"]) return glm.Example(item) if isinstance(item, Iterable): input, output = list(item) return glm.Example(input=_make_message(input), output=_make_message(output)) # try anyway return 
glm.Example(item) def _make_examples_from_flat( examples: List[discuss_types.MessageOptions], ) -> List[glm.Example]: if len(examples) % 2 != 0: raise ValueError( textwrap.dedent( f"""\ You must pass `glm.Example` objects, pairs of messages, or an *even* number of messages, got: {len(examples)} messages""" ) ) result = [] pair = [] for n, item in enumerate(examples): msg = _make_message(item) pair.append(msg) if n % 2 == 0: continue primer = glm.Example( input=pair[0], output=pair[1], ) result.append(primer) pair = [] return result def _make_examples(examples: discuss_types.ExamplesOptions) -> List[glm.Example]: if isinstance(examples, glm.Example): return [examples] if isinstance(examples, dict): return [_make_example(examples)] examples = list(examples) if not examples: return examples first = examples[0] if isinstance(first, dict): if "content" in first: # These are `Messages` return _make_examples_from_flat(examples) else: if not ("input" in first and "output" in first): raise TypeError( "To create an `Example` from a dict you must supply both `input` and `output` keys" ) else: if isinstance(first, discuss_types.MESSAGE_OPTIONS): return _make_examples_from_flat(examples) result = [] for item in examples: result.append(_make_example(item)) return result def _make_message_prompt_dict( prompt: discuss_types.MessagePromptOptions = None, *, context: Optional[str] = None, examples: Optional[discuss_types.ExamplesOptions] = None, messages: Optional[discuss_types.MessagesOptions] = None, ) -> glm.MessagePrompt: if prompt is None: prompt = dict( context=context, examples=examples, messages=messages, ) else: flat_prompt = ( (context is not None) or (examples is not None) or (messages is not None) ) if flat_prompt: raise ValueError( "You can't set `prompt` and its fields (`context`, `examples`, `messages`)" " at the same time" ) if isinstance(prompt, glm.MessagePrompt): return prompt elif isinstance(prompt, dict): # Always check dict before Iterable.
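`_make_examples_from_flat` turns an even-length flat list of messages into (input, output) pairs. A stripped-down sketch of the same pairing logic, operating on plain strings instead of `glm` message types (`pair_examples` is a hypothetical helper name, not part of the package):

```python
# Pair a flat, even-length list of messages into (input, output)
# examples, mirroring the loop in `_make_examples_from_flat` above.

from typing import List, Tuple


def pair_examples(messages: List[str]) -> List[Tuple[str, str]]:
    if len(messages) % 2 != 0:
        raise ValueError(
            f"Expected an even number of messages, got: {len(messages)}"
        )
    result = []
    pair = []
    for n, msg in enumerate(messages):
        pair.append(msg)
        if n % 2 == 0:
            continue  # wait for the second message of the pair
        result.append((pair[0], pair[1]))
        pair = []
    return result
```

Even-indexed entries are treated as user inputs, odd-indexed entries as the model outputs that follow them.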
pass else: prompt = {"messages": prompt} keys = set(prompt.keys()) if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS): raise KeyError( f"Found extra entries in the prompt dictionary: {keys - discuss_types.MESSAGE_PROMPT_KEYS}" ) examples = prompt.get("examples", None) if examples is not None: prompt["examples"] = _make_examples(examples) messages = prompt.get("messages", None) if messages is not None: prompt["messages"] = _make_messages(messages) prompt = {k: v for k, v in prompt.items() if v is not None} return prompt def _make_message_prompt( prompt: discuss_types.MessagePromptOptions = None, *, context: Optional[str] = None, examples: Optional[discuss_types.ExamplesOptions] = None, messages: Optional[discuss_types.MessagesOptions] = None, ) -> glm.MessagePrompt: prompt = _make_message_prompt_dict( prompt=prompt, context=context, examples=examples, messages=messages ) return glm.MessagePrompt(prompt) def _make_generate_message_request( *, model: Optional[model_types.ModelNameOptions], context: Optional[str] = None, examples: Optional[discuss_types.ExamplesOptions] = None, messages: Optional[discuss_types.MessagesOptions] = None, temperature: Optional[float] = None, candidate_count: Optional[int] = None, top_p: Optional[float] = None, top_k: Optional[float] = None, prompt: Optional[discuss_types.MessagePromptOptions] = None, ) -> glm.GenerateMessageRequest: model = model_types.make_model_name(model) prompt = _make_message_prompt( prompt=prompt, context=context, examples=examples, messages=messages ) return glm.GenerateMessageRequest( model=model, prompt=prompt, temperature=temperature, top_p=top_p, top_k=top_k, candidate_count=candidate_count, ) def set_doc(doc): def inner(f): f.__doc__ = doc return f return inner DEFAULT_DISCUSS_MODEL = "models/chat-bison-001" def chat( *, model: Optional[model_types.ModelNameOptions] = "models/chat-bison-001", context: Optional[str] = None, examples: Optional[discuss_types.ExamplesOptions] = None, messages: 
Optional[discuss_types.MessagesOptions] = None, temperature: Optional[float] = None, candidate_count: Optional[int] = None, top_p: Optional[float] = None, top_k: Optional[float] = None, prompt: Optional[discuss_types.MessagePromptOptions] = None, client: Optional[glm.DiscussServiceClient] = None, ) -> discuss_types.ChatResponse: """Calls the API and returns a `types.ChatResponse` containing the response. Args: model: Which model to call, as a string or a `types.Model`. context: Text that should be provided to the model first, to ground the response. If not empty, this `context` will be given to the model first before the `examples` and `messages`. This field can be a description of your prompt to the model to help provide context and guide the responses. Examples: * "Translate the phrase from English to French." * "Given a statement, classify the sentiment as happy, sad or neutral." Anything included in this field will take precedence over history in `messages` if the total input size exceeds the model's `Model.input_token_limit`. examples: Examples of what the model should generate. This includes both the user input and the response that the model should emulate. These `examples` are treated identically to conversation messages except that they take precedence over the history in `messages`: If the total input size exceeds the model's `input_token_limit` the input will be truncated. Items will be dropped from `messages` before `examples` messages: A snapshot of the conversation history sorted chronologically. Turns alternate between two authors. If the total input size exceeds the model's `input_token_limit` the input will be truncated: The oldest items will be dropped from `messages`. temperature: Controls the randomness of the output. Must be positive. Typical values are in the range: `[0.0,1.0]`. Higher values produce a more random and varied response. A temperature of zero will be deterministic. 
candidate_count: The **maximum** number of generated response messages to return. This value must be in the range `[1, 8]`, inclusive. If unset, this will default to `1`. Note: Only unique candidates are returned. Higher temperatures are more likely to produce unique candidates. Setting `temperature=0.0` will always return 1 candidate regardless of the `candidate_count`. top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling. `top_k` sets the maximum number of tokens to sample from on each step. top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling. `top_p` configures the nucleus sampling. It sets the maximum cumulative probability of tokens to sample from. For example, if the sorted probabilities are `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample as `[0.625, 0.25, 0.125, 0, 0, 0]`. Typical values are in the `[0.9, 1.0]` range. prompt: You may pass a `types.MessagePromptOptions` **instead** of setting `context`/`examples`/`messages`, but not both. client: If you're not relying on the default client, you may pass a `glm.DiscussServiceClient` instead. Returns: A `types.ChatResponse` containing the model's reply.
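The `top_p` worked example in the docstring can be checked numerically. The following is a plain illustration of nucleus truncation and renormalization over an already-sorted probability list, not the service's actual sampler; the epsilon guards against float round-off in the running sum:

```python
# Nucleus (top-p) truncation over sorted probabilities:
# keep tokens while the cumulative mass before them is < top_p,
# zero out the rest, and renormalize the kept mass.

from typing import List


def nucleus(sorted_probs: List[float], top_p: float) -> List[float]:
    eps = 1e-9  # tolerate float round-off in the running sum
    kept = []
    cumulative = 0.0
    for p in sorted_probs:
        if cumulative < top_p - eps:
            kept.append(p)
            cumulative += p
        else:
            kept.append(0.0)
    mass = sum(kept)
    return [p / mass for p in kept]
```

With `sorted_probs=[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` and `top_p=0.8` this reproduces the docstring's `[0.625, 0.25, 0.125, 0, 0, 0]` up to float round-off: the first three tokens accumulate to 0.8, and each kept probability is divided by that mass.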
""" request = _make_generate_message_request( model=model, context=context, examples=examples, messages=messages, temperature=temperature, candidate_count=candidate_count, top_p=top_p, top_k=top_k, prompt=prompt, ) return _generate_response(client=client, request=request) @set_doc(chat.__doc__) async def chat_async( *, model: Optional[model_types.ModelNameOptions] = None, context: Optional[str] = None, examples: Optional[discuss_types.ExamplesOptions] = None, messages: Optional[discuss_types.MessagesOptions] = None, temperature: Optional[float] = None, candidate_count: Optional[int] = None, top_p: Optional[float] = None, top_k: Optional[float] = None, prompt: Optional[discuss_types.MessagePromptOptions] = None, client: Optional[glm.DiscussServiceAsyncClient] = None, ) -> discuss_types.ChatResponse: request = _make_generate_message_request( model=model, context=context, examples=examples, messages=messages, temperature=temperature, candidate_count=candidate_count, top_p=top_p, top_k=top_k, prompt=prompt, ) return await _generate_response_async(client=client, request=request) if (sys.version_info.major, sys.version_info.minor) >= (3, 10): DATACLASS_KWARGS = {"kw_only": True} else: DATACLASS_KWARGS = {} @set_doc(discuss_types.ChatResponse.__doc__) @dataclasses.dataclass(**DATACLASS_KWARGS, init=False) class ChatResponse(discuss_types.ChatResponse): _client: Optional[glm.DiscussServiceClient] = dataclasses.field( default=lambda: None, repr=False ) def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) @property @set_doc(discuss_types.ChatResponse.last.__doc__) def last(self) -> Optional[str]: if self.messages[-1]: return self.messages[-1]["content"] else: return None @last.setter def last(self, message: discuss_types.MessageOptions): message = _make_message(message) self.messages[-1] = message @set_doc(discuss_types.ChatResponse.reply.__doc__) def reply( self, message: discuss_types.MessageOptions ) -> discuss_types.ChatResponse: if 
isinstance(self._client, glm.DiscussServiceAsyncClient): raise TypeError( "reply can't be called on an async client, use reply_async instead." ) if self.last is None: raise ValueError( "The last response from the model did not return any candidates.\n" "Check the `.filters` attribute to see why the responses were filtered:\n" f"{self.filters}" ) request = self.to_dict() request.pop("candidates") request.pop("filters", None) request["messages"] = list(request["messages"]) request["messages"].append(_make_message(message)) request = _make_generate_message_request(**request) return _generate_response(request=request, client=self._client) @set_doc(discuss_types.ChatResponse.reply.__doc__) async def reply_async( self, message: discuss_types.MessageOptions ) -> discuss_types.ChatResponse: if isinstance(self._client, glm.DiscussServiceClient): raise TypeError( "reply_async can't be called on a non-async client, use reply instead." ) request = self.to_dict() request.pop("candidates") request.pop("filters", None) request["messages"] = list(request["messages"]) request["messages"].append(_make_message(message)) request = _make_generate_message_request(**request) return await _generate_response_async(request=request, client=self._client) def _build_chat_response( request: glm.GenerateMessageRequest, response: glm.GenerateMessageResponse, client: Union[glm.DiscussServiceClient, glm.DiscussServiceAsyncClient], ) -> ChatResponse: request = type(request).to_dict(request) prompt = request.pop("prompt") request["examples"] = prompt["examples"] request["context"] = prompt["context"] request["messages"] = prompt["messages"] response = type(response).to_dict(response) response.pop("messages") response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) if response["candidates"]: last = response["candidates"][0] else: last = None request["messages"].append(last) request.setdefault("temperature", None) request.setdefault("candidate_count", None) return ChatResponse(
_client=client, **response, **request ) # pytype: disable=missing-parameter def _generate_response( request: glm.GenerateMessageRequest, client: Optional[glm.DiscussServiceClient] = None, ) -> ChatResponse: if client is None: client = get_default_discuss_client() response = client.generate_message(request) return _build_chat_response(request, response, client) async def _generate_response_async( request: glm.GenerateMessageRequest, client: Optional[glm.DiscussServiceAsyncClient] = None, ) -> ChatResponse: if client is None: client = get_default_discuss_async_client() response = await client.generate_message(request) return _build_chat_response(request, response, client) def count_message_tokens( *, prompt: discuss_types.MessagePromptOptions = None, context: Optional[str] = None, examples: Optional[discuss_types.ExamplesOptions] = None, messages: Optional[discuss_types.MessagesOptions] = None, model: str = DEFAULT_DISCUSS_MODEL, client: Optional[glm.DiscussServiceClient] = None, ): prompt = _make_message_prompt( prompt, context=context, examples=examples, messages=messages ) if client is None: client = get_default_discuss_client() result = client.count_message_tokens(model=model, prompt=prompt) return type(result).to_dict(result) 0707010000000D000081A4000000000000000000000001645983950000034D000000000000000000000000000000000000004600000000generative-ai-python-0.1.0~rc2/google/generativeai/docstring_utils.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. def strip_oneof(docstring): lines = docstring.splitlines() lines = [line for line in lines if ".. _oneof:" not in line] lines = [line for line in lines if "This field is a member of `oneof`_" not in line] return "\n".join(lines) 0707010000000E000081A40000000000000000000000016459839500000B82000000000000000000000000000000000000003D00000000generative-ai-python-0.1.0~rc2/google/generativeai/models.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import dataclasses from typing import Optional, List import google.ai.generativelanguage as glm from google.generativeai.client import get_default_model_client from google.generativeai.types import model_types def get_model(name, *, client=None) -> model_types.Model: """Get the `types.Model` for the given model name.""" if client is None: client = get_default_model_client() result = client.get_model(name=name) result = type(result).to_dict(result) return model_types.Model(**result) class ModelsIterable(model_types.ModelsIterable): def __init__( self, *, page_size: int, page_token: Optional[str], models: List[model_types.Model], client: Optional[glm.ModelServiceClient] ): self._page_size = page_size self._page_token = page_token self._models = models self._client = client def __iter__(self): while self: page = self._models yield from page self = self._next_page() def _next_page(self): if not self._page_token: return None return _list_models( page_size=self._page_size, page_token=self._page_token, client=self._client ) def _list_models(page_size, page_token, client): result = client.list_models(page_size=page_size, page_token=page_token) result = result._response result = type(result).to_dict(result) result["models"] = [model_types.Model(**mod) for mod in result["models"]] result["page_size"] = page_size result["page_token"] = result.pop("next_page_token") result["client"] = client return ModelsIterable(**result) def list_models( *, page_size: Optional[int] = None, client: Optional[glm.ModelServiceClient] = None ) -> model_types.ModelsIterable: """Lists available models. ``` import pprint for model in genai.list_models(): pprint.pprint(model) ``` Args: page_size: How many `types.Models` to fetch per page (api call). client: You may pass a `glm.ModelServiceClient` instead of using the default client. Returns: An iterable of `types.Model` objects. 
""" if client is None: client = get_default_model_client() return _list_models(page_size, page_token=None, client=client) 0707010000000F000041ED0000000000000000000000036459839500000000000000000000000000000000000000000000003C00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook07070100000010000081A40000000000000000000000016459839500000463000000000000000000000000000000000000004800000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/__init__.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Notebook extensions for Generative AI.""" def load_ipython_extension(ipython): """Register the Colab Magic extension to support %load_ext.""" # pylint: disable-next=g-import-not-at-top from google.generativeai.notebook import magics ipython.register_magics(magics.Magics) # Since we're in an interactive environment, make the tables prettier. try: # pylint: disable-next=g-import-not-at-top from google import colab colab.data_table.enable_dataframe_formatter() except ImportError: pass 07070100000011000081A40000000000000000000000016459839500000ECF000000000000000000000000000000000000004F00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/argument_parser.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Customized ArgumentParser. The default behavior of argparse.ArgumentParser's parse_args() method is to exit with a SystemExit exception in the following cases: 1. When the user requests a help message (with the --help or -h flags), or 2. When there's a parsing error (e.g. missing required flags or mistyped flags) To make the errors more user-friendly, this class customizes argparse.ArgumentParser and raises either ParserNormalExit for (1) or ParserError for (2); this way the caller has control over how to display them to the user. """ from __future__ import annotations import abc import argparse from typing import Sequence from google.generativeai.notebook import ipython_env # pylint: disable-next=g-bad-exception-name class _ParserBaseException(RuntimeError, metaclass=abc.ABCMeta): """Base class for parser exceptions including normal exit.""" def __init__(self, msgs: Sequence[str], *args, **kwargs): super().__init__("".join(msgs), *args, **kwargs) self._msgs = msgs self._ipython_env: ipython_env.IPythonEnv | None = None def set_ipython_env(self, env: ipython_env.IPythonEnv) -> None: self._ipython_env = env def _ipython_display_(self): self.display(self._ipython_env) def msgs(self) -> Sequence[str]: return self._msgs @abc.abstractmethod def display(self, env: ipython_env.IPythonEnv | None) -> None: """Display this exception on an IPython console.""" # ParserNormalExit is not an error: it's a way for ArgumentParser to indicate # that the user has entered a special request (e.g. "--help") instead of a # runnable command.
# pylint: disable-next=g-bad-exception-name class ParserNormalExit(_ParserBaseException): """Exception thrown when the parser exits normally. This is usually thrown when the user requests the help message. """ def display(self, env: ipython_env.IPythonEnv | None) -> None: for msg in self._msgs: print(msg) class ParserError(_ParserBaseException): """Exception thrown when there is an error.""" def display(self, env: ipython_env.IPythonEnv | None) -> None: for msg in self._msgs: print(msg) if env is not None: # Highlight to the user that an error has occurred. env.display_html("<b style='font-family:courier new'>ERROR</b>") class ArgumentParser(argparse.ArgumentParser): """Customized ArgumentParser for LLM Magics. This class overrides the parent argparse.ArgumentParser's error-handling methods to avoid side-effects like printing to stderr. The messages are accumulated and passed into the raised exceptions for the caller to handle them. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._messages: list[str] = [] def _print_message(self, message, file=None): """Override ArgumentParser's _print_message() method.""" del file self._messages.append(message) def exit(self, status=0, message=None): """Override ArgumentParser's exit() method.""" if message: self._print_message(message) msgs = self._messages self._messages = [] if status == 0: raise ParserNormalExit(msgs=msgs) else: raise ParserError(msgs=msgs) 07070100000012000081A40000000000000000000000016459839500000776000000000000000000000000000000000000005400000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/argument_parser_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
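The override pattern in `ArgumentParser` above needs only two hooks: `_print_message` (capture output instead of writing to stderr) and `exit` (raise instead of calling `sys.exit`). A trimmed, dependency-free sketch of the same idea, with hypothetical exception names since it omits the IPython display machinery:

```python
# Minimal ArgumentParser that raises instead of exiting, following
# the same two overrides as the class above.

import argparse


class NormalExit(Exception):
    """Raised for --help and other non-error exits."""


class ParseError(Exception):
    """Raised for parsing errors (bad or missing flags)."""


class RaisingParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._messages = []

    def _print_message(self, message, file=None):
        # Capture output instead of printing to stdout/stderr.
        if message:
            self._messages.append(message)

    def exit(self, status=0, message=None):
        if message:
            self._print_message(message)
        msgs, self._messages = self._messages, []
        if status == 0:
            raise NormalExit("".join(msgs))
        raise ParseError("".join(msgs))
```

Because argparse routes both help text and error messages through `_print_message` before calling `exit`, the raised exception carries the full message for the caller to display however it likes.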
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unittest for ArgumentParser.""" from __future__ import annotations import argparse from absl.testing import absltest from google.generativeai.notebook import argument_parser as parser_lib class ArgumentParserTest(absltest.TestCase): def test_help(self): """Verify that help messages raise ParserNormalExit.""" parser = parser_lib.ArgumentParser() with self.assertRaisesRegex( parser_lib.ParserNormalExit, "show this help message and exit" ): parser.parse_args(["-h"]) def test_parse_arg_errors(self): def new_parser() -> argparse.ArgumentParser: parser = parser_lib.ArgumentParser() parser.add_argument("--value", type=int, required=True) return parser # Normal case: no error. results = new_parser().parse_args(["--value", "42"]) self.assertEqual(42, results.value) with self.assertRaisesRegex(parser_lib.ParserError, "invalid int value"): new_parser().parse_args(["--value", "forty-two"]) with self.assertRaisesRegex( parser_lib.ParserError, "the following arguments are required" ): new_parser().parse_args([]) with self.assertRaisesRegex( parser_lib.ParserError, "expected one argument" ): new_parser().parse_args(["--value"]) if __name__ == "__main__": absltest.main() 07070100000013000081A40000000000000000000000016459839500004C9C000000000000000000000000000000000000004F00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/cmd_line_parser.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Parses an LLM command line.""" from __future__ import annotations import argparse import shlex import sys from typing import AbstractSet, Any, Callable, MutableMapping, Sequence from google.generativeai.notebook import argument_parser from google.generativeai.notebook import flag_def from google.generativeai.notebook import input_utils from google.generativeai.notebook import model_registry from google.generativeai.notebook import output_utils from google.generativeai.notebook import parsed_args_lib from google.generativeai.notebook import post_process_utils from google.generativeai.notebook import py_utils from google.generativeai.notebook import sheets_utils from google.generativeai.notebook.lib import llm_function from google.generativeai.notebook.lib import llmfn_inputs_source from google.generativeai.notebook.lib import llmfn_outputs from google.generativeai.notebook.lib import model as model_lib _MIN_CANDIDATE_COUNT = 1 _MAX_CANDIDATE_COUNT = 8 def _validate_input_source_against_placeholders( source: llmfn_inputs_source.LLMFnInputsSource, placeholders: AbstractSet[str], ) -> None: for inputs in source.to_normalized_inputs(): for keyword in placeholders: if keyword not in inputs: raise ValueError('Placeholder "{}" not found in input'.format(keyword)) def _get_resolve_input_from_py_var_fn( placeholders: AbstractSet[str] | None, ) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]: def _fn(var_name: str) -> llmfn_inputs_source.LLMFnInputsSource: source = input_utils.get_inputs_source_from_py_var(var_name) if placeholders: 
            _validate_input_source_against_placeholders(source, placeholders)
        return source

    return _fn


def _resolve_compare_fn_var(
    name: str,
) -> tuple[str, parsed_args_lib.TextResultCompareFn]:
    """Resolves a value passed into --compare_fn."""
    fn = py_utils.get_py_var(name)
    if not isinstance(fn, Callable):
        raise ValueError(
            'Variable "{}" does not contain a Callable object'.format(name)
        )
    return name, fn


def _resolve_ground_truth_var(name: str) -> Sequence[str]:
    """Resolves a value passed into --ground_truth."""
    value = py_utils.get_py_var(name)
    # "str" and "bytes" are also Sequences but we want an actual Sequence of
    # strings, like a list.
    if (
        not isinstance(value, Sequence)
        or isinstance(value, str)
        or isinstance(value, bytes)
    ):
        raise ValueError(
            'Variable "{}" does not contain a Sequence of strings'.format(name)
        )
    for x in value:
        if not isinstance(x, str):
            raise ValueError(
                'Variable "{}" does not contain a Sequence of strings'.format(
                    name
                )
            )
    return value


def _get_resolve_sheets_inputs_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    def _fn(value: str) -> llmfn_inputs_source.LLMFnInputsSource:
        sheets_id = sheets_utils.get_sheets_id_from_str(value)
        source = sheets_utils.SheetsInputs(sheets_id)
        if placeholders:
            _validate_input_source_against_placeholders(source, placeholders)
        return source

    return _fn


def _resolve_sheets_outputs(value: str) -> llmfn_outputs.LLMFnOutputsSink:
    sheets_id = sheets_utils.get_sheets_id_from_str(value)
    return sheets_utils.SheetsOutputs(sheets_id)


def _add_model_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags that are related to model selection and config."""
    flag_def.EnumFlagDef(
        name="model_type",
        short_name="mt",
        enum_type=model_registry.ModelName,
        default_value=model_registry.ModelRegistry.DEFAULT_MODEL,
        help_msg="The type of model to use.",
    ).add_argument_to_parser(parser)

    def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
        if x < 0:
            raise ValueError(
                "Value should be greater than or equal to zero, got {}".format(x)
            )
        return x

    flag_def.SingleValueFlagDef(
        name="temperature",
        short_name="t",
        parse_type=float,
        # Use None for default value to indicate that this will use the default
        # value in the Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_is_greater_than_or_equal_to_zero,
        help_msg=(
            "Controls the randomness of the output. Must be non-negative."
            " Typical values are in the range: [0.0, 1.0]. Higher values"
            " produce a more random and varied response. A temperature of zero"
            " will be deterministic."
        ),
    ).add_argument_to_parser(parser)

    flag_def.SingleValueFlagDef(
        name="model",
        short_name="m",
        default_value=None,
        help_msg=(
            "The name of the model to use. If not provided, a default model"
            " will be used."
        ),
    ).add_argument_to_parser(parser)

    def _check_candidate_count_range(x: Any) -> int:
        if x < _MIN_CANDIDATE_COUNT or x > _MAX_CANDIDATE_COUNT:
            raise ValueError(
                "Value should be in the range [{}, {}], got {}".format(
                    _MIN_CANDIDATE_COUNT, _MAX_CANDIDATE_COUNT, x
                )
            )
        return int(x)

    flag_def.SingleValueFlagDef(
        name="candidate_count",
        short_name="cc",
        parse_type=int,
        # Use None for default value to indicate that this will use the default
        # value in the Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_candidate_count_range,
        help_msg="The number of candidates to produce.",
    ).add_argument_to_parser(parser)

    flag_def.BooleanFlagDef(
        name="unique",
        help_msg="Whether to dedupe candidates returned by the model.",
    ).add_argument_to_parser(parser)


def _add_input_flags(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags to read inputs from a Python variable or Sheets."""
    flag_def.MultiValuesFlagDef(
        name="inputs",
        short_name="i",
        dest_type=llmfn_inputs_source.LLMFnInputsSource,
        parse_to_dest_type_fn=_get_resolve_input_from_py_var_fn(placeholders),
        help_msg=(
            "Optional names of Python variables containing inputs to use to"
            " instantiate a prompt. The variable must be either: a dictionary"
            " {'key1': ['val1', 'val2'] ...}, or an instance of"
            " LLMFnInputsSource such as SheetsInput."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_input_names",
        short_name="si",
        dest_type=llmfn_inputs_source.LLMFnInputsSource,
        parse_to_dest_type_fn=_get_resolve_sheets_inputs_fn(placeholders),
        help_msg=(
            "Optional names of Google Sheets to read inputs from. This is"
            " equivalent to using --inputs with the names of variables that are"
            " instances of SheetsInputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)


def _add_output_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags to write outputs to a Python variable."""
    flag_def.MultiValuesFlagDef(
        name="outputs",
        short_name="o",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=output_utils.get_outputs_sink_from_py_var,
        help_msg=(
            "Optional names of Python variables to output to. If the Python"
            " variable has not already been defined, it will be created. If the"
            " variable is defined and is an instance of LLMFnOutputsSink, the"
            " outputs will be written through the sink's write_outputs()"
            " method."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_output_names",
        short_name="so",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=_resolve_sheets_outputs,
        help_msg=(
            "Optional names of Google Sheets to write outputs to. This is"
            " equivalent to using --outputs with the names of variables that"
            " are instances of SheetsOutputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)


def _add_compare_flags(
    parser: argparse.ArgumentParser,
) -> None:
    flag_def.MultiValuesFlagDef(
        name="compare_fn",
        dest_type=tuple,
        parse_to_dest_type_fn=_resolve_compare_fn_var,
        help_msg=(
            "An optional function that takes two inputs: (lhs_result,"
            " rhs_result) which are the results of the left- and right-hand"
            " side functions. Multiple comparison functions can be provided."
        ),
    ).add_argument_to_parser(parser)


def _add_eval_flags(
    parser: argparse.ArgumentParser,
) -> None:
    flag_def.SingleValueFlagDef(
        name="ground_truth",
        required=True,
        dest_type=Sequence,
        parse_to_dest_type_fn=_resolve_ground_truth_var,
        help_msg=(
            "A variable containing a Sequence of strings representing the"
            " ground truth that the output of this cell will be compared"
            " against. It should have the same number of entries as inputs."
        ),
    ).add_argument_to_parser(parser)


def _create_run_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the `run` command.

    `run` sends one or more prompts to a model.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the cell contents.
    """
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)


def _create_compile_parser(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags for the compile command.

    `compile` "compiles" a prompt and model call into a callable function.

    Args:
      parser: The parser to which flags will be added.
    """

    # Add a positional argument for "compile_save_name".
    def _compile_save_name_fn(var_name: str) -> str:
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as e:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, "{}".format(e)) from e
        return var_name

    save_name_help = (
        "The name of a Python variable to save the compiled function to."
    )
    parser.add_argument(
        "compile_save_name", help=save_name_help, type=_compile_save_name_fn
    )
    _add_model_flags(parser)


def _create_compare_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the compare command.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the compiled functions.
    """

    # Add positional arguments.
    def _resolve_llm_function_fn(
        var_name: str,
    ) -> tuple[str, llm_function.LLMFunction]:
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as e:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, "{}".format(e)) from e
        fn = py_utils.get_py_var(var_name)
        if not isinstance(fn, llm_function.LLMFunction):
            raise argparse.ArgumentError(
                None,
                '{} is not a function created with the "compile" command'.format(
                    var_name
                ),
            )
        return var_name, fn

    name_help = (
        "The name of a Python variable containing a function previously"
        ' created with the "compile" command.'
    )
    parser.add_argument(
        "lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn
    )
    parser.add_argument(
        "rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn
    )

    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)


def _create_eval_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the eval command.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the cell contents.
    """
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)
    _add_eval_flags(parser)


def _create_parser(
    placeholders: AbstractSet[str] | None,
) -> argparse.ArgumentParser:
    """Create the full parser."""
    system_name = "palm"
    description = "A system for interacting with LLMs."
    epilog = ""

    # Commands
    extra_args = {}
    if sys.version_info[0:2] >= (3, 9):
        extra_args["exit_on_error"] = False
    parser = argument_parser.ArgumentParser(
        prog=system_name,
        description=description,
        epilog=epilog,
        **extra_args,
    )
    subparsers = parser.add_subparsers(dest="cmd")
    _create_run_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.RUN_CMD.value),
        placeholders,
    )
    _create_compile_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value)
    )
    _create_compare_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.COMPARE_CMD.value),
        placeholders,
    )
    _create_eval_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.EVAL_CMD.value),
        placeholders,
    )
    return parser


def _validate_parsed_args(parsed_args: parsed_args_lib.ParsedArgs) -> None:
    # If candidate_count is not set (i.e. is None), assume the default value
    # is 1.
    if parsed_args.unique and (
        parsed_args.model_args.candidate_count is None
        or parsed_args.model_args.candidate_count == 1
    ):
        print(
            '"--unique" works across candidates only: it should be used with'
            " --candidate_count set to a value greater than one."
        )


class CmdLineParser:
    """Implementation of Magics command line parser."""

    # Commands
    DEFAULT_CMD = parsed_args_lib.CommandName.RUN_CMD

    # Post-processing operator.
    PIPE_OP = "|"

    @classmethod
    def _split_post_processing_tokens(
        cls,
        tokens: Sequence[str],
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Splits inputs into the command and post processing tokens.

        The command is represented as a sequence of tokens. See comments on
        the PostProcessingTokens type alias.

        E.g. Given: "run --temperature 0.5 | add_score | to_lower_case"
        The command will be: ["run", "--temperature", "0.5"].
        The post processing tokens will be: [["add_score"], ["to_lower_case"]]

        Args:
          tokens: The command line tokens.

        Returns:
          A tuple of (command line, post processing tokens).
        """
        split_tokens = []
        start_idx: int | None = None
        for token_num, token in enumerate(tokens):
            if start_idx is None:
                start_idx = token_num
            if token == CmdLineParser.PIPE_OP:
                split_tokens.append(
                    tokens[start_idx:token_num] if start_idx is not None else []
                )
                start_idx = None
        # Add the remaining tokens after the last PIPE_OP.
        split_tokens.append(tokens[start_idx:] if start_idx is not None else [])
        return split_tokens[0], split_tokens[1:]

    @classmethod
    def _tokenize_line(
        cls, line: str
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Parses `line` and returns command line and post processing tokens."""
        # Check to make sure there is a command at the start. If not, add the
        # default command to the list of tokens.
        tokens = shlex.split(line)
        if not tokens:
            tokens = [CmdLineParser.DEFAULT_CMD.value]
        first_token = tokens[0]
        # Add the default command if the first token is not the help token.
        if not first_token[0].isalpha() and first_token not in ["-h", "--help"]:
            tokens = [CmdLineParser.DEFAULT_CMD.value] + tokens

        # Split the tokens into the command and post-processing expressions.
        return CmdLineParser._split_post_processing_tokens(tokens)

    @classmethod
    def _get_model_args(
        cls, parsed_results: MutableMapping[str, Any]
    ) -> tuple[MutableMapping[str, Any], model_lib.ModelArguments]:
        """Extracts fields for model args from `parsed_results`.

        Keys specific to model arguments will be removed from `parsed_results`.

        Args:
          parsed_results: A dictionary of parsed arguments (from
            ArgumentParser). It will be modified in place.

        Returns:
          A tuple of (updated parsed_results, model arguments).
        """
        model = parsed_results.pop("model", None)
        temperature = parsed_results.pop("temperature", None)
        candidate_count = parsed_results.pop("candidate_count", None)
        model_args = model_lib.ModelArguments(
            model=model, temperature=temperature, candidate_count=candidate_count
        )
        return parsed_results, model_args

    def parse_line(
        self,
        line: str,
        placeholders: AbstractSet[str] | None = None,
    ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
        """Parses the command line and returns ParsedArgs and post-processing tokens.

        Args:
          line: The line to parse (usually contents from cell Magics).
          placeholders: Placeholders from prompts in the cell contents.

        Returns:
          A tuple of (parsed_args, post_processing_tokens).
        """
        tokens, post_processing_tokens = CmdLineParser._tokenize_line(line)
        parsed_args = self._get_parsed_args_from_cmd_line_tokens(
            tokens=tokens, placeholders=placeholders
        )
        # Special-case for the "compare" command: because the prompts are
        # compiled into the left- and right-hand side functions rather than in
        # the cell body, we cannot examine the cell body to get the
        # placeholders.
        #
        # Instead we parse the command line twice: once to get the left- and
        # right-hand functions, then we query the functions for their
        # placeholders, then parse the command line again to validate the
        # inputs.
        if parsed_args.cmd == parsed_args_lib.CommandName.COMPARE_CMD:
            assert parsed_args.lhs_name_and_fn is not None
            assert parsed_args.rhs_name_and_fn is not None
            _, lhs_fn = parsed_args.lhs_name_and_fn
            _, rhs_fn = parsed_args.rhs_name_and_fn
            parsed_args = self._get_parsed_args_from_cmd_line_tokens(
                tokens=tokens,
                placeholders=frozenset(lhs_fn.get_placeholders()).union(
                    rhs_fn.get_placeholders()
                ),
            )

        _validate_parsed_args(parsed_args)

        for expr in post_processing_tokens:
            post_process_utils.validate_one_post_processing_expression(expr)

        return parsed_args, post_processing_tokens

    def _get_parsed_args_from_cmd_line_tokens(
        self,
        tokens: Sequence[str],
        placeholders: AbstractSet[str] | None,
    ) -> parsed_args_lib.ParsedArgs:
        """Returns ParsedArgs from a tokenized command line."""
        # Create a new parser to avoid reusing the temporary argparse.Namespace
        # object.
        results = _create_parser(placeholders).parse_args(tokens)
        results_dict = vars(results)
        results_dict["cmd"] = parsed_args_lib.CommandName(results_dict["cmd"])
        results_dict, model_args = CmdLineParser._get_model_args(results_dict)
        results_dict["model_args"] = model_args
        return parsed_args_lib.ParsedArgs(**results_dict)


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/cmd_line_parser_test.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
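The pipe-splitting behavior described in `_split_post_processing_tokens` above (a command followed by `|`-separated post-processing expressions) can be sketched as a standalone function. This is an illustrative re-implementation for clarity, not part of the package; the names `split_post_processing_tokens` and `PIPE_OP` are ours:

```python
import shlex

PIPE_OP = "|"


def split_post_processing_tokens(tokens):
    """Split tokens into (command tokens, post-processing expressions)."""
    split_tokens = []
    start_idx = None
    for token_num, token in enumerate(tokens):
        if start_idx is None:
            start_idx = token_num
        if token == PIPE_OP:
            # Close off the current group at the pipe operator.
            split_tokens.append(tokens[start_idx:token_num])
            start_idx = None
    # Remaining tokens after the last pipe (or the whole line if no pipe).
    split_tokens.append(tokens[start_idx:] if start_idx is not None else [])
    return split_tokens[0], split_tokens[1:]


tokens = shlex.split("run --temperature 0.5 | add_score | to_lower_case")
cmd, post = split_post_processing_tokens(tokens)
# cmd == ["run", "--temperature", "0.5"]
# post == [["add_score"], ["to_lower_case"]]
```

With no pipe in the line, the whole token list is returned as the command and the post-processing list is empty, matching the docstring's description.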
"""Unittests for cmd_line_parser.""" from __future__ import annotations import sys from unittest import mock from absl.testing import absltest from google.generativeai.notebook import argument_parser from google.generativeai.notebook import cmd_line_parser from google.generativeai.notebook import model_registry from google.generativeai.notebook import parsed_args_lib from google.generativeai.notebook import post_process_utils from google.generativeai.notebook.lib import llm_function from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_outputs from google.generativeai.notebook.lib import model as model_lib _INPUT_VAR_ONE = {"word": ["one"]} _INPUT_VAR_TWO = {"word": ["two"]} _INPUT_VAR_THREE = {"word": ["three"]} _NOT_WORD_INPUT_VAR = {"not_word": ["hello", "world"]} _OUTPUT_VAR_ONE: llmfn_outputs.LLMFnOutputs | None = None _OUTPUT_VAR_TWO: llmfn_outputs.LLMFnOutputs | None = None _GROUND_TRUTH_VAR = ["apple", "banana", "cantaloupe"] def _set_output_sink( text_result: str, sink: llmfn_outputs.LLMFnOutputsSink ) -> None: sink.write_outputs( llmfn_outputs.LLMFnOutputs( outputs=[ llmfn_outputs.LLMFnOutputEntry( prompt_num=0, input_num=0, prompt_vars={}, output_rows=[ llmfn_output_row.LLMFnOutputRow( data={ llmfn_outputs.ColumnNames.RESULT_NUM: 0, llmfn_outputs.ColumnNames.TEXT_RESULT: ( text_result ), }, result_type=str, ) ], ), ] ) ) class CmdLineParserTestBase(absltest.TestCase): def setUp(self): super().setUp() # Reset variables. global _OUTPUT_VAR_ONE global _OUTPUT_VAR_TWO _OUTPUT_VAR_ONE = None _OUTPUT_VAR_TWO = None # `unittest discover` does not run via __main__, so patch this context in. 
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class CmdLineParserCommonTest(CmdLineParserTestBase):
    """For tests that are not specific to any command."""

    def test_parse_args_help(self):
        parser = cmd_line_parser.CmdLineParser()
        with self.assertRaises(argument_parser.ParserNormalExit):
            parser.parse_line("-h")
        with self.assertRaises(argument_parser.ParserNormalExit):
            parser.parse_line("--help")
        with self.assertRaises(argument_parser.ParserNormalExit):
            parser.parse_line("run -h")
        with self.assertRaises(argument_parser.ParserNormalExit):
            parser.parse_line("run --help")

    def test_parse_args_empty(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line("")
        self.assertEqual(parsed_args_lib.CommandName.RUN_CMD, results.cmd)
        self.assertEqual(model_registry.ModelName.TEXT_MODEL, results.model_type)
        self.assertEqual([], results.inputs)
        self.assertEqual(
            model_lib.ModelArguments(),
            results.model_args,
        )

    def test_parse_args_no_reuse(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line("--inputs _INPUT_VAR_ONE")
        self.assertLen(results.inputs, 1)
        self.assertEqual(
            [{"word": "one"}], results.inputs[0].to_normalized_inputs()
        )

        # Calling parse_line() again should return brand new results.
        results, _ = parser.parse_line("--inputs _INPUT_VAR_TWO _INPUT_VAR_THREE")
        self.assertLen(results.inputs, 2)
        self.assertEqual(
            [{"word": "two"}], results.inputs[0].to_normalized_inputs()
        )
        self.assertEqual(
            [{"word": "three"}], results.inputs[1].to_normalized_inputs()
        )


# `unittest discover` does not run via __main__, so patch this context in.
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class CmdLineParserModelFlagsTest(CmdLineParserTestBase):

    def test_parse_args_sets_model_type(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line("--model_type=echo")
        self.assertEqual(model_registry.ModelName.ECHO_MODEL, results.model_type)
        results, _ = parser.parse_line("--model_type=text")
        self.assertEqual(model_registry.ModelName.TEXT_MODEL, results.model_type)

    def test_parse_args_sets_model(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line("--model=/ml/test")
        self.assertEqual(
            model_lib.ModelArguments(model="/ml/test"), results.model_args
        )

    def test_parse_args_sets_temperature(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line("--temperature=0")
        self.assertEqual(
            model_lib.ModelArguments(temperature=0), results.model_args
        )
        results, _ = parser.parse_line("--temperature=0.5")
        self.assertEqual(
            model_lib.ModelArguments(temperature=0.5), results.model_args
        )
        with self.assertRaisesRegex(
            argument_parser.ParserError,
            (
                'Error with value "-1.0", got ValueError: Value should be'
                " greater than or equal to zero, got -1.0"
            ),
        ):
            parser.parse_line("--temperature=-1")

    def test_parse_args_sets_candidate_count(self):
        parser = cmd_line_parser.CmdLineParser()
        # Test that the min and max values are accepted.
        results, _ = parser.parse_line("--candidate_count=1")
        self.assertEqual(
            model_lib.ModelArguments(candidate_count=1), results.model_args
        )
        results, _ = parser.parse_line("--candidate_count=8")
        self.assertEqual(
            model_lib.ModelArguments(candidate_count=8), results.model_args
        )

        # Test that values outside the min and max are rejected.
        with self.assertRaisesRegex(
            argument_parser.ParserError,
            (
                r'Error with value "0", got ValueError: Value should be in the'
                r" range \[1, 8\], got 0"
            ),
        ):
            parser.parse_line("--candidate_count=0")
        with self.assertRaisesRegex(
            argument_parser.ParserError,
            (
                r'Error with value "9", got ValueError: Value should be in the'
                r" range \[1, 8\], got 9"
            ),
        ):
            parser.parse_line("--candidate_count=9")

    def test_parse_args_sets_unique(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line("")
        self.assertFalse(results.unique)
        results, _ = parser.parse_line("--unique")
        self.assertTrue(results.unique)


# `unittest discover` does not run via __main__, so patch this context in.
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class CmdLineParserRunTest(CmdLineParserTestBase):
    """For the "run" command."""

    def test_parse_args_run_is_default(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line("--model_type=echo")
        self.assertEqual(parsed_args_lib.CommandName.RUN_CMD, results.cmd)
        self.assertEqual(model_registry.ModelName.ECHO_MODEL, results.model_type)

    def test_parse_input_and_output_args(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line(
            "run --model_type=echo --inputs _INPUT_VAR_ONE _INPUT_VAR_TWO"
            " --outputs _OUTPUT_VAR_ONE _OUTPUT_VAR_TWO _UNDECLARED_OUTPUT_VAR"
        )
        self.assertEqual(parsed_args_lib.CommandName.RUN_CMD, results.cmd)
        self.assertEqual(model_registry.ModelName.ECHO_MODEL, results.model_type)
        self.assertLen(results.inputs, 2)
        self.assertEqual(
            [{"word": "one"}], results.inputs[0].to_normalized_inputs()
        )
        self.assertEqual(
            [{"word": "two"}], results.inputs[1].to_normalized_inputs()
        )
        self.assertLen(results.outputs, 3)

        # Check that the output is going to the correct variable by writing a
        # value to the sink then reading it back.
        _set_output_sink(text_result="one", sink=results.outputs[0])
        self.assertIsInstance(_OUTPUT_VAR_ONE, llmfn_outputs.LLMFnOutputs)
        self.assertEqual(
            "one",
            _OUTPUT_VAR_ONE[0].output_rows[0][
                llmfn_outputs.ColumnNames.TEXT_RESULT
            ],
        )

        _set_output_sink(text_result="two", sink=results.outputs[1])
        self.assertIsInstance(_OUTPUT_VAR_TWO, llmfn_outputs.LLMFnOutputs)
        self.assertEqual(
            "two",
            _OUTPUT_VAR_TWO[0].output_rows[0][
                llmfn_outputs.ColumnNames.TEXT_RESULT
            ],
        )

        _set_output_sink(text_result="undeclared", sink=results.outputs[2])
        # pylint: disable-next=undefined-variable
        undeclared_var = _UNDECLARED_OUTPUT_VAR  # type: ignore
        self.assertIsInstance(undeclared_var, llmfn_outputs.LLMFnOutputs)
        self.assertEqual(
            "undeclared",
            undeclared_var[0].output_rows[0][
                llmfn_outputs.ColumnNames.TEXT_RESULT
            ],
        )

    def test_placeholder_error(self):
        parser = cmd_line_parser.CmdLineParser()
        with self.assertRaisesRegex(
            argument_parser.ParserError,
            (
                'argument --inputs/-i: Error with value "_NOT_WORD_INPUT_VAR",'
                ' got ValueError: Placeholder "word" not found in input'
            ),
        ):
            parser.parse_line(
                "run --inputs _NOT_WORD_INPUT_VAR",
                placeholders=frozenset({"word"}),
            )


# `unittest discover` does not run via __main__, so patch this context in.
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class CmdLineParserCompileTest(CmdLineParserTestBase):
    """For the "compile" command."""

    def test_parse_args_needs_save_name(self):
        parser = cmd_line_parser.CmdLineParser()
        with self.assertRaisesRegex(
            argument_parser.ParserError,
            "the following arguments are required: compile_save_name",
        ):
            parser.parse_line("compile")

    def test_parse_args_bad_save_name(self):
        parser = cmd_line_parser.CmdLineParser()
        with self.assertRaisesRegex(
            argument_parser.ParserError, "Invalid Python variable name"
        ):
            parser.parse_line("compile 1234")

    def test_parse_args_has_save_name(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line("compile my_fn")
        self.assertEqual("my_fn", results.compile_save_name)


_test_lhs_fn = llm_function.LLMFunctionImpl(
    model=model_lib.EchoModel(), prompts=["dummy lhs prompt {word}"]
)
_test_rhs_fn = llm_function.LLMFunctionImpl(
    model=model_lib.EchoModel(), prompts=["dummy rhs prompt {word}"]
)


def _test_compare_fn(lhs: str, rhs: str) -> bool:
    return lhs == rhs


# `unittest discover` does not run via __main__, so patch this context in.
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class CmdLineParserCompareTest(CmdLineParserTestBase):
    """For the "compare" command."""

    def test_compare(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line(
            "compare _test_lhs_fn _test_rhs_fn --inputs _INPUT_VAR_ONE"
        )
        self.assertEqual(("_test_lhs_fn", _test_lhs_fn), results.lhs_name_and_fn)
        self.assertEqual(("_test_rhs_fn", _test_rhs_fn), results.rhs_name_and_fn)
        self.assertEmpty(results.compare_fn)

    def test_compare_with_custom_compare_fn(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line(
            "compare _test_lhs_fn _test_rhs_fn --inputs _INPUT_VAR_ONE"
            " --compare_fn _test_compare_fn"
        )
        self.assertEqual(("_test_lhs_fn", _test_lhs_fn), results.lhs_name_and_fn)
        self.assertEqual(("_test_rhs_fn", _test_rhs_fn), results.rhs_name_and_fn)
        self.assertEqual(
            [("_test_compare_fn", _test_compare_fn)], results.compare_fn
        )

    def test_placeholder_error(self):
        parser = cmd_line_parser.CmdLineParser()
        with self.assertRaisesRegex(
            argument_parser.ParserError,
            (
                'argument --inputs/-i: Error with value "_NOT_WORD_INPUT_VAR",'
                ' got ValueError: Placeholder "word" not found in input'
            ),
        ):
            parser.parse_line(
                "compare _test_lhs_fn _test_rhs_fn --inputs _NOT_WORD_INPUT_VAR"
            )


# `unittest discover` does not run via __main__, so patch this context in.
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class CmdLineParserEvalTest(CmdLineParserTestBase):
    """For the "eval" command."""

    def test_eval(self):
        parser = cmd_line_parser.CmdLineParser()
        results, _ = parser.parse_line(
            "eval --ground_truth _GROUND_TRUTH_VAR --inputs _INPUT_VAR_ONE"
        )
        self.assertEqual(["apple", "banana", "cantaloupe"], results.ground_truth)

    def test_placeholder_error(self):
        parser = cmd_line_parser.CmdLineParser()
        with self.assertRaisesRegex(
            argument_parser.ParserError,
            (
                'argument --inputs/-i: Error with value "_NOT_WORD_INPUT_VAR",'
                ' got ValueError: Placeholder "word" not found in input'
            ),
        ):
            parser.parse_line(
                "eval --ground_truth _GROUND_TRUTH_VAR --inputs"
                " _NOT_WORD_INPUT_VAR",
                placeholders=frozenset({"word"}),
            )


# `unittest discover` does not run via __main__, so patch this context in.
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class CmdLineParserPostProcessingTest(CmdLineParserTestBase):
    """For post-processing expressions."""

    def test_parse_tokens(self):
        parser = cmd_line_parser.CmdLineParser()
        _, post_process_exprs = parser.parse_line("| add_length | to_upper")
        self.assertLen(post_process_exprs, 2)
        self.assertEqual(["add_length"], post_process_exprs[0])
        self.assertEqual(["to_upper"], post_process_exprs[1])

    def test_illformed_expression(self):
        parser = cmd_line_parser.CmdLineParser()
        with self.assertRaisesRegex(
            post_process_utils.PostProcessParseError,
            "Cannot have empty post-processing expression",
        ):
            parser.parse_line("| | to_upper")
        with self.assertRaisesRegex(
            post_process_utils.PostProcessParseError,
            "Cannot have empty post-processing expression",
        ):
            parser.parse_line("| ")
        with self.assertRaisesRegex(
            post_process_utils.PostProcessParseError,
            "Cannot have empty post-processing expression",
        ):
            parser.parse_line("| add_length |")

    def test_cannot_parse_multiple_tokens_in_one_expression(self):
        parser = cmd_line_parser.CmdLineParser()
        with self.assertRaisesRegex(
            post_process_utils.PostProcessParseError,
            "Post-processing expression should be a single token",
        ):
            parser.parse_line("| add_length to_upper")


if __name__ == "__main__":
    absltest.main()


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/command.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command."""
from __future__ import annotations

import abc
import collections
from typing import Sequence

from google.generativeai.notebook import parsed_args_lib
from google.generativeai.notebook import post_process_utils

ProcessingCommand = collections.namedtuple("ProcessingCommand", ["name", "fn"])


class Command(abc.ABC):
    """Base class for implementation of Magics commands like "run"."""

    @abc.abstractmethod
    def execute(
        self,
        parsed_args: parsed_args_lib.ParsedArgs,
        cell_content: str,
        post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
    ):
        """Executes the command given `parsed_args` and the `cell_content`."""

    @abc.abstractmethod
    def parse_post_processing_tokens(
        self, tokens: Sequence[Sequence[str]]
    ) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
        """Parses post-processing tokens for this command."""


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/command_utils.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Commands.

Common methods for Commands such as RunCommand and CompileCommand.
"""
from __future__ import annotations

from typing import AbstractSet, Any, Callable, Sequence

from google.generativeai.notebook import ipython_env
from google.generativeai.notebook import model_registry
from google.generativeai.notebook import parsed_args_lib
from google.generativeai.notebook import post_process_utils
from google.generativeai.notebook.lib import llm_function
from google.generativeai.notebook.lib import llmfn_input_utils
from google.generativeai.notebook.lib import llmfn_output_row
from google.generativeai.notebook.lib import llmfn_outputs
from google.generativeai.notebook.lib import unique_fn


class _GroundTruthLLMFunction(llm_function.LLMFunction):
    """LLMFunction that returns pre-generated ground truth data."""

    def __init__(self, data: Sequence[str]):
        super().__init__(outputs_ipython_display_fn=None)
        self._data = data

    def get_placeholders(self) -> AbstractSet[str]:
        # Ground truth is fixed and thus has no placeholders.
        return frozenset({})

    def _call_impl(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None
    ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
        normalized_inputs = llmfn_input_utils.to_normalized_inputs(inputs)
        if len(self._data) != len(normalized_inputs):
            raise RuntimeError(
                "Ground truth should have same number of entries as inputs:"
                " {} vs {}".format(len(self._data), len(normalized_inputs))
            )
        outputs: list[llmfn_outputs.LLMFnOutputEntry] = []
        for idx, (value, prompt_vars) in enumerate(
            zip(self._data, normalized_inputs)
        ):
            output_row = llmfn_output_row.LLMFnOutputRow(
                data={
                    llmfn_outputs.ColumnNames.RESULT_NUM: 0,
                    llmfn_outputs.ColumnNames.TEXT_RESULT: value,
                },
                result_type=str,
            )
            outputs.append(
                llmfn_outputs.LLMFnOutputEntry(
                    prompt_num=0,
                    input_num=idx,
                    prompt_vars=prompt_vars,
                    output_rows=[output_row],
                )
            )
        return outputs


def _get_ipython_display_fn(
    env: ipython_env.IPythonEnv,
) -> Callable[[llmfn_outputs.LLMFnOutputs], None]:
    return lambda x: env.display(x.as_pandas_dataframe())


def create_llm_function(
    models: model_registry.ModelRegistry,
    env: ipython_env.IPythonEnv | None,
    parsed_args: parsed_args_lib.ParsedArgs,
    cell_content: str,
    post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
) -> llm_function.LLMFunction:
    """Creates an LLMFunction from Command.execute() arguments."""
    prompts: list[str] = [cell_content]
    llmfn_outputs_display_fn = _get_ipython_display_fn(env) if env else None
    llm_fn = llm_function.LLMFunctionImpl(
        model=models.get_model(parsed_args.model_type),
        model_args=parsed_args.model_args,
        prompts=prompts,
        outputs_ipython_display_fn=llmfn_outputs_display_fn,
    )

    if parsed_args.unique:
        llm_fn = llm_fn.add_post_process_reorder_fn(
            name="unique", fn=unique_fn.unique_fn
        )

    for fn in post_processing_fns:
        llm_fn = fn.add_to_llm_function(llm_fn)
    return llm_fn


def _convert_simple_compare_fn(
    name_and_simple_fn: tuple[str, Callable[[str, str], Any]]
) -> tuple[str, llm_function.CompareFn]:
    simple_fn = name_and_simple_fn[1]
    new_fn = lambda x, y: simple_fn(x.result_value(), y.result_value())
    return name_and_simple_fn[0], new_fn


def create_llm_compare_function(
    env: ipython_env.IPythonEnv | None,
    parsed_args: parsed_args_lib.ParsedArgs,
    post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
) -> llm_function.LLMFunction:
    """Creates an LLMCompareFunction from Command.execute() arguments."""
    llmfn_outputs_display_fn = _get_ipython_display_fn(env) if env else None
    llm_cmp_fn = llm_function.LLMCompareFunction(
        lhs_name_and_fn=parsed_args.lhs_name_and_fn,
        rhs_name_and_fn=parsed_args.rhs_name_and_fn,
        compare_name_and_fns=[
            _convert_simple_compare_fn(x) for x in parsed_args.compare_fn
        ],
        outputs_ipython_display_fn=llmfn_outputs_display_fn,
    )
    for fn in post_processing_fns:
        llm_cmp_fn = fn.add_to_llm_function(llm_cmp_fn)
    return llm_cmp_fn


def create_llm_eval_function(
    models: model_registry.ModelRegistry,
    env: ipython_env.IPythonEnv | None,
    parsed_args: parsed_args_lib.ParsedArgs,
    cell_content: str,
    post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
) -> llm_function.LLMFunction:
    """Creates an LLMCompareFunction from Command.execute() arguments."""
    llmfn_outputs_display_fn = _get_ipython_display_fn(env) if env else None
    # First construct a regular LLMFunction from the cell contents.
    llm_fn = create_llm_function(
        models=models,
        env=env,
        parsed_args=parsed_args,
        cell_content=cell_content,
        post_processing_fns=post_processing_fns,
    )
    # Next create a LLMCompareFunction.
    ground_truth_fn = _GroundTruthLLMFunction(data=parsed_args.ground_truth)
    llm_cmp_fn = llm_function.LLMCompareFunction(
        lhs_name_and_fn=("actual", llm_fn),
        rhs_name_and_fn=("ground_truth", ground_truth_fn),
        compare_name_and_fns=[
            _convert_simple_compare_fn(x) for x in parsed_args.compare_fn
        ],
        outputs_ipython_display_fn=llmfn_outputs_display_fn,
    )
    return llm_cmp_fn


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/compare_cmd.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The compare command."""
from __future__ import annotations

from typing import Sequence

from google.generativeai.notebook import command
from google.generativeai.notebook import command_utils
from google.generativeai.notebook import input_utils
from google.generativeai.notebook import ipython_env
from google.generativeai.notebook import output_utils
from google.generativeai.notebook import parsed_args_lib
from google.generativeai.notebook import post_process_utils

import pandas


class CompareCommand(command.Command):
    """Implementation of "compare" command."""

    def __init__(
        self,
        env: ipython_env.IPythonEnv | None = None,
    ):
        """Constructor.

        Args:
          env: The IPythonEnv environment.
        """
        super().__init__()
        self._ipython_env = env

    def execute(
        self,
        parsed_args: parsed_args_lib.ParsedArgs,
        cell_content: str,
        post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr],
    ) -> pandas.DataFrame:
        # We expect CmdLineParser to have already read the inputs once to
        # validate that the placeholders in the prompt are present in the
        # inputs, so we can suppress the status messages here.
        inputs = input_utils.join_inputs_sources(
            parsed_args, suppress_status_msgs=True
        )
        llm_cmp_fn = command_utils.create_llm_compare_function(
            env=self._ipython_env,
            parsed_args=parsed_args,
            post_processing_fns=post_processing_fns,
        )
        results = llm_cmp_fn(inputs=inputs)
        output_utils.write_to_outputs(results=results, parsed_args=parsed_args)
        return results.as_pandas_dataframe()

    def parse_post_processing_tokens(
        self, tokens: Sequence[Sequence[str]]
    ) -> Sequence[post_process_utils.ParsedPostProcessExpr]:
        if tokens:
            raise RuntimeError('Post-processing is not supported by "compare"')
        return []


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/compile_cmd.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The compile command.""" from __future__ import annotations from typing import Sequence from google.generativeai.notebook import command from google.generativeai.notebook import command_utils from google.generativeai.notebook import ipython_env from google.generativeai.notebook import model_registry from google.generativeai.notebook import parsed_args_lib from google.generativeai.notebook import post_process_utils from google.generativeai.notebook import py_utils class CompileCommand(command.Command): """Implementation of the "compile" command.""" def __init__( self, models: model_registry.ModelRegistry, env: ipython_env.IPythonEnv | None = None, ): """Constructor. Args: models: ModelRegistry instance. env: The IPythonEnv environment. """ super().__init__() self._models = models self._ipython_env = env def execute( self, parsed_args: parsed_args_lib.ParsedArgs, cell_content: str, post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr], ) -> str: llm_fn = command_utils.create_llm_function( models=self._models, env=self._ipython_env, parsed_args=parsed_args, cell_content=cell_content, post_processing_fns=post_processing_fns, ) py_utils.set_py_var(parsed_args.compile_save_name, llm_fn) return "Saved function to Python variable: {}".format( parsed_args.compile_save_name ) def parse_post_processing_tokens( self, tokens: Sequence[Sequence[str]] ) -> Sequence[post_process_utils.ParsedPostProcessExpr]: return post_process_utils.resolve_post_processing_tokens(tokens) 07070100000019000081A40000000000000000000000016459839500000A58000000000000000000000000000000000000004800000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/eval_cmd.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """The eval command.""" from __future__ import annotations from typing import Sequence from google.generativeai.notebook import command from google.generativeai.notebook import command_utils from google.generativeai.notebook import input_utils from google.generativeai.notebook import ipython_env from google.generativeai.notebook import model_registry from google.generativeai.notebook import output_utils from google.generativeai.notebook import parsed_args_lib from google.generativeai.notebook import post_process_utils import pandas class EvalCommand(command.Command): """Implementation of "eval" command.""" def __init__( self, models: model_registry.ModelRegistry, env: ipython_env.IPythonEnv | None = None, ): """Constructor. Args: models: ModelRegistry instance. env: The IPythonEnv environment. """ super().__init__() self._models = models self._ipython_env = env def execute( self, parsed_args: parsed_args_lib.ParsedArgs, cell_content: str, post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr], ) -> pandas.DataFrame: # We expect CmdLineParser to have already read the inputs once to validate # that the placeholders in the prompt are present in the inputs, so we can # suppress the status messages here. 
inputs = input_utils.join_inputs_sources( parsed_args, suppress_status_msgs=True ) llm_cmp_fn = command_utils.create_llm_eval_function( models=self._models, env=self._ipython_env, parsed_args=parsed_args, cell_content=cell_content, post_processing_fns=post_processing_fns, ) results = llm_cmp_fn(inputs=inputs) output_utils.write_to_outputs(results=results, parsed_args=parsed_args) return results.as_pandas_dataframe() def parse_post_processing_tokens( self, tokens: Sequence[Sequence[str]] ) -> Sequence[post_process_utils.ParsedPostProcessExpr]: return post_process_utils.resolve_post_processing_tokens(tokens) 0707010000001A000081A40000000000000000000000016459839500003F67000000000000000000000000000000000000004800000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/flag_def.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Classes that define arguments for populating ArgumentParser. The argparse module's ArgumentParser.add_argument() takes several parameters and is quite customizable. However this can lead to bugs where arguments do not behave as expected. For better ease-of-use and better testability, define a set of classes for the types of flags used by LLM Magics. 
Sample usage: str_flag = SingleValueFlagDef(name="title", required=True) enum_flag = EnumFlagDef(name="colors", required=True, enum_type=ColorsEnum) str_flag.add_argument_to_parser(my_parser) enum_flag.add_argument_to_parser(my_parser) """ from __future__ import annotations import abc import argparse import dataclasses import enum from typing import Any, Callable, Sequence, Union, Tuple from google.generativeai.notebook.lib import llmfn_inputs_source from google.generativeai.notebook.lib import llmfn_outputs # These are the intermediate types that argparse.ArgumentParser.parse_args() # will pass command line arguments into. _PARSETYPES = Union[str, int, float] # These are the final result types that the intermediate parsed values will be # converted into. It is a superset of _PARSETYPES because we support converting # the parsed type into a more precise type, e.g. from str to Enum. _DESTTYPES = Union[ _PARSETYPES, enum.Enum, Tuple[str, Callable[[str, str], Any]], # For --compare_fn Sequence[str], # For --ground_truth llmfn_inputs_source.LLMFnInputsSource, # For --inputs llmfn_outputs.LLMFnOutputsSink, # For --outputs ] # The signature of a function that converts a command line argument from the # intermediate parsed type to the result type. _PARSEFN = Callable[[_PARSETYPES], _DESTTYPES] def _get_type_name(x: type[Any]) -> str: try: return x.__name__ except AttributeError: return str(x) def _validate_flag_name(name: str) -> str: """Validation for long and short names for flags.""" if not name: raise ValueError("Cannot be empty") if name[0] == "-": raise ValueError("Cannot start with dash") return name @dataclasses.dataclass(frozen=True) class FlagDef(abc.ABC): """Abstract base class for flag definitions. Attributes: name: Long name, e.g. "colors" will define the flag "--colors". required: Whether the flag must be provided on the command line. short_name: Optional short name. parse_type: The type that ArgumentParser should parse the command line argument to. 
dest_type: The type that the parsed value is converted to. This is used when we want ArgumentParser to parse as one type, then convert to a different type. E.g. for enums we parse as "str" then convert to the desired enum type in order to provide cleaner help messages. parse_to_dest_type_fn: If provided, this function will be used to convert the value from `parse_type` to `dest_type`. This can be used for validation as well. choices: If provided, limit the set of acceptable values to these choices. help_msg: If provided, adds help message when -h is used in the command line. """ name: str required: bool = False short_name: str | None = None parse_type: type[_PARSETYPES] = str dest_type: type[_DESTTYPES] | None = None parse_to_dest_type_fn: _PARSEFN | None = None choices: list[_PARSETYPES] | None = None help_msg: str | None = None @abc.abstractmethod def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None: """Adds this flag as an argument to `parser`. Child classes should implement this as a call to parser.add_argument() with the appropriate parameters. Args: parser: The parser to which this argument will be added. 
""" @abc.abstractmethod def _do_additional_validation(self) -> None: """For child classes to do additional validation.""" def _get_dest_type(self) -> type[_DESTTYPES]: """Returns the final converted type.""" return self.parse_type if self.dest_type is None else self.dest_type def _get_parse_to_dest_type_fn( self, ) -> _PARSEFN: """Returns a function to convert from parse_type to dest_type.""" if self.parse_to_dest_type_fn is not None: return self.parse_to_dest_type_fn dest_type = self._get_dest_type() if dest_type == self.parse_type: return lambda x: x else: return dest_type def __post_init__(self): _validate_flag_name(self.name) if self.short_name is not None: _validate_flag_name(self.short_name) self._do_additional_validation() def _has_non_default_value( namespace: argparse.Namespace, dest: str, has_default: bool = False, default_value: Any = None, ) -> bool: """Returns true if `namespace.dest` is set to a non-default value. Args: namespace: The Namespace that is populated by ArgumentParser. dest: The attribute in the Namespacde to be populated. has_default: "None" is a valid default value so we use an additional `has_default` boolean to indicate that `default_value` is present. default_value: The default value to use when `has_default` is True. Returns: Whether namespace.dest is set to something other than the default value. """ if not hasattr(namespace, dest): return False if not has_default: # No default value provided so `namespace.dest` cannot possibly be equal to # the default value. return True return getattr(namespace, dest) != default_value class _SingleValueStoreAction(argparse.Action): """Custom Action for storing a value in an argparse.Namespace. This action checks that the flag is specified at-most once. 
""" def __init__( self, option_strings, dest, dest_type: type[Any], parse_to_dest_type_fn: _PARSEFN, **kwargs, ): super().__init__(option_strings, dest, **kwargs) self._dest_type = dest_type self._parse_to_dest_type_fn = parse_to_dest_type_fn def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: str | Sequence[Any] | None, option_string: str | None = None, ): # Because `nargs` is set to 1, `values` must be a Sequence, rather # than a string. assert not isinstance(values, str) and not isinstance(values, bytes) if _has_non_default_value( namespace, self.dest, has_default=hasattr(self, "default"), default_value=getattr(self, "default"), ): raise argparse.ArgumentError( self, "Cannot set {} more than once".format(option_string) ) try: converted_value = self._parse_to_dest_type_fn(values[0]) except Exception as e: raise argparse.ArgumentError( self, 'Error with value "{}", got {}: {}'.format( values[0], _get_type_name(type(e)), e ), ) if not isinstance(converted_value, self._dest_type): raise RuntimeError( "Converted to wrong type, expected {} got {}".format( _get_type_name(self._dest_type), _get_type_name(type(converted_value)), ) ) setattr(namespace, self.dest, converted_value) class _MultiValuesAppendAction(argparse.Action): """Custom Action for appending values in an argparse.Namespace. This action checks that the flag is specified at-most once. """ def __init__( self, option_strings, dest, dest_type: type[Any], parse_to_dest_type_fn: _PARSEFN, **kwargs, ): super().__init__(option_strings, dest, **kwargs) self._dest_type = dest_type self._parse_to_dest_type_fn = parse_to_dest_type_fn def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: str | Sequence[Any] | None, option_string: str | None = None, ): # Because `nargs` is set to "+", `values` must be a Sequence, rather # than a string. 
assert not isinstance(values, str) and not isinstance(values, bytes) curr_value = getattr(namespace, self.dest) if curr_value: raise argparse.ArgumentError( self, "Cannot set {} more than once".format(option_string) ) for value in values: try: converted_value = self._parse_to_dest_type_fn(value) except Exception as e: raise argparse.ArgumentError( self, 'Error with value "{}", got {}: {}'.format( value, _get_type_name(type(e)), e ), ) if not isinstance(converted_value, self._dest_type): raise RuntimeError( "Converted to wrong type, expected {} got {}".format( _get_type_name(self._dest_type), _get_type_name(type(converted_value)) ) ) if converted_value in curr_value: raise argparse.ArgumentError( self, 'Duplicate values "{}"'.format(value) ) curr_value.append(converted_value) class _BooleanValueStoreAction(argparse.Action): """Custom Action for setting a boolean value in argparse.Namespace. The boolean flag expects the default to be False and will set the value to True. This action checks that the flag is specified at-most once. """ def __init__( self, option_strings, dest, **kwargs, ): super().__init__(option_strings, dest, **kwargs) def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: str | Sequence[Any] | None, option_string: str | None = None, ): if _has_non_default_value( namespace, self.dest, has_default=True, default_value=False, ): raise argparse.ArgumentError( self, "Cannot set {} more than once".format(option_string) ) setattr(namespace, self.dest, True) @dataclasses.dataclass(frozen=True) class SingleValueFlagDef(FlagDef): """Definition for a flag that takes a single value. Sample usage: # This defines a flag that can be specified on the command line as: # --count=10 flag = SingleValueFlagDef(name="count", parse_type=int, required=True) flag.add_argument_to_parser(argument_parser) Attributes: default_value: Default value for optional flags. """ class _DefaultValue(enum.Enum): """Special value to represent "no value provided".
"None" can be used as a default value, so in order to differentiate between "None" and "no value provided", create a special value for "no value provided". """ NOT_SET = None default_value: _DESTTYPES | _DefaultValue | None = _DefaultValue.NOT_SET def _has_default_value(self) -> bool: """Returns whether `default_value` has been provided.""" return self.default_value != SingleValueFlagDef._DefaultValue.NOT_SET def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None: args = ["--" + self.name] if self.short_name is not None: args += ["-" + self.short_name] kwargs = {} if self._has_default_value(): kwargs["default"] = self.default_value if self.choices is not None: kwargs["choices"] = self.choices if self.help_msg is not None: kwargs["help"] = self.help_msg parser.add_argument( *args, action=_SingleValueStoreAction, type=self.parse_type, dest_type=self._get_dest_type(), parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(), required=self.required, nargs=1, **kwargs, ) def _do_additional_validation(self) -> None: if self.required: if self._has_default_value(): raise ValueError("Required flags cannot have default value") else: if not self._has_default_value(): raise ValueError("Optional flags must have a default value") if self._has_default_value() and self.default_value is not None: if not isinstance(self.default_value, self._get_dest_type()): raise ValueError( "Default value must be of the same type as the destination type" ) class EnumFlagDef(SingleValueFlagDef): """Definition for a flag that takes a value from an Enum. 
Sample usage: # This defines a flag that can be specified on the command line as: # --color=red flag = EnumFlagDef(name="color", enum_type=ColorsEnum, required=True) flag.add_argument_to_parser(argument_parser) """ def __init__(self, *args, enum_type: type[enum.Enum], **kwargs): if not issubclass(enum_type, enum.Enum): raise TypeError('"enum_type" must be of type Enum') # These properties are set by "enum_type" so don't let the caller set them. if "parse_type" in kwargs: raise ValueError( 'Cannot set "parse_type" for EnumFlagDef; set "enum_type" instead' ) kwargs["parse_type"] = str if "dest_type" in kwargs: raise ValueError( 'Cannot set "dest_type" for EnumFlagDef; set "enum_type" instead' ) kwargs["dest_type"] = enum_type if "choices" in kwargs: # Verify that entries in `choices` are valid enum values. for x in kwargs["choices"]: try: enum_type(x) except ValueError: raise ValueError( 'Invalid value in "choices": "{}"'.format(x) ) from None else: kwargs["choices"] = [x.value for x in enum_type] super().__init__(*args, **kwargs) class MultiValuesFlagDef(FlagDef): """Definition for a flag that takes multiple values. Sample usage: # This defines a flag that can be specified on the command line as: # --colors red green blue flag = MultiValuesFlagDef(name="colors", parse_type=str, required=True) flag.add_argument_to_parser(argument_parser) """ def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None: args = ["--" + self.name] if self.short_name is not None: args += ["-" + self.short_name] kwargs = {} if self.choices is not None: kwargs["choices"] = self.choices if self.help_msg is not None: kwargs["help"] = self.help_msg parser.add_argument( *args, action=_MultiValuesAppendAction, type=self.parse_type, dest_type=self._get_dest_type(), parse_to_dest_type_fn=self._get_parse_to_dest_type_fn(), required=self.required, default=[], nargs="+", **kwargs, )
pass @dataclasses.dataclass(frozen=True) class BooleanFlagDef(FlagDef): """Definition for a Boolean flag. A boolean flag is always optional with a default value of False. The flag does not take any values. Specifying the flag on the commandline will set it to True. """ def _do_additional_validation(self) -> None: if self.dest_type is not None: raise ValueError("dest_type cannot be set for BooleanFlagDef") if self.parse_to_dest_type_fn is not None: raise ValueError("parse_to_dest_type_fn cannot be set for BooleanFlagDef") if self.choices is not None: raise ValueError("choices cannot be set for BooleanFlagDef") def add_argument_to_parser(self, parser: argparse.ArgumentParser) -> None: args = ["--" + self.name] if self.short_name is not None: args += ["-" + self.short_name] kwargs = {} if self.help_msg is not None: kwargs["help"] = self.help_msg parser.add_argument( *args, action=_BooleanValueStoreAction, type=bool, required=False, default=False, nargs=0, **kwargs, ) 0707010000001B000081A4000000000000000000000001645983950000382E000000000000000000000000000000000000004D00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/flag_def_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
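All of the custom Actions defined above share one job argparse does not do out of the box: rejecting a flag that appears more than once on the command line. A stripped-down sketch of the technique using only the standard library (`StoreOnceAction` is a simplified analogue of `_SingleValueStoreAction`; it assumes Python 3.9+ for `exit_on_error=False`):

```python
import argparse


class StoreOnceAction(argparse.Action):
    """Store a single value; raise if the flag is repeated."""

    def __call__(self, parser, namespace, values, option_string=None):
        # The default is None, so any non-None value means the flag
        # was already seen earlier on this command line.
        if getattr(namespace, self.dest, None) is not None:
            raise argparse.ArgumentError(
                self, f"Cannot set {option_string} more than once")
        setattr(namespace, self.dest, values)


# exit_on_error=False makes parse_args() re-raise ArgumentError
# instead of printing usage and calling sys.exit().
parser = argparse.ArgumentParser(exit_on_error=False)
parser.add_argument("--value", action=StoreOnceAction)

print(parser.parse_args(["--value", "forty-two"]).value)  # forty-two
```

The library's real Actions additionally convert the parsed string to a destination type and check the result with `isinstance`, but the repetition guard is the same `namespace` check shown here.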
"""Unittest for Argument Definition.""" from __future__ import annotations import argparse import enum import math from absl.testing import absltest from google.generativeai.notebook import argument_parser from google.generativeai.notebook import flag_def def _new_parser(flag: flag_def.FlagDef) -> argparse.ArgumentParser: """Returns a new ArgumentParser with `flag` added as an argument.""" parser = argument_parser.ArgumentParser() flag.add_argument_to_parser(parser) return parser class SingleValueFlagDefTest(absltest.TestCase): def test_short_name(self): flag = flag_def.SingleValueFlagDef( name="value", short_name="v", parse_type=str, required=True ) results = _new_parser(flag).parse_args(["--value=forty-one"]) self.assertEqual("forty-one", results.value) results = _new_parser(flag).parse_args(["-v", "forty-two"]) self.assertEqual("forty-two", results.value) def test_cardinality(self): flag = flag_def.SingleValueFlagDef( name="value", parse_type=str, required=True ) # Parser should not accept flag without no values. with self.assertRaisesRegex( argument_parser.ParserError, "expected 1 argument" ): _new_parser(flag).parse_args(["--value"]) # Parser should not accept flag with more-than-one value. with self.assertRaisesRegex( argument_parser.ParserError, "unrecognized arguments: forty-two" ): _new_parser(flag).parse_args(["--value", "forty-one", "forty-two"]) # Parser should not accept a command-line with the flag specified # more-than-once. with self.assertRaisesRegex( argument_parser.ParserError, "Cannot set --value more than once" ): _new_parser(flag).parse_args( ["--value", "forty-one", "--value", "forty-two"] ) results = _new_parser(flag).parse_args(["--value", "forty-one"]) self.assertEqual("forty-one", results.value) def test_required(self): req_flag = flag_def.SingleValueFlagDef( name="value", parse_type=str, required=True, ) # Parser is able to parse a command line containing the required argument. 
results = _new_parser(req_flag).parse_args(["--value=forty-two"]) self.assertEqual("forty-two", results.value) # Parser should raise an Exception if the commandline does not contain the # required argument. with self.assertRaisesRegex( argument_parser.ParserError, "the following arguments are required" ): _new_parser(req_flag).parse_args([]) def test_optional(self): # Optional flags should have a default value with self.assertRaisesRegex( ValueError, "Optional flags must have a default value" ): flag_def.SingleValueFlagDef( name="value", parse_type=str, required=False, ) opt_flag = flag_def.SingleValueFlagDef( name="value", parse_type=str, required=False, default_value="zero", ) # Parser is able to parse a command line containing the optional argument. results = _new_parser(opt_flag).parse_args(["--value=forty-two"]) self.assertEqual("forty-two", results.value) # Parser should return the default value if the command line does not # contain the optional argument. results = _new_parser(opt_flag).parse_args([]) self.assertEqual("zero", results.value) # Parser should not accept flag without a value. with self.assertRaisesRegex( argument_parser.ParserError, "expected 1 argument" ): _new_parser(opt_flag).parse_args(["--value"]) def test_default_is_none(self): """Make sure None can be accepted as a default value.""" opt_flag = flag_def.SingleValueFlagDef( name="value", parse_type=str, required=False, default_value=None, ) # Parser is able to parse a command line containing the optional argument. results = _new_parser(opt_flag).parse_args(["--value=forty-two"]) self.assertEqual("forty-two", results.value) # Parser should return the default value if the command line does not # contain the optional argument. results = _new_parser(opt_flag).parse_args([]) self.assertIsNone(results.value) # Parser should not accept flag without a value. 
with self.assertRaisesRegex( argument_parser.ParserError, "expected 1 argument" ): _new_parser(opt_flag).parse_args(["--value"]) def test_type_conversion(self): # Default value must be of the same type as destination type. with self.assertRaisesRegex( ValueError, "Default value must be of the same type as the destination type", ): flag_def.SingleValueFlagDef( name="value", parse_type=int, required=False, default_value="zero", ) int_flag_def = flag_def.SingleValueFlagDef( name="value", parse_type=int, required=False, default_value=0, ) # Parser should not accept a value of the wrong type. with self.assertRaisesRegex( argument_parser.ParserError, "invalid int value: 'forty-two'" ): _new_parser(int_flag_def).parse_args(["--value", "forty-two"]) results = _new_parser(int_flag_def).parse_args(["--value", "42"]) self.assertEqual(42, results.value) def test_validation(self): def _check_is_not_nan(x: float) -> float: if math.isnan(x): raise ValueError("Must not be NAN") return x float_flag_def = flag_def.SingleValueFlagDef( name="value", parse_type=float, parse_to_dest_type_fn=_check_is_not_nan, required=True, ) results = _new_parser(float_flag_def).parse_args(["--value", "0.25"]) self.assertEqual(0.25, results.value) with self.assertRaisesRegex(argument_parser.ParserError, "Must not be NAN"): _new_parser(float_flag_def).parse_args(["--value", "nan"]) class ColorsEnum(enum.Enum): RED = "red" GREEN = "green" BLUE = "blue" class EnumFlagDefTest(absltest.TestCase): def test_construction(self): # "enum_type" must be provided. with self.assertRaisesRegex( TypeError, "missing 1 required keyword-only argument" ): # pylint: disable-next=missing-kwoa flag_def.EnumFlagDef(name="color", required=True) # type: ignore # "parse_type" cannot be provided. with self.assertRaisesRegex( ValueError, 'Cannot set "parse_type" for EnumFlagDef' ): flag_def.EnumFlagDef( name="color", required=True, enum_type=ColorsEnum, parse_type=int ) # "dest_type" cannot be provided. 
with self.assertRaisesRegex( ValueError, 'Cannot set "dest_type" for EnumFlagDef' ): flag_def.EnumFlagDef( name="color", required=True, enum_type=ColorsEnum, dest_type=str ) # This should succeed. flag_def.EnumFlagDef(name="color", required=True, enum_type=ColorsEnum) def test_parsing(self): flag = flag_def.EnumFlagDef( name="color", required=False, enum_type=ColorsEnum, default_value=ColorsEnum.RED, ) # "teal" is not one of the enum values. with self.assertRaisesRegex( argument_parser.ParserError, "invalid choice: 'teal'" ): _new_parser(flag).parse_args(["--color=teal"]) results = _new_parser(flag).parse_args(["--color=red"]) self.assertEqual(ColorsEnum.RED, results.color) results = _new_parser(flag).parse_args(["--color=green"]) self.assertEqual(ColorsEnum.GREEN, results.color) results = _new_parser(flag).parse_args(["--color=blue"]) self.assertEqual(ColorsEnum.BLUE, results.color) def test_choices(self): # If `choices` is provided, all values must be valid. with self.assertRaisesRegex(ValueError, 'Invalid value in "choices"'): flag_def.EnumFlagDef( name="color", required=True, enum_type=ColorsEnum, choices=["red", "green", "teal"], ) # Exclude "blue". flag = flag_def.EnumFlagDef( name="color", required=True, enum_type=ColorsEnum, choices=["red", "green"], ) # "blue" is no longer one of the choices. with self.assertRaisesRegex( argument_parser.ParserError, "invalid choice: 'blue'" ): _new_parser(flag).parse_args(["--color=blue"]) results = _new_parser(flag).parse_args(["--color=red"]) self.assertEqual(ColorsEnum.RED, results.color) results = _new_parser(flag).parse_args(["--color=green"]) self.assertEqual(ColorsEnum.GREEN, results.color) class MultiValuesFlagDefTest(absltest.TestCase): def test_basic(self): # Default value is not needed even if optional; the value would just be the # empty list. flag = flag_def.MultiValuesFlagDef( name="colors", parse_type=str, required=False ) # Default value is the empty list. 
        results = _new_parser(flag).parse_args([])
        self.assertEmpty(results.colors)

        results = _new_parser(flag).parse_args(["--colors", "red"])
        self.assertEqual(["red"], results.colors)

        results = _new_parser(flag).parse_args(["--colors", "red", "green"])
        self.assertEqual(["red", "green"], results.colors)

    def test_required(self):
        flag = flag_def.MultiValuesFlagDef(
            name="colors",
            parse_type=str,
            required=True,
        )
        # Parser is able to parse a command line containing the required argument.
        results = _new_parser(flag).parse_args(["--colors", "red"])
        self.assertEqual(["red"], results.colors)

        # Parser should raise an Exception if the commandline does not contain the
        # required argument.
        with self.assertRaisesRegex(
            argument_parser.ParserError, "the following arguments are required"
        ):
            _new_parser(flag).parse_args([])

    def test_cannot_set_default_value(self):
        # `default_value` is not a field for MultiValueFlagsDef.
        with self.assertRaisesRegex(
            TypeError, "got an unexpected keyword argument"
        ):
            # pylint: disable-next=unexpected-keyword-arg
            flag_def.MultiValuesFlagDef(  # type: ignore
                name="colors",
                parse_type=str,
                required=False,
                default_value="fuschia",
            )

    def test_values_must_be_unique(self):
        flag = flag_def.MultiValuesFlagDef(name="colors")
        # Cannot specify "red" more than once.
        with self.assertRaisesRegex(
            argument_parser.ParserError, 'Duplicate values "red"'
        ):
            _new_parser(flag).parse_args(["--colors", "red", "green", "red"])

    def test_cardinality(self):
        flag = flag_def.MultiValuesFlagDef(
            name="colors",
            parse_type=str,
            required=False,
        )
        # Must have at least one argument.
        with self.assertRaisesRegex(
            argument_parser.ParserError, "expected at least one argument"
        ):
            _new_parser(flag).parse_args(["--colors"])

        # Cannot specify "--colors" more than once.
        with self.assertRaisesRegex(
            argument_parser.ParserError, "Cannot set --colors more than once"
        ):
            _new_parser(flag).parse_args(["--colors", "red", "--colors", "blue"])

    def test_dest_type_conversion(self):
        flag = flag_def.MultiValuesFlagDef(
            name="colors",
            parse_type=str,
            dest_type=ColorsEnum,
            required=False,
            choices=[x.value for x in ColorsEnum],
        )
        # "fuschia" is not a valid value for enum.
        with self.assertRaisesRegex(
            argument_parser.ParserError, "invalid choice: 'fuschia'"
        ):
            _new_parser(flag).parse_args(["--colors", "fuschia"])

        # Results are converted to a list of enums.
        results = _new_parser(flag).parse_args(["--colors", "red", "green"])
        self.assertEqual([ColorsEnum.RED, ColorsEnum.GREEN], results.colors)

    def test_validation(self):
        def _check_is_not_nan(x: float) -> float:
            if math.isnan(x):
                raise ValueError("Must not be NAN")
            return x

        flag = flag_def.MultiValuesFlagDef(
            name="values",
            parse_type=float,
            parse_to_dest_type_fn=_check_is_not_nan,
        )
        results = _new_parser(flag).parse_args(["--value", "0.25", "0.5"])
        self.assertEqual([0.25, 0.5], results.values)

        with self.assertRaisesRegex(argument_parser.ParserError, "Must not be NAN"):
            _new_parser(flag).parse_args(["--value", "0.25", "nan"])


class BooleanFlagDefTest(absltest.TestCase):
    def test_basic(self):
        flag = flag_def.BooleanFlagDef(name="unique")
        results = _new_parser(flag).parse_args([])
        self.assertFalse(results.unique)
        results = _new_parser(flag).parse_args(["--unique"])
        self.assertTrue(results.unique)

    def test_constructor(self):
        """Check that invalid constructor arguments are rejected."""
        with self.assertRaisesRegex(
            ValueError, "dest_type cannot be set for BooleanFlagDef"
        ):
            flag_def.BooleanFlagDef(name="unique", dest_type=bool)

        with self.assertRaisesRegex(
            ValueError, "parse_to_dest_type_fn cannot be set for BooleanFlagDef"
        ):
            flag_def.BooleanFlagDef(
                name="unique", parse_to_dest_type_fn=lambda x: True
            )

        with self.assertRaisesRegex(
            ValueError, "choices cannot be set for BooleanFlagDef"
        ):
            flag_def.BooleanFlagDef(name="unique", choices=[True])

    def test_cardinality(self):
        flag = flag_def.BooleanFlagDef(name="unique")
        with self.assertRaisesRegex(
            argument_parser.ParserError, "error: unrecognized arguments: True"
        ):
            _new_parser(flag).parse_args(["--unique", "True"])

        with self.assertRaisesRegex(
            argument_parser.ParserError, "Cannot set --unique more than once"
        ):
            _new_parser(flag).parse_args(["--unique", "--unique"])


if __name__ == "__main__":
    absltest.main()
0707010000001C000081A40000000000000000000000016459839500001D1A000000000000000000000000000000000000004E00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/gspread_client.py
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module that holds a global gspread.client.Client."""
from __future__ import annotations

import abc
import datetime
from typing import Any, Callable, Mapping, Sequence

from google.auth import credentials
from google.generativeai.notebook import html_utils
from google.generativeai.notebook import ipython_env
from google.generativeai.notebook import sheets_id

# The code may be running in an environment where the gspread library has not
# been installed.
_gspread_import_error: Exception | None = None
try:
    # pylint: disable-next=g-import-not-at-top
    from gspread import gspread
except (ImportError, ModuleNotFoundError):
    try:
        # pylint: disable-next=g-import-not-at-top
        import gspread
    except ImportError as e:
        _gspread_import_error = e
        gspread = None

# Base class of exceptions that gspread.open(), open_by_url() and open_by_key()
# may throw.
GSpreadException = Exception if gspread is None else gspread.exceptions.GSpreadException  # type: ignore


class SpreadsheetNotFoundError(RuntimeError):
    pass


def _get_import_error() -> Exception:
    return RuntimeError(
        '"gspread" module not imported, got: {}'.format(_gspread_import_error)
    )


class GSpreadClient(abc.ABC):
    """Wrapper around gspread.client.Client.

    This adds a layer of indirection for us to inject mocks for testing.
    """

    @abc.abstractmethod
    def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
        """Validates that `name` is the name of a Google Sheets document.

        Raises an exception if false.

        Args:
            sid: The identifier for the document.
        """

    @abc.abstractmethod
    def get_all_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        worksheet_id: int,
    ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
        """Returns all records for a Google Sheets worksheet."""

    @abc.abstractmethod
    def write_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        rows: Sequence[Sequence[Any]],
    ) -> None:
        """Writes results to a new worksheet in the Google Sheets document."""


class GSpreadClientImpl(GSpreadClient):
    """Concrete implementation of GSpreadClient."""

    def __init__(self, client: Any, env: ipython_env.IPythonEnv | None):
        """Constructor.

        Args:
            client: Instance of gspread.client.Client.
            env: Optional instance of IPythonEnv. This is used to display
                messages such as the URL of the output Worksheet.
        """
        self._client = client
        self._ipython_env = env

    def _open(self, sid: sheets_id.SheetsIdentifier):
        """Opens a Sheets document from `sid`.

        Args:
            sid: The identifier for the Sheets document.

        Raises:
            SpreadsheetNotFoundError: If the Sheets document cannot be found
                or cannot be opened.

        Returns:
            A gspread.Worksheet instance representing the worksheet referred
            to by `sid`.
        """
        try:
            if sid.name():
                return self._client.open(sid.name())
            if sid.key():
                return self._client.open_by_key(str(sid.key()))
            if sid.url():
                return self._client.open_by_url(str(sid.url()))
        except GSpreadException as exc:
            raise SpreadsheetNotFoundError(
                "Unable to find Sheets with {}".format(sid)
            ) from exc
        raise SpreadsheetNotFoundError("Invalid sheets_id.SheetsIdentifier")

    def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
        self._open(sid)

    def get_all_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        worksheet_id: int,
    ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
        sheet = self._open(sid)
        worksheet = sheet.get_worksheet(worksheet_id)
        if self._ipython_env is not None:
            env = self._ipython_env

            def _display_fn():
                env.display_html(
                    "Reading inputs from worksheet {}".format(
                        html_utils.get_anchor_tag(
                            url=sheets_id.SheetsURL(worksheet.url),
                            text="{} in {}".format(worksheet.title, sheet.title),
                        )
                    )
                )

        else:

            def _display_fn():
                print(
                    "Reading inputs from worksheet {} in {}".format(
                        worksheet.title, sheet.title
                    )
                )

        return worksheet.get_all_records(), _display_fn

    def write_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        rows: Sequence[Sequence[Any]],
    ) -> None:
        sheet = self._open(sid)
        # Create a new Worksheet.
        # `title` has to be carefully constructed: some characters like colon ":"
        # will not work with gspread in Worksheet.append_rows().
        current_datetime = datetime.datetime.now()
        title = f"Results {current_datetime:%Y_%m_%d} ({current_datetime:%s})"
        # append_rows() will resize the worksheet as needed, so `rows` and `cols`
        # can be set to 1 to create a worksheet with only a single cell.
        worksheet = sheet.add_worksheet(title=title, rows=1, cols=1)
        worksheet.append_rows(values=rows)
        if self._ipython_env is not None:
            self._ipython_env.display_html(
                "Results written to new worksheet {}".format(
                    html_utils.get_anchor_tag(
                        url=sheets_id.SheetsURL(worksheet.url),
                        text="{} in {}".format(worksheet.title, sheet.title),
                    )
                )
            )
        else:
            print(
                "Results written to new worksheet {} in {}".format(
                    worksheet.title, sheet.title
                )
            )


class NullGSpreadClient(GSpreadClient):
    """Null-object implementation of GSpreadClient.

    This class raises an error if any of its methods are called. It is used
    when the gspread library is not available.
    """

    def validate(self, sid: sheets_id.SheetsIdentifier) -> None:
        raise _get_import_error()

    def get_all_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        worksheet_id: int,
    ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
        raise _get_import_error()

    def write_records(
        self,
        sid: sheets_id.SheetsIdentifier,
        rows: Sequence[Sequence[Any]],
    ) -> None:
        raise _get_import_error()


# Global instance of gspread client.
_gspread_client: GSpreadClient | None = None


def authorize(
    creds: credentials.Credentials, env: ipython_env.IPythonEnv | None
) -> None:
    """Sets up credentials for gspread."""
    global _gspread_client
    if gspread is not None:
        client = gspread.authorize(creds)  # type: ignore
        _gspread_client = GSpreadClientImpl(client=client, env=env)
    else:
        _gspread_client = NullGSpreadClient()


def get_client() -> GSpreadClient:
    if not _gspread_client:
        raise RuntimeError("Must call authorize() first")
    return _gspread_client


def testonly_set_client(client: GSpreadClient) -> None:
    """Overrides the global client for testing."""
    global _gspread_client
    _gspread_client = client
0707010000001D000081A400000000000000000000000164598395000005E9000000000000000000000000000000000000004A00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/html_utils.py
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating HTML."""
from __future__ import annotations

from xml.etree import ElementTree

from google.generativeai.notebook import sheets_id


def get_anchor_tag(url: sheets_id.SheetsURL, text: str) -> str:
    """Returns an HTML string representing an anchor tag.

    This function uses the xml.etree library to handle HTML escaping.

    Args:
        url: The Sheets URL to link to.
        text: The text body of the link.

    Returns:
        A string representing an HTML fragment.
""" tag = ElementTree.Element( "a", attrib={ # Open in a new window/tab "target": "_blank", # See: # https://developer.chrome.com/en/docs/lighthouse/best-practices/external-anchors-use-rel-noopener/ "rel": "noopener", "href": str(url), }, ) tag.text = text if text else "link" return ElementTree.tostring(tag, encoding="unicode", method="html") 0707010000001E000081A4000000000000000000000001645983950000070A000000000000000000000000000000000000004F00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/html_utils_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unittest for html_utils.""" from __future__ import annotations from absl.testing import absltest from google.generativeai.notebook import html_utils from google.generativeai.notebook import sheets_id class HtmlUtilsTest(absltest.TestCase): def test_get_anchor_tag_text_is_escaped(self): html = html_utils.get_anchor_tag( url=sheets_id.SheetsURL("https://docs.google.com/?a=b#hello"), text="hello<evil_tag/>world", ) self.assertEqual( ( '<a target="_blank" rel="noopener"' ' href="https://docs.google.com/?a=b#hello">hello<evil_tag/>world</a>' ), html, ) def test_get_anchor_tag_url_is_escaped(self): url = sheets_id.SheetsURL("https://docs.google.com/") # Break encapsulation to modify the URL. 
url._url = 'https://docs.google.com/"evil_string"' html = html_utils.get_anchor_tag( url=url, text="hello world", ) self.assertEqual( ( '<a target="_blank" rel="noopener"' ' href="https://docs.google.com/"evil_string"">hello' " world</a>" ), html, ) if __name__ == "__main__": absltest.main() 0707010000001F000081A40000000000000000000000016459839500000AB7000000000000000000000000000000000000004B00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/input_utils.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for handling input variables.""" from __future__ import annotations from typing import Callable, Mapping from google.generativeai.notebook import parsed_args_lib from google.generativeai.notebook import py_utils from google.generativeai.notebook.lib import llmfn_input_utils from google.generativeai.notebook.lib import llmfn_inputs_source class _NormalizedInputsSource(llmfn_inputs_source.LLMFnInputsSource): """Wrapper around NormalizedInputsList. By design LLMFunction does not take NormalizedInputsList as input because NormalizedInputsList is an internal representation so we want to minimize exposure to the caller. When we have inputs already in normalized format (e.g. from join_prompt_inputs()) we can wrap it as an LLMFnInputsSource to pass as an input to LLMFunction. 
""" def __init__( self, normalized_inputs: llmfn_inputs_source.NormalizedInputsList ): super().__init__() self._normalized_inputs = normalized_inputs def _to_normalized_inputs_impl( self, ) -> tuple[llmfn_inputs_source.NormalizedInputsList, Callable[[], None]]: return self._normalized_inputs, lambda: None def get_inputs_source_from_py_var( var_name: str, ) -> llmfn_inputs_source.LLMFnInputsSource: data = py_utils.get_py_var(var_name) if isinstance(data, llmfn_inputs_source.LLMFnInputsSource): # No conversion needed. return data normalized_inputs = llmfn_input_utils.to_normalized_inputs(data) return _NormalizedInputsSource(normalized_inputs) def join_inputs_sources( parsed_args: parsed_args_lib.ParsedArgs, suppress_status_msgs: bool = False, ) -> llmfn_inputs_source.LLMFnInputsSource: """Get a single combined input source from `parsed_args.""" combined_inputs: list[Mapping[str, str]] = [] for source in parsed_args.inputs: combined_inputs.extend( source.to_normalized_inputs(suppress_status_msgs=suppress_status_msgs) ) for source in parsed_args.sheets_input_names: combined_inputs.extend( source.to_normalized_inputs(suppress_status_msgs=suppress_status_msgs) ) return _NormalizedInputsSource(combined_inputs) 07070100000020000081A4000000000000000000000001645983950000095C000000000000000000000000000000000000005000000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/input_utils_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unittest for input_utils."""
from __future__ import annotations

import sys
from unittest import mock

from absl.testing import absltest

from google.generativeai.notebook import input_utils

_EMPTY_INPUT_VAR_ONE = {}
_EMPTY_INPUT_VAR_TWO = {"word": []}
_INPUT_VAR_ONE = {"word": ["lukewarm"]}
_INPUT_VAR_TWO = {"word": ["hot", "cold"]}
_MULTI_INPUTS_VAR_ONE = {"a": ["apple"], "b": ["banana"]}
_MULTI_INPUTS_VAR_TWO = {"a": ["australia", "alpha"], "b": ["brazil", "beta"]}


# `unittest discover` does not run via __main__, so patch this context in.
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class InputUtilsTest(absltest.TestCase):
    def test_get_inputs_source_from_py_var_invalid_name(self):
        with self.assertRaisesRegex(NameError, "UnknownVar"):
            input_utils.get_inputs_source_from_py_var("UnknownVar")

    def test_get_inputs_source_from_py_var_empty_one(self):
        source = input_utils.get_inputs_source_from_py_var("_EMPTY_INPUT_VAR_ONE")
        results = source.to_normalized_inputs()
        self.assertEmpty(results)

    def test_get_inputs_source_from_py_var_empty_two(self):
        source = input_utils.get_inputs_source_from_py_var("_EMPTY_INPUT_VAR_TWO")
        results = source.to_normalized_inputs()
        self.assertEmpty(results)

    def test_get_inputs_source_from_py_var_single_input_one(self):
        source = input_utils.get_inputs_source_from_py_var("_INPUT_VAR_ONE")
        results = source.to_normalized_inputs()
        self.assertEqual([{"word": "lukewarm"}], results)

    def test_get_inputs_source_from_py_var_single_input_two(self):
        source = input_utils.get_inputs_source_from_py_var("_INPUT_VAR_TWO")
        results = source.to_normalized_inputs()
        self.assertEqual([{"word": "hot"}, {"word": "cold"}], results)


if __name__ == "__main__":
    absltest.main()
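The tests above pin down how column-oriented input variables such as `_INPUT_VAR_TWO = {"word": ["hot", "cold"]}` are normalized into one dict per row. A minimal standalone sketch of that conversion, assuming equal-length columns; `normalize` here is a hypothetical helper for illustration, not the library's `llmfn_input_utils.to_normalized_inputs`:

```python
from typing import Mapping, Sequence


def normalize(data: Mapping[str, Sequence[str]]) -> list[dict[str, str]]:
    """Convert column-oriented {"word": ["hot", "cold"]} into
    row-oriented [{"word": "hot"}, {"word": "cold"}]."""
    if not data:
        return []
    # Truncate to the shortest column so every row is complete.
    num_rows = min(len(column) for column in data.values())
    return [{key: column[i] for key, column in data.items()} for i in range(num_rows)]
```

This reproduces the behavior the tests expect: empty dicts and empty columns normalize to an empty list, and multi-column inputs zip into complete rows.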
07070100000021000081A40000000000000000000000016459839500000664000000000000000000000000000000000000004B00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/ipython_env.py
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract IPythonEnv base class.

This module provides a layer of abstraction to address the following problems:

1. Sometimes the code needs to run in an environment where IPython is not
   available, e.g. inside a unittest.
2. We want to limit dependencies on IPython to code that deals directly with
   the notebook environment.
"""
from __future__ import annotations

import abc
from typing import Any


class IPythonEnv(abc.ABC):
    """Abstract base class that provides a wrapper around IPython methods."""

    @abc.abstractmethod
    def display(self, x: Any) -> None:
        """Wrapper around IPython.core.display.display()."""

    @abc.abstractmethod
    def display_html(self, x: str) -> None:
        """Wrapper to display HTML.

        This method is equivalent to calling:
            display.display(display.HTML(x))

        display() and HTML() are combined into a single method because
        display.HTML() returns an object, which would be complicated to model
        with this abstract interface.

        Args:
            x: An HTML string to be displayed.
""" 07070100000022000081A40000000000000000000000016459839500000414000000000000000000000000000000000000005000000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/ipython_env_impl.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """IPythonEnvImpl.""" from __future__ import annotations from typing import Any from google.generativeai.notebook import ipython_env from IPython.core import display as ipython_display class IPythonEnvImpl(ipython_env.IPythonEnv): """Concrete implementation of IPythonEnv.""" def display(self, x: Any) -> None: ipython_display.display(x) def display_html(self, x: str) -> None: ipython_display.display(ipython_display.HTML(x)) 07070100000023000041ED0000000000000000000000026459839500000000000000000000000000000000000000000000004000000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib07070100000024000081A40000000000000000000000016459839500000257000000000000000000000000000000000000004C00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/__init__.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 07070100000025000081A40000000000000000000000016459839500004266000000000000000000000000000000000000005000000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llm_function.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""LLMFunction.""" from __future__ import annotations import abc import dataclasses from typing import AbstractSet, Any, Callable, Iterable, Mapping, Optional, Sequence from google.generativeai.notebook.lib import llmfn_input_utils from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_outputs from google.generativeai.notebook.lib import llmfn_post_process from google.generativeai.notebook.lib import llmfn_post_process_cmds from google.generativeai.notebook.lib import model as model_lib from google.generativeai.notebook.lib import prompt_utils # In the same spirit as post-processing functions (see: llmfn_post_process.py), # we keep the LLM functions more flexible by providing the entire left- and # right-hand side rows to the user-defined comparison function. # # Possible use-cases include adding a scoring function as a post-process # command, then comparing the scores. CompareFn = Callable[ [llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView], Any, ] def _is_equal_fn( lhs: llmfn_output_row.LLMFnOutputRowView, rhs: llmfn_output_row.LLMFnOutputRowView, ) -> bool: """Default function used when comparing outputs.""" return lhs.result_value() == rhs.result_value() def _convert_compare_fn_to_batch_add_fn( fn: Callable[ [ llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, ], Any, ] ) -> llmfn_post_process.LLMCompareFnPostProcessBatchAddFn: """Vectorize a single-row-based comparison function.""" def _fn( lhs_and_rhs_rows: Sequence[ tuple[ llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, ] ] ) -> Sequence[Any]: return [fn(lhs, rhs) for lhs, rhs in lhs_and_rhs_rows] return _fn @dataclasses.dataclass class _PromptInfo: prompt_num: int prompt: str input_num: int prompt_vars: Mapping[str, str] model_input: str def _generate_prompts( prompts: Sequence[str], inputs: llmfn_input_utils.LLMFunctionInputs | None ) -> Iterable[_PromptInfo]: """Generate a 
tuple of fields needed for processing prompts. Args: prompts: A list of prompts, with optional keyword placeholders. inputs: A list of key/value pairs to substitute into placeholders in `prompts`. Yields: A _PromptInfo instance. """ normalized_inputs: Sequence[Mapping[str, str]] = [] if inputs is not None: normalized_inputs = llmfn_input_utils.to_normalized_inputs(inputs) # Must have at least one entry so that we execute the prompt at least once. if not normalized_inputs: normalized_inputs = [{}] for prompt_num, prompt in enumerate(prompts): for input_num, prompt_vars in enumerate(normalized_inputs): # Perform keyword substitution on the prompt based on `prompt_vars`. model_input = prompt.format(**prompt_vars) yield _PromptInfo( prompt_num=prompt_num, prompt=prompt, input_num=input_num, prompt_vars=prompt_vars, model_input=model_input, ) class LLMFunction( Callable[ [Optional[llmfn_input_utils.LLMFunctionInputs]], llmfn_outputs.LLMFnOutputs, ], metaclass=abc.ABCMeta, ): """Base class for LLMFunctionImpl and LLMCompareFunction.""" def __init__( self, outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None, ): """Constructor. Args: outputs_ipython_display_fn: Optional function that will be used to override how the outputs of this LLMFunction will be displayed in a notebook (See further documentation in LLMFnOutputs.__init__().) 
""" self._post_process_cmds: list[ llmfn_post_process_cmds.LLMFnPostProcessCommand ] = [] self._outputs_ipython_display_fn = outputs_ipython_display_fn @abc.abstractmethod def get_placeholders(self) -> AbstractSet[str]: """Returns the placeholders that should be present in inputs for this function.""" @abc.abstractmethod def _call_impl( self, inputs: llmfn_input_utils.LLMFunctionInputs | None ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]: """Concrete implementation of __call__().""" def __call__( self, inputs: llmfn_input_utils.LLMFunctionInputs | None = None ) -> llmfn_outputs.LLMFnOutputs: """Runs and returns results based on `inputs`.""" outputs = self._call_impl(inputs) return llmfn_outputs.LLMFnOutputs( outputs=outputs, ipython_display_fn=self._outputs_ipython_display_fn ) def add_post_process_reorder_fn( self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReorderFn ) -> LLMFunction: self._post_process_cmds.append( llmfn_post_process_cmds.LLMFnPostProcessReorderCommand(name=name, fn=fn) ) return self def add_post_process_add_fn( self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchAddFn, ) -> LLMFunction: self._post_process_cmds.append( llmfn_post_process_cmds.LLMFnPostProcessAddCommand(name=name, fn=fn) ) return self def add_post_process_replace_fn( self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReplaceFn, ) -> LLMFunction: self._post_process_cmds.append( llmfn_post_process_cmds.LLMFnPostProcessReplaceCommand(name=name, fn=fn) ) return self class LLMFunctionImpl(LLMFunction): """Callable class that executes the contents of a Magics cell. An LLMFunction is constructed from the Magics command line and cell contents specified by the user. It is defined by: - A model instance, - Model arguments - A prompt template (e.g. "the opposite of hot is {word}") with an optional keyword placeholder. The LLMFunction takes as its input a sequence of dictionaries containing values for keyword replacement, e.g. 
[{"word": "hot"}, {"word": "tall"}]. This will cause the model to be executed with the following prompts: "The opposite of hot is" "The opposite of tall is" The results will be returned in a LLMFnOutputs instance. """ def __init__( self, model: model_lib.AbstractModel, prompts: Sequence[str], model_args: model_lib.ModelArguments | None = None, outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None, ): """Constructor. Args: model: The model that the prompts will execute on. prompts: A sequence of prompt templates with optional placeholders. The placeholders will be replaced by the inputs passed into this function. model_args: Optional set of model arguments to configure how the model executes the prompts. outputs_ipython_display_fn: See documentation in LLMFunction.__init__(). """ super().__init__(outputs_ipython_display_fn=outputs_ipython_display_fn) self._model = model self._prompts = prompts self._model_args = ( model_lib.ModelArguments() if model_args is None else model_args ) # Compute placeholders. 
self._placeholders = frozenset({}) for prompt in self._prompts: self._placeholders = self._placeholders.union( prompt_utils.get_placeholders(prompt) ) def _run_post_processing_cmds( self, results: Sequence[llmfn_output_row.LLMFnOutputRow] ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: """Runs post-processing commands over `results`.""" for cmd in self._post_process_cmds: try: if isinstance(cmd, llmfn_post_process_cmds.LLMFnImplPostProcessCommand): results = cmd.run(results) else: raise llmfn_post_process.PostProcessExecutionError( "Unsupported post-process command type: {}".format(type(cmd)) ) except llmfn_post_process.PostProcessExecutionError: raise except RuntimeError as e: raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}", got {}: {}'.format( cmd.name(), type(e).__name__, e ) ) return results def get_placeholders(self) -> AbstractSet[str]: return self._placeholders def _call_impl( self, inputs: llmfn_input_utils.LLMFunctionInputs | None ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]: results: list[llmfn_outputs.LLMFnOutputEntry] = [] for info in _generate_prompts(prompts=self._prompts, inputs=inputs): model_results = self._model.call_model( model_input=info.model_input, model_args=self._model_args ) output_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for result_num, text_result in enumerate(model_results.text_results): output_rows.append( llmfn_output_row.LLMFnOutputRow( data={ llmfn_outputs.ColumnNames.RESULT_NUM: result_num, llmfn_outputs.ColumnNames.TEXT_RESULT: text_result, }, result_type=str, ) ) results.append( llmfn_outputs.LLMFnOutputEntry( prompt_num=info.prompt_num, input_num=info.input_num, prompt=info.prompt, prompt_vars=info.prompt_vars, model_input=info.model_input, model_results=model_results, output_rows=self._run_post_processing_cmds(output_rows), ) ) return results class LLMCompareFunction(LLMFunction): """LLMFunction for comparisons. 
LLMCompareFunction runs an input over a pair of LLMFunctions and compares the result. """ def __init__( self, lhs_name_and_fn: tuple[str, LLMFunction], rhs_name_and_fn: tuple[str, LLMFunction], compare_name_and_fns: Sequence[tuple[str, CompareFn]] | None = None, outputs_ipython_display_fn: Callable[[llmfn_outputs.LLMFnOutputs], None] | None = None, ): """Constructor. Args: lhs_name_and_fn: Name and function for the left-hand side of the comparison. rhs_name_and_fn: Name and function for the right-hand side of the comparison. compare_name_and_fns: Optional names and functions for comparing the results of the left- and right-hand sides. outputs_ipython_display_fn: See documentation in LLMFunction.__init__(). """ super().__init__(outputs_ipython_display_fn=outputs_ipython_display_fn) self._lhs_name: str = lhs_name_and_fn[0] self._lhs_fn: LLMFunction = lhs_name_and_fn[1] self._rhs_name: str = rhs_name_and_fn[0] self._rhs_fn: LLMFunction = rhs_name_and_fn[1] self._placeholders = frozenset(self._lhs_fn.get_placeholders()).union( self._rhs_fn.get_placeholders() ) if not compare_name_and_fns: self._result_name = "is_equal" self._result_compare_fn = _is_equal_fn else: # Assume the last entry in `compare_name_and_fns` is the one that # produces value for the result cell. name, fn = compare_name_and_fns[-1] self._result_name = name self._result_compare_fn = fn # Treat the other compare_fns as post-processing operators. 
        for name, cmp_fn in compare_name_and_fns[:-1]:
            self.add_compare_post_process_add_fn(
                name=name, fn=_convert_compare_fn_to_batch_add_fn(cmp_fn)
            )

    def _run_post_processing_cmds(
        self,
        lhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
        rhs_output_rows: Sequence[llmfn_output_row.LLMFnOutputRow],
        results: Sequence[llmfn_output_row.LLMFnOutputRow],
    ) -> Sequence[llmfn_output_row.LLMFnOutputRow]:
        """Runs post-processing commands over `results`."""
        for cmd in self._post_process_cmds:
            try:
                if isinstance(cmd, llmfn_post_process_cmds.LLMFnImplPostProcessCommand):
                    results = cmd.run(results)
                elif isinstance(
                    cmd, llmfn_post_process_cmds.LLMCompareFnPostProcessCommand
                ):
                    results = cmd.run(
                        list(zip(lhs_output_rows, rhs_output_rows, results))
                    )
                else:
                    raise RuntimeError(
                        "Unsupported post-process command type: {}".format(type(cmd))
                    )
            except llmfn_post_process.PostProcessExecutionError:
                raise
            except RuntimeError as e:
                raise llmfn_post_process.PostProcessExecutionError(
                    'Error executing "{}", got {}: {}'.format(
                        cmd.name(), type(e).__name__, e
                    )
                )
        return results

    def get_placeholders(self) -> AbstractSet[str]:
        return self._placeholders

    def _call_impl(
        self, inputs: llmfn_input_utils.LLMFunctionInputs | None
    ) -> Sequence[llmfn_outputs.LLMFnOutputEntry]:
        lhs_results = self._lhs_fn(inputs)
        rhs_results = self._rhs_fn(inputs)

        # Combine the results.
        outputs: list[llmfn_outputs.LLMFnOutputEntry] = []
        for lhs_entry, rhs_entry in zip(lhs_results, rhs_results):
            if lhs_entry.prompt_num != rhs_entry.prompt_num:
                raise RuntimeError(
                    "Prompt num mismatch: {} vs {}".format(
                        lhs_entry.prompt_num, rhs_entry.prompt_num
                    )
                )
            if lhs_entry.input_num != rhs_entry.input_num:
                raise RuntimeError(
                    "Input num mismatch: {} vs {}".format(
                        lhs_entry.input_num, rhs_entry.input_num
                    )
                )
            if lhs_entry.prompt_vars != rhs_entry.prompt_vars:
                raise RuntimeError(
                    "Prompt vars mismatch: {} vs {}".format(
                        lhs_entry.prompt_vars, rhs_entry.prompt_vars
                    )
                )

            # The two functions may have different numbers of results due to
            # options like candidate_count, so we can only compare up to the
            # minimum of the two.
            num_output_rows = min(
                len(lhs_entry.output_rows), len(rhs_entry.output_rows)
            )
            lhs_output_rows = lhs_entry.output_rows[:num_output_rows]
            rhs_output_rows = rhs_entry.output_rows[:num_output_rows]

            output_rows: list[llmfn_output_row.LLMFnOutputRow] = []
            for result_num, lhs_and_rhs_output_row in enumerate(
                zip(lhs_output_rows, rhs_output_rows)
            ):
                lhs_output_row, rhs_output_row = lhs_and_rhs_output_row

                # Combine cells from lhs_output_row and rhs_output_row into a
                # single row.
                # Although it is possible for RESULT_NUM (the index of each
                # text_result if a prompt produces multiple text_results) to be
                # different between the left and right sides, we ignore their
                # RESULT_NUM entries and write our own.
                row_data: dict[str, Any] = {
                    llmfn_outputs.ColumnNames.RESULT_NUM: result_num,
                    self._result_name: self._result_compare_fn(
                        lhs_output_row, rhs_output_row
                    ),
                }
                output_row = llmfn_output_row.LLMFnOutputRow(
                    data=row_data, result_type=Any
                )

                # Add the prompt vars.
                output_row.add(
                    llmfn_outputs.ColumnNames.PROMPT_VARS, lhs_entry.prompt_vars
                )

                # Add the results from the left-hand side and right-hand side.
                for name, row in [
                    (self._lhs_name, lhs_output_row),
                    (self._rhs_name, rhs_output_row),
                ]:
                    for k, v in row.items():
                        if k != llmfn_outputs.ColumnNames.RESULT_NUM:
                            # We use LLMFnOutputRow.add() because it handles
                            # column name collisions.
                            output_row.add("{}_{}".format(name, k), v)
                output_rows.append(output_row)

            outputs.append(
                llmfn_outputs.LLMFnOutputEntry(
                    prompt_num=lhs_entry.prompt_num,
                    input_num=lhs_entry.input_num,
                    prompt_vars=lhs_entry.prompt_vars,
                    output_rows=self._run_post_processing_cmds(
                        lhs_output_rows=lhs_output_rows,
                        rhs_output_rows=rhs_output_rows,
                        results=output_rows,
                    ),
                )
            )
        return outputs

    def add_compare_post_process_add_fn(
        self,
        name: str,
        fn: llmfn_post_process.LLMCompareFnPostProcessBatchAddFn,
    ) -> LLMFunction:
        self._post_process_cmds.append(
            llmfn_post_process_cmds.LLMCompareFnPostProcessAddCommand(
                name=name, fn=fn
            )
        )
        return self


generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llm_function_test.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
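Aside: the column-name collision handling that `_call_impl` relies on (via `LLMFnOutputRow.add()`) reduces to a small suffix-probing loop. A minimal standalone sketch with plain dicts — the helper name `add_with_suffix` is hypothetical, not part of the package:

```python
def add_with_suffix(row: dict, key: str, value) -> None:
    """Insert `value` under `key`, probing key_1, key_2, ... on collision."""
    if key in row:
        idx = 1
        candidate = key
        while candidate in row:
            candidate = "{}_{}".format(key, idx)
            idx += 1
        key = candidate
    row[key] = value


row = {"text_result": "lhs_hey"}
add_with_suffix(row, "text_result", "rhs_hey")
print(row)  # {'text_result': 'lhs_hey', 'text_result_1': 'rhs_hey'}
```

This is why the name-collision test below expects `fn_text_result` and `fn_text_result_1` when both sides are given the same name.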
"""Unittest for llm_function.""" from __future__ import annotations from typing import Any, Callable, Optional, Mapping, Sequence from absl.testing import absltest from google.generativeai.notebook.lib import llm_function from google.generativeai.notebook.lib import llmfn_inputs_source from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_outputs from google.generativeai.notebook.lib import model as model_lib LLMCompareFunction = llm_function.LLMCompareFunction LLMFunctionImpl = llm_function.LLMFunctionImpl LLMFnOutputs = llmfn_outputs.LLMFnOutputs LLMFnOutputRow = llmfn_output_row.LLMFnOutputRow LLMFnOutputRowView = llmfn_output_row.LLMFnOutputRowView class _MockModel(model_lib.AbstractModel): """Mock model that returns a caller-provided result.""" def __init__(self, mock_results: Sequence[str]): self._mock_results = mock_results def call_model( self, model_input: str, model_args: model_lib.ModelArguments | None = None ) -> model_lib.ModelResults: return model_lib.ModelResults( model_input=model_input, text_results=self._mock_results ) class _MockInputsSource(llmfn_inputs_source.LLMFnInputsSource): def _to_normalized_inputs_impl( self, ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]: return [ {"word_one": "apple", "word_two": "banana"}, {"word_one": "australia", "word_two": "brazil"}, ], lambda: None class LLMFunctionBasicTest(absltest.TestCase): """Test basic functionality such as execution and input-handling.""" def _test_is_callable( self, llm_fn: Callable[[Optional[Sequence[tuple[str, str]]]], LLMFnOutputs], ) -> LLMFnOutputs: return llm_fn(None) def test_run(self): llm_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["the opposite of hot is"], ) results = self._test_is_callable(llm_fn) expected_results = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["the opposite of hot is"], "text_result": ["the opposite of hot is"], } self.assertEqual(expected_results, 
results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) def test_inputs(self): llm_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=[ "A for {word_one}, B for {word_two}, C for", "if A is to {word_one} as B is to {word_two}, then C is", ], ) results = llm_fn( inputs={ "word_one": ["apple", "australia"], "word_two": ["banana", "brazil"], } ) expected_results = { "Prompt Num": [0, 0, 1, 1], "Input Num": [0, 1, 0, 1], "Result Num": [0, 0, 0, 0], "Prompt": [ "A for apple, B for banana, C for", "A for australia, B for brazil, C for", "if A is to apple as B is to banana, then C is", "if A is to australia as B is to brazil, then C is", ], "text_result": [ "A for apple, B for banana, C for", "A for australia, B for brazil, C for", "if A is to apple as B is to banana, then C is", "if A is to australia as B is to brazil, then C is", ], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) def test_inputs_source(self): llm_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=[ "A for {word_one}, B for {word_two}, C for", ], ) results = llm_fn(_MockInputsSource()) expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt": [ "A for apple, B for banana, C for", "A for australia, B for brazil, C for", ], "text_result": [ "A for apple, B for banana, C for", "A for australia, B for brazil, C for", ], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) def test_one_prompt_many_results(self): llm_fn = LLMFunctionImpl( model=_MockModel(mock_results=["cold", "cold", "cold"]), prompts=["The opposite of hot is"], ) results = llm_fn() expected_results = { "Prompt Num": [0, 0, 0], "Input Num": [0, 0, 0], "Result Num": [0, 1, 2], "Prompt": [ "The opposite of hot is", "The opposite of hot is", "The opposite of hot is", ], 
"text_result": ["cold", "cold", "cold"], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) class LLMFunctionPostProcessTest(absltest.TestCase): """Test post-processing features.""" def test_add_post_process_reorder_fn(self): llm_fn = LLMFunctionImpl( model=_MockModel( mock_results=["cold", "freezing", "chilly"], ), prompts=["The opposite of {word} is"], ) # Reverse the order of rows. def reverse_fn( rows: Sequence[LLMFnOutputRowView], ) -> Sequence[int]: indices = list(range(0, len(rows))) indices.reverse() return indices results = llm_fn.add_post_process_reorder_fn( name="reverse_fn", fn=reverse_fn )({"word": ["hot"]}) expected_results = { "Prompt Num": [0, 0, 0], "Input Num": [0, 0, 0], "Result Num": [2, 1, 0], "Prompt": [ "The opposite of hot is", "The opposite of hot is", "The opposite of hot is", ], "text_result": ["chilly", "freezing", "cold"], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) def test_add_post_process_add_fn(self): llm_fn = LLMFunctionImpl( model=_MockModel( mock_results=["cold", "freezing", "chilly"], ), prompts=["The opposite of {word} is"], ) def add_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: return [len(row.result_value()) for row in rows] results = llm_fn.add_post_process_add_fn(name="length", fn=add_fn)( {"word": ["hot"]} ) expected_results = { "Prompt Num": [0, 0, 0], "Input Num": [0, 0, 0], "Result Num": [0, 1, 2], "Prompt": [ "The opposite of hot is", "The opposite of hot is", "The opposite of hot is", ], "length": [4, 8, 6], "text_result": ["cold", "freezing", "chilly"], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) def test_add_post_process_replace_fn(self): llm_fn = LLMFunctionImpl( model=_MockModel( mock_results=["cold", "freezing", "chilly"], ), 
prompts=["The opposite of {word} is"], ) def replace_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[str]: return [row.result_value().upper() for row in rows] results = llm_fn.add_post_process_replace_fn( name="replace_fn", fn=replace_fn )({"word": ["hot"]}) expected_results = { "Prompt Num": [0, 0, 0], "Input Num": [0, 0, 0], "Result Num": [0, 1, 2], "Prompt": [ "The opposite of hot is", "The opposite of hot is", "The opposite of hot is", ], "text_result": ["COLD", "FREEZING", "CHILLY"], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) class LLMCompareFunctionTest(absltest.TestCase): """Test LLMCompareFunction.""" def test_basic_run(self): """Basic comparison test.""" lhs_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["lhs_{word}"], ) rhs_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["rhs_{word}"], ) compare_fn = LLMCompareFunction( lhs_name_and_fn=("lhs", lhs_fn), rhs_name_and_fn=("rhs", rhs_fn) ) results = compare_fn({ "word": ["hello", "world"], }) expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt vars": [{"word": "hello"}, {"word": "world"}], "lhs_text_result": ["lhs_hello", "lhs_world"], "rhs_text_result": ["rhs_hello", "rhs_world"], "is_equal": [False, False], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) def test_run_with_post_process(self): """Comparison test with post-processing operations.""" def length_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: return [len(str(row.result_value())) for row in rows] lhs_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["lhs_{word}"], ).add_post_process_add_fn(name="length", fn=length_fn) rhs_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["rhs_{word}"], ).add_post_process_add_fn(name="length", fn=length_fn) compare_fn = 
LLMCompareFunction( lhs_name_and_fn=("lhs", lhs_fn), rhs_name_and_fn=("rhs", rhs_fn) ).add_post_process_add_fn(name="length", fn=length_fn) results = compare_fn({ "word": ["hi", "world"], }) # Post-processing results from the LHS, RHS and compare functions are # all included in the results. expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt vars": [{"word": "hi"}, {"word": "world"}], "lhs_length": [6, 9], "lhs_text_result": ["lhs_hi", "lhs_world"], "rhs_length": [6, 9], "rhs_text_result": ["rhs_hi", "rhs_world"], "length": [5, 5], "is_equal": [False, False], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) def test_run_with_name_collisions(self): """Test with name collisions.""" def length_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: return [len(str(row.result_value())) for row in rows] lhs_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["lhs_{word}"], ).add_post_process_add_fn(name="length", fn=length_fn) rhs_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["rhs_{word}"], ).add_post_process_add_fn(name="length", fn=length_fn) # Give left- and right functions the same names. compare_fn = LLMCompareFunction( lhs_name_and_fn=("fn", lhs_fn), rhs_name_and_fn=("fn", rhs_fn) ).add_post_process_add_fn(name="length", fn=length_fn) results = compare_fn({ "word": ["hey", "world"], }) # Name collisions are resolved. 
expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt vars": [{"word": "hey"}, {"word": "world"}], "fn_length": [7, 9], "fn_text_result": ["lhs_hey", "lhs_world"], "fn_length_1": [7, 9], "fn_text_result_1": ["rhs_hey", "rhs_world"], "length": [5, 5], "is_equal": [False, False], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) def test_custom_compare(self): """Test custom comparison function.""" def length_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: return [len(str(row.result_value())) for row in rows] lhs_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["{word}_{word}"], ).add_post_process_add_fn(name="length", fn=length_fn) rhs_fn = LLMFunctionImpl( model=model_lib.EchoModel(), prompts=["abcd_{word}"], ).add_post_process_add_fn(name="length", fn=length_fn) # We deliberately have our custom fn take Mapping[str, Any] instead # of LLMFnOutputRowView to make sure the typechecker allows this # as well. # Note that this function returns a non-string as well. def _is_length_less_than( lhs: Mapping[str, Any], rhs: Mapping[str, Any] ) -> bool: return lhs["length"] < rhs["length"] def _is_length_greater_than( lhs: Mapping[str, Any], rhs: Mapping[str, Any] ) -> bool: return lhs["length"] > rhs["length"] # Batch-based comparison function for post-processing. def _sum_of_lengths( rows: Sequence[tuple[Mapping[str, Any], Mapping[str, Any]]] ) -> Sequence[int]: return [lhs["length"] + rhs["length"] for lhs, rhs in rows] compare_fn = LLMCompareFunction( lhs_name_and_fn=("lhs", lhs_fn), rhs_name_and_fn=("rhs", rhs_fn), compare_name_and_fns=[ ("is_shorter_than", _is_length_less_than), ("is_longer_than", _is_length_greater_than), ], ).add_compare_post_process_add_fn(name="sum_of_lengths", fn=_sum_of_lengths) results = compare_fn({ "word": ["hey", "world"], }) # Name collisions are resolved. 
expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt vars": [{"word": "hey"}, {"word": "world"}], "lhs_length": [7, 11], "lhs_text_result": ["hey_hey", "world_world"], "rhs_length": [8, 10], "rhs_text_result": ["abcd_hey", "abcd_world"], "is_shorter_than": [True, False], "sum_of_lengths": [15, 21], "is_longer_than": [False, True], } self.assertEqual(expected_results, results.as_dict()) self.assertEqual( list(expected_results.keys()), list(results.as_dict().keys()) ) if __name__ == "__main__": absltest.main() 07070100000027000081A40000000000000000000000016459839500000B26000000000000000000000000000000000000005500000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_input_utils.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for handling input variables.""" from __future__ import annotations from typing import Any, Mapping, Sequence, Union from google.generativeai.notebook.lib import llmfn_inputs_source _NormalizedInputsList = llmfn_inputs_source.NormalizedInputsList _ColumnOrderValuesList = Mapping[str, Sequence[str]] LLMFunctionInputs = Union[ _ColumnOrderValuesList, llmfn_inputs_source.LLMFnInputsSource, ] def _is_column_order_values_list(inputs: Any) -> bool: """See if inputs is of the form: {"key1": ["val1", "val2", ...]}. 
    This is similar to the format produced by:
        pandas.DataFrame.to_dict(orient="list")

    Args:
        inputs: The inputs passed into an LLMFunction.

    Returns:
        Whether `inputs` is a column-ordered list of values.
    """
    if not isinstance(inputs, Mapping):
        return False
    for x in inputs.values():
        if not isinstance(x, Sequence):
            return False
        # Strings and bytes are also considered Sequences but we disallow
        # them here because the values contained in their Sequences are
        # single characters rather than words.
        if isinstance(x, str) or isinstance(x, bytes):
            return False
    return True


# TODO(b/273688393): Perform stricter validation on `inputs`.
def _normalize_column_order_values_list(
    inputs: _ColumnOrderValuesList,
) -> _NormalizedInputsList:
    """Transforms prompt inputs into a list of dictionaries."""
    return_list: list[dict[str, str]] = []
    keys = list(inputs.keys())
    if keys:
        first_key = keys[0]
        for row_num in range(len(inputs[first_key])):
            row_dict = {}
            return_list.append(row_dict)
            for key in keys:
                row_dict[key] = inputs[key][row_num]
    return return_list


def to_normalized_inputs(inputs: LLMFunctionInputs) -> _NormalizedInputsList:
    """Handles the different types of `inputs` and returns a normalized form."""
    normalized_inputs: list[Mapping[str, str]] = []
    if isinstance(inputs, llmfn_inputs_source.LLMFnInputsSource):
        normalized_inputs.extend(inputs.to_normalized_inputs())
    elif _is_column_order_values_list(inputs):
        normalized_inputs.extend(_normalize_column_order_values_list(inputs))
    else:
        raise ValueError("Unsupported input type {!r}".format(inputs))
    return normalized_inputs


generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_inputs_source.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LLMFnInputsSource."""
from __future__ import annotations

import abc
from typing import Callable, Mapping, Sequence

NormalizedInputsList = Sequence[Mapping[str, str]]


class LLMFnInputsSource(abc.ABC):
    """Abstract class representing a source of inputs for LLMFunction.

    This class could be extended with concrete implementations that read
    data from external sources, such as Google Sheets.
    """

    def __init__(self):
        self._cached_inputs: NormalizedInputsList | None = None
        self._display_status_fn: Callable[[], None] = lambda: None

    def to_normalized_inputs(
        self, suppress_status_msgs: bool = False
    ) -> NormalizedInputsList:
        """Returns a sequence of normalized inputs.

        The return value is a sequence of dictionaries of (placeholder,
        value) pairs, e.g.
            [{"word": "hot"}, {"word": "cold"}, ...]

        These are used for keyword-substitution for prompts in LLMFunctions.

        Args:
            suppress_status_msgs: If True, suppress status messages
                regarding the input being read.

        Returns:
            A sequence of normalized inputs.
        """
        if self._cached_inputs is None:
            self._cached_inputs, self._display_status_fn = (
                self._to_normalized_inputs_impl()
            )
        if not suppress_status_msgs:
            self._display_status_fn()
        return self._cached_inputs

    @abc.abstractmethod
    def _to_normalized_inputs_impl(
        self,
    ) -> tuple[NormalizedInputsList, Callable[[], None]]:
        """Returns a tuple of NormalizedInputsList and a display function.

        The display function displays some status about the input (e.g.
        where it is read from). This way the status continues to be
        displayed even when the results are cached.
""" 07070100000029000081A400000000000000000000000164598395000015A7000000000000000000000000000000000000005400000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_output_row.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """LLMFnOutputRow.""" from __future__ import annotations import abc from typing import Any, Iterator, Mapping # The type of value stored in a cell. _CELLVALUETYPE = Any def _get_name_of_type(x: type[Any]) -> str: if hasattr(x, "__name__"): return x.__name__ return str(x) def _validate_is_result_type(value: Any, result_type: type[Any]) -> None: if result_type == Any: return if not isinstance(value, result_type): raise ValueError( 'Value of last entry must be of type "{}", got "{}"'.format( _get_name_of_type(result_type), _get_name_of_type(type(value)), ) ) class LLMFnOutputRowView(Mapping[str, _CELLVALUETYPE], metaclass=abc.ABCMeta): """Immutable view of LLMFnOutputRow.""" # Additional methods (not required by Mapping[str, _CELLVALUETYPE]) @abc.abstractmethod def __contains__(self, k: str) -> bool: """For expressions like: x in this_instance.""" @abc.abstractmethod def __str__(self) -> str: """For expressions like: str(this_instance).""" # Own methods. 
    @abc.abstractmethod
    def result_type(self) -> type[Any]:
        """Returns the type enforced for the result cell."""

    @abc.abstractmethod
    def result_value(self) -> Any:
        """Get the value of the result cell."""

    @abc.abstractmethod
    def result_key(self) -> str:
        """Get the key of the result cell."""


class LLMFnOutputRow(LLMFnOutputRowView):
    """Container that represents a single row in a table of outputs.

    We represent outputs as a table. This class represents a single row in
    the table like a dictionary, where the key is the column name and the
    value is the cell value.

    A single cell is designated the "result". This contains the output of
    the LLM model after running any post-processing functions specified by
    the user.

    In addition to behaving like a dictionary, this class provides
    additional methods, including:
    - Getting the value of the "result" cell
    - Setting the value (and optionally the key) of the "result" cell
    - Adding a new non-result cell

    Notes:
        As an implementation detail, the result-cell is always kept as the
        rightmost cell.
    """

    def __init__(
        self, data: Mapping[str, _CELLVALUETYPE], result_type: type[Any]
    ):
        """Constructor.

        Args:
            data: The initial value of the row. The last entry will be
                treated as the result. Cannot be empty. The value of the
                last entry must be `str`.
            result_type: The type of the result cell. This will be enforced
                at runtime.
        """
        self._data: dict[str, _CELLVALUETYPE] = dict(data)
        if not self._data:
            raise ValueError("Must provide non-empty data")

        self._result_type = result_type
        result_value = list(self._data.values())[-1]
        _validate_is_result_type(result_value, self._result_type)

    # Methods needed for Mapping[str, _CELLVALUETYPE]:
    def __iter__(self) -> Iterator[str]:
        return self._data.__iter__()

    def __len__(self) -> int:
        return self._data.__len__()

    def __getitem__(self, k: str) -> _CELLVALUETYPE:
        return self._data.__getitem__(k)

    # Additional methods for LLMFnOutputRowView.
    def __contains__(self, k: str) -> bool:
        return self._data.__contains__(k)

    def __str__(self) -> str:
        return "LLMFnOutputRow: {}".format(self._data.__str__())

    def result_type(self) -> type[Any]:
        return self._result_type

    def result_value(self) -> Any:
        return self._data[self.result_key()]

    def result_key(self) -> str:
        # Our invariant is that the result-cell is always the rightmost cell.
        return list(self._data.keys())[-1]

    # Mutable methods.
    def set_result_value(self, value: Any, key: str | None = None) -> None:
        """Set the value of the result cell.

        Sets the value (and optionally the key) of the result cell.

        Args:
            value: The value to set the result cell to.
            key: Optionally change the key as well.
        """
        _validate_is_result_type(value, self._result_type)
        current_key = self.result_key()
        if key is None or key == current_key:
            self._data[current_key] = value
            return
        del self._data[current_key]
        self._data[key] = value

    def add(self, key: str, value: _CELLVALUETYPE) -> None:
        """Add a non-result cell.

        Adds a new non-result cell. This does not affect the result cell.

        Args:
            key: The key of the new cell to add.
            value: The value of the new cell to add.
        """
        # Handle collisions with `key`.
        if key in self._data:
            idx = 1
            candidate_key = key
            while candidate_key in self._data:
                candidate_key = "{}_{}".format(key, idx)
                idx = idx + 1
            key = candidate_key

        # Insert the new key/value into the second rightmost position to
        # keep the result cell as the rightmost cell.
        result_key = self.result_key()
        result_value = self._data.pop(result_key)
        self._data[key] = value
        self._data[result_key] = result_value


generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_output_row_test.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unittest for llmfn_output_row."""
from __future__ import annotations

from typing import Any, Mapping

from absl.testing import absltest
from google.generativeai.notebook.lib import llmfn_output_row

LLMFnOutputRow = llmfn_output_row.LLMFnOutputRow


class LLMFnOutputRowTest(absltest.TestCase):

    def _test_is_mapping_impl(self, row: Mapping[str, Any]) -> int:
        """Dummy function that asserts that a LLMFnOutputRow is a Mapping."""
        count = 0
        for _ in row:
            count = count + 1
        return count

    def test_is_mapping(self):
        row = LLMFnOutputRow(data={"result": "none"}, result_type=str)
        self.assertLen(row, self._test_is_mapping_impl(row))

    def _test_is_output_row_view_impl(
        self, view: llmfn_output_row.LLMFnOutputRowView
    ) -> None:
        self.assertEqual("result", view.result_key())
        self.assertEqual("none", view.result_value())

    def test_is_output_row_view(self):
        row = LLMFnOutputRow(data={"result": "none"}, result_type=str)
        self._test_is_output_row_view_impl(row)

    def test_constructor(self):
        with self.assertRaisesRegex(ValueError, "Must provide non-empty data"):
            LLMFnOutputRow(data={}, result_type=str)

        with self.assertRaisesRegex(
            ValueError, 'Value of last entry must be of type "str"'
        ):
            LLMFnOutputRow(data={"result": 42}, result_type=str)

        # Non-strings are accepted for non-rightmost cells.
        _ = LLMFnOutputRow(
            data={"int": 42, "result": "forty-two"}, result_type=str
        )

    def test_add(self):
        row = LLMFnOutputRow(data={"result": "none"}, result_type=str)
        row.add("score", 42)
        row.set_result_value("hello")
        self.assertEqual({"score": 42, "result": "hello"}, dict(row))
        self.assertEqual("result", row.result_key())
        self.assertEqual("hello", row.result_value())

    def test_add_with_collision(self):
        row = LLMFnOutputRow(data={"result": "none"}, result_type=str)
        row.add("score", 42)
        row.add("score", "forty-two")
        row.set_result_value("hello")
        self.assertEqual(
            {"score": 42, "score_1": "forty-two", "result": "hello"},
            dict(row.items()),
        )
        self.assertEqual("result", row.result_key())
        self.assertEqual("hello", row.result_value())

    def test_add_does_not_affect_result_cell(self):
        row = LLMFnOutputRow(data={"result": "hello"}, result_type=str)
        self.assertEqual("hello", row.result_value())

        row.add("column_one", 42)
        row.add("column_two", "forty-two")
        self.assertEqual("hello", row.result_value())
        self.assertEqual(
            {"column_one": 42, "column_two": "forty-two", "result": "hello"},
            dict(row),
        )
        self.assertEqual("result", row.result_key())
        self.assertEqual("hello", row.result_value())

    def test_set_result_value(self):
        row = LLMFnOutputRow(data={"result": "none"}, result_type=str)
        row.set_result_value("hello")
        self.assertEqual("result", row.result_key())
        self.assertEqual("hello", row.result_value())

        # Results should remain unaffected when a new column is added.
        row.add("column_one", 42)
        self.assertEqual("result", row.result_key())
        self.assertEqual("hello", row.result_value())
        self.assertEqual(
            {"column_one": 42, "result": "hello"},
            dict(row),
        )

    def test_get_item(self):
        row = LLMFnOutputRow(
            data={"one": "first", "two": "second", "three": "third"},
            result_type=str,
        )
        self.assertEqual("first", row["one"])
        self.assertEqual("second", row["two"])
        self.assertEqual("third", row["three"])

    def test_result_type(self):
        # Cannot construct the row if the result cell is of the wrong type.
        with self.assertRaisesRegex(
            ValueError, 'Value of last entry must be of type "int", got "str"'
        ):
            LLMFnOutputRow(
                data={"one": "first", "two": "second", "three": "third"},
                result_type=int,
            )

        row = LLMFnOutputRow(
            data={"one": "first", "two": "second", "three": 3},
            result_type=int,
        )

        # Cannot set the result value to the wrong type.
        with self.assertRaisesRegex(
            ValueError, 'Value of last entry must be of type "int", got "str"'
        ):
            row.set_result_value("third")

        # Can set the result value if it's the correct type.
        row.set_result_value(42)
        self.assertEqual(42, row.result_value())


if __name__ == "__main__":
    absltest.main()


generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_outputs.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
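The rightmost-result invariant exercised by the tests above rests on Python dicts preserving insertion order: to add a column without displacing the result, pop the last key and re-append it afterwards. A minimal standalone sketch of that trick (the helper name `insert_before_result` is ours, not the package's):

```python
def insert_before_result(row: dict, key: str, value) -> None:
    """Insert a new column while keeping the result column rightmost."""
    # Dicts preserve insertion order (Python 3.7+), so the result column is
    # simply the last key: pop it, insert the new column, re-append it.
    result_key = list(row)[-1]
    result_value = row.pop(result_key)
    row[key] = value
    row[result_key] = result_value


row = {"result": "hello"}
insert_before_result(row, "score", 42)
print(list(row))  # ['score', 'result']
```

`LLMFnOutputRow.add()` combines this with the collision-suffixing step before the insert.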
"""Output of LLMFunction.""" from __future__ import annotations import abc import dataclasses from typing import overload, Any, Callable, Iterable, Iterator, Mapping, Sequence from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import model as model_lib import pandas class ColumnNames: """Names of columns that are used to represent output.""" PROMPT_NUM = "Prompt Num" INPUT_NUM = "Input Num" RESULT_NUM = "Result Num" # In the code we refer to "model_input" as the full keyword-substituted prompt # and "prompt" as the template with placeholders. # When displaying the results however we use "prompt" since "model_input" is # an internal name. MODEL_INPUT = "Prompt" PROMPT_VARS = "Prompt vars" TEXT_RESULT = "text_result" @dataclasses.dataclass class LLMFnOutputEntry: """The output of a single model input from LLMFunction. A model input is a prompt where the keyword placeholders have been substituted (by `prompt_vars`). E.g. If we have: prompt: "the opposite of {word} is" prompt_vars: {"word", "hot"} Then we will have the following model input: model_input: "the opposite of hot is" Note: The model may produce one-or-more results for a given model_input. This is represented by the sequence `output_rows`. """ prompt_num: int input_num: int prompt_vars: Mapping[str, str] output_rows: Sequence[llmfn_output_row.LLMFnOutputRow] prompt: str | None = None model_input: str | None = None model_results: model_lib.ModelResults | None = None def _has_model_input_field(outputs: Iterable[LLMFnOutputEntry]): for entry in outputs: if entry.model_input is not None: return True return False class LLMFnOutputsBase(Sequence[LLMFnOutputEntry]): """Parent class for LLMFnOutputs. This class exists mainly to avoid a circular dependency between LLMFnOutputs and LLMFnOutputsSink. Most users should use LLMFnOutputs directly instead. """ def __init__( self, outputs: Iterable[LLMFnOutputEntry] | None = None, ): """Constructor. 
        Args:
            outputs: The contents of this LLMFnOutputs instance.
        """
        self._outputs: list[LLMFnOutputEntry] = (
            list(outputs) if outputs is not None else []
        )

    # Needed for Iterable[LLMFnOutputEntry].
    def __iter__(self) -> Iterator[LLMFnOutputEntry]:
        return self._outputs.__iter__()

    # Needed for Sequence[LLMFnOutputEntry].
    def __len__(self) -> int:
        return self._outputs.__len__()

    # Needed for Sequence[LLMFnOutputEntry].
    @overload
    def __getitem__(self, x: int) -> LLMFnOutputEntry:
        ...

    @overload
    def __getitem__(self, x: slice) -> Sequence[LLMFnOutputEntry]:
        ...

    def __getitem__(
        self, x: int | slice
    ) -> LLMFnOutputEntry | Sequence[LLMFnOutputEntry]:
        return self._outputs.__getitem__(x)

    # Convenience methods.
    def __bool__(self) -> bool:
        return bool(self._outputs)

    def __str__(self) -> str:
        return self.as_pandas_dataframe().__str__()

    # Own methods
    def as_dict(self) -> Mapping[str, Sequence[Any]]:
        """Formats returned results as dictionary."""
        # `data` is a table in column order, with the columns listed from
        # left to right.
        data = {
            ColumnNames.PROMPT_NUM: [],
            ColumnNames.INPUT_NUM: [],
            # RESULT_NUM is special: each LLMFnOutputRow in self._outputs is
            # expected to have a RESULT_NUM key.
            ColumnNames.RESULT_NUM: [],
        }
        if _has_model_input_field(self._outputs):
            data[ColumnNames.MODEL_INPUT] = []

        if not self._outputs:
            return data

        # Add column names of added data.
        # The last key in LLMFnOutputRow is special as it is considered
        # the result. To preserve order in the (unlikely) event of
        # inconsistent keys across rows, we first add all-but-the-last key
        # to `total_keys_set`, then the last key.
        # Note: `total_keys_set` is a Python dictionary instead of a Python
        # set because Python dictionaries preserve the order in which
        # entries are added, whereas Python sets do not.
total_keys_set: dict[str, None] = {k: None for k in data.keys()} for output in self._outputs: for result in output.output_rows: for key in list(result.keys())[:-1]: total_keys_set[key] = None for output in self._outputs: for result in output.output_rows: total_keys_set[list(result.keys())[-1]] = None # `data` represents the table as a dictionary of: # column names -> list of values for key in total_keys_set: data[key] = [] next_num_rows = 1 for output in self._outputs: for result in output.output_rows: data[ColumnNames.PROMPT_NUM].append(output.prompt_num) data[ColumnNames.INPUT_NUM].append(output.input_num) if ColumnNames.MODEL_INPUT in data: data[ColumnNames.MODEL_INPUT].append(output.model_input) for key, value in result.items(): data[key].append(value) # Look for empty cells and pad them with None. for column in data.values(): if len(column) < next_num_rows: column.append(None) next_num_rows += 1 return data def as_pandas_dataframe(self) -> pandas.DataFrame: return pandas.DataFrame(self.as_dict()) class LLMFnOutputsSink(abc.ABC): """Abstract class representing an exporter for the output of LLMFunction. This class could be extended to write to external documents, such as Google Sheets. """ def write_outputs(self, outputs: LLMFnOutputsBase) -> None: """Writes `outputs` to some destination.""" class LLMFnOutputs(LLMFnOutputsBase): """A sequence of LLMFnOutputEntry instances. Notes: - Each LLMFnOutputEntry represents the results of running one model input (see documentation for LLMFnOutputEntry for what "model input" means.) - A single model input may produce more-than-one text results. """ def __init__( self, outputs: Iterable[LLMFnOutputEntry] | None = None, ipython_display_fn: Callable[[LLMFnOutputs], None] | None = None, ): """Constructor. Args: outputs: The contents of this LLMFnOutputs instance. ipython_display_fn: An optional function for pretty-printing this instance when it is the output of a cell in a notebook. 
If this argument is not None, the _ipython_display_ method will be defined, which will in turn invoke this function. """ super().__init__(outputs=outputs) if ipython_display_fn: self._ipython_display_fn = ipython_display_fn # We define the _ipython_display_ method only when `ipython_display_fn` # is set. This lets us fall back to a default implementation defined by # the notebook when `ipython_display_fn` is not set, instead of having to # provide our own default implementation. setattr(self, "_ipython_display_", getattr(self, "_ipython_display_impl")) def _ipython_display_impl(self): """Actual implementation of _ipython_display_. This method should only be invoked if self._ipython_display_fn is set. """ self._ipython_display_fn(self) def export(self, sink: LLMFnOutputsSink) -> None: """Export contents to `sink`.""" sink.write_outputs(self) 0707010000002C000081A40000000000000000000000016459839500001683000000000000000000000000000000000000005600000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_outputs_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Unittest for llmfn_outputs.""" from __future__ import annotations from typing import Sequence from absl.testing import absltest from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_outputs from google.generativeai.notebook.lib import model as model_lib LLMFnOutputEntry = llmfn_outputs.LLMFnOutputEntry LLMFnOutputs = llmfn_outputs.LLMFnOutputs LLMFnOutputRow = llmfn_output_row.LLMFnOutputRow def _get_empty_model_results() -> model_lib.ModelResults: return model_lib.ModelResults(model_input="", text_results=[]) class LLMFnOutputsTest(absltest.TestCase): def _test_is_sequence(self, outputs: Sequence[LLMFnOutputEntry]): # Make sure `outputs` is iterable. count = 0 for _ in outputs: count = count + 1 # Make sure len(outputs) works. self.assertLen(outputs, count) # Make sure regular integer indices are accepted. self.assertIsInstance(outputs[0], LLMFnOutputEntry) # Make sure slices are accepted. self.assertLen(outputs[:-1], count - 1) self.assertIsInstance(outputs[:-1][0], LLMFnOutputEntry) def test_is_sequence(self): tmp_entry_list = [] for prompt_num in range(0, 3): prompt = ["one", "two", "three"][prompt_num] output_entry = LLMFnOutputEntry( prompt_num=prompt_num, input_num=0, prompt=prompt, prompt_vars={}, model_input=prompt, model_results=_get_empty_model_results(), output_rows=[], ) tmp_entry_list.append(output_entry) outputs = LLMFnOutputs(tmp_entry_list) self._test_is_sequence(outputs) def test_as_dict_basic(self): outputs_list = [] for prompt_num in range(0, 3): prompt = ["one", "two", "three"][prompt_num] text_results = ["red", "green", "blue"][prompt_num] output_entry = LLMFnOutputEntry( prompt_num=prompt_num, input_num=0, prompt=prompt, prompt_vars={}, model_input=prompt, model_results=_get_empty_model_results(), output_rows=[ LLMFnOutputRow( data={ "Result Num": 0, "text_results": "{}_one".format(text_results), }, result_type=str, ), LLMFnOutputRow( data={ "Result Num": 1, "text_results": 
"{}_two".format(text_results), }, result_type=str, ), ], ) outputs_list.append(output_entry) expected_dict = { "Prompt Num": [0, 0, 1, 1, 2, 2], "Input Num": [0, 0, 0, 0, 0, 0], "Result Num": [0, 1, 0, 1, 0, 1], "Prompt": ["one", "one", "two", "two", "three", "three"], "text_results": [ "red_one", "red_two", "green_one", "green_two", "blue_one", "blue_two", ], } outputs = LLMFnOutputs(outputs_list) self.assertEqual(expected_dict, outputs.as_dict()) # Keys must be in the same order as well. self.assertEqual(list(expected_dict.keys()), list(outputs.as_dict().keys())) def test_as_dict_with_holes(self): outputs_list = [] for prompt_num in range(0, 3): prompt = ["one", "two", "three"][prompt_num] text_results = ["red", "green", "blue"][prompt_num] output_entry = LLMFnOutputEntry( prompt_num=prompt_num, input_num=0, prompt=prompt, prompt_vars={}, model_input=prompt, model_results=_get_empty_model_results(), output_rows=[ LLMFnOutputRow( data={ "Result Num": 0, text_results: True, "text_results": "{}_one".format(text_results), }, result_type=str, ), LLMFnOutputRow( data={ "Result Num": 1, text_results: True, "text_results": "{}_two".format(text_results), }, result_type=str, ), ], ) outputs_list.append(output_entry) expected_dict = { "Prompt Num": [0, 0, 1, 1, 2, 2], "Input Num": [0, 0, 0, 0, 0, 0], "Result Num": [0, 1, 0, 1, 0, 1], "Prompt": ["one", "one", "two", "two", "three", "three"], "red": [True, True, None, None, None, None], "green": [None, None, True, True, None, None], "blue": [None, None, None, None, True, True], "text_results": [ "red_one", "red_two", "green_one", "green_two", "blue_one", "blue_two", ], } outputs = LLMFnOutputs(outputs_list) self.assertEqual(expected_dict, outputs.as_dict()) # Keys must be in the same order as well. 
self.assertEqual(list(expected_dict.keys()), list(outputs.as_dict().keys())) if __name__ == "__main__": absltest.main() 0707010000002D000081A400000000000000000000000164598395000009A4000000000000000000000000000000000000005600000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_post_process.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Signatures for post-processing functions and other common definitions.""" from __future__ import annotations from typing import Any, Callable, Sequence, Tuple from google.generativeai.notebook.lib import llmfn_output_row class PostProcessExecutionError(RuntimeError): """An error while executing a post-processing command.""" # A batch-process function takes a batch of rows, and returns a sequence of # indices representing which rows to keep. # This can be used to implement operations such as filtering and sorting. # # Requires: # - Indices must be in the range [0, len(input rows)). LLMFnPostProcessBatchReorderFn = Callable[ [Sequence[llmfn_output_row.LLMFnOutputRowView]], Sequence[int], ] # An add function takes a batch of rows and returns a sequence of values to # be added as new columns. # # Requires: # - Output sequence must be exactly the same length as number of rows. 
LLMFnPostProcessBatchAddFn = Callable[ [Sequence[llmfn_output_row.LLMFnOutputRowView]], Sequence[Any] ] # A replace function takes a batch of rows and returns a sequence of values # to replace the existing results. # # Requires: # - Output sequence must be exactly the same length as number of rows. # - Return type must match the result_type of LLMFnOutputRow. LLMFnPostProcessBatchReplaceFn = Callable[ [Sequence[llmfn_output_row.LLMFnOutputRowView]], Sequence[Any] ] # An add function takes a batch of pairs of rows and returns a sequence of # values to be added as new columns. # # This is used for LLMCompareFunction. # # Requires: # - Output sequence must be exactly the same length as number of rows. LLMCompareFnPostProcessBatchAddFn = Callable[ [ Sequence[ Tuple[ llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, ] ] ], Sequence[Any], ] 0707010000002E000081A40000000000000000000000016459839500001F7A000000000000000000000000000000000000005B00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_post_process_cmds.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Internal representation of post-process commands for LLMFunction. This module is internal to LLMFunction and should only be used by llm_function.py. 
""" from __future__ import annotations import abc from typing import Sequence from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_post_process def _convert_view_to_output_row( row: llmfn_output_row.LLMFnOutputRowView, ) -> llmfn_output_row.LLMFnOutputRow: """Convenience method to conert a LLMFnOutputRowView to LLMFnOutputRow. If `row` is already a LLMFnOutputRow, return as-is for efficiency. This could potentially break encapsulation as it could let code to modify a LLMFnOutputRowView that was intended to be immutable, so it should be used with care. Args: row: An instance of LLMFnOutputRowView. Returns: An instance of LLMFnOutputRow. May be the same instance as `row` if `row` is already an instance of LLMFnOutputRow. """ if isinstance(row, llmfn_output_row.LLMFnOutputRow): return row return llmfn_output_row.LLMFnOutputRow( data=row, result_type=row.result_type() ) class LLMFnPostProcessCommand(abc.ABC): """Abstract class representing post-processing commands.""" @abc.abstractmethod def name(self) -> str: """Returns the name of this post-processing command.""" class LLMFnImplPostProcessCommand(LLMFnPostProcessCommand): """Post-processing commands for LLMFunctionImpl.""" @abc.abstractmethod def run( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView] ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: """Processes a batch of results and returns a new batch. Args: rows: The rows in a batch. Note that `rows` are not guaranteed to be remain unmodified. Returns: A new set of rows that should replace the batch. """ class LLMFnPostProcessReorderCommand(LLMFnImplPostProcessCommand): """A batch command processes a set of results at once. Note that a "batch" represents a set of results coming from a single prompt, as the model may produce more-than-one result for a prompt. 
""" def __init__( self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReorderFn ): self._name = name self._fn = fn def name(self) -> str: return self._name def run( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: new_row_indices = self._fn(rows) if len(set(new_row_indices)) != len(new_row_indices): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned indices should be unique'.format( self._name ) ) new_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for idx in new_row_indices: if idx < 0: raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned indices must be greater than or' ' equal to zero, got {}'.format(self._name, idx) ) if idx >= len(rows): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned indices must be less than length of' ' rows (={}), got {}'.format(self._name, len(rows), idx) ) new_rows.append(_convert_view_to_output_row(rows[idx])) return new_rows class LLMFnPostProcessAddCommand(LLMFnImplPostProcessCommand): """A command that adds each row with a new column. This does not change the value of the results cell. 
""" def __init__( self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchAddFn ): self._name = name self._fn = fn def name(self) -> str: return self._name def run( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: new_values = self._fn(rows) if len(new_values) != len(rows): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned length ({}) != number of input rows' ' ({})'.format(self._name, len(new_values), len(rows)) ) new_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for new_value, row in zip(new_values, rows): new_row = _convert_view_to_output_row(row) new_row.add(key=self._name, value=new_value) new_rows.append(new_row) return new_rows class LLMFnPostProcessReplaceCommand(LLMFnImplPostProcessCommand): """A command that modifies the results in each row.""" def __init__( self, name: str, fn: llmfn_post_process.LLMFnPostProcessBatchReplaceFn ): self._name = name self._fn = fn def name(self) -> str: return self._name def run( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: new_values = self._fn(rows) if len(new_values) != len(rows): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned length ({}) != number of input rows' ' ({})'.format(self._name, len(new_values), len(rows)) ) new_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for new_value, row in zip(new_values, rows): new_row = _convert_view_to_output_row(row) new_row.set_result_value(value=new_value) new_rows.append(new_row) return new_rows class LLMCompareFnPostProcessCommand(LLMFnPostProcessCommand): """Post-processing commands for LLMCompareFunction.""" @abc.abstractmethod def run( self, rows: Sequence[ tuple[ llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, ] ], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: """Processes a batch of left- and right-hand side 
results. Args: rows: The rows in a batch. Each row is a three-tuple containing: - The left-hand side results, - The right-hand side results, and - The current combined results Returns: A new set of rows that should replace the combined results. """ class LLMCompareFnPostProcessAddCommand(LLMCompareFnPostProcessCommand): """A command that adds a new column to each row. This does not change the value of the results cell. """ def __init__( self, name: str, fn: llmfn_post_process.LLMCompareFnPostProcessBatchAddFn ): self._name = name self._fn = fn def name(self) -> str: return self._name def run( self, rows: Sequence[ tuple[ llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, llmfn_output_row.LLMFnOutputRowView, ] ], ) -> Sequence[llmfn_output_row.LLMFnOutputRow]: new_values = self._fn([(lhs, rhs) for lhs, rhs, _ in rows]) if len(new_values) != len(rows): raise llmfn_post_process.PostProcessExecutionError( 'Error executing "{}": returned length ({}) != number of input rows' ' ({})'.format(self._name, len(new_values), len(rows)) ) new_rows: list[llmfn_output_row.LLMFnOutputRow] = [] for new_value, row in zip( new_values, [combined for _, _, combined in rows] ): new_row = _convert_view_to_output_row(row) new_row.add(key=self._name, value=new_value) new_rows.append(new_row) return new_rows 0707010000002F000081A40000000000000000000000016459839500001A40000000000000000000000000000000000000006000000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/llmfn_post_process_cmds_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from typing import Sequence from absl.testing import absltest from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_post_process from google.generativeai.notebook.lib import llmfn_post_process_cmds LLMFnOutputRow = llmfn_output_row.LLMFnOutputRow LLMFnOutputRowView = llmfn_output_row.LLMFnOutputRowView PostProcessExecutionError = llmfn_post_process.PostProcessExecutionError LLMFnPostProcessReorderCommand = ( llmfn_post_process_cmds.LLMFnPostProcessReorderCommand ) LLMFnPostProcessAddCommand = llmfn_post_process_cmds.LLMFnPostProcessAddCommand LLMFnPostProcessReplaceCommand = ( llmfn_post_process_cmds.LLMFnPostProcessReplaceCommand ) LLMCompareFnPostProcessAddCommand = ( llmfn_post_process_cmds.LLMCompareFnPostProcessAddCommand ) class LLMFnPostProcessCmdTest(absltest.TestCase): def test_post_process_reorder_cmd_bad_index_duplicate_indices(self): def bad_index_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: del rows return [0, 0] cmd = LLMFnPostProcessReorderCommand(name="test", fn=bad_index_fn) expected_msg = 'Error executing "test": returned indices should be unique' with self.assertRaisesRegex(PostProcessExecutionError, expected_msg): cmd.run([LLMFnOutputRow(data={"text_result": "hello"}, result_type=str)]) def test_post_process_reorder_cmd_bad_index_less_than_zero(self): def bad_index_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: del rows return [-1] cmd = LLMFnPostProcessReorderCommand(name="test", fn=bad_index_fn) expected_msg = ( 'Error executing 
"test": returned indices must be greater than or equal' " to zero, got -1" ) with self.assertRaisesRegex(PostProcessExecutionError, expected_msg): cmd.run([]) def test_post_process_reorder_cmd_bad_index_greater_than_equal_to_len(self): def bad_index_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: del rows return [1] cmd = LLMFnPostProcessReorderCommand(name="test", fn=bad_index_fn) expected_msg = ( 'Error executing "test": returned indices must be less than length of' " rows \\(=1\\), got 1" ) with self.assertRaisesRegex(PostProcessExecutionError, expected_msg): cmd.run([LLMFnOutputRow(data={"text_result": "hello"}, result_type=str)]) def test_post_process_reorder_cmd_reverse(self): def reverse_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: del rows return [1, 0] cmd = LLMFnPostProcessReorderCommand(name="test", fn=reverse_fn) results = cmd.run([ LLMFnOutputRow(data={"text_result": "one"}, result_type=str), LLMFnOutputRow(data={"text_result": "two"}, result_type=str), ]) self.assertLen(results, 2) self.assertEqual({"text_result": "two"}, dict(results[0])) self.assertEqual({"text_result": "one"}, dict(results[1])) def test_post_process_reorder_cmd_filter(self): def filter_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: del rows return [1] cmd = LLMFnPostProcessReorderCommand(name="test", fn=filter_fn) results = cmd.run([ LLMFnOutputRow(data={"text_result": "one"}, result_type=str), LLMFnOutputRow(data={"text_result": "two"}, result_type=str), ]) self.assertLen(results, 1) self.assertEqual({"text_result": "two"}, dict(results[0])) def test_post_process_reorder_cmd_filter_to_empty(self): def filter_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: del rows return [] cmd = LLMFnPostProcessReorderCommand(name="test", fn=filter_fn) results = cmd.run([ LLMFnOutputRow(data={"text_result": "one"}, result_type=str), LLMFnOutputRow(data={"text_result": "two"}, result_type=str), ]) self.assertEmpty(results) def 
test_post_process_add_cmd(self): def add_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[int]: return [len(row.result_value()) for row in rows] cmd = LLMFnPostProcessAddCommand(name="test", fn=add_fn) results = cmd.run([ LLMFnOutputRow(data={"text_result": "apple"}, result_type=str), LLMFnOutputRow(data={"text_result": "banana"}, result_type=str), ]) self.assertLen(results, 2) self.assertEqual({"test": 5, "text_result": "apple"}, dict(results[0])) self.assertEqual({"test": 6, "text_result": "banana"}, dict(results[1])) def test_post_process_replace_cmd(self): def replace_fn(rows: Sequence[LLMFnOutputRowView]) -> Sequence[str]: return [row.result_value().upper() for row in rows] cmd = LLMFnPostProcessReplaceCommand(name="test", fn=replace_fn) results = cmd.run([ LLMFnOutputRow(data={"text_result": "apple"}, result_type=str), LLMFnOutputRow(data={"text_result": "banana"}, result_type=str), ]) self.assertLen(results, 2) self.assertEqual({"text_result": "APPLE"}, dict(results[0])) self.assertEqual({"text_result": "BANANA"}, dict(results[1])) class LLMCompareFnPostProcessTest(absltest.TestCase): def test_cmp_post_process_add_cmd(self): def add_fn( rows: Sequence[tuple[LLMFnOutputRowView, LLMFnOutputRowView]] ) -> Sequence[int]: return [x.result_value() + y.result_value() for x, y in rows] cmd = LLMCompareFnPostProcessAddCommand(name="sum", fn=add_fn) results = cmd.run([ ( LLMFnOutputRow(data={"int_result": 1}, result_type=int), LLMFnOutputRow(data={"int_result": 2}, result_type=int), LLMFnOutputRow(data={"text_result": "ok"}, result_type=str), ), ( LLMFnOutputRow(data={"int_result": 3}, result_type=int), LLMFnOutputRow(data={"int_result": 4}, result_type=int), LLMFnOutputRow(data={"int_result": 5}, result_type=int), ), ]) self.assertLen(results, 2) self.assertEqual({"sum": 3, "text_result": "ok"}, dict(results[0])) self.assertEqual({"sum": 7, "int_result": 5}, dict(results[1])) if __name__ == "__main__": absltest.main() 
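The reorder-command contract exercised by the tests above — a post-process function returns row indices, which must be unique and in range before the batch is rebuilt — can be sketched standalone. This is a minimal illustration with plain dicts standing in for LLMFnOutputRow; `run_reorder` is a hypothetical helper, not part of the package:

```python
from typing import Any, Callable, Dict, List, Sequence

Row = Dict[str, Any]

def run_reorder(name: str,
                fn: Callable[[Sequence[Row]], Sequence[int]],
                rows: Sequence[Row]) -> List[Row]:
    """Validates the indices returned by `fn`, then reorders `rows`.

    Mirrors the checks in LLMFnPostProcessReorderCommand.run: indices
    must be unique and lie in the range [0, len(rows)).
    """
    indices = fn(rows)
    if len(set(indices)) != len(indices):
        raise RuntimeError(
            'Error executing "{}": returned indices should be unique'.format(name))
    for idx in indices:
        if not 0 <= idx < len(rows):
            raise RuntimeError(
                'Error executing "{}": index {} out of range'.format(name, idx))
    # Keep only the selected rows, in the order the function chose.
    return [rows[idx] for idx in indices]

rows = [{"text_result": "one"}, {"text_result": "two"}]
reversed_rows = run_reorder("reverse", lambda rs: [1, 0], rows)
filtered_rows = run_reorder("filter", lambda rs: [1], rows)
```

Filtering, sorting, and de-duplication (as in `unique_fn` below) all fit this one interface: the function only decides which indices survive and in what order.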
07070100000030000081A400000000000000000000000164598395000007B1000000000000000000000000000000000000004900000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/model.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Abstract interface for models.""" from __future__ import annotations import abc import dataclasses from typing import Sequence @dataclasses.dataclass(frozen=True) class ModelArguments: """Common arguments for models. Attributes: model: The model string to use. If None a default model will be selected. temperature: The temperature. Must be greater-than-or-equal-to zero. candidate_count: Number of candidates to return. """ model: str | None = None temperature: float | None = None candidate_count: int | None = None @dataclasses.dataclass class ModelResults: """Results from calling AbstractModel.call_model().""" model_input: str text_results: Sequence[str] class AbstractModel(abc.ABC): @abc.abstractmethod def call_model( self, model_input: str, model_args: ModelArguments | None = None ) -> ModelResults: """Executes the model.""" class EchoModel(AbstractModel): """Model that returns the original input. This is primarily used for testing. 
""" def call_model( self, model_input: str, model_args: ModelArguments | None = None ) -> ModelResults: candidate_count = model_args.candidate_count if model_args else None if candidate_count is None: candidate_count = 1 return ModelResults( model_input=model_input, text_results=[model_input] * candidate_count ) 07070100000031000081A400000000000000000000000164598395000004D0000000000000000000000000000000000000005000000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/prompt_utils.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for processing prompts.""" from __future__ import annotations import string from typing import AbstractSet def get_placeholders(prompt: str) -> AbstractSet[str]: """Returns the placeholders for `prompt`. E.g. Given "A for {word_one} B for {word_two}", returns {"word_one", "word_two"}. Args: prompt: A prompt template with optional placeholders. Returns: A sequence of placeholders in `prompt`. 
""" placeholders: list[str] = [] for _, field_name, _, _ in string.Formatter().parse(prompt): if field_name is not None: placeholders.append(field_name) return frozenset(placeholders) 07070100000032000081A400000000000000000000000164598395000005E2000000000000000000000000000000000000005500000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/prompt_utils_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from absl.testing import absltest from google.generativeai.notebook.lib import prompt_utils class PromptUtilsTest(absltest.TestCase): def test_get_placeholders_empty(self): placeholders = prompt_utils.get_placeholders("") self.assertEmpty(placeholders) placeholders = prompt_utils.get_placeholders( "There are no placeholders here" ) self.assertEmpty(placeholders) def test_get_placeholders(self): placeholders = prompt_utils.get_placeholders("today {hello} world") self.assertEqual(frozenset({"hello"}), placeholders) placeholders = prompt_utils.get_placeholders("{hello} {world}") self.assertEqual(frozenset({"hello", "world"}), placeholders) placeholders = prompt_utils.get_placeholders("{hello} {hello}") self.assertEqual(frozenset({"hello"}), placeholders) if __name__ == "__main__": absltest.main() 07070100000033000081A4000000000000000000000001645983950000059F000000000000000000000000000000000000004D00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/unique_fn.py# 
-*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Function for de-duping results.""" from __future__ import annotations from typing import Sequence from google.generativeai.notebook.lib import llmfn_output_row def unique_fn( rows: Sequence[llmfn_output_row.LLMFnOutputRowView], ) -> Sequence[int]: """Returns a list of indices with duplicates removed. E.g. if rows has results ["hello", "hello", "world"], the return value would be [0, 2], indicating that the result at index 1 is a duplicate and should be removed. Args: rows: The input rows. Returns: A sequence of indices indicating which entries have unique results. """ indices: list[int] = [] seen_entries = set() for idx, row in enumerate(rows): value = row.result_value() if value in seen_entries: continue seen_entries.add(value) indices.append(idx) return indices 07070100000034000081A400000000000000000000000164598395000006F2000000000000000000000000000000000000005200000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/lib/unique_fn_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from absl.testing import absltest from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import unique_fn LLMFnOutputRow = llmfn_output_row.LLMFnOutputRow class UniqueFntest(absltest.TestCase): def test_all_unique(self): rows = [ LLMFnOutputRow(data={"text_result": "red"}, result_type=str), LLMFnOutputRow(data={"text_result": "green"}, result_type=str), LLMFnOutputRow(data={"text_result": "blue"}, result_type=str), ] self.assertEqual([0, 1, 2], unique_fn.unique_fn(rows)) def test_some_dupes(self): rows = [ LLMFnOutputRow(data={"text_result": "red"}, result_type=str), LLMFnOutputRow(data={"text_result": "red"}, result_type=str), LLMFnOutputRow(data={"text_result": "green"}, result_type=str), LLMFnOutputRow(data={"text_result": "red"}, result_type=str), LLMFnOutputRow(data={"text_result": "green"}, result_type=str), LLMFnOutputRow(data={"text_result": "blue"}, result_type=str), ] self.assertEqual([0, 2, 5], unique_fn.unique_fn(rows)) if __name__ == "__main__": absltest.main() 07070100000035000081A4000000000000000000000001645983950000103C000000000000000000000000000000000000004600000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/magics.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colab Magics class.

Installs %%palm magics.
"""
from __future__ import annotations

import abc

from google.auth import credentials
from google.generativeai import client as palm
from google.generativeai.notebook import gspread_client
from google.generativeai.notebook import ipython_env
from google.generativeai.notebook import ipython_env_impl
from google.generativeai.notebook import magics_engine
from google.generativeai.notebook import post_process_utils
from google.generativeai.notebook import sheets_utils
import IPython
from IPython.core import magic

# Set the UA to distinguish the magic from the client. Do this at import-time
# so that a user can still call `palm.configure()`, and both their settings
# and this are honored.
palm.USER_AGENT = "genai-py-magic"

SheetsInputs = sheets_utils.SheetsInputs
SheetsOutputs = sheets_utils.SheetsOutputs

# Decorator functions for post-processing.
post_process_add_fn = post_process_utils.post_process_add_fn
post_process_replace_fn = post_process_utils.post_process_replace_fn

# Globals.
_ipython_env: ipython_env.IPythonEnv | None = None


def _get_ipython_env() -> ipython_env.IPythonEnv:
  """Lazily constructs and returns a global IPythonEnv instance."""
  global _ipython_env
  if _ipython_env is None:
    _ipython_env = ipython_env_impl.IPythonEnvImpl()
  return _ipython_env


def authorize(creds: credentials.Credentials) -> None:
  """Sets up credentials.

  This is used for interacting with Google APIs, such as Google Sheets.

  Args:
    creds: The credentials that will be used (e.g. to read from Google
      Sheets.)
  """
  gspread_client.authorize(creds=creds, env=_get_ipython_env())


class AbstractMagics(abc.ABC):
  """Defines interface to Magics class."""

  @abc.abstractmethod
  def palm(self, cell_line: str | None, cell_body: str | None):
    """Perform various LLM-related operations.

    Args:
      cell_line: String to pass to the MagicsEngine.
      cell_body: Contents of the cell body.
    """
    raise NotImplementedError()


class MagicsImpl(AbstractMagics):
  """Actual class implementing the magics functionality.

  We use a separate class to ensure a single, global instance of the magics
  class.
  """

  def __init__(self):
    self._engine = magics_engine.MagicsEngine(env=_get_ipython_env())

  def palm(self, cell_line: str | None, cell_body: str | None):
    """Perform various LLM-related operations.

    Args:
      cell_line: String to pass to the MagicsEngine.
      cell_body: Contents of the cell body.

    Returns:
      Results from running MagicsEngine.
    """
    cell_line = cell_line or ""
    cell_body = cell_body or ""
    return self._engine.execute_cell(cell_line, cell_body)


@magic.magics_class
class Magics(magic.Magics):
  """Class to register the magic with Colab.

  Objects of this class delegate all calls to a single, global instance.
  """

  # Global instance
  _instance = None

  @classmethod
  def get_instance(cls) -> AbstractMagics:
    """Retrieve global instance of the Magics object."""
    if cls._instance is None:
      cls._instance = MagicsImpl()
    return cls._instance

  @magic.line_cell_magic
  def palm(self, cell_line: str | None, cell_body: str | None):
    """Perform various LLM-related operations.

    Args:
      cell_line: String to pass to the MagicsEngine.
      cell_body: Contents of the cell body.

    Returns:
      Results from running MagicsEngine.
""" return Magics.get_instance().palm(cell_line=cell_line, cell_body=cell_body) IPython.get_ipython().register_magics(Magics) 07070100000036000081A400000000000000000000000164598395000014AE000000000000000000000000000000000000004D00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/magics_engine.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MagicsEngine class.""" from __future__ import annotations from typing import AbstractSet, Sequence from google.generativeai.notebook import argument_parser from google.generativeai.notebook import cmd_line_parser from google.generativeai.notebook import command from google.generativeai.notebook import compare_cmd from google.generativeai.notebook import compile_cmd from google.generativeai.notebook import eval_cmd from google.generativeai.notebook import ipython_env from google.generativeai.notebook import model_registry from google.generativeai.notebook import parsed_args_lib from google.generativeai.notebook import post_process_utils from google.generativeai.notebook import run_cmd from google.generativeai.notebook.lib import prompt_utils class MagicsEngine: """Implementation of functionality used by Magics. This class provides the implementation for Magics, decoupled from the details of integrating with Colab Magics such as registration. 
""" def __init__( self, registry: model_registry.ModelRegistry | None = None, env: ipython_env.IPythonEnv | None = None, ): self._ipython_env = env models = registry or model_registry.ModelRegistry() self._cmd_handlers: dict[parsed_args_lib.CommandName, command.Command] = { parsed_args_lib.CommandName.RUN_CMD: run_cmd.RunCommand( models=models, env=env ), parsed_args_lib.CommandName.COMPILE_CMD: compile_cmd.CompileCommand( models=models, env=env ), parsed_args_lib.CommandName.COMPARE_CMD: compare_cmd.CompareCommand( env=env ), parsed_args_lib.CommandName.EVAL_CMD: eval_cmd.EvalCommand( models=models, env=env ), } def parse_line( self, line: str, placeholders: AbstractSet[str], ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]: return cmd_line_parser.CmdLineParser().parse_line(line, placeholders) def _get_handler( self, line: str, placeholders: AbstractSet[str] ) -> tuple[ command.Command, parsed_args_lib.ParsedArgs, Sequence[post_process_utils.ParsedPostProcessExpr], ]: """Given the command line, parse and return all components. Args: line: The LLM Magics command line. placeholders: Placeholders from prompts in the cell contents. Returns: A three-tuple containing: - The command (e.g. 
"run")
      - Parsed arguments for the command,
      - Parsed post-processing expressions
    """
    parsed_args, post_processing_tokens = self.parse_line(line, placeholders)
    cmd_name = parsed_args.cmd
    handler = self._cmd_handlers[cmd_name]
    post_processing_fns = handler.parse_post_processing_tokens(
        post_processing_tokens
    )
    return handler, parsed_args, post_processing_fns

  def execute_cell(self, line: str, cell_content: str):
    """Executes the supplied magic line and cell payload."""
    cell = _clean_cell(cell_content)
    placeholders = prompt_utils.get_placeholders(cell)
    try:
      handler, parsed_args, post_processing_fns = self._get_handler(
          line, placeholders
      )
      return handler.execute(parsed_args, cell, post_processing_fns)
    except argument_parser.ParserNormalExit as e:
      if self._ipython_env is not None:
        e.set_ipython_env(self._ipython_env)
      # ParserNormalExit implements the _ipython_display_ method so it can
      # be returned as the output of this cell for display.
      return e
    except argument_parser.ParserError as e:
      e.display(self._ipython_env)
      # Raise an exception to indicate that execution for this cell has
      # failed.
      # The exception is re-raised as SystemExit because Colab automatically
      # suppresses traceback for SystemExit but not other exceptions. Because
      # ParserErrors are usually due to user error (e.g. a missing required
      # flag or an invalid flag value), we want to hide the traceback to
      # avoid distracting the user from the error message, and we want to
      # reserve exceptions-with-traceback for actual bugs and unexpected
      # errors.
      error_msg = (
          "Got parser error: {}".format(e.msgs()[-1]) if e.msgs() else ""
      )
      raise SystemExit(error_msg) from e


def _clean_cell(cell_content: str) -> str:
  # Colab includes a trailing newline in cell_content. Remove only the last
  # line break from cell contents (i.e. not rstrip), so that multi-line and
  # intentional line breaks are preserved, but single-line prompts don't have
  # a trailing line break.
  cell = cell_content
  if cell.endswith("\n"):
    cell = cell[:-1]
  return cell
07070100000037000081A40000000000000000000000016459839500006100000000000000000000000000000000000000005200000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/magics_engine_test.py# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
def add_length(result: str) -> int: return len(result) @post_process_utils.post_process_add_fn def add_length_decorated(result: str) -> int: return len(result) @post_process_utils.post_process_replace_fn def repeat(result: str) -> str: return result + result @post_process_utils.post_process_replace_fn def to_upper(result: str) -> str: return result.upper() # Comparison functions for "compare" command. def get_sum_of_lengths(lhs: str, rhs: str) -> int: return len(lhs) + len(rhs) def concat(lhs: str, rhs: str) -> str: return lhs + " " + rhs def my_is_equal_fn(lhs: str, rhs: str) -> bool: return lhs == rhs class EchoModelRegistry(model_registry.ModelRegistry): """Fake model registry for testing.""" def __init__(self, alt_model=None): self.model = alt_model or model.EchoModel() self.get_model_name: model_registry.ModelName | None = None def get_model( self, model_name: model_registry.ModelName ) -> model.AbstractModel: self.get_model_name = model_name return self.model class FakeIPythonEnv(ipython_env.IPythonEnv): """Fake IPythonEnv for testing.""" def __init__(self): self.display_args: Any = None def clear(self) -> None: self.display_args = None def display(self, x: Any) -> None: self.display_args = x logging.info("IPythonEnv.display called with:\n%r", x) def display_html(self, x: Any) -> None: logging.info("IPythonEnv.display_html called with:\n%r", x) class FakeInputsSource(llmfn_inputs_source.LLMFnInputsSource): def _to_normalized_inputs_impl( self, ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]: return [{"word": "quack3"}, {"word": "quack4"}], lambda: None class FakeOutputsSink(llmfn_outputs.LLMFnOutputsSink): def __init__(self): self.outputs: llmfn_outputs.LLMFnOutputsBase | None = None def write_outputs(self, outputs: llmfn_outputs.LLMFnOutputsBase) -> None: self.outputs = outputs class MockGSpreadClient(gspread_client.GSpreadClient): def __init__(self): self.get_all_records_name: str | None = None self.get_all_records_worksheet_id: int | None = 
None
    self.write_records_name: str | None = None
    self.write_records_rows: Sequence[Sequence[Any]] | None = None

  def validate(self, sid: sheets_id.SheetsIdentifier):
    if sid.name() is None:
      raise gspread_client.SpreadsheetNotFoundError(
          "Sheets not found: {}".format(sid)
      )
    pass

  def get_all_records(
      self,
      sid: sheets_id.SheetsIdentifier,
      worksheet_id: int,
  ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
    self.get_all_records_name = sid.name()
    self.get_all_records_worksheet_id = worksheet_id
    return [{"word": "quack5"}, {"word": "quack6"}], lambda: None

  def write_records(
      self,
      sid: sheets_id.SheetsIdentifier,
      rows: Sequence[Sequence[Any]],
  ) -> None:
    self.write_records_name = sid.name()
    self.write_records_rows = rows


# Fake input variables to test the --inputs flag.
_INPUT_VAR_ONE = {"word": ["quack1", "quack2"]}
_INPUT_VAR_TWO = FakeInputsSource()
_SHEETS_INPUT_VAR = None

# Variable to test the --ground_truth flag.
_GROUND_TRUTH_VAR = ["QUACK QUACK1", "NOT QUACK QUACK2"]

# Variables to test the --outputs flag.
_output_var: llmfn_outputs.LLMFnOutputs | None = None
_output_sink_var = FakeOutputsSink()


def _reset_globals() -> None:
  # Reset all our globals.
  global _compiled_function
  global _compiled_lhs_function
  global _compiled_rhs_function
  global _output_var
  global _output_sink_var
  global _SHEETS_INPUT_VAR
  _compiled_function = None
  _compiled_lhs_function = None
  _compiled_rhs_function = None
  _output_var = None
  _output_sink_var = FakeOutputsSink()
  # This should be done after the MockGSpreadClient has been set up.
_SHEETS_INPUT_VAR = sheets_utils.SheetsInputs( sid=sheets_id.SheetsIdentifier(name="fake_sheets"), worksheet_id=42, ) class EndToEndTests(absltest.TestCase): def setUp(self): super().setUp() self._mock_client = MockGSpreadClient() gspread_client.testonly_set_client(self._mock_client) _reset_globals() def _assert_is_expected_pandas_dataframe( self, results: pandas.DataFrame, expected_results: Mapping[str, Any] ) -> None: self.assertIsInstance(results, pandas.DataFrame) self.assertEqual(expected_results, results.to_dict(orient="list")) self.assertEqual( list(expected_results.keys()), list(results.to_dict(orient="list").keys()), ) def _assert_output_var_is_expected_results( self, var: Any, expected_results: Mapping[str, Any], fake_env: FakeIPythonEnv, ) -> None: self.assertIsInstance(var, llmfn_outputs.LLMFnOutputs) # Make sure output vars are also populated. self.assertEqual(expected_results, var.as_dict()) self.assertEqual( list(expected_results.keys()), list(var.as_dict().keys()), ) # Make sure the output object is displayable in notebooks. self.assertTrue(hasattr(var, "_ipython_display_")) fake_env.clear() # The typechecker thinks LLMFnOutputs does not have _ipython_display_ # because the method is conditionally added. var._ipython_display_() # type: ignore self.assertIsInstance(fake_env.display_args, pandas.DataFrame) self.assertEqual( expected_results, fake_env.display_args.to_dict(orient="list") ) @mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]}) class HelpEndToEndTests(EndToEndTests): def test_help(self): mock_registry = EchoModelRegistry() magic_line = "--help" engine = magics_engine.MagicsEngine(registry=mock_registry) # Should not raise an exception. 
results = engine.execute_cell(magic_line, "ignored") self.assertRegex(str(results), "A system for interacting with LLMs") def test_run_help(self): mock_registry = EchoModelRegistry() magic_line = "run --help" engine = magics_engine.MagicsEngine(registry=mock_registry) # Should not raise an exception. results = engine.execute_cell(magic_line, "ignored") self.assertRegex(str(results), "usage: palm run") def test_error(self): mock_registry = EchoModelRegistry() magic_line = "run --this_is_an_invalid_flag" engine = magics_engine.MagicsEngine(registry=mock_registry) with self.assertRaisesRegex( SystemExit, "unrecognized arguments: --this_is_an_invalid_flag" ): engine.execute_cell(magic_line, "ignored") # `unittest discover` does not run via __main__, so patch this context in. @mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]}) class RunCmdEndToEndTests(EndToEndTests): def test_run_cmd(self): """Smoke test for executing the run command.""" mock_registry = EchoModelRegistry() magic_line = "run --model_type=echo" engine = magics_engine.MagicsEngine(registry=mock_registry) results = engine.execute_cell(magic_line, "quack quack") expected_results = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["quack quack"], "text_result": ["quack quack"], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) results = engine.execute_cell(magic_line, "line 1\nline 2\n") expected_results = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["line 1\nline 2"], "text_result": ["line 1\nline 2"], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) # Confirm trailing line breaks preserved (bar 1) results = engine.execute_cell(magic_line, "line 1\nline 2\n\n") expected_results = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["line 1\nline 2\n"], "text_result": ["line 1\nline 2\n"], } self._assert_is_expected_pandas_dataframe( 
results=results, expected_results=expected_results ) # --model_type should be parsed and passed to the ModelRegistry instance. self.assertEqual( model_registry.ModelName.ECHO_MODEL, mock_registry.get_model_name ) def test_model_args_passed(self): mock_model = mock.create_autospec(model.EchoModel) reg = EchoModelRegistry(mock_model) engine = magics_engine.MagicsEngine(registry=reg) _ = engine.execute_cell( ( "run --model_type=echo --model=the_best_model --temperature=0.25" " --candidate_count=3" ), "quack", ) expected_model_args = model.ModelArguments( model="the_best_model", temperature=0.25, candidate_count=3 ) actual_model_args = mock_model.call_model.call_args.kwargs["model_args"] self.assertEqual(actual_model_args, expected_model_args) def test_candidate_count(self): mock_registry = EchoModelRegistry() engine = magics_engine.MagicsEngine(registry=mock_registry) results = engine.execute_cell( "run --model_type=echo --candidate=3", "quack", ) expected_results = { "Prompt Num": [0, 0, 0], "Input Num": [0, 0, 0], "Result Num": [0, 1, 2], "Prompt": ["quack", "quack", "quack"], "text_result": ["quack", "quack", "quack"], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) def test_unique(self): mock_registry = EchoModelRegistry() engine = magics_engine.MagicsEngine(registry=mock_registry) results = engine.execute_cell( "run --model_type=echo --candidate=3 --unique", "quack", ) expected_results = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["quack"], "text_result": ["quack"], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) def test_inputs_passed(self): magic_line = ( "run --model_type=echo --inputs _INPUT_VAR_ONE _INPUT_VAR_TWO" " _SHEETS_INPUT_VAR" ) engine = magics_engine.MagicsEngine(registry=EchoModelRegistry()) results = engine.execute_cell(magic_line, "quack {word}") self.assertIsInstance(results, pandas.DataFrame) expected_results = { "Prompt 
Num": [0, 0, 0, 0, 0, 0], "Input Num": [0, 1, 2, 3, 4, 5], "Result Num": [0, 0, 0, 0, 0, 0], "Prompt": [ "quack quack1", "quack quack2", "quack quack3", "quack quack4", "quack quack5", "quack quack6", ], "text_result": [ "quack quack1", "quack quack2", "quack quack3", "quack quack4", "quack quack5", "quack quack6", ], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) self.assertEqual("fake_sheets", self._mock_client.get_all_records_name) self.assertEqual(42, self._mock_client.get_all_records_worksheet_id) def test_sheets_input_names_passed(self): """Test using --sheets_input_names.""" magic_line = "run --model_type=echo --sheets_input_names my_fake_sheets" engine = magics_engine.MagicsEngine(registry=EchoModelRegistry()) results = engine.execute_cell(magic_line, "quack {word}") self.assertIsInstance(results, pandas.DataFrame) expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt": [ "quack quack5", "quack quack6", ], "text_result": [ "quack quack5", "quack quack6", ], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) self.assertEqual("my_fake_sheets", self._mock_client.get_all_records_name) # The default worksheet_id should be used. 
self.assertEqual(0, self._mock_client.get_all_records_worksheet_id) def test_validate_inputs_against_placeholders(self): engine = magics_engine.MagicsEngine(registry=EchoModelRegistry()) with self.assertRaisesRegex( SystemExit, ( 'argument --inputs/-i: Error with value "_INPUT_VAR_ONE", got' ' ValueError: Placeholder "not_word" not found in input' ), ): engine.execute_cell( "run --model_type=echo --inputs _INPUT_VAR_ONE", "quack {not_word}" ) with self.assertRaisesRegex( SystemExit, ( 'argument --inputs/-i: Error with value "_INPUT_VAR_TWO", got' ' ValueError: Placeholder "not_word" not found in input' ), ): engine.execute_cell( "run --model_type=echo --inputs _INPUT_VAR_TWO", "quack {not_word}" ) with self.assertRaisesRegex( SystemExit, ( 'argument --inputs/-i: Error with value "_SHEETS_INPUT_VAR", got' ' ValueError: Placeholder "not_word" not found in input' ), ): engine.execute_cell( "run --model_type=echo --inputs _SHEETS_INPUT_VAR", "quack {not_word}" ) def test_validate_sheets_inputs_against_placeholders(self): engine = magics_engine.MagicsEngine(registry=EchoModelRegistry()) with self.assertRaisesRegex( SystemExit, ( "argument --sheets_input_names/-si: Error with value" ' "my_fake_sheets", got ValueError: Placeholder "not_word" not' " found in input" ), ): engine.execute_cell( "run --model_type=echo --sheets_input_names my_fake_sheets", "quack {not_word}", ) def test_post_process(self): magic_line = ( "run --model_type=echo | add_length | repeat | add_length_decorated" ) engine = magics_engine.MagicsEngine(registry=EchoModelRegistry()) results = engine.execute_cell(magic_line, "quack") self.assertIsInstance(results, pandas.DataFrame) expected_results = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["quack"], "add_length": [5], "add_length_decorated": [10], "text_result": ["quackquack"], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) def test_outputs(self): # Include post-processing 
commands to make sure their results are exported # as well. magic_line = ( "run --model_type=echo --outputs _output_var | add_length | repeat" ) fake_env = FakeIPythonEnv() engine = magics_engine.MagicsEngine( registry=EchoModelRegistry(), env=fake_env ) results = engine.execute_cell(magic_line, "quack") self.assertIsInstance(_output_var, llmfn_outputs.LLMFnOutputs) expected_results = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["quack"], "add_length": [5], "text_result": ["quackquack"], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) self._assert_output_var_is_expected_results( var=_output_var, expected_results=expected_results, fake_env=fake_env ) def test_outputs_sink(self): # Include post-processing commands to make sure their results are exported # as well. magic_line = ( "run --model_type=echo --outputs _output_sink_var | add_length | repeat" ) engine = magics_engine.MagicsEngine(registry=EchoModelRegistry()) results = engine.execute_cell(magic_line, "quack") self.assertIsNotNone(_output_sink_var.outputs) self.assertIsInstance(_output_sink_var.outputs, llmfn_outputs.LLMFnOutputs) expected_results = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["quack"], "add_length": [5], "text_result": ["quackquack"], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) self.assertEqual( expected_results, _output_sink_var.outputs.as_pandas_dataframe().to_dict(orient="list"), ) def test_sheets_outputs_names(self): # Include post-processing commands to make sure their results are exported # as well. 
magic_line = ( "run --model_type=echo --sheets_output_names my_fake_output_sheets |" " add_length | repeat" ) engine = magics_engine.MagicsEngine(registry=EchoModelRegistry()) _ = engine.execute_cell(magic_line, "quack") self.assertEqual( "my_fake_output_sheets", self._mock_client.write_records_name ) expected_rows = [ [ "Prompt Num", "Input Num", "Result Num", "Prompt", "add_length", "text_result", ], [0, 0, 0, "quack", 5, "quackquack"], ] self.assertEqual(expected_rows, self._mock_client.write_records_rows) # `unittest discover` does not run via __main__, so patch this context in. @mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]}) class CompileCmdEndToEndTests(EndToEndTests): def test_compile_cmd(self): fake_env = FakeIPythonEnv() engine = magics_engine.MagicsEngine( registry=EchoModelRegistry(), env=fake_env ) _ = engine.execute_cell( "compile _compiled_function --model_type=echo", "quack {word}" ) # The "compile" command produces a saved function. # Execute the saved function and check that it produces the expected output. self.assertIsInstance(_compiled_function, llm_function.LLMFunction) outputs = _compiled_function({"word": ["LOUD QUACK"]}) expected_outputs = { "Prompt Num": [0], "Input Num": [0], "Result Num": [0], "Prompt": ["quack LOUD QUACK"], "text_result": ["quack LOUD QUACK"], } self._assert_output_var_is_expected_results( var=outputs, expected_results=expected_outputs, fake_env=fake_env ) # `unittest discover` does not run via __main__, so patch this context in. @mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]}) class CompareCmdEndToEndTests(EndToEndTests): def test_compare_cmd_with_default_compare_fn(self): fake_env = FakeIPythonEnv() engine = magics_engine.MagicsEngine( registry=EchoModelRegistry(), env=fake_env ) # Create a pair of LLMFunctions to compare. 
_ = engine.execute_cell( "compile _compiled_lhs_function --model_type=echo", "left quack {word}" ) _ = engine.execute_cell( "compile _compiled_rhs_function --model_type=echo", "right quack {word}" ) # Run comparison. results = engine.execute_cell( ( "compare _compiled_lhs_function _compiled_rhs_function --inputs" " _INPUT_VAR_ONE --outputs _output_var" ), "ignored", ) # Check results. expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt vars": [{"word": "quack1"}, {"word": "quack2"}], "_compiled_lhs_function_text_result": [ "left quack quack1", "left quack quack2", ], "_compiled_rhs_function_text_result": [ "right quack quack1", "right quack quack2", ], "is_equal": [False, False], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) self._assert_output_var_is_expected_results( var=_output_var, expected_results=expected_results, fake_env=fake_env ) def test_compare_cmd_with_custom_compare_fn(self): fake_env = FakeIPythonEnv() engine = magics_engine.MagicsEngine( registry=EchoModelRegistry(), env=fake_env ) # Create a pair of LLMFunctions to compare. _ = engine.execute_cell( "compile _compiled_lhs_function --model_type=echo", "left quack {word}" ) _ = engine.execute_cell( "compile _compiled_rhs_function --model_type=echo", "right quack {word}" ) # Run comparison. results = engine.execute_cell( ( "compare _compiled_lhs_function _compiled_rhs_function --inputs" " _INPUT_VAR_ONE --outputs _output_var --compare_fn concat" " get_sum_of_lengths" ), "ignored", ) # Check results. 
expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt vars": [{"word": "quack1"}, {"word": "quack2"}], "_compiled_lhs_function_text_result": [ "left quack quack1", "left quack quack2", ], "_compiled_rhs_function_text_result": [ "right quack quack1", "right quack quack2", ], "concat": [ "left quack quack1 right quack quack1", "left quack quack2 right quack quack2", ], "get_sum_of_lengths": [35, 35], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) self._assert_output_var_is_expected_results( var=_output_var, expected_results=expected_results, fake_env=fake_env ) # `unittest discover` does not run via __main__, so patch this context in. @mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]}) class EvalCmdEndToEndTests(EndToEndTests): def test_eval_cmd(self): fake_env = FakeIPythonEnv() engine = magics_engine.MagicsEngine( registry=EchoModelRegistry(), env=fake_env ) # Run evaluation. # # Some of the features tested are: # 1. We evoke a model flag (to make sure model flags are parsed.) # 2. We add a post-processing function to the prompt results # 3. We add a few custom comparison functions. # 4. We write to an output variable. 
results = engine.execute_cell( ( "eval --model_type=echo --ground_truth _GROUND_TRUTH_VAR --inputs" " _INPUT_VAR_ONE --outputs _output_var --compare_fn" " get_sum_of_lengths my_is_equal_fn | to_upper" ), "quack {word}", ) expected_results = { "Prompt Num": [0, 0], "Input Num": [0, 1], "Result Num": [0, 0], "Prompt vars": [{"word": "quack1"}, {"word": "quack2"}], "actual_text_result": ["QUACK QUACK1", "QUACK QUACK2"], "ground_truth_text_result": ["QUACK QUACK1", "NOT QUACK QUACK2"], "get_sum_of_lengths": [24, 28], "my_is_equal_fn": [True, False], } self._assert_is_expected_pandas_dataframe( results=results, expected_results=expected_results ) self._assert_output_var_is_expected_results( var=_output_var, expected_results=expected_results, fake_env=fake_env ) if __name__ == "__main__": absltest.main() 07070100000038000081A40000000000000000000000016459839500000743000000000000000000000000000000000000004E00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/model_registry.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Maintains set of LLM models that can be instantiated by name.""" from __future__ import annotations import enum from typing import Callable from google.generativeai.notebook import text_model from google.generativeai.notebook.lib import model as model_lib class ModelName(enum.Enum): ECHO_MODEL = "echo" TEXT_MODEL = "text" class ModelRegistry: """Registry that instantiates and caches models.""" DEFAULT_MODEL = ModelName.TEXT_MODEL def __init__(self): self._model_cache: dict[ModelName, model_lib.AbstractModel] = {} self._model_constructors: dict[ ModelName, Callable[[], model_lib.AbstractModel] ] = { ModelName.ECHO_MODEL: model_lib.EchoModel, ModelName.TEXT_MODEL: text_model.TextModel, } def get_model(self, model_name: ModelName) -> model_lib.AbstractModel: """Given `model_name`, return the corresponding Model instance. Model instances are cached and reused for the same `model_name`. Args: model_name: The name of the model. Returns: The corresponding model instance for `model_name`. """ if model_name not in self._model_cache: self._model_cache[model_name] = self._model_constructors[model_name]() return self._model_cache[model_name] 07070100000039000081A400000000000000000000000164598395000004E1000000000000000000000000000000000000005300000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/model_registry_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Unittest for model_registry.""" from __future__ import annotations from absl.testing import absltest from google.generativeai.notebook import model_registry class ModelRegistryTest(absltest.TestCase): def test_get_model_echo_model(self): registry = model_registry.ModelRegistry() model = registry.get_model(model_registry.ModelName.ECHO_MODEL) results = model.call_model(model_input="this_is_a_test") self.assertEqual("this_is_a_test", results.model_input) # Echo model returns the model_input as text results. self.assertEqual(["this_is_a_test"], results.text_results) if __name__ == "__main__": absltest.main() 0707010000003A000081A400000000000000000000000164598395000007E7000000000000000000000000000000000000004C00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/output_utils.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for exporting outputs from LLMFunctions.""" from __future__ import annotations import copy from google.generativeai.notebook import parsed_args_lib from google.generativeai.notebook import py_utils from google.generativeai.notebook.lib import llmfn_outputs class _PyVarOutputsSink(llmfn_outputs.LLMFnOutputsSink): """Sink that writes results to a Python variable.""" def __init__(self, var_name: str): self._var_name = var_name def write_outputs(self, outputs: llmfn_outputs.LLMFnOutputsBase) -> None: # Clone our results so that they are all independent. 
py_utils.set_py_var(self._var_name, copy.deepcopy(outputs)) def get_outputs_sink_from_py_var( var_name: str, ) -> llmfn_outputs.LLMFnOutputsSink: # The output variable `var_name` will be created if it does not already # exist. if py_utils.has_py_var(var_name): data = py_utils.get_py_var(var_name) if isinstance(data, llmfn_outputs.LLMFnOutputsSink): return data return _PyVarOutputsSink(var_name) def write_to_outputs( results: llmfn_outputs.LLMFnOutputs, parsed_args: parsed_args_lib.ParsedArgs, ) -> None: """Writes `results` to the sinks provided. Args: results: The results to export. parsed_args: Arguments parsed from the command line. """ for sink in parsed_args.outputs: results.export(sink) for sink in parsed_args.sheets_output_names: results.export(sink) 0707010000003B000081A40000000000000000000000016459839500000BE8000000000000000000000000000000000000004F00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/parsed_args_lib.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Results from parsing the commandline. This module separates the results from commandline parsing from the parser itself so that classes that operate on the results (e.g. the subclasses of Command) do not have to depend on the commandline parser as well. 
""" from __future__ import annotations import dataclasses import enum from typing import Any, Callable, Sequence from google.generativeai.notebook import model_registry from google.generativeai.notebook.lib import llm_function from google.generativeai.notebook.lib import llmfn_inputs_source from google.generativeai.notebook.lib import llmfn_outputs from google.generativeai.notebook.lib import model as model_lib # Post processing tokens are represented as a sequence of sequence of tokens, # because the pipe operator could be used more than once. PostProcessingTokens = Sequence[Sequence[str]] # The type of function taken by the "compare_fn" flag. # It takes the text_results of the left- and right-hand side functions as # inputs and returns a comparison result. TextResultCompareFn = Callable[[str, str], Any] class CommandName(enum.Enum): RUN_CMD = "run" COMPILE_CMD = "compile" COMPARE_CMD = "compare" EVAL_CMD = "eval" @dataclasses.dataclass(frozen=True) class ParsedArgs: """The results of parsing the command line.""" cmd: CommandName # For run, compile and eval commands. model_args: model_lib.ModelArguments model_type: model_registry.ModelName | None = None unique: bool = False # For run, compare and eval commands. inputs: Sequence[llmfn_inputs_source.LLMFnInputsSource] = dataclasses.field( default_factory=list ) sheets_input_names: Sequence[llmfn_inputs_source.LLMFnInputsSource] = ( dataclasses.field(default_factory=list) ) outputs: Sequence[llmfn_outputs.LLMFnOutputsSink] = dataclasses.field( default_factory=list ) sheets_output_names: Sequence[llmfn_outputs.LLMFnOutputsSink] = ( dataclasses.field(default_factory=list) ) # For compile command. compile_save_name: str | None = None # For compare command. lhs_name_and_fn: tuple[str, llm_function.LLMFunction] | None = None rhs_name_and_fn: tuple[str, llm_function.LLMFunction] | None = None # For compare and eval commands. 
compare_fn: Sequence[tuple[str, TextResultCompareFn]] = dataclasses.field( default_factory=list ) # For eval command. ground_truth: Sequence[str] = dataclasses.field(default_factory=list) 0707010000003C000081A400000000000000000000000164598395000015ED000000000000000000000000000000000000005200000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/post_process_utils.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for working with post-processing tokens.""" from __future__ import annotations import abc from typing import Any, Callable, Sequence from google.generativeai.notebook import py_utils from google.generativeai.notebook.lib import llm_function from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import llmfn_post_process class PostProcessParseError(RuntimeError): """An error parsing the post-processing tokens.""" class ParsedPostProcessExpr(abc.ABC): """A post-processing expression parsed from the command line.""" @abc.abstractmethod def name(self) -> str: """Returns the name of this expression.""" @abc.abstractmethod def add_to_llm_function( self, llm_fn: llm_function.LLMFunction ) -> llm_function.LLMFunction: """Adds this parsed expression to `llm_fn` as a post-processing command.""" class _ParsedPostProcessAddExpr( ParsedPostProcessExpr, llmfn_post_process.LLMFnPostProcessBatchAddFn ): """An expression that returns the value of a new column to add to a 
row.""" def __init__(self, name: str, fn: Callable[[str], Any]): """Constructor. Args: name: The name of the expression. The name of the new column will be derived from this. fn: A function that takes the result of a row and returns a new value to add as a new column in the row. """ self._name = name self._fn = fn def name(self) -> str: return self._name def __call__( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView] ) -> Sequence[Any]: return [self._fn(row.result_value()) for row in rows] def add_to_llm_function( self, llm_fn: llm_function.LLMFunction ) -> llm_function.LLMFunction: return llm_fn.add_post_process_add_fn(name=self._name, fn=self) class _ParsedPostProcessReplaceExpr( ParsedPostProcessExpr, llmfn_post_process.LLMFnPostProcessBatchReplaceFn ): """An expression that returns the new result value for a row.""" def __init__(self, name: str, fn: Callable[[str], str]): """Constructor. Args: name: The name of the expression. fn: A function that takes the result of a row and returns the new result. """ self._name = name self._fn = fn def name(self) -> str: return self._name def __call__( self, rows: Sequence[llmfn_output_row.LLMFnOutputRowView] ) -> Sequence[str]: return [self._fn(row.result_value()) for row in rows] def add_to_llm_function( self, llm_fn: llm_function.LLMFunction ) -> llm_function.LLMFunction: return llm_fn.add_post_process_replace_fn(name=self._name, fn=self) # Decorator functions. 
def post_process_add_fn(fn: Callable[[str], Any]): return _ParsedPostProcessAddExpr(name=fn.__name__, fn=fn) def post_process_replace_fn(fn: Callable[[str], str]): return _ParsedPostProcessReplaceExpr(name=fn.__name__, fn=fn) def validate_one_post_processing_expression( tokens: Sequence[str], ) -> None: if not tokens: raise PostProcessParseError("Cannot have empty post-processing expression") if len(tokens) > 1: raise PostProcessParseError( "Post-processing expression should be a single token" ) def _resolve_one_post_processing_expression( tokens: Sequence[str], ) -> tuple[str, Any]: """Returns name and the resolved expression.""" validate_one_post_processing_expression(tokens) token_parts = tokens[0].split(".") current_module = py_utils.get_main_module() for part_num, part in enumerate(token_parts): current_module_vars = vars(current_module) if part not in current_module_vars: raise PostProcessParseError( 'Unable to resolve "{}"'.format(".".join(token_parts[: part_num + 1])) ) current_module = current_module_vars[part] return (" ".join(tokens), current_module) def resolve_post_processing_tokens( tokens: Sequence[Sequence[str]], ) -> Sequence[ParsedPostProcessExpr]: """Resolves post-processing tokens into ParsedPostProcessExprs. E.g. Given [["add_length"], ["to_upper"]] as input, this function will return a sequence of ParsedPostProcessExprs that will execute add_length() and to_upper() on each entry of the LLM output as post-processing operations. Raises: PostProcessParseError: An error parsing or resolving the tokens. Args: tokens: A sequence of post-processing tokens after splitting. Returns: A sequence of ParsedPostProcessExprs. """ results: list[ParsedPostProcessExpr] = [] for expression in tokens: expr_name, expr_value = _resolve_one_post_processing_expression(expression) if isinstance(expr_value, ParsedPostProcessExpr): results.append(expr_value) elif isinstance(expr_value, Callable): # By default, assume that an undecorated function is an "add" function. 
results.append(_ParsedPostProcessAddExpr(name=expr_name, fn=expr_value)) else: raise PostProcessParseError("{} is not callable".format(expr_name)) return results 0707010000003D000081A400000000000000000000000164598395000021AD000000000000000000000000000000000000005700000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/post_process_utils_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unittests for post_process_utils.""" from __future__ import annotations import sys from unittest import mock from absl.testing import absltest from google.generativeai.notebook import post_process_utils from google.generativeai.notebook import post_process_utils_test_helper as helper from google.generativeai.notebook.lib import llm_function from google.generativeai.notebook.lib import llmfn_output_row from google.generativeai.notebook.lib import model as model_lib NOT_A_FUNCTION = "this is a string not a function" LLMFnOutputRow = llmfn_output_row.LLMFnOutputRow LLMFnOutputRowView = llmfn_output_row.LLMFnOutputRowView PostProcessParseError = post_process_utils.PostProcessParseError def add_length(x: str) -> int: return len(x) @post_process_utils.post_process_add_fn def add_length_decorated(x: str) -> int: return len(x) @post_process_utils.post_process_replace_fn def to_upper(x: str) -> str: return x.upper() # `unittest discover` does not run via __main__, so patch this context in. 
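The `post_process_add_fn` decorator above wraps a plain function in a parsed-expression object that is applied across a batch of rows and named after the function. A minimal standalone sketch of that wrapping, with plain strings standing in for the library's `LLMFnOutputRow` objects:

```python
from typing import Any, Callable, Sequence


class ParsedAddExpr:
    """Applies `fn` across a batch of row results to build a new column."""

    def __init__(self, name: str, fn: Callable[[str], Any]):
        self._name = name
        self._fn = fn

    def name(self) -> str:
        return self._name

    def __call__(self, rows: Sequence[str]) -> Sequence[Any]:
        # The real implementation reads row.result_value(); plain strings
        # stand in for rows here.
        return [self._fn(row) for row in rows]


def post_process_add_fn(fn: Callable[[str], Any]) -> ParsedAddExpr:
    # The new column's name is derived from the decorated function's name.
    return ParsedAddExpr(name=fn.__name__, fn=fn)


@post_process_add_fn
def add_length(x: str) -> int:
    return len(x)


assert add_length.name() == "add_length"
assert add_length(["hello", "hi"]) == [5, 2]
```

Because the decorator returns an object rather than a function, the resolver can distinguish decorated expressions (instances) from bare callables and handle each case.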
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]}) class PostProcessUtilsResolveTest(absltest.TestCase): def test_cannot_resolve_empty_expression(self): with self.assertRaisesRegex(PostProcessParseError, "Cannot have empty"): post_process_utils._resolve_one_post_processing_expression([]) def test_cannot_resolve_multiword_expression(self): with self.assertRaisesRegex( PostProcessParseError, "should be a single token" ): post_process_utils._resolve_one_post_processing_expression( ["hello", "world"] ) def test_cannot_resolve_invalid_module(self): with self.assertRaisesRegex( PostProcessParseError, 'Unable to resolve "invalid_module"' ): post_process_utils._resolve_one_post_processing_expression( ["invalid_module.add_length"] ) def test_cannot_resolve_invalid_function(self): with self.assertRaisesRegex( PostProcessParseError, 'Unable to resolve "helper.invalid_function"' ): post_process_utils._resolve_one_post_processing_expression( ["helper.invalid_function"] ) def test_resolve_undecorated_function(self): name, expr = post_process_utils._resolve_one_post_processing_expression( ["add_length"] ) self.assertEqual("add_length", name) self.assertEqual(add_length, expr) self.assertEqual(11, expr("hello_world")) def test_resolve_decorated_add_function(self): name, expr = post_process_utils._resolve_one_post_processing_expression( ["add_length_decorated"] ) self.assertEqual("add_length_decorated", name) self.assertEqual(add_length_decorated, expr) self.assertIsInstance(expr, post_process_utils._ParsedPostProcessAddExpr) self.assertEqual( [11], expr( [ LLMFnOutputRow( data={"text_result": "hello_world"}, result_type=str ) ] ), ) def test_resolve_decorated_replace_function(self): # Test to_upper(). 
name, expr = post_process_utils._resolve_one_post_processing_expression( ["to_upper"] ) self.assertEqual("to_upper", name) self.assertEqual(to_upper, expr) self.assertIsInstance( expr, post_process_utils._ParsedPostProcessReplaceExpr ) self.assertEqual( ["HELLO_WORLD"], expr( [ LLMFnOutputRow( data={"text_result": "hello_world"}, result_type=str ) ] ), ) def test_resolve_module_undecorated_function(self): name, expr = post_process_utils._resolve_one_post_processing_expression( ["helper.add_length"] ) self.assertEqual("helper.add_length", name) self.assertEqual(helper.add_length, expr) self.assertEqual(11, expr("hello_world")) def test_resolve_module_decorated_add_function(self): name, expr = post_process_utils._resolve_one_post_processing_expression( ["helper.add_length_decorated"] ) self.assertEqual("helper.add_length_decorated", name) self.assertEqual(helper.add_length_decorated, expr) self.assertIsInstance(expr, post_process_utils._ParsedPostProcessAddExpr) self.assertEqual( [11], expr( [ LLMFnOutputRow( data={"text_result": "hello_world"}, result_type=str ) ] ), ) def test_resolve_module_decorated_replace_function(self): name, expr = post_process_utils._resolve_one_post_processing_expression( ["helper.to_upper"] ) self.assertEqual("helper.to_upper", name) self.assertEqual(helper.to_upper, expr) self.assertIsInstance( expr, post_process_utils._ParsedPostProcessReplaceExpr ) self.assertEqual( ["HELLO_WORLD"], expr( [ LLMFnOutputRow( data={"text_result": "hello_world"}, result_type=str ) ] ), ) # `unittest discover` does not run via __main__, so patch this context in. 
@mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]})
class PostProcessUtilsTest(absltest.TestCase):

    def test_must_be_callable(self):
        with self.assertRaisesRegex(
            PostProcessParseError, "NOT_A_FUNCTION is not callable"
        ):
            post_process_utils.resolve_post_processing_tokens([["NOT_A_FUNCTION"]])

    def test_parsed_post_process_add_fn(self):
        """Test the full path from a post-processing token to an updated LLMFunction."""
        parsed_exprs = post_process_utils.resolve_post_processing_tokens([
            ["add_length"],
        ])
        self.assertLen(parsed_exprs, 1)
        self.assertIsInstance(
            parsed_exprs[0], post_process_utils._ParsedPostProcessAddExpr
        )
        llm_fn = llm_function.LLMFunctionImpl(
            model=model_lib.EchoModel(), prompts=["hello"]
        )
        parsed_exprs[0].add_to_llm_function(llm_fn)
        results = llm_fn()
        self.assertEqual(
            {
                "Input Num": [0],
                "Prompt Num": [0],
                "Prompt": ["hello"],
                "Result Num": [0],
                "add_length": [5],
                "text_result": ["hello"],
            },
            results.as_dict(),
        )

    def test_parsed_post_process_replace_fn(self):
        parsed_exprs = post_process_utils.resolve_post_processing_tokens([
            ["to_upper"],
        ])
        self.assertLen(parsed_exprs, 1)
        self.assertIsInstance(
            parsed_exprs[0], post_process_utils._ParsedPostProcessReplaceExpr
        )
        llm_fn = llm_function.LLMFunctionImpl(
            model=model_lib.EchoModel(), prompts=["hello"]
        )
        parsed_exprs[0].add_to_llm_function(llm_fn)
        results = llm_fn()
        self.assertEqual(
            {
                "Input Num": [0],
                "Prompt Num": [0],
                "Prompt": ["hello"],
                "Result Num": [0],
                "text_result": ["HELLO"],
            },
            results.as_dict(),
        )

    def test_resolve_post_processing_tokens(self):
        parsed_exprs = post_process_utils.resolve_post_processing_tokens([
            ["add_length"],
            ["to_upper"],
            ["add_length_decorated"],
            ["helper.add_length"],
            ["helper.add_length_decorated"],
            ["helper.to_upper"],
        ])
        for fn in parsed_exprs:
            self.assertIsInstance(fn, post_process_utils.ParsedPostProcessExpr)
        llm_fn = llm_function.LLMFunctionImpl(
            model=model_lib.EchoModel(), prompts=["hello"]
        )
        for expr in parsed_exprs:
            expr.add_to_llm_function(llm_fn)
        results =
llm_fn() self.assertEqual( { "Input Num": [0], "Prompt Num": [0], "Prompt": ["hello"], "Result Num": [0], "add_length": [5], "add_length_decorated": [5], "add_length_decorated_1": [5], "helper.add_length": [5], "text_result": ["HELLO"], }, results.as_dict(), ) if __name__ == "__main__": absltest.main() 0707010000003E000081A400000000000000000000000164598395000003D9000000000000000000000000000000000000005E00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/post_process_utils_test_helper.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Helper module for post_process_utils_test.""" from __future__ import annotations from google.generativeai.notebook import post_process_utils def add_length(x: str) -> int: return len(x) @post_process_utils.post_process_add_fn def add_length_decorated(x: str) -> int: return len(x) @post_process_utils.post_process_replace_fn def to_upper(x: str) -> str: return x.upper() 0707010000003F000081A400000000000000000000000164598395000007D3000000000000000000000000000000000000004800000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/py_utils.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Convenience functions for writing to and reading from Python variables.""" from __future__ import annotations import builtins import keyword import sys from typing import Any def validate_var_name(var_name: str) -> None: """Validates that the variable name is a valid identifier.""" if not var_name.isidentifier(): raise ValueError('Invalid Python variable name, got "{}"'.format(var_name)) if keyword.iskeyword(var_name): raise ValueError('Cannot use Python keywords, got "{}"'.format(var_name)) def get_main_module(): return sys.modules['__main__'] def get_py_var(var_name: str) -> Any: """Retrieves the value of `var_name` from the global environment.""" validate_var_name(var_name) g_vars = vars(get_main_module()) if var_name in g_vars: return g_vars[var_name] elif var_name in vars(builtins): return vars(builtins)[var_name] raise NameError('"{}" not found'.format(var_name)) def has_py_var(var_name: str) -> bool: """Returns true if `var_name` is defined in the global environment.""" try: validate_var_name(var_name) _ = get_py_var(var_name) except ValueError: return False except NameError: return False return True def set_py_var(var_name: str, val: Any) -> None: """Sets the value of `var_name` in the global environment.""" validate_var_name(var_name) g_vars = vars(get_main_module()) g_vars[var_name] = val 07070100000040000081A400000000000000000000000164598395000006B5000000000000000000000000000000000000004D00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/py_utils_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache 
License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unittest for py_utils.""" from __future__ import annotations import sys from unittest import mock from absl.testing import absltest from google.generativeai.notebook import py_utils _INPUT_VAR = "hello world" _OUTPUT_VAR = None # `unittest discover` does not run via __main__, so patch this context in. @mock.patch.dict(sys.modules, {"__main__": sys.modules[__name__]}) class PyUtilsTest(absltest.TestCase): def test_get_py_var(self): # get_py_var() with an invalid var should raise an error. with self.assertRaisesRegex(NameError, "IncorrectVar"): py_utils.get_py_var("IncorrectVar") results = py_utils.get_py_var("_INPUT_VAR") self.assertEqual("hello world", results) def test_set_py_var(self): py_utils.set_py_var("_OUTPUT_VAR", "world hello") self.assertEqual("world hello", _OUTPUT_VAR) # Calling with a new variable name creates a new variable. py_utils.set_py_var("_NEW_VAR", "world hello world") # pylint: disable-next=undefined-variable self.assertEqual("world hello world", _NEW_VAR) # type: ignore if __name__ == "__main__": absltest.main() 07070100000041000081A40000000000000000000000016459839500000A4C000000000000000000000000000000000000004700000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/run_cmd.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """The run command.""" from __future__ import annotations from typing import Sequence from google.generativeai.notebook import command from google.generativeai.notebook import command_utils from google.generativeai.notebook import input_utils from google.generativeai.notebook import ipython_env from google.generativeai.notebook import model_registry from google.generativeai.notebook import output_utils from google.generativeai.notebook import parsed_args_lib from google.generativeai.notebook import post_process_utils import pandas class RunCommand(command.Command): """Implementation of the "run" command.""" def __init__( self, models: model_registry.ModelRegistry, env: ipython_env.IPythonEnv | None = None, ): """Constructor. Args: models: ModelRegistry instance. env: The IPythonEnv environment. """ super().__init__() self._models = models self._ipython_env = env def execute( self, parsed_args: parsed_args_lib.ParsedArgs, cell_content: str, post_processing_fns: Sequence[post_process_utils.ParsedPostProcessExpr], ) -> pandas.DataFrame: # We expect CmdLineParser to have already read the inputs once to validate # that the placeholders in the prompt are present in the inputs, so we can # suppress the status messages here. 
inputs = input_utils.join_inputs_sources( parsed_args, suppress_status_msgs=True ) llm_fn = command_utils.create_llm_function( models=self._models, env=self._ipython_env, parsed_args=parsed_args, cell_content=cell_content, post_processing_fns=post_processing_fns, ) results = llm_fn(inputs=inputs) output_utils.write_to_outputs(results=results, parsed_args=parsed_args) return results.as_pandas_dataframe() def parse_post_processing_tokens( self, tokens: Sequence[Sequence[str]] ) -> Sequence[post_process_utils.ParsedPostProcessExpr]: return post_process_utils.resolve_post_processing_tokens(tokens) 07070100000042000081A40000000000000000000000016459839500000B55000000000000000000000000000000000000004900000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/sheets_id.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Module for classes related to identifying a Sheets document.""" from __future__ import annotations import re from google.generativeai.notebook import sheets_sanitize_url def _sanitize_key(key: str) -> str: if not re.fullmatch("[a-zA-Z0-9_-]+", key): raise ValueError('"{}" is not a valid Sheets key'.format(key)) return key class SheetsURL: """Class that enforces safety by ensuring that URLs are sanitized.""" def __init__(self, url: str): self._url: str = sheets_sanitize_url.sanitize_sheets_url(url) def __str__(self) -> str: return self._url class SheetsKey: """Class that enforces safety by ensuring that keys are sanitized.""" def __init__(self, key: str): self._key: str = _sanitize_key(key) def __str__(self) -> str: return self._key class SheetsIdentifier: """Encapsulates a means to identify a Sheets document. The gspread library provides three ways to look up a Sheets document: by name, by url and by key. An instance of this class represents exactly one of the methods. """ def __init__( self, name: str | None = None, key: SheetsKey | None = None, url: SheetsURL | None = None, ): """Constructor. Exactly one of the arguments should be provided. Args: name: The name of the Sheets document. More-than-one Sheets documents can have the same name, so this is the least precise method of identifying the document. key: The key of the Sheets document url: The url to the Sheets document Raises: ValueError: If the caller does not specify exactly one of name, url or key. """ self._name = name self._key = key self._url = url # There should be exactly one. 
num_inputs = ( int(bool(self._name)) + int(bool(self._key)) + int(bool(self._url)) ) if num_inputs != 1: raise ValueError("Must set exactly one of name, key or url") def name(self) -> str | None: return self._name def key(self) -> SheetsKey | None: return self._key def url(self) -> SheetsURL | None: return self._url def __str__(self): if self._name: return "name={}".format(self._name) elif self._key: return "key={}".format(self._key) else: return "url={}".format(self._url) 07070100000043000081A4000000000000000000000001645983950000082B000000000000000000000000000000000000004E00000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/sheets_id_test.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
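`SheetsIdentifier` rejects anything other than exactly one lookup method, and its counting trick (`bool()` of each optional argument) means an empty string fails the same way as an omitted argument. The check can be sketched in isolation (hypothetical helper names, not the library's API):

```python
def count_set(*args) -> int:
    # bool() treats None and "" the same way, so an empty name does not count.
    return sum(int(bool(arg)) for arg in args)


def validate_exactly_one(name=None, key=None, url=None) -> None:
    if count_set(name, key, url) != 1:
        raise ValueError("Must set exactly one of name, key or url")


validate_exactly_one(name="hello")  # accepted

error_message = ""
try:
    validate_exactly_one(name="")  # an empty name counts as unset
except ValueError as exc:
    error_message = str(exc)
assert "exactly one" in error_message
```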
"""Unittest for py_utils.""" from __future__ import annotations from absl.testing import absltest from google.generativeai.notebook import sheets_id class SheetsIdentifierTest(absltest.TestCase): def test_constructor(self): sid = sheets_id.SheetsIdentifier(name="hello") self.assertEqual("name=hello", str(sid)) sid = sheets_id.SheetsIdentifier(key=sheets_id.SheetsKey("hello")) self.assertEqual("key=hello", str(sid)) sid = sheets_id.SheetsIdentifier( url=sheets_id.SheetsURL("https://docs.google.com/") ) self.assertEqual("url=https://docs.google.com/", str(sid)) def test_constructor_error(self): with self.assertRaisesRegex( ValueError, "Must set exactly one of name, key or url" ): sheets_id.SheetsIdentifier() # Empty "name" is also considered an invalid name. with self.assertRaisesRegex( ValueError, "Must set exactly one of name, key or url" ): sheets_id.SheetsIdentifier(name="") with self.assertRaisesRegex( ValueError, "Must set exactly one of name, key or url" ): sheets_id.SheetsIdentifier(name="hello", key=sheets_id.SheetsKey("hello")) with self.assertRaisesRegex( ValueError, "Must set exactly one of name, key or url" ): sheets_id.SheetsIdentifier( name="hello", key=sheets_id.SheetsKey("hello"), url=sheets_id.SheetsURL("https://docs.google.com/"), ) if __name__ == "__main__": absltest.main() 07070100000044000081A40000000000000000000000016459839500000B75000000000000000000000000000000000000005300000000generative-ai-python-0.1.0~rc2/google/generativeai/notebook/sheets_sanitize_url.py# -*- coding: utf-8 -*- # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
"""Utilities for working with URLs."""
from __future__ import annotations

import re
from urllib import parse


def _validate_url_part(part: str) -> None:
    if not re.fullmatch('[a-zA-Z0-9_-]*', part):
        raise ValueError(
            '"{}" is outside the restricted character set'.format(part)
        )


def _validate_url_query_or_fragment(part: str) -> None:
    for key, values in parse.parse_qs(part).items():
        _validate_url_part(key)
        for value in values:
            _validate_url_part(value)


def sanitize_sheets_url(url: str) -> str:
    """Sanitize a Sheets URL.

    Runs some safety checks to verify that `url` is a Sheets URL.

    This is not a general-purpose URL sanitizer. Rather, it makes use of the
    fact that we know the URL has to be for Sheets, so we can make a few
    assumptions (e.g. about the domain).

    Args:
        url: The url to sanitize.

    Returns:
        The sanitized url.

    Raises:
        ValueError: If `url` does not match the expected restrictions for a
            Sheets URL.
    """
    parse_result = parse.urlparse(url)
    if parse_result.scheme != 'https':
        raise ValueError(
            'Scheme for Sheets url must be "https", got "{}"'.format(
                parse_result.scheme
            )
        )
    if parse_result.netloc not in ('docs.google.com', 'sheets.googleapis.com'):
        raise ValueError(
            'Domain for Sheets url must be "docs.google.com", got "{}"'.format(
                parse_result.netloc
            )
        )

    # Path component.
    try:
        for fragment in parse_result.path.split('/'):
            _validate_url_part(fragment)
    except ValueError as exc:
        raise ValueError(
            'Invalid path for Sheets url, got "{}"'.format(parse_result.path)
        ) from exc

    # Params component.
    if parse_result.params:
        raise ValueError(
            'Params component must be empty, got "{}"'.format(parse_result.params)
        )

    # Query component.
    try:
        _validate_url_query_or_fragment(parse_result.query)
    except ValueError as exc:
        raise ValueError(
            'Invalid query for Sheets url, got "{}"'.format(parse_result.query)
        ) from exc

    # Fragment component.
    try:
        _validate_url_query_or_fragment(parse_result.fragment)
    except ValueError as exc:
        raise ValueError(
            'Invalid fragment for Sheets url, got "{}"'.format(
                parse_result.fragment
            )
        ) from exc

    return url


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/sheets_sanitize_url_test.py

"""Unittest for sheets_sanitize_url."""
from __future__ import annotations

from absl.testing import absltest

from google.generativeai.notebook import sheets_sanitize_url

sanitize_sheets_url = sheets_sanitize_url.sanitize_sheets_url


class SheetsSanitizeURLTest(absltest.TestCase):

    def test_scheme_must_be_https(self):
        """The URL must be https://."""
        with self.assertRaisesRegex(
            ValueError, 'Scheme for Sheets url must be "https", got "http"'
        ):
            sanitize_sheets_url("http://docs.google.com")

        # HTTPS goes through.
        url = sanitize_sheets_url("https://docs.google.com")
        self.assertEqual("https://docs.google.com", str(url))

    def test_domain_must_be_docs_google_com(self):
        """Domain must be docs.google.com."""
        with self.assertRaisesRegex(
            ValueError,
            (
                'Domain for Sheets url must be "docs.google.com", got'
                ' "sheets.google.com"'
            ),
        ):
            sanitize_sheets_url("https://sheets.google.com")

        # docs.google.com goes through.
        url = sanitize_sheets_url("https://docs.google.com")
        self.assertEqual("https://docs.google.com", str(url))

    def test_params_must_be_empty(self):
        """Params component must be empty."""
        with self.assertRaisesRegex(
            ValueError, 'Params component must be empty, got "hello"'
        ):
            sanitize_sheets_url("https://docs.google.com/;hello")

        # URL without params goes through.
        url = sanitize_sheets_url("https://docs.google.com")
        self.assertEqual("https://docs.google.com", str(url))

    def test_path_must_be_limited_character_set(self):
        """Path can only contain a limited character set."""
        with self.assertRaisesRegex(
            ValueError, 'Invalid path for Sheets url, got "/abc/def/sheets.php'
        ):
            sanitize_sheets_url("https://docs.google.com/abc/def/sheets.php")

        # Valid path goes through.
        url = sanitize_sheets_url("https://docs.google.com/abc/DEF/123/-_-")
        self.assertEqual("https://docs.google.com/abc/DEF/123/-_-", str(url))

    def test_query_must_be_limited_character_set(self):
        """Query can only contain a limited character set."""
        with self.assertRaisesRegex(
            ValueError, 'Invalid query for Sheets url, got "a=b&key=sheets.php"'
        ):
            sanitize_sheets_url("https://docs.google.com/?a=b&key=sheets.php")

        # Valid query goes through.
        url = sanitize_sheets_url(
            "https://docs.google.com/?k1=abc&k2=DEF&k3=123&k4=-_-"
        )
        self.assertEqual(
            "https://docs.google.com/?k1=abc&k2=DEF&k3=123&k4=-_-", str(url)
        )

    def test_fragment_must_be_limited_character_set(self):
        """Fragment can only contain a limited character set."""
        with self.assertRaisesRegex(
            ValueError, 'Invalid fragment for Sheets url, got "a=b&key=sheets.php"'
        ):
            sanitize_sheets_url("https://docs.google.com/#a=b&key=sheets.php")

        # Valid fragment goes through.
        url = sanitize_sheets_url(
            "https://docs.google.com/#k1=abc&k2=DEF&k3=123&k4=-_-"
        )
        self.assertEqual(
            "https://docs.google.com/#k1=abc&k2=DEF&k3=123&k4=-_-", str(url)
        )


if __name__ == "__main__":
    absltest.main()


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/sheets_utils.py

"""SheetsInputs."""
from __future__ import annotations

from typing import Any, Callable, Mapping, Sequence
from urllib import parse

from google.generativeai.notebook import gspread_client
from google.generativeai.notebook import sheets_id
from google.generativeai.notebook.lib import llmfn_inputs_source
from google.generativeai.notebook.lib import llmfn_outputs


def _try_sheet_id_as_url(value: str) -> sheets_id.SheetsIdentifier | None:
    """Try to open a Sheets document with `value` as a URL."""
    try:
        parse_result = parse.urlparse(value)
    except ValueError:
        # If there's a URL parsing error, then it's not a URL.
        return None

    if parse_result.scheme:
        # If it looks like a URL, try to open the document as a URL but don't
        # fall back to trying as key or name since it's very unlikely that a
        # key or name looks like a URL.
        sid = sheets_id.SheetsIdentifier(url=sheets_id.SheetsURL(value))
        gspread_client.get_client().validate(sid)
        return sid
    return None


def _try_sheet_id_as_key(value: str) -> sheets_id.SheetsIdentifier | None:
    """Try to open a Sheets document with `value` as a key."""
    try:
        sid = sheets_id.SheetsIdentifier(key=sheets_id.SheetsKey(value))
    except ValueError:
        # `value` is not a well-formed Sheets key.
        return None
    try:
        gspread_client.get_client().validate(sid)
    except gspread_client.SpreadsheetNotFoundError:
        return None
    return sid


def _try_sheet_id_as_name(value: str) -> sheets_id.SheetsIdentifier | None:
    """Try to open a Sheets document with `value` as a name."""
    sid = sheets_id.SheetsIdentifier(name=value)
    try:
        gspread_client.get_client().validate(sid)
    except gspread_client.SpreadsheetNotFoundError:
        return None
    return sid


def get_sheets_id_from_str(value: str) -> sheets_id.SheetsIdentifier:
    if sid := _try_sheet_id_as_url(value):
        return sid
    if sid := _try_sheet_id_as_key(value):
        return sid
    if sid := _try_sheet_id_as_name(value):
        return sid
    raise RuntimeError(
        'No Sheets found with "{}" as URL, key or name'.format(value)
    )


class SheetsInputs(llmfn_inputs_source.LLMFnInputsSource):
    """Inputs to an LLMFunction from Google Sheets."""

    def __init__(self, sid: sheets_id.SheetsIdentifier, worksheet_id: int = 0):
        super().__init__()
        self._sid = sid
        self._worksheet_id = worksheet_id

    def _to_normalized_inputs_impl(
        self,
    ) -> tuple[Sequence[Mapping[str, str]], Callable[[], None]]:
        return gspread_client.get_client().get_all_records(
            sid=self._sid, worksheet_id=self._worksheet_id
        )


class SheetsOutputs(llmfn_outputs.LLMFnOutputsSink):
    """Writes outputs from an LLMFunction to Google Sheets."""

    def __init__(self, sid: sheets_id.SheetsIdentifier):
        self._sid = sid

    def write_outputs(self, outputs: llmfn_outputs.LLMFnOutputsBase) -> None:
        # Transpose `outputs` into a list of rows.
        outputs_dict = outputs.as_dict()
        outputs_rows: list[Sequence[Any]] = []
        outputs_rows.append(list(outputs_dict.keys()))
        outputs_rows.extend([list(x) for x in zip(*outputs_dict.values())])
        gspread_client.get_client().write_records(
            sid=self._sid,
            rows=outputs_rows,
        )


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/text_model.py
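`SheetsOutputs.write_outputs` above turns a column-oriented dict into row-oriented records: one header row of keys, then one row per index via `zip(*values)`. A standalone sketch of that transpose (the function name is illustrative, not from the library):

```python
from typing import Any, Dict, List, Sequence


def columns_to_rows(columns: Dict[str, Sequence[Any]]) -> List[List[Any]]:
    # First row: the column names, in insertion order.
    rows: List[List[Any]] = [list(columns.keys())]
    # zip(*values) pairs up the i-th element of every column into one row.
    rows.extend(list(row) for row in zip(*columns.values()))
    return rows
```

Note that `zip` stops at the shortest column, so ragged columns would be silently truncated.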
"""Model that uses the Text service."""
from __future__ import annotations

from google.api_core import retry
from google.generativeai import text
from google.generativeai.notebook.lib import model as model_lib


class TextModel(model_lib.AbstractModel):
    """Concrete model that uses the Text service."""

    def _generate_text(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float | None = None,
        candidate_count: int | None = None,
        **kwargs,
    ) -> text.Completion:
        if model is not None:
            kwargs["model"] = model
        if temperature is not None:
            kwargs["temperature"] = temperature
        if candidate_count is not None:
            kwargs["candidate_count"] = candidate_count
        return text.generate_text(prompt=prompt, **kwargs)

    def call_model(
        self, model_input: str, model_args: model_lib.ModelArguments | None = None
    ) -> model_lib.ModelResults:
        if model_args is None:
            model_args = model_lib.ModelArguments()

        # Wrap the generation function here, rather than decorating it, so
        # that the retry applies to any overridden calls too.
        retryable_fn = retry.Retry(retry.if_transient_error)(self._generate_text)
        response = retryable_fn(
            prompt=model_input,
            model=model_args.model,
            temperature=model_args.temperature,
            candidate_count=model_args.candidate_count,
        )
        return model_lib.ModelResults(
            model_input=model_input,
            text_results=[x["output"] for x in response.candidates],
        )


File: generative-ai-python-0.1.0~rc2/google/generativeai/notebook/text_model_test.py
from __future__ import annotations

from unittest import mock

from absl.testing import absltest

from google.api_core import exceptions
from google.generativeai import text
from google.generativeai.notebook import text_model
from google.generativeai.notebook.lib import model as model_lib


def _fake_generator(
    prompt: str,
    model: str | None = None,
    temperature: float | None = None,
    candidate_count: int | None = None,
) -> text.Completion:
    return text.Completion(
        prompt=prompt,
        model=model,
        temperature=temperature,
        candidate_count=candidate_count,
        # Smuggle the parameters as text output, so we can make assertions.
        candidates=[
            {'output': f'{prompt}_1'},
            {'output': model},
            {'output': temperature},
            {'output': candidate_count},
        ],
    )


class TestModel(text_model.TextModel):
    """A TextModel, but with _generate_text stubbed out."""

    def _generate_text(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float | None = None,
        candidate_count: int | None = None,
        **kwargs,
    ) -> text.Completion:
        return _fake_generator(
            prompt=prompt,
            model=model,
            temperature=temperature,
            candidate_count=candidate_count,
        )


class TextModelTestCase(absltest.TestCase):

    def test_generate_text(self):
        model = TestModel()

        result = model.call_model('prompt goes in')
        self.assertEqual(result.text_results[0], 'prompt goes in_1')
        self.assertIsNone(result.text_results[1])
        self.assertIsNone(result.text_results[2])
        self.assertIsNone(result.text_results[3])

        args = model_lib.ModelArguments(
            model='model_name', temperature=0.42, candidate_count=5
        )
        result = model.call_model('prompt goes in', args)
        self.assertEqual(result.text_results[0], 'prompt goes in_1')
        self.assertEqual(result.text_results[1], 'model_name')
        self.assertEqual(result.text_results[2], 0.42)
        self.assertEqual(result.text_results[3], 5)

    def test_retry(self):
        model = TestModel()
        with mock.patch.object(model, '_generate_text') as erroneous_generator:
            erroneous_generator.side_effect = [
                exceptions.ResourceExhausted('Over quota'),
                mock.DEFAULT,
            ]
            _ = model.call_model('phew it worked')
        self.assertEqual(erroneous_generator.call_count, 2)


if __name__ == '__main__':
    absltest.main()


File: generative-ai-python-0.1.0~rc2/google/generativeai/text.py
from __future__ import annotations

import dataclasses
from typing import List, Iterable, Iterator, Optional, Union

import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_text_client
from google.generativeai.types import text_types
from google.generativeai.types import model_types
from google.generativeai.types import safety_types


def _make_text_prompt(prompt: Union[str, dict[str, str]]) -> glm.TextPrompt:
    if isinstance(prompt, str):
        return glm.TextPrompt(text=prompt)
    elif isinstance(prompt, dict):
        return glm.TextPrompt(prompt)
    else:
        raise TypeError("Expected string or dictionary for text prompt.")


def _make_generate_text_request(
    *,
    model: model_types.ModelNameOptions = "models/chat-lamda-001",
    prompt: Optional[str] = None,
    temperature: Optional[float] = None,
    candidate_count: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    top_p: Optional[int] = None,
    top_k: Optional[int] = None,
    safety_settings: Optional[List[safety_types.SafetySettingDict]] = None,
    stop_sequences: Optional[Union[str, Iterable[str]]] = None,
) -> glm.GenerateTextRequest:
    model = model_types.make_model_name(model)
    prompt = _make_text_prompt(prompt=prompt)
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    if stop_sequences:
        stop_sequences = list(stop_sequences)
    return glm.GenerateTextRequest(
        model=model,
        prompt=prompt,
        temperature=temperature,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        safety_settings=safety_settings,
        stop_sequences=stop_sequences,
    )


def generate_text(
    *,
    model: Optional[model_types.ModelNameOptions] = "models/text-bison-001",
    prompt: str,
    temperature: Optional[float] = None,
    candidate_count: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    top_k: Optional[float] = None,
    safety_settings: Optional[Iterable[safety_types.SafetySettingDict]] = None,
    stop_sequences: Optional[Union[str, Iterable[str]]] = None,
    client: Optional[glm.TextServiceClient] = None,
) -> text_types.Completion:
    """Calls the API and returns a `types.Completion` containing the response.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        prompt: Free-form input text given to the model. Given a prompt, the
            model will generate text that completes the input text.
        temperature: Controls the randomness of the output. Must be positive.
            Typical values are in the range `[0.0, 1.0]`. Higher values
            produce a more random and varied response. A temperature of zero
            will be deterministic.
        candidate_count: The **maximum** number of generated response messages
            to return. This value must be in the range `[1, 8]`, inclusive.
            If unset, this will default to `1`.

            Note: Only unique candidates are returned. Higher temperatures
            are more likely to produce unique candidates. Setting
            `temperature=0.0` will always return 1 candidate regardless of
            the `candidate_count`.
        max_output_tokens: Maximum number of tokens to include in a
            candidate. Must be greater than zero. If unset, will default
            to 64.
        top_k: The API uses combined
            [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling.
            `top_k` sets the maximum number of tokens to sample from on each
            step.
        top_p: The API uses combined
            [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling.
            `top_p` configures the nucleus sampling. It sets the maximum
            cumulative probability of tokens to sample from. For example, if
            the sorted probabilities are `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]`,
            a `top_p` of `0.8` will sample as `[0.625, 0.25, 0.125, 0, 0, 0]`.
        safety_settings: A list of unique `types.SafetySetting` instances for
            blocking unsafe content. These will be enforced on the `prompt`
            and `candidates`. There should not be more than one setting for
            each `types.SafetyCategory` type. The API will block any prompts
            and responses that fail to meet the thresholds set by these
            settings. This list overrides the default settings for each
            `SafetyCategory` specified in the safety_settings. If there is no
            `types.SafetySetting` for a given `SafetyCategory` provided in
            the list, the API will use the default safety setting for that
            category.
        stop_sequences: A set of up to 5 character sequences that will stop
            output generation. If specified, the API will stop at the first
            appearance of a stop sequence. The stop sequence will not be
            included as part of the response.
        client: If you're not relying on a default client, you pass a
            `glm.TextServiceClient` instead.

    Returns:
        A `types.Completion` containing the model's text completion response.
    """
    request = _make_generate_text_request(
        model=model,
        prompt=prompt,
        temperature=temperature,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        top_p=top_p,
        top_k=top_k,
        safety_settings=safety_settings,
        stop_sequences=stop_sequences,
    )

    return _generate_response(client=client, request=request)


@dataclasses.dataclass(init=False)
class Completion(text_types.Completion):
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.result = None
        if self.candidates:
            self.result = self.candidates[0]["output"]


def _generate_response(
    request: glm.GenerateTextRequest, client: glm.TextServiceClient = None
) -> Completion:
    if client is None:
        client = get_default_text_client()

    response = client.generate_text(request)
    response = type(response).to_dict(response)

    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
    response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums(
        response["safety_feedback"]
    )
    response["candidates"] = safety_types.convert_candidate_enums(
        response["candidates"]
    )

    return Completion(_client=client, **response)


def generate_embeddings(model: str, text: str, client: glm.TextServiceClient = None):
    """Calls the API to create an embedding for the text passed in.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        text: Free-form input text given to the model. Given a string, the
            model will generate an embedding based on the input text.
        client: If you're not relying on a default client, you pass a
            `glm.TextServiceClient` instead.

    Returns:
        Dictionary containing the embedding (list of float values) for the
        input text.
    """
    if model is None:
        model = "models/chat-lamda-001"
    else:
        model = model_types.make_model_name(model)

    if client is None:
        client = get_default_text_client()

    embedding_request = glm.EmbedTextRequest(model=model, text=text)
    embedding_response = client.embed_text(embedding_request)

    embedding_dict = type(embedding_response).to_dict(embedding_response)
    embedding_dict["embedding"] = embedding_dict["embedding"]["value"]

    return embedding_dict


File: generative-ai-python-0.1.0~rc2/google/generativeai/types/__init__.py
"""A collection of type definitions used throughout the library."""
from google.generativeai.types.discuss_types import *
from google.generativeai.types.model_types import *
from google.generativeai.types.text_types import *
from google.generativeai.types.citation_types import *
from google.generativeai.types.safety_types import *

del discuss_types
del model_types
del text_types
del citation_types
del safety_types


File: generative-ai-python-0.1.0~rc2/google/generativeai/types/citation_types.py
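The `types/__init__.py` above uses the `import *` plus `del` idiom: star-importing each submodule copies the names listed in its `__all__` into the package namespace, and `del` then removes the submodule names themselves, which the star import binds as a side effect. A sketch of the `__all__` behaviour using a throwaway in-memory module (the module and class names are invented for illustration):

```python
import sys
import types

# Build a fake submodule whose __all__ exposes only one of its two classes.
sub = types.ModuleType("fake_submodule")
exec(
    "__all__ = ['Public']\n"
    "class Public:\n    pass\n"
    "class _Hidden:\n    pass\n",
    sub.__dict__,
)
sys.modules["fake_submodule"] = sub

ns = {}
exec("from fake_submodule import *", ns)
# Only the names listed in __all__ were copied into the importing namespace.
```

Without an `__all__`, a star import would copy every name not starting with an underscore, so `__all__` is what keeps these `*Dict` and `*Options` re-exports deliberate.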
from typing import List, Optional, TypedDict

from google.ai import generativelanguage as glm
from google.generativeai import docstring_utils

__all__ = [
    "CitationMetadataDict",
    "CitationSourceDict",
]


class CitationSourceDict(TypedDict):
    start_index: Optional[int]
    end_index: Optional[int]
    uri: Optional[str]
    license: Optional[str]

    __doc__ = docstring_utils.strip_oneof(glm.CitationSource.__doc__)


class CitationMetadataDict(TypedDict):
    citation_sources: Optional[List[CitationSourceDict]]

    __doc__ = docstring_utils.strip_oneof(glm.CitationMetadata.__doc__)


File: generative-ai-python-0.1.0~rc2/google/generativeai/types/discuss_types.py
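The `TypedDict` classes in `citation_types.py` above are plain dicts at runtime; the annotations exist only for static type checkers, and `Optional[...]` means the *value* may be `None`, not that the key is optional. A small sketch using a local stand-in for `CitationSourceDict` (the class name here is a hypothetical copy, not the library's own):

```python
from typing import Optional, TypedDict


class CitationSourceLike(TypedDict):
    # Local stand-in for CitationSourceDict. Optional[...] allows a None
    # value; keys would only become optional with TypedDict(..., total=False).
    start_index: Optional[int]
    end_index: Optional[int]
    uri: Optional[str]
    license: Optional[str]


source: CitationSourceLike = {
    "start_index": 0,
    "end_index": 42,
    "uri": "https://docs.google.com/",
    "license": None,
}
```

At runtime `source` is just a `dict`; a type checker such as mypy is what enforces the key set and value types.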
"""Type definitions for the discuss service."""
import abc
import dataclasses
from typing import Any, Dict, TypedDict, Union, Iterable, Optional, Tuple, List

import google.ai.generativelanguage as glm

from google.generativeai.types import safety_types
from google.generativeai.types import citation_types

__all__ = [
    "MessageDict",
    "MessageOptions",
    "MessagesOptions",
    "ExampleDict",
    "ExampleOptions",
    "ExamplesOptions",
    "MessagePromptDict",
    "MessagePromptOptions",
    "ResponseDict",
    "ChatResponse",
    "AuthorError",
]


class MessageDict(TypedDict):
    """A dict representation of a `glm.Message`."""

    author: str
    content: str
    citation_metadata: Optional[citation_types.CitationMetadataDict]


MessageOptions = Union[str, MessageDict, glm.Message]
MESSAGE_OPTIONS = (str, dict, glm.Message)

MessagesOptions = Union[
    MessageOptions,
    Iterable[MessageOptions],
]
MESSAGES_OPTIONS = (MESSAGE_OPTIONS, Iterable)


class ExampleDict(TypedDict):
    """A dict representation of a `glm.Example`."""

    input: MessageOptions
    output: MessageOptions


ExampleOptions = Union[
    Tuple[MessageOptions, MessageOptions],
    Iterable[MessageOptions],
    ExampleDict,
    glm.Example,
]
EXAMPLE_OPTIONS = (glm.Example, dict, Iterable)
ExamplesOptions = Union[ExampleOptions, Iterable[ExampleOptions]]


class MessagePromptDict(TypedDict, total=False):
    """A dict representation of a `glm.MessagePrompt`."""

    context: str
    examples: ExamplesOptions
    messages: MessagesOptions


MessagePromptOptions = Union[
    str,
    glm.Message,
    Iterable[Union[str, glm.Message]],
    MessagePromptDict,
    glm.MessagePrompt,
]
MESSAGE_PROMPT_KEYS = {"context", "examples", "messages"}


class ResponseDict(TypedDict):
    """A dict representation of a `glm.GenerateMessageResponse`."""

    messages: List[MessageDict]
    candidates: List[MessageDict]


@dataclasses.dataclass(init=False)
class ChatResponse(abc.ABC):
    """A chat response from the model.

    * Use `response.last` (settable) for easy access to the text of the last
      response.
      (`response.messages[-1]['content']`)
    * Use `response.messages` to access the message history (including
      `.last`).
    * Use `response.candidates` to access all the responses generated by the
      model.

    Other attributes are just saved from the arguments to `genai.chat`, so
    you can easily continue a conversation:

    ```
    import google.generativeai as genai

    genai.configure(api_key=os.environ['API_KEY'])

    response = genai.chat(messages=["Hello."])
    print(response.last)  # 'Hello! What can I help you with?'
    response.reply("Can you tell me a joke?")
    ```

    See `genai.chat` for more details.

    Attributes:
        candidates: A list of candidate responses from the model. The top
            candidate is appended to the `messages` field.

            This list will contain a *maximum* of `candidate_count`
            candidates. It may contain fewer (duplicates are dropped), but it
            will contain at least one.

            Note: The `temperature` field affects the variability of the
            responses. Low temperatures will return few candidates. Setting
            `temperature=0` is deterministic, so it will only ever return one
            candidate.
        filters: This indicates which `types.SafetyCategory`(s) blocked a
            candidate from this response, the lowest `types.HarmProbability`
            that triggered a block, and the `types.HarmThreshold` setting for
            that category. This indicates the smallest change to the
            `types.SafetySettings` that would be necessary to unblock at
            least one response.

            The blocking is configured by the `types.SafetySettings` in the
            request (or the default `types.SafetySettings` of the API).
        messages: A snapshot of the conversation history sorted
            chronologically: all the `messages` that were passed when the
            model was called, plus the top `candidate` message.
        model: The model name.
        context: Text that should be provided to the model first, to ground
            the response.
        examples: Examples of what the model should generate.
        temperature: Controls the randomness of the output. Must be positive.
        candidate_count: The **maximum** number of generated response
            messages to return.
        top_k: The maximum number of tokens to consider when sampling.
        top_p: The maximum cumulative probability of tokens to consider when
            sampling.
    """

    model: str
    context: str
    examples: List[ExampleDict]
    messages: List[Optional[MessageDict]]
    temperature: Optional[float]
    candidate_count: Optional[int]
    candidates: List[MessageDict]
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    filters: List[safety_types.ContentFilterDict]

    @property
    @abc.abstractmethod
    def last(self) -> Optional[str]:
        """A settable property that provides simple access to the last
        response string.

        A shortcut for `response.messages[-1]['content']`.
        """
        pass

    def to_dict(self) -> Dict[str, Any]:
        result = {
            "model": self.model,
            "context": self.context,
            "examples": self.examples,
            "messages": self.messages,
            "temperature": self.temperature,
            "candidate_count": self.candidate_count,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "candidates": self.candidates,
        }
        return result

    @abc.abstractmethod
    def reply(self, message: MessageOptions) -> "ChatResponse":
        """Add a message to the conversation, and get the model's response."""
        pass


class AuthorError(Exception):
    pass


File: generative-ai-python-0.1.0~rc2/google/generativeai/types/model_types.py
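`ChatResponse` above is abstract, so `last` and `reply` must come from a concrete subclass. A minimal toy (not the library's real implementation; the echo reply is invented) showing the intended `last`/`reply` contract:

```python
from typing import List, Optional


class FakeChatResponse:
    """Toy stand-in for a concrete ChatResponse."""

    def __init__(self, messages: List[dict]):
        self.messages = messages

    @property
    def last(self) -> Optional[str]:
        # `last` is shorthand for the newest message's content.
        return self.messages[-1]["content"] if self.messages else None

    def reply(self, message: str) -> "FakeChatResponse":
        # Record the user's message, then append a canned "model" answer.
        self.messages.append({"author": "user", "content": message})
        self.messages.append({"author": "model", "content": "echo: " + message})
        return self


chat = FakeChatResponse(messages=[{"author": "model", "content": "Hello!"}])
chat.reply("Tell me a joke?")
```

In the real library, `reply` calls the API and returns a fresh `ChatResponse`; the toy just appends to `messages` so the `last` property has something new to report.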
"""Type definitions for the models service."""
import abc
import dataclasses
from typing import Iterator, List, Optional, Union

__all__ = [
    "Model",
    "ModelNameOptions",
    "ModelsIterable",
]


@dataclasses.dataclass
class Model:
    """A dataclass representation of a `glm.Model`.

    Attributes:
        name: The resource name of the `Model`. Format: `models/{model}`,
            where `{model}` follows the naming convention
            `{base_model_id}-{version}`. For example:
            `models/chat-bison-001`.
        base_model_id: The base name of the model. For example: `chat-bison`.
        version: The major version number of the model. For example: `001`.
        display_name: The human-readable name of the model. E.g.
            `"Chat Bison"`. The name can be up to 128 characters long and can
            consist of any UTF-8 characters.
        description: A short description of the model.
        input_token_limit: Maximum number of input tokens allowed for this
            model.
        output_token_limit: Maximum number of output tokens available for
            this model.
        supported_generation_methods: Lists which methods are supported by
            the model. The method names are defined as camel-case strings,
            such as `generateMessage`, which correspond to API methods.
    """

    name: str
    base_model_id: str
    version: str
    display_name: str
    description: str
    input_token_limit: int
    output_token_limit: int
    supported_generation_methods: List[str]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None


ModelNameOptions = Union[str, Model]


def make_model_name(name: ModelNameOptions):
    if isinstance(name, Model):
        name = name.name
    return name


class ModelsIterable(abc.ABC):
    """Iterate over this to yield `types.Model` objects."""

    @abc.abstractmethod
    def __iter__(self) -> Iterator[Model]:
        pass


@dataclasses.dataclass
class TokenCount:
    """A dataclass representation of a `glm.TokenCountResponse`.

    Attributes:
        token_count: The number of tokens returned by the model's tokenizer
            for the `input_text`.
        token_count_limit: The maximum number of tokens the model accepts as
            input.
    """

    token_count: int
    token_count_limit: int

    def over_limit(self):
        return self.token_count > self.token_count_limit

File: generative-ai-python-0.1.0~rc2/google/generativeai/types/safety_types.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum

from google.ai import generativelanguage as glm
from google.generativeai import docstring_utils
from typing import Iterable, List, TypedDict

__all__ = [
    "HarmCategory",
    "HarmProbability",
    "HarmBlockThreshold",
    "BlockedReason",
    "ContentFilterDict",
    "SafetyRatingDict",
    "SafetySettingDict",
    "SafetyFeedbackDict",
]

# These are basic Python enums; it's okay to expose them.
HarmCategory = glm.HarmCategory
HarmProbability = glm.SafetyRating.HarmProbability
HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold
BlockedReason = glm.ContentFilter.BlockedReason


class ContentFilterDict(TypedDict):
    reason: BlockedReason
    message: str

    __doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)


def convert_filters_to_enums(filters: Iterable[dict]) -> List[ContentFilterDict]:
    result = []
    for f in filters:
        f = f.copy()
        f["reason"] = BlockedReason(f["reason"])
        result.append(f)
    return result


class SafetyRatingDict(TypedDict):
    category: HarmCategory
    probability: HarmProbability

    __doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__)


def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
    return {
        "category": HarmCategory(rating["category"]),
        "probability": HarmProbability(rating["probability"]),
    }


def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]:
    result = []
    for r in ratings:
        result.append(convert_rating_to_enum(r))
    return result


class SafetySettingDict(TypedDict):
    category: HarmCategory
    threshold: HarmBlockThreshold

    __doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)


def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
    return {
        "category": HarmCategory(setting["category"]),
        "threshold": HarmBlockThreshold(setting["threshold"]),
    }


class SafetyFeedbackDict(TypedDict):
    rating: SafetyRatingDict
    setting: SafetySettingDict

    __doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__)


def convert_safety_feedback_to_enums(
    safety_feedback: Iterable[dict],
) -> List[SafetyFeedbackDict]:
    result = []
    for sf in safety_feedback:
        result.append(
            {
                "rating": convert_rating_to_enum(sf["rating"]),
                "setting": convert_setting_to_enum(sf["setting"]),
            }
        )
    return result


def convert_candidate_enums(candidates):
    result = []
    for candidate in candidates:
        candidate = candidate.copy()
        candidate["safety_ratings"] = convert_ratings_to_enum(
            candidate["safety_ratings"]
        )
        result.append(candidate)
    return result

File: generative-ai-python-0.1.0~rc2/google/generativeai/types/text_types.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import dataclasses
from typing import Any, Dict, Optional, List, Iterator, TypedDict

from google.generativeai.types import safety_types
from google.generativeai.types import citation_types

__all__ = ["Completion"]


class TextCompletion(TypedDict, total=False):
    output: str
    safety_ratings: Optional[List[safety_types.SafetyRatingDict]]
    citation_metadata: Optional[citation_types.CitationMetadataDict]


@dataclasses.dataclass(init=False)
class Completion(abc.ABC):
    """The result returned by `generativeai.generate_text`.

    Use `GenerateTextResponse.candidates` to access all the completions
    generated by the model.

    Attributes:
        candidates: A list of candidate text completions generated by the
            model.
        result: The output of the first candidate.
        filters: Indicates the reasons why content may have been blocked.
            Either Unspecified, Safety, or Other. See `types.ContentFilter`.
        safety_feedback: Indicates which safety settings blocked content in
            this result.
    """

    candidates: List[TextCompletion]
    result: Optional[str]
    filters: Optional[List[safety_types.ContentFilterDict]]
    safety_feedback: Optional[List[safety_types.SafetyFeedbackDict]]

    def to_dict(self) -> Dict[str, Any]:
        result = {
            "candidates": self.candidates,
            "filters": self.filters,
            "safety_feedback": self.safety_feedback,
        }
        return result

File: generative-ai-python-0.1.0~rc2/google/generativeai/version.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.1.0rc2"

File: generative-ai-python-0.1.0~rc2/pyproject.toml

[tool.pytype]
inputs = ['google', 'tests']

File: generative-ai-python-0.1.0~rc2/setup.py

# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os

import setuptools  # type: ignore

package_root = os.path.abspath(os.path.dirname(__file__))

name = "google-generativeai"
description = "Google Generative AI High level API client library and tools."
version = {}
with open(os.path.join(package_root, "google/generativeai/version.py")) as fp:
    exec(fp.read(), version)
version = version["__version__"]

if version[0] == "0":
    release_status = "Development Status :: 4 - Beta"
else:
    release_status = "Development Status :: 5 - Production/Stable"

dependencies = ["google-ai-generativelanguage==0.2.0"]

extras_require = {
    "dev": [
        "absl-py",
        "asynctest",
        "black",
        "nose2",
        "pandas",
        "pytype",
        "pyyaml",
    ],
}

url = "https://github.com/google/generative-ai-python"

readme_filename = os.path.join(package_root, "README.md")
with io.open(readme_filename, encoding="utf-8") as readme_file:
    readme = readme_file.read()

packages = [
    package
    for package in setuptools.PEP420PackageFinder.find()
    if package.startswith("google")
]

namespaces = ["google"]

setuptools.setup(
    name=name,
    version=version,
    description=description,
    long_description=readme,
    author="Google LLC",
    author_email="googleapis-packages@google.com",
    license="Apache 2.0",
    url=url,
    classifiers=[
        release_status,
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",  # Colab
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Operating System :: OS Independent",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    platforms="Posix; MacOS X; Windows",
    packages=packages,
    python_requires=">=3.8",
    namespace_packages=namespaces,
    install_requires=dependencies,
    extras_require=extras_require,
    include_package_data=True,
    zip_safe=False,
)
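Note: `setup.py` reads the package version by exec-ing `version.py` into a scratch dict rather than importing the package, so the version can be resolved before any dependencies are installed. A minimal self-contained sketch of that pattern (the file content below is a stand-in for `google/generativeai/version.py`):

```python
import os
import tempfile

# Stand-in content for google/generativeai/version.py.
VERSION_FILE_CONTENT = '__version__ = "0.1.0rc2"\n'

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "version.py")
    with open(path, "w") as fp:
        fp.write(VERSION_FILE_CONTENT)

    # Same pattern as setup.py: exec the file into a throwaway namespace
    # and pull out __version__, without importing the package itself.
    version = {}
    with open(path) as fp:
        exec(fp.read(), version)
    version = version["__version__"]

print(version)  # -> 0.1.0rc2
```

The leading-"0" check in `setup.py` then maps this pre-1.0 version onto the "Development Status :: 4 - Beta" classifier.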
Directory: generative-ai-python-0.1.0~rc2/tests
File: generative-ai-python-0.1.0~rc2/tests/__init__.py (empty)

File: generative-ai-python-0.1.0~rc2/tests/test_client.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from google.api_core import client_options

from google.generativeai import client


class ClientTests(parameterized.TestCase):
    def setUp(self):
        super().setUp()
        client.default_client_config = {}

    def test_api_key_passed_directly(self):
        client.configure(api_key="AIzA_direct")

        client_opts = client.default_client_config["client_options"]
        self.assertEqual(client_opts.api_key, "AIzA_direct")

    def test_api_key_passed_via_client_options(self):
        client_opts = client_options.ClientOptions(api_key="AIzA_client_opts")
        client.configure(client_options=client_opts)

        client_opts = client.default_client_config["client_options"]
        self.assertEqual(client_opts.api_key, "AIzA_client_opts")

    @mock.patch.dict(os.environ, {"GOOGLE_API_KEY": "AIzA_env"})
    def test_api_key_from_environment(self):
        # Default to API key loaded from environment.
        client.configure()
        client_opts = client.default_client_config["client_options"]
        self.assertEqual(client_opts.api_key, "AIzA_env")

        # But not when a key is provided explicitly.
        client.configure(api_key="AIzA_client")
        client_opts = client.default_client_config["client_options"]
        self.assertEqual(client_opts.api_key, "AIzA_client")

    def test_api_key_cannot_be_set_twice(self):
        client_opts = client_options.ClientOptions(api_key="AIzA_client_opts")

        with self.assertRaisesRegex(ValueError, "You can't set both"):
            client.configure(api_key="AIzA_client", client_options=client_opts)

    def test_api_key_and_client_options(self):
        # Client options should merge with an API key, as long as both do not
        # have the key set.
        client_opts = client_options.ClientOptions(api_endpoint="web.site")
        client.configure(api_key="AIzA_client", client_options=client_opts)

        actual_client_opts = client.default_client_config["client_options"]
        self.assertEqual(actual_client_opts.api_key, "AIzA_client")
        self.assertEqual(actual_client_opts.api_endpoint, "web.site")

    @parameterized.parameters(
        client.get_default_discuss_client,
        client.get_default_text_client,
        client.get_default_discuss_async_client,
        client.get_default_model_client,
    )
    @mock.patch.dict(os.environ, {"GOOGLE_API_KEY": "AIzA_env"})
    def test_configureless_client_with_key(self, factory_fn):
        _ = factory_fn()

        # And ensure that it has set the default options.
        actual_client_opts = client.default_client_config["client_options"]
        self.assertEqual(actual_client_opts.api_key, "AIzA_env")


if __name__ == "__main__":
    absltest.main()

File: generative-ai-python-0.1.0~rc2/tests/test_discuss.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest.mock

import google.ai.generativelanguage as glm

from google.generativeai import discuss
from google.generativeai import client
import google.generativeai as genai
from google.generativeai.types import safety_types

from absl.testing import absltest
from absl.testing import parameterized

# TODO: replace returns with 'assert' statements


class UnitTests(parameterized.TestCase):
    def setUp(self):
        self.client = unittest.mock.MagicMock()
        client.default_discuss_client = self.client

        self.observed_request = None

        self.mock_response = glm.GenerateMessageResponse(
            candidates=[
                glm.Message(content="a", author="1"),
                glm.Message(content="b", author="1"),
                glm.Message(content="c", author="1"),
            ],
        )

        def fake_generate_message(
            request: glm.GenerateMessageRequest,
        ) -> glm.GenerateMessageResponse:
            self.observed_request = request
            response = copy.copy(self.mock_response)
            response.messages = request.prompt.messages
            return response

        self.client.generate_message = fake_generate_message

    @parameterized.named_parameters(
        ["string", "Hello", ""],
        ["dict", {"content": "Hello"}, ""],
        ["dict_author", {"content": "Hello", "author": "me"}, "me"],
        ["proto", glm.Message(content="Hello"), ""],
        ["proto_author", glm.Message(content="Hello", author="me"), "me"],
    )
    def test_make_message(self, message, author):
        x = discuss._make_message(message)
        self.assertIsInstance(x, glm.Message)
        self.assertEqual("Hello", x.content)
        self.assertEqual(author, x.author)

    @parameterized.named_parameters(
        ["string", "Hello", ["Hello"]],
        ["dict", {"content": "Hello"}, ["Hello"]],
        ["proto", glm.Message(content="Hello"), ["Hello"]],
        [
            "list",
            ["hello0", {"content": "hello1"}, glm.Message(content="hello2")],
            ["hello0", "hello1", "hello2"],
        ],
    )
    def test_make_messages(self, messages, expected_contents):
        messages = discuss._make_messages(messages)
        for expected, message in zip(expected_contents, messages):
            self.assertEqual(expected, message.content)

    @parameterized.named_parameters(
        ["tuple", ("hello", {"content": "goodbye"})],
        ["iterable", iter(["hello", "goodbye"])],
        ["dict", {"input": "hello", "output": "goodbye"}],
        [
            "proto",
            glm.Example(
                input=glm.Message(content="hello"),
                output=glm.Message(content="goodbye"),
            ),
        ],
    )
    def test_make_example(self, example):
        x = discuss._make_example(example)
        self.assertIsInstance(x, glm.Example)
        self.assertEqual("hello", x.input.content)
        self.assertEqual("goodbye", x.output.content)
        return

    @parameterized.named_parameters(
        [
            "messages",
            [
                "Hi",
                {"content": "Hello!"},
                "what's your name?",
                glm.Message(content="Dave, what's yours"),
            ],
        ],
        [
            "examples",
            [
                ("Hi", "Hello!"),
                {
                    "input": "what's your name?",
                    "output": {"content": "Dave, what's yours"},
                },
            ],
        ],
    )
    def test_make_examples(self, examples):
        examples = discuss._make_examples(examples)
        self.assertLen(examples, 2)
        self.assertEqual(examples[0].input.content, "Hi")
        self.assertEqual(examples[0].output.content, "Hello!")
        self.assertEqual(examples[1].input.content, "what's your name?")
        self.assertEqual(examples[1].output.content, "Dave, what's yours")
        return

    def test_make_examples_from_example(self):
        ex_dict = {"input": "hello", "output": "meow!"}
        example = discuss._make_example(ex_dict)

        examples1 = discuss._make_examples(ex_dict)
        examples2 = discuss._make_examples(discuss._make_example(ex_dict))

        self.assertEqual(example, examples1[0])
        self.assertEqual(example, examples2[0])

    @parameterized.named_parameters(
        ["str", "hello"],
        ["message", glm.Message(content="hello")],
        ["messages", ["hello"]],
        ["dict", {"messages": "hello"}],
        ["dict2", {"messages": ["hello"]}],
        ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])],
    )
    def test_make_message_prompt_from_messages(self, prompt):
        x = discuss._make_message_prompt(prompt)
        self.assertIsInstance(x, glm.MessagePrompt)
        self.assertEqual(x.messages[0].content, "hello")
        return

    @parameterized.named_parameters(
        [
            "dict",
            [
                {
                    "context": "you are a cat",
                    "examples": ["are you hungry?", "meow!"],
                    "messages": "hello",
                }
            ],
            {},
        ],
        [
            "kwargs",
            [],
            {
                "context": "you are a cat",
                "examples": ["are you hungry?", "meow!"],
                "messages": "hello",
            },
        ],
        [
            "proto",
            [
                glm.MessagePrompt(
                    context="you are a cat",
                    examples=[
                        glm.Example(
                            input=glm.Message(content="are you hungry?"),
                            output=glm.Message(content="meow!"),
                        )
                    ],
                    messages=[glm.Message(content="hello")],
                )
            ],
            {},
        ],
    )
    def test_make_message_prompt_from_prompt(self, args, kwargs):
        x = discuss._make_message_prompt(*args, **kwargs)
        self.assertIsInstance(x, glm.MessagePrompt)
        self.assertEqual(x.context, "you are a cat")
        self.assertEqual(x.examples[0].input.content, "are you hungry?")
        self.assertEqual(x.examples[0].output.content, "meow!")
        self.assertEqual(x.messages[0].content, "hello")

    def test_make_generate_message_request_nested(
        self,
    ):
        request0 = discuss._make_generate_message_request(
            **{
                "model": "Dave",
                "context": "you are a cat",
                "examples": ["hello", "meow", "are you hungry?", "meow!"],
                "messages": "Please catch that mouse.",
                "temperature": 0.2,
                "candidate_count": 7,
            }
        )
        request1 = discuss._make_generate_message_request(
            **{
                "model": "Dave",
                "prompt": {
                    "context": "you are a cat",
                    "examples": ["hello", "meow", "are you hungry?", "meow!"],
                    "messages": "Please catch that mouse.",
                },
                "temperature": 0.2,
                "candidate_count": 7,
            }
        )

        self.assertIsInstance(request0, glm.GenerateMessageRequest)
        self.assertIsInstance(request1, glm.GenerateMessageRequest)
        self.assertEqual(request0, request1)

    @parameterized.parameters(
        {"prompt": {}, "context": "You are a cat."},
        {"prompt": {"context": "You are a cat."}, "examples": ["hello", "meow"]},
        {"prompt": {"examples": ["hello", "meow"]}, "messages": "hello"},
    )
    def test_make_generate_message_request_flat_prompt_conflict(
        self,
        context=None,
        examples=None,
        messages=None,
        prompt=None,
    ):
        with self.assertRaises(ValueError):
            x = discuss._make_generate_message_request(
                model="test",
                context=context,
                examples=examples,
                messages=messages,
                prompt=prompt,
            )

    @parameterized.parameters(
        {"kwargs": {"context": "You are a cat."}},
        {"kwargs": {"messages": "hello"}},
        {"kwargs": {"examples": [["a", "b"], ["c", "d"]]}},
        {"kwargs": {"messages": ["hello"], "examples": [["a", "b"], ["c", "d"]]}},
    )
    def test_reply(self, kwargs):
        response = genai.chat(**kwargs)
        first_messages = response.messages

        self.assertEqual("a", response.last)
        self.assertEqual(
            [
                {"author": "1", "content": "a"},
                {"author": "1", "content": "b"},
                {"author": "1", "content": "c"},
            ],
            response.candidates,
        )

        response = response.reply("again")

    def test_receive_and_reply_with_filters(self):
        self.mock_response = mock_response = glm.GenerateMessageResponse(
            candidates=[glm.Message(content="a", author="1")],
            filters=[
                glm.ContentFilter(
                    reason=safety_types.BlockedReason.SAFETY, message="unsafe"
                ),
                glm.ContentFilter(reason=safety_types.BlockedReason.OTHER),
            ],
        )
        response = discuss.chat(messages="do filters work?")

        filters = response.filters
        self.assertLen(filters, 2)
        self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
        self.assertEqual(filters[0]["reason"], safety_types.BlockedReason.SAFETY)
        self.assertEqual(filters[0]["message"], "unsafe")

        self.mock_response = glm.GenerateMessageResponse(
            candidates=[glm.Message(content="a", author="1")],
            filters=[
                glm.ContentFilter(
                    reason=safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED
                )
            ],
        )

        response = response.reply("Does reply work?")
        filters = response.filters
        self.assertLen(filters, 1)
        self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
        self.assertEqual(
            filters[0]["reason"],
            safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED,
        )

    def test_chat_citations(self):
        self.mock_response = mock_response = glm.GenerateMessageResponse(
            candidates=[
                {
                    "content": "Hello google!",
                    "author": "1",
                    "citation_metadata": {
                        "citation_sources": [
                            {
                                "start_index": 6,
                                "end_index": 12,
                                "uri": "https://google.com",
                            }
                        ]
                    },
                }
            ],
        )

        response = discuss.chat(messages="Do citations work?")

        self.assertEqual(
            response.candidates[0]["citation_metadata"]["citation_sources"][0][
                "start_index"
            ],
            6,
        )

        response = response.reply("What about a second time?")

        self.assertEqual(
            response.candidates[0]["citation_metadata"]["citation_sources"][0][
                "start_index"
            ],
            6,
        )
        self.assertLen(response.messages, 4)


if __name__ == "__main__":
    absltest.main()

File: generative-ai-python-0.1.0~rc2/tests/test_discuss_async.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest

if sys.version_info < (3, 11):
    import asynctest
    from asynctest import mock as async_mock

import google.ai.generativelanguage as glm

from google.generativeai import discuss
from absl.testing import absltest
from absl.testing import parameterized

bases = (parameterized.TestCase,)

if sys.version_info < (3, 11):
    bases = bases + (asynctest.TestCase,)


@unittest.skipIf(
    sys.version_info >= (3, 11), "asynctest is not supported on python 3.11+"
)
class AsyncTests(*bases):
    if sys.version_info < (3, 11):

        async def test_chat_async(self):
            client = async_mock.MagicMock()

            observed_request = None

            async def fake_generate_message(
                request: glm.GenerateMessageRequest,
            ) -> glm.GenerateMessageResponse:
                nonlocal observed_request
                observed_request = request
                return glm.GenerateMessageResponse(
                    candidates=[
                        glm.Message(
                            author="1", content="Why did the chicken cross the road?"
                        )
                    ]
                )

            client.generate_message = fake_generate_message

            observed_response = await discuss.chat_async(
                model="models/bard",
                context="Example Prompt",
                examples=[["Example from human", "Example response from AI"]],
                messages=["Tell me a joke"],
                temperature=0.75,
                candidate_count=1,
                client=client,
            )

            self.assertEqual(
                observed_request,
                glm.GenerateMessageRequest(
                    model="models/bard",
                    prompt=glm.MessagePrompt(
                        context="Example Prompt",
                        examples=[
                            glm.Example(
                                input=glm.Message(content="Example from human"),
                                output=glm.Message(
                                    content="Example response from AI"
                                ),
                            )
                        ],
                        messages=[glm.Message(author="0", content="Tell me a joke")],
                    ),
                    temperature=0.75,
                    candidate_count=1,
                ),
            )
            self.assertEqual(
                observed_response.candidates,
                [{"author": "1", "content": "Why did the chicken cross the road?"}],
            )


if __name__ == "__main__":
    absltest.main()

File: generative-ai-python-0.1.0~rc2/tests/test_text.py

# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
import unittest.mock as mock

import google.ai.generativelanguage as glm

from google.generativeai import text as text_service
from google.generativeai import client
from google.generativeai.types import safety_types

from absl.testing import absltest
from absl.testing import parameterized


class UnitTests(parameterized.TestCase):
    def setUp(self):
        self.client = unittest.mock.MagicMock()
        client.default_text_client = self.client

        self.observed_request = None

        self.mock_response = glm.GenerateTextResponse(
            candidates=[
                glm.TextCompletion(output=" road?"),
                glm.TextCompletion(output=" bridge?"),
                glm.TextCompletion(output=" river?"),
            ]
        )

        def fake_generate_completion(
            request: glm.GenerateTextRequest,
        ) -> glm.GenerateTextResponse:
            self.observed_request = request
            return self.mock_response

        self.client.generate_text = fake_generate_completion

        def fake_embed_text(
            request: glm.EmbedTextRequest,
        ) -> glm.EmbedTextResponse:
            self.observed_request = request
            return glm.EmbedTextResponse(embedding=glm.Embedding(value=[1, 2, 3]))

        self.client.embed_text = fake_embed_text

    @parameterized.named_parameters(
        [
            dict(testcase_name="string", prompt="Hello how are"),
        ]
    )
    def test_make_prompt(self, prompt):
        x = text_service._make_text_prompt(prompt)
        self.assertIsInstance(x, glm.TextPrompt)
        self.assertEqual("Hello how are", x.text)

    @parameterized.named_parameters(
        [
            dict(testcase_name="string", prompt="What are you"),
        ]
    )
    def test_make_generate_text_request(self, prompt):
        x = text_service._make_generate_text_request(prompt=prompt)
        self.assertEqual("models/chat-lamda-001", x.model)
        self.assertIsInstance(x, glm.GenerateTextRequest)

    @parameterized.named_parameters(
        [
            dict(
                testcase_name="basic_model",
                model="models/chat-lamda-001",
                text="What are you",
            )
        ]
    )
    def test_generate_embeddings(self, model, text):
        emb = text_service.generate_embeddings(model=model, text=text)

        self.assertIsInstance(emb, dict)
        self.assertEqual(
            self.observed_request, glm.EmbedTextRequest(model=model, text=text)
        )

    @parameterized.named_parameters(
        [
            dict(testcase_name="basic", prompt="Why did the chicken cross the"),
            dict(
                testcase_name="temperature",
                prompt="Why did the chicken cross the",
                temperature=0.75,
            ),
            dict(
                testcase_name="stop_list",
                prompt="Why did the chicken cross the",
                stop_sequences=["a", "b", "c"],
            ),
            dict(
                testcase_name="count",
                prompt="Why did the chicken cross the",
                candidate_count=2,
            ),
        ]
    )
    def test_generate_response(self, *, prompt, **kwargs):
        complete = text_service.generate_text(prompt=prompt, **kwargs)

        self.assertEqual(
            self.observed_request,
            glm.GenerateTextRequest(
                model="models/text-bison-001",
                prompt=glm.TextPrompt(text=prompt),
                **kwargs,
            ),
        )

        self.assertIsInstance(complete.result, str)

        self.assertEqual(
            complete.candidates,
            [
                {"output": " road?", "safety_ratings": []},
                {"output": " bridge?", "safety_ratings": []},
                {"output": " river?", "safety_ratings": []},
            ],
        )

    def test_stop_string(self):
        complete = text_service.generate_text(prompt="Hello", stop_sequences="stop")

        self.assertEqual(
            self.observed_request,
            glm.GenerateTextRequest(
                model="models/text-bison-001",
                prompt=glm.TextPrompt(text="Hello"),
                stop_sequences=["stop"],
            ),
        )
        # Just make sure it made it into the request object.
        self.assertEqual(self.observed_request.stop_sequences, ["stop"])

    def test_safety_settings(self):
        result = text_service.generate_text(
            prompt="Say something wicked.",
            safety_settings=[
                {
                    "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
                    "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE,
                },
                {
                    "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
                    "threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                },
            ],
        )

        self.assertEqual(
            self.observed_request.safety_settings[0].category,
            safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
        )

    def test_filters(self):
        self.mock_response = glm.GenerateTextResponse(
            candidates=[{"output": "hello"}],
            filters=[
                {"reason": safety_types.BlockedReason.SAFETY, "message": "not safe"}
            ],
        )

        response = text_service.generate_text(prompt="do filters work?")
        self.assertIsInstance(
            response.filters[0]["reason"], safety_types.BlockedReason
        )
        self.assertEqual(
            response.filters[0]["reason"], safety_types.BlockedReason.SAFETY
        )

    def test_safety_feedback(self):
        self.mock_response = glm.GenerateTextResponse(
            candidates=[{"output": "hello"}],
            safety_feedback=[
                {
                    "rating": {
                        "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
                        "probability": safety_types.HarmProbability.HIGH,
                    },
                    "setting": {
                        "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
                        "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE,
                    },
                }
            ],
        )

        response = text_service.generate_text(prompt="does safety feedback work?")
        self.assertIsInstance(
            response.safety_feedback[0]["rating"]["probability"],
            safety_types.HarmProbability,
        )
        self.assertEqual(
            response.safety_feedback[0]["rating"]["probability"],
            safety_types.HarmProbability.HIGH,
        )
        self.assertIsInstance(
            response.safety_feedback[0]["setting"]["category"],
            safety_types.HarmCategory,
        )
        self.assertEqual(
            response.safety_feedback[0]["setting"]["category"],
            safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
        )

    def test_candidate_safety_feedback(self):
        self.mock_response = glm.GenerateTextResponse(
            candidates=[
                {
                    "output": "hello",
                    "safety_ratings": [
                        {
                            "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
                            "probability": safety_types.HarmProbability.HIGH,
                        },
                        {
                            "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
                            "probability": safety_types.HarmProbability.LOW,
                        },
                    ],
                }
            ]
        )

        result = text_service.generate_text(prompt="Write a story from the ER.")
        self.assertIsInstance(
            result.candidates[0]["safety_ratings"][0]["category"],
            safety_types.HarmCategory,
        )
        self.assertEqual(
            result.candidates[0]["safety_ratings"][0]["category"],
            safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
        )
        self.assertIsInstance(
            result.candidates[0]["safety_ratings"][0]["probability"],
            safety_types.HarmProbability,
        )
        self.assertEqual(
            result.candidates[0]["safety_ratings"][0]["probability"],
            safety_types.HarmProbability.HIGH,
        )

    def test_candidate_citations(self):
        self.mock_response = glm.GenerateTextResponse(
            candidates=[
                {
                    "output": "Hello Google!",
                    "citation_metadata": {
                        "citation_sources": [
                            {
                                "start_index": 6,
                                "end_index": 12,
                                "uri": "https://google.com",
                            }
                        ]
                    },
                }
            ]
        )
        result = text_service.generate_text(prompt="Hi my name is Google")
        self.assertEqual(
            result.candidates[0]["citation_metadata"]["citation_sources"][0][
                "start_index"
            ],
            6,
        )


if __name__ == "__main__":
    absltest.main()
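Note: the converter helpers in `safety_types.py` all follow one pattern: take the plain dicts produced by the API layer and replace raw integer fields with their enum equivalents, so callers can compare against readable enum members. A minimal self-contained sketch of that pattern, with a stand-in `IntEnum` (illustrative values) in place of the real `glm` enums:

```python
import enum
from typing import List, TypedDict


class HarmProbability(enum.IntEnum):
    # Stand-in for glm.SafetyRating.HarmProbability; values are illustrative.
    NEGLIGIBLE = 1
    LOW = 2
    HIGH = 4


class SafetyRatingDict(TypedDict):
    category: int
    probability: HarmProbability


def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
    # Same shape as the converter in safety_types.py: rebuild the dict
    # with an enum member in place of the raw int.
    return {
        "category": rating["category"],
        "probability": HarmProbability(rating["probability"]),
    }


# Raw dicts as the API layer would return them.
ratings = [{"category": 0, "probability": 4}, {"category": 0, "probability": 2}]
converted: List[SafetyRatingDict] = [convert_rating_to_enum(r) for r in ratings]
print(converted[0]["probability"].name)  # -> HIGH
```

This is why the tests above can assert `assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)` on what is otherwise a plain dict response.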