{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Databricks Notebook with MLFlow AutoLogging for LiteLLM Proxy calls\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "5e2812ed-8000-4793-b090-49a31464d810",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "%pip install -U -qqqq databricks-agents mlflow langchain==0.3.1 langchain-core==0.3.6 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "52530b37-1860-4bba-a6c1-723de83bc58f",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "%pip install \"langchain-openai<=0.3.1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "43c6f4b1-e2d5-431c-b1a2-b97df7707d59",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Before logging this chain using the driver notebook, you must comment out this line.\n",
    "dbutils.library.restartPython() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "88eb8dd7-16b1-480b-aa70-cd429ef87159",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "from operator import itemgetter\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from langchain_databricks import ChatDatabricks\n",
    "from langchain_openai import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "f0fdca8f-6f6f-407c-ad4a-0d5a2778728e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "mlflow.langchain.autolog()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "2ef67315-e468-4d60-a318-98c2cac75bc4",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# These helper functions parse the `messages` array.\n",
    "\n",
    "# Return the string contents of the most recent message from the user\n",
    "def extract_user_query_string(chat_messages_array):\n",
    "    return chat_messages_array[-1][\"content\"]\n",
    "\n",
    "\n",
    "# Return the chat history, which is is everything before the last question\n",
    "def extract_chat_history(chat_messages_array):\n",
    "    return chat_messages_array[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "17708467-1976-48bd-94a0-8c7895cfae3b",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "model = ChatOpenAI(\n",
    "    openai_api_base=\"LITELLM_PROXY_BASE_URL\", # e.g.: http://0.0.0.0:4000\n",
    "    model = \"gpt-3.5-turbo\", # LITELLM 'model_name'\n",
    "    temperature=0.1, \n",
    "    api_key=\"LITELLM_PROXY_API_KEY\" # e.g.: \"sk-1234\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "a5f2c2af-82f7-470d-b559-47b67fb00cda",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "############\n",
    "# Prompt Template for generation\n",
    "############\n",
    "prompt = PromptTemplate(\n",
    "    template=\"You are a hello world bot.  Respond with a reply to the user's question that is fun and interesting to the user.  User's question: {question}\",\n",
    "    input_variables=[\"question\"],\n",
    ")\n",
    "\n",
    "############\n",
    "# FM for generation\n",
    "# ChatDatabricks accepts any /llm/v1/chat model serving endpoint\n",
    "############\n",
    "model = ChatDatabricks(\n",
    "    endpoint=\"databricks-dbrx-instruct\",\n",
    "    extra_params={\"temperature\": 0.01, \"max_tokens\": 500},\n",
    ")\n",
    "\n",
    "\n",
    "############\n",
    "# Simple chain\n",
    "############\n",
    "# The framework requires the chain to return a string value.\n",
    "chain = (\n",
    "    {\n",
    "        \"question\": itemgetter(\"messages\")\n",
    "        | RunnableLambda(extract_user_query_string),\n",
    "        \"chat_history\": itemgetter(\"messages\") | RunnableLambda(extract_chat_history),\n",
    "    }\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "366edd90-62a1-4d6f-8a65-0211fb24ca02",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Hello there! I\\'m here to help with your questions. Regarding your query about \"rag,\" it\\'s not something typically associated with a \"hello world\" bot, but I\\'m happy to explain!\\n\\nRAG, or Remote Angular GUI, is a tool that allows you to create and manage Angular applications remotely. It\\'s a way to develop and test Angular components and applications without needing to set up a local development environment. This can be particularly useful for teams working on distributed systems or for developers who prefer to work in a cloud-based environment.\\n\\nI hope this explanation of RAG has been helpful and interesting! If you have any other questions or need further clarification, feel free to ask.'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "application/databricks.mlflow.trace": "\"tr-ea2226413395413ba2cf52cffc523502\"",
      "text/plain": [
       "Trace(request_id=tr-ea2226413395413ba2cf52cffc523502)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# This is the same input your chain's REST API will accept.\n",
    "question = {\n",
    "    \"messages\": [\n",
    "               {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": \"what is rag?\",\n",
    "        },\n",
    "    ]\n",
    "}\n",
    "\n",
    "chain.invoke(question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "5d68e37d-0980-4a02-bf8d-885c3853f6c1",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "mlflow.models.set_model(model=model)"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "environmentMetadata": null,
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 4
   },
   "notebookName": "Untitled Notebook 2024-10-16 19:35:16",
   "widgets": {}
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
