
    Retrieval-Augmented Generation (RAG) with Spark NLP: A Practical Guide

    RAG is the process of optimizing the output of a large language model so that it references an authoritative knowledge base outside its training data before generating a response.

    It has three steps:

    1. Ingestion – source data is ingested, cleaned, converted into embeddings, and stored in a vector database.
    2. Retrieval – the system finds the most relevant documents based on the user’s query.
    3. Generation – the retrieved information is fed into a language model to create a more accurate and context-aware response.
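    The three steps can be illustrated without Spark or Pinecone. The sketch below is a toy, dependency-free stand-in, not part of the Spark NLP pipeline: the hypothetical embed function uses word counts instead of a real embedding model, an in-memory list plays the role of the vector database, and generation is reduced to building the prompt that would be sent to an LLM.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    """Toy stand-in for an embedding model: lowercase word counts."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


# 1. Ingestion: embed each document and keep (vector, text) pairs.
documents = [
    "Harry is sorted into Gryffindor House.",
    "Hermione Granger and Ron Weasley are Harry's friends.",
]
store = [(embed(doc), doc) for doc in documents]


# 2. Retrieval: rank stored documents by similarity to the query.
def retrieve(query: str, k: int = 1) -> list:
    query_vector = embed(query)
    ranked = sorted(store, key=lambda item: cosine(query_vector, item[0]), reverse=True)
    return [text for _, text in ranked[:k]]


# 3. Generation: place the retrieved context in the prompt sent to an LLM.
def build_prompt(query: str) -> str:
    return f"Query: {query} Context: {' '.join(retrieve(query))}"


print(build_prompt("Which House did Harry belong to?"))
```

    With a real embedding model, the ranking reflects meaning rather than shared words; the rest of this tutorial replaces each toy piece with Spark NLP and Pinecone.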

    In this tutorial, we show how to build a simple RAG pipeline using Spark NLP and Pinecone. We load a dataset into Spark, generate embeddings with Spark NLP, store and search them in a Pinecone index, and use the AutoGGUFModel provided by Spark NLP to generate answers based on the retrieved information.

    Before you get started, download any text file to use as your dataset and set up a Pinecone account and API key. For this tutorial, we stored the text of the Harry Potter Wikipedia page in a text file.

    Part 1 — Ingestion

    1. Read your source text into a DataFrame.
    2. Clean the text by tokenizing and normalizing.
    3. Generate embeddings for your text.
    4. Connect to Pinecone.
    5. Create an index to store your embeddings in Pinecone.
    6. Insert the embeddings into the Pinecone index.

    1.1 Read your source text into a DataFrame.

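    import sparknlp
    from sparknlp.base import DocumentAssembler, Finisher
    from sparknlp.annotator import SentenceDetector, AutoGGUFModel, Tokenizer, Normalizer
    from sparknlp.annotator.embeddings import AutoGGUFEmbeddings
    from pyspark.ml import Pipeline
    
    # Start a Spark session with Spark NLP (use sparknlp.start(gpu=True) for the GPU build)
    spark = sparknlp.start()
    
    # Each line of the file becomes one row in a column called "value"
    text_data = (
      spark
      .read
      .text("harry-potter.txt")  # Add your own text file
      .withColumnRenamed("value", "text")
    )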
    
    text_data.show()
    
    # OUTPUT
    +--------------------+
    |                text|
    +--------------------+
    |Harry Potter is a...|
    +--------------------+

    1.2 Clean the text (tokenize and normalize it) and generate embeddings for your text.

    We build a Spark NLP Pipeline with the following stages:

    • DocumentAssembler: Entry annotator for our pipelines; it creates the data structure for the Annotation Framework.
    • SentenceDetector: Annotator that detects sentence boundaries inside each document; with setExplodeSentences(True), each sentence becomes its own row.
    • Tokenizer: Annotator used to convert text into tokens.
    • Normalizer: Annotator that cleans out tokens. Removes all dirty characters from text following a regex pattern and transforms words based on a provided dictionary.
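    • Finisher: Converts the cleaned tokens back into a plain string sentence, joined with spaces.
    • DocumentAssembler (second instance): Wraps the normalized string back into an annotation column, named sentence, for the embeddings annotator.
    • AutoGGUFEmbeddings: Used to generate embeddings for each sentence. Any embeddings of your choice can be used in this step.

    from pyspark.sql.functions import monotonically_increasing_id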
    from pyspark.sql.types import StringType
    
    document_assembler = (
        DocumentAssembler()
            .setInputCol("text")
            .setOutputCol("document")
    )
    
    sentence_detector = (
        SentenceDetector()
            .setInputCols(["document"])
            .setOutputCol("sentence")
            .setExplodeSentences(True)
    )
    
    tokenizer = (
        Tokenizer()
        .setInputCols("sentence")
        .setOutputCol("token")
    )
    
    normalizer = (
        Normalizer()
        .setInputCols("token")
        .setOutputCol("normalized")
    )
    
    finisher = (
        Finisher()
        .setInputCols("normalized")
        .setOutputCols("normalized_sentence")
        .setOutputAsArray(False)
        .setAnnotationSplitSymbol(" ")
    )
    
    normalized_document = (
        DocumentAssembler()
            .setInputCol("normalized_sentence")
            .setOutputCol("sentence")
    )
    
    embeddings = (
        AutoGGUFEmbeddings
        .pretrained("Nomic_Embed_Text_v1.5.Q8_0.gguf")
        .setInputCols(["sentence"])
        .setOutputCol("embeddings")
        .setBatchSize(4)
        .setNGpuLayers(99)  # number of model layers to offload to the GPU
        .setNCtx(8191)      # context size in tokens
    )
    
    pipeline = Pipeline(
        stages=[
            document_assembler,
            sentence_detector,
            tokenizer,
            normalizer,
            finisher,
            normalized_document,
            embeddings,
        ]
    )
    result = pipeline.fit(text_data).transform(text_data)
    
    # Spark allows only one generator (explode) per select, and each row already holds
    # exactly one sentence and one embedding, so take the first element of each array
    sentence_and_embeddings_df = result.selectExpr(
        "sentence.result[0] as sentence",
        "embeddings.embeddings[0] as vector"
    )
    
    # Shape rows for Pinecone: string id, "values" vector, sentence stored as metadata
    df_pinecone = (
        sentence_and_embeddings_df
        .withColumn("id", monotonically_increasing_id().cast(StringType()))
        .withColumnRenamed("vector", "values")
        .withColumn("metadata", sentence_and_embeddings_df.sentence)
        .drop("sentence")
    )
    
    df_pinecone.show()
    
    #OUTPUT
    +--------------------+---+--------------------+
    |              values| id|            metadata|
    +--------------------+---+--------------------+
    |[0.009264344, 0.0...|  0|Harry Potter is a...|
    |[-0.05384197, 0.0...|  1|The novels chroni...|
    |[-0.056440134, -0...|  2|The main story ar...|
    |[0.056693483, 0.0...|  3|The series was or...|
    |[-0.041599147, 0....|  4|A series of many ...|
    |[-0.030768316, 0....|  5|Major themes in t...|
    |[6.1618997E-4, 0....|  6|Since the release...|
    |[-0.047986094, -0...|  7|They have attract...|
    |[0.027252585, 0.0...|  8|As of February th...|
    |[0.04030519, 0.04...|  9|The last four boo...|
    |[0.026217783, 0.0...| 10|It holds the Guin...|
    |[0.007119997, -0....| 11|         Warner Bros|
    |[0.0041893553, 0....| 12|Pictures adapted ...|
    |[0.034328587, 0.0...| 13|In the total valu...|
    |[0.017639026, 0.0...| 14|Harry Potter and ...|
    |[-0.026738804, 0....| 15|A television seri...|
    |[-0.02948736, 0.0...| 16|Themed attraction...|
    |[-0.0024689648, 0...| 17|In the first book...|
    |[-0.041317504, 0....| 18|At the age of Har...|
    |[-0.059096724, 0....| 19|He meets a halfgi...|
    +--------------------+---+--------------------+
    only showing top 20 rows
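    The Tokenizer, Normalizer and Finisher stages explain why the stored sentences differ slightly from the source text; row 11 above, for instance, reads "Warner Bros" without its period. The pure-Python function below approximates that cleanup, assuming the Normalizer keeps only letters in each token and drops tokens that become empty. It splits on whitespace and keeps ASCII letters only, which simplifies Spark NLP's own tokenizer and cleanup pattern.

```python
import re


def normalize_sentence(sentence: str) -> str:
    """Approximate Tokenizer -> Normalizer -> Finisher on one sentence.

    Assumption: tokens are split on whitespace and only ASCII letters are kept,
    a simplification of Spark NLP's tokenizer and Unicode-aware cleanup pattern.
    """
    tokens = sentence.split()
    cleaned = (re.sub(r"[^A-Za-z]", "", token) for token in tokens)
    # Drop tokens that became empty (e.g. standalone punctuation), then rejoin with spaces
    return " ".join(token for token in cleaned if token)


print(normalize_sentence("Harry Potter and the Philosopher's Stone, by J. K. Rowling."))
# Harry Potter and the Philosophers Stone by J K Rowling
```

    The same effect shows up in the retrieved contexts later on, for example "comingofage" and "J K Rowling".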

    1.3 Connect to Pinecone and create an index to store your embeddings.

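    import os
    from pinecone import Pinecone, ServerlessSpec
    
    # Read credentials and settings from environment variables
    API_KEY = os.environ["PINECONE_API_KEY"]
    INDEX_NAME = os.environ["PINECONE_INDEX_NAME"]  # hypothetical variable name; any name works
    EMBEDDING_DIM = 768  # Nomic Embed Text v1.5 produces 768-dimensional vectors
    
    pc = Pinecone(api_key=API_KEY)
    
    # Create the index only if it does not exist yet
    if INDEX_NAME not in pc.list_indexes().names():
        pc.create_index(
            name=INDEX_NAME,
            dimension=EMBEDDING_DIM,
            metric="cosine",
            spec=ServerlessSpec(
                cloud="aws",
                region="us-east-1"
            )
        )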
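    The index dimension must equal the length of the vectors you upsert (768 for the Nomic model used here); Pinecone rejects vectors whose dimension does not match the index. A minimal pure-Python check could look like this. It is written against the same record shape that insert_batch builds in step 1.4, and the helper name mismatched_ids is hypothetical.

```python
EMBEDDING_DIM = 768  # must equal the `dimension` passed to create_index


def mismatched_ids(records, dim=EMBEDDING_DIM):
    """Return the ids of records whose vector length differs from the index dimension.

    `records` are dicts shaped like Pinecone upsert payloads: {"id", "values", "metadata"}.
    """
    return [record["id"] for record in records if len(record["values"]) != dim]


records = [
    {"id": "0", "values": [0.0] * 768, "metadata": {"text": "right size"}},
    {"id": "1", "values": [0.0] * 384, "metadata": {"text": "wrong size"}},
]
print(mismatched_ids(records))  # ['1']
```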

    1.4 Insert the embeddings into Pinecone.

    # Upsert all rows of one Spark partition into Pinecone (runs on the executors)
    def insert_batch(rows):
        # Create the client inside the function so each executor has its own connection
        pc_executor = Pinecone(api_key=API_KEY)
        index = pc_executor.Index(INDEX_NAME)
        vectors = [
            {
                "id": row["id"],
                "values": row["values"],
                "metadata": {"text": row["metadata"]},
            }
            for row in rows
        ]
        if vectors:
            index.upsert(vectors=vectors)
    
    # Convert DataFrame to RDD and process partitions in parallel
    df_pinecone.rdd.foreachPartition(insert_batch)
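    insert_batch sends a whole partition in a single upsert call. Pinecone limits how much data one upsert request may carry, so large partitions may need to be split. A minimal sketch of a hypothetical helper that chunks any iterable into lists of at most size items; the batch size of 100 used below is an assumption, so check Pinecone's documented request limits for your vector dimension.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batched(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most `size` items from `items`."""
    if size < 1:
        raise ValueError("size must be at least 1")
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, size))
        if not batch:
            return
        yield batch


print(list(batched(range(5), 2)))  # [[0, 1], [2, 3], [4]]
```

    Inside insert_batch, the single upsert call could then become `for chunk in batched(vectors, 100): index.upsert(vectors=chunk)`.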

    Part 2 — Retrieval

    1. Write your queries into a DataFrame.
    2. Generate the embedding for the queries.
    3. Query the Pinecone vector database using the embeddings to find the relevant context.

    2.1 Write your queries into a DataFrame.

    We create a DataFrame that contains our queries and run it through the same pipeline used to generate the embeddings in the ingestion step. The result is a new DataFrame with each query sentence and its embedding, and we collect the embeddings into a Python list.

    from pyspark.sql.types import StringType, StructType, StructField
    from pyspark.sql import Row
    
    # Create your own queries
    queries = [
        Row(text="Who are Harry Potter's parents?"),
        Row(text="Which House did Harry belong to?"),
        Row(text="Who were Harry's friends?"),
        Row(text="What are the major themes in the Harry Potter series?")
    ]
    
    schema = StructType([StructField("text", StringType(), True)])
    query_df = spark.createDataFrame(queries, schema)
    transformed_query = pipeline.fit(query_df).transform(query_df)
    
    transformed_query = transformed_query.selectExpr(
        "sentence.result[0] as sentence",  # Extract the query sentence (one per row)
        "embeddings.embeddings[0] as vector"  # Extract its embedding
    )
    
    vector_list = (
        transformed_query
        .select("vector")
        .rdd
        .flatMap(lambda x: x)
        .collect()
    )
    
    print(f"Total queries: {len(vector_list)}")
    
    #OUTPUT 
    Total queries: 4
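    The retrieval loop in step 2.2 pairs the i-th embedding in vector_list with queries[i]. That holds only if each query produces exactly one sentence: because the pipeline's SentenceDetector explodes sentences, a query such as "Who is Harry? Where does he live?" would produce two embeddings and shift the pairing. A small guard, shown here as a hypothetical helper, fails fast instead of silently mismatching:

```python
def pair_queries_with_vectors(queries, vectors):
    """Zip queries with their embeddings, refusing to pair lists of different lengths."""
    if len(queries) != len(vectors):
        raise ValueError(
            f"{len(queries)} queries but {len(vectors)} embeddings; "
            "a query may have been split into several sentences"
        )
    return list(zip(queries, vectors))


pairs = pair_queries_with_vectors(["q1", "q2"], [[0.1, 0.2], [0.3, 0.4]])
print(pairs)  # [('q1', [0.1, 0.2]), ('q2', [0.3, 0.4])]
```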

    2.2 Query the Pinecone vector database using the embeddings to find the relevant context.

    For each query embedding, we call the query method of the Pinecone index, which returns the stored vectors most similar to the query together with their metadata (here, the original sentence). The top_k parameter sets how many matches are fetched from the vector database.

    from collections import defaultdict
    
    query_and_context = defaultdict(list)
    vector_database = pc.Index(INDEX_NAME)
    
    # vector_list[i] is the embedding of queries[i]
    for i, rag_query in enumerate(vector_list):
        response = vector_database.query(
            vector=rag_query,
            top_k=3,  # number of matches to fetch per query
            include_metadata=True
        )
    
        # Keep the sentence text stored as metadata for each match
        for match in response["matches"]:
            query_and_context[queries[i].text].append(match["metadata"]["text"])
    
    print(query_and_context)
    
    #OUTPUT
    defaultdict(<class 'list'>, {"Who are Harry Potter's parents?": ['Harry learns that his parents Lily and James Potter also had magical powers and were murdered by the dark wizard Lord Voldemort when Harry was a baby', 'wizards of Muggle parentage are the primary targets', 'He gains the friendship of Ron Weasley a member of a large but poor wizarding family and Hermione Granger a witch of nonmagical or Muggle parentage'], 'Which House did Harry belong to?': ['The event made Harry famous among the community of wizards and witchesHarry becomes a student at Hogwarts and is sorted into Gryffindor House', 'Harry learns that his parents Lily and James Potter also had magical powers and were murdered by the dark wizard Lord Voldemort when Harry was a baby', 'In the first book Harry Potter and the Philosophers Stone Harry Potter and the Sorcerers Stone in the US Harry lives in a cupboard under the stairs in the house of the Dursleys his aunt uncle and cousin who all treat him poorly'], "Who were Harry's friends?": ['The novels chronicle the lives of a young wizard Harry Potter and his friends Hermione Granger and Ron Weasley all of whom are students at Hogwarts School of Witchcraft and Wizardry', 'He gains the friendship of Ron Weasley a member of a large but poor wizarding family and Hermione Granger a witch of nonmagical or Muggle parentage', 'Lupin enters the shack and explains that Sirius was James Potters best friend'], 'What are the major themes in the Harry Potter series?': ['A series of many genres including fantasy drama comingofage fiction and the British school story which includes elements of mystery thriller adventure horror and romance the world of Harry Potter explores numerous themes and includes many cultural meanings and references', 'Major themes in the series include prejudice corruption madness love and death', 'Harry Potter is a series of seven fantasy novels written by British author J K Rowling']})
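    Each Pinecone match also carries a similarity score, and not every top_k match is relevant: for the parents query above, the second context ("wizards of Muggle parentage are the primary targets") does not answer the question. One option is to drop matches below a minimum score before building the prompt. The sketch below works on a plain dict shaped like the query response, with made-up scores; the helper name and the threshold are hypothetical and would need tuning for your embedding model.

```python
def contexts_above(response, min_score):
    """Return the metadata text of matches whose score is at least `min_score`."""
    return [
        match["metadata"]["text"]
        for match in response["matches"]
        if match["score"] >= min_score
    ]


# Illustrative response with made-up scores
example_response = {
    "matches": [
        {"id": "17", "score": 0.82, "metadata": {"text": "relevant sentence"}},
        {"id": "42", "score": 0.41, "metadata": {"text": "weakly related sentence"}},
    ]
}
print(contexts_above(example_response, min_score=0.5))  # ['relevant sentence']
```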

    Part 3 — Generation

    1. Set up the prompt assembler using the template to query the model.
    2. Fill the prompt template with the query and the relevant context fetched from the retrieval step.
    3. Pass the prompt to the LLM to receive an answer to your query.

    3.1 Set up the prompt assembler using the template to query the model.

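    Here, we define a Jinja chat template (this one follows the Llama 3.1 chat format) and pass it to PromptAssembler with setChatTemplate. PromptAssembler uses it to turn each list of chat messages into a single prompt string.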

    from sparknlp.base import PromptAssembler
    
    template = (
        "{{- bos_token }} {%- if custom_tools is defined %} {%- set tools = custom_tools %} {%- "
        "endif %} {%- if not tools_in_user_message is defined %} {%- set tools_in_user_message = true %} {%- "
        'endif %} {%- if not date_string is defined %} {%- set date_string = "26 Jul 2024" %} {%- endif %} '
        "{%- if not tools is defined %} {%- set tools = none %} {%- endif %} {#- This block extracts the "
        "system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %}"
        " {%- set system_message = messages[0]['content']|trim %} {%- set messages = messages[1:] %} {%- else"
        ' %} {%- set system_message = "" %} {%- endif %} {#- System message + builtin tools #} {{- '
        '"<|start_header_id|>system<|end_header_id|>\\n\\n" }} {%- if builtin_tools is defined or tools is '
        'not none %} {{- "Environment: ipython\\n" }} {%- endif %} {%- if builtin_tools is defined %} {{- '
        '"Tools: " + builtin_tools | reject(\'equalto\', \'code_interpreter\') | join(", ") + "\\n\\n"}} '
        '{%- endif %} {{- "Cutting Knowledge Date: December 2023\\n" }} {{- "Today Date: " + date_string '
        '+ "\\n\\n" }} {%- if tools is not none and not tools_in_user_message %} {{- "You have access to '
        'the following functions. To call a function, please respond with JSON for a function call." }} {{- '
        '\'Respond in the format {"name": function name, "parameters": dictionary of argument name and its'
        ' value}.\' }} {{- "Do not use variables.\\n\\n" }} {%- for t in tools %} {{- t | tojson(indent=4) '
        '}} {{- "\\n\\n" }} {%- endfor %} {%- endif %} {{- system_message }} {{- "<|eot_id|>" }} {#- '
        "Custom tools are passed in a user message with some extra guidance #} {%- if tools_in_user_message "
        "and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if "
        "messages | length != 0 %} {%- set first_user_message = messages[0]['content']|trim %} {%- set "
        'messages = messages[1:] %} {%- else %} {{- raise_exception("Cannot put tools in the first user '
        "message when there's no first user message!\") }} {%- endif %} {{- "
        "'<|start_header_id|>user<|end_header_id|>\\n\\n' -}} {{- \"Given the following functions, please "
        'respond with a JSON for a function call " }} {{- "with its proper arguments that best answers the '
        'given prompt.\\n\\n" }} {{- \'Respond in the format {"name": function name, "parameters": '
        'dictionary of argument name and its value}.\' }} {{- "Do not use variables.\\n\\n" }} {%- for t in '
        'tools %} {{- t | tojson(indent=4) }} {{- "\\n\\n" }} {%- endfor %} {{- first_user_message + '
        "\"<|eot_id|>\"}} {%- endif %} {%- for message in messages %} {%- if not (message.role == 'ipython' "
        "or message.role == 'tool' or 'tool_calls' in message) %} {{- '<|start_header_id|>' + message['role']"
        " + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }} {%- elif 'tool_calls' in "
        'message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception("This model only '
        'supports single tool-calls at once!") }} {%- endif %} {%- set tool_call = message.tool_calls[0]'
        ".function %} {%- if builtin_tools is defined and tool_call.name in builtin_tools %} {{- "
        "'<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- \"<|python_tag|>\" + tool_call.name + "
        '".call(" }} {%- for arg_name, arg_val in tool_call.arguments | items %} {{- arg_name + \'="\' + '
        'arg_val + \'"\' }} {%- if not loop.last %} {{- ", " }} {%- endif %} {%- endfor %} {{- ")" }} {%- '
        "else %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- '{\"name\": \"' + "
        'tool_call.name + \'", \' }} {{- \'"parameters": \' }} {{- tool_call.arguments | tojson }} {{- "}" '
        "}} {%- endif %} {%- if builtin_tools is defined %} {#- This means we're in ipython mode #} {{- "
        '"<|eom_id|>" }} {%- else %} {{- "<|eot_id|>" }} {%- endif %} {%- elif message.role == "tool" '
        'or message.role == "ipython" %} {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }} {%- '
        "if message.content is mapping or message.content is iterable %} {{- message.content | tojson }} {%- "
        'else %} {{- message.content }} {%- endif %} {{- "<|eot_id|>" }} {%- endif %} {%- endfor %} {%- if '
        "add_generation_prompt %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }} {%- endif %} "
    )
    
    promptAssembler = (
        PromptAssembler()
        .setInputCol("messages")
        .setOutputCol("prompt")
        .setChatTemplate(template)
    )
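    To see what the template produces, the sketch below re-implements only its no-tools path in plain Python: the system message goes into a system header together with the knowledge-cutoff and date lines, every other message gets its own header, and an assistant header is appended so the model continues from there. It is a simplified illustration, not PromptAssembler's implementation; it omits bos_token and the tool-calling branches, and it assumes the generation prompt (the final assistant header) is enabled.

```python
def render_llama3_no_tools(messages, date_string="26 Jul 2024", add_generation_prompt=True):
    """Simplified rendering of the template above for (role, content) pairs without tools."""
    system_message = ""
    if messages and messages[0][0] == "system":
        system_message = messages[0][1].strip()
        messages = messages[1:]
    prompt = (
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "Cutting Knowledge Date: December 2023\n"
        f"Today Date: {date_string}\n\n"
        f"{system_message}<|eot_id|>"
    )
    # Every remaining message: header with its role, trimmed content, end-of-turn token
    for role, content in messages:
        prompt += f"<|start_header_id|>{role}<|end_header_id|>\n\n{content.strip()}<|eot_id|>"
    if add_generation_prompt:
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt


print(render_llama3_no_tools([("system", "Answer briefly."), ("user", "Who are Harry's friends?")]))
```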

    3.2 Fill the prompt template with the query and the relevant context fetched from the retrieval step.

    Now we build the chat messages for each query: a system instruction, an assistant greeting, and a user turn containing the query and the context fetched from the vector database. We then convert the messages into a DataFrame to pass to the prompt assembler. The output below shows these messages before the template is applied.

    prompts = []
    for query, context in query_and_context.items():
        messages = [
            ("system", "You are a question answering system. You will be given a query and some context, you need to answer the query based on the context provided. Use your own knowledge if relevant context is not provided. Give your answer as a full sentence with minimum text."),
            ("assistant", "Hello there! What is your query today?"),
            ("user", f"Query: {query} Context: {''.join(context)}"),
        ]
        prompts.append([messages])
    
    promptDF = spark.createDataFrame(prompts, ["messages"])
    promptDF.show(truncate=False)
    
    #OUTPUT
    +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |messages                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        |
    +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |[{system, You are a question answering system. You will be given a query and some context, you need to answer the query based on the context provided. Use your own knowledge if relevant context is not provided. Give your answer as a full sentence with minimum text.}, {assistant, Hello there! What is your query today?}, {user, Query: Who are Harry Potter's parents? Context: Harry learns that his parents Lily and James Potter also had magical powers and were murdered by the dark wizard Lord Voldemort when Harry was a babywizards of Muggle parentage are the primary targetsHe gains the friendship of Ron Weasley a member of a large but poor wizarding family and Hermione Granger a witch of nonmagical or Muggle parentage}]                                                                                                                                                                           |
    |[{system, You are a question answering system. You will be given a query and some context, you need to answer the query based on the context provided. Use your own knowledge if relevant context is not provided. Give your answer as a full sentence with minimum text.}, {assistant, Hello there! What is your query today?}, {user, Query: Which House did Harry belong to? Context: The event made Harry famous among the community of wizards and witchesHarry becomes a student at Hogwarts and is sorted into Gryffindor HouseHarry learns that his parents Lily and James Potter also had magical powers and were murdered by the dark wizard Lord Voldemort when Harry was a babyIn the first book Harry Potter and the Philosophers Stone Harry Potter and the Sorcerers Stone in the US Harry lives in a cupboard under the stairs in the house of the Dursleys his aunt uncle and cousin who all treat him poorly}]|
    |[{system, You are a question answering system. You will be given a query and some context, you need to answer the query based on the context provided. Use your own knowledge if relevant context is not provided. Give your answer as a full sentence with minimum text.}, {assistant, Hello there! What is your query today?}, {user, Query: Who were Harry's friends? Context: The novels chronicle the lives of a young wizard Harry Potter and his friends Hermione Granger and Ron Weasley all of whom are students at Hogwarts School of Witchcraft and WizardryHe gains the friendship of Ron Weasley a member of a large but poor wizarding family and Hermione Granger a witch of nonmagical or Muggle parentageLupin enters the shack and explains that Sirius was James Potters best friend}]                                                                                                                       |
    |[{system, You are a question answering system. You will be given a query and some context, you need to answer the query based on the context provided. Use your own knowledge if relevant context is not provided. Give your answer as a full sentence with minimum text.}, {assistant, Hello there! What is your query today?}, {user, Query: What are the major themes in the Harry Potter series? Context: A series of many genres including fantasy drama comingofage fiction and the British school story which includes elements of mystery thriller adventure horror and romance the world of Harry Potter explores numerous themes and includes many cultural meanings and referencesMajor themes in the series include prejudice corruption madness love and deathHarry Potter is a series of seven fantasy novels written by British author J K Rowling}]                                                             |
    +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
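    In the code above, ''.join(context) concatenates the retrieved sentences without a separator, which is why the prompts contain run-together words such as "babywizards" and "HouseHarry". A small variation, sketched here as a hypothetical build_messages helper, keeps the same messages but joins the context sentences with a space:

```python
SYSTEM_PROMPT = (
    "You are a question answering system. You will be given a query and some context, "
    "you need to answer the query based on the context provided. Use your own knowledge "
    "if relevant context is not provided. Give your answer as a full sentence with minimum text."
)


def build_messages(query, context_sentences, separator=" "):
    """Build the (role, content) messages for one query, separating context sentences."""
    return [
        ("system", SYSTEM_PROMPT),
        ("assistant", "Hello there! What is your query today?"),
        ("user", f"Query: {query} Context: {separator.join(context_sentences)}"),
    ]


messages = build_messages("Which House did Harry belong to?", ["Sentence one", "Sentence two"])
print(messages[-1])
```

    Each returned list can be appended to prompts as [messages], exactly as in the loop above.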

    Next, load the AutoGGUFModel, which loads the phi3.5_mini_4k_instruct_q4_gguf model by default when pretrained() is called without arguments. Any LLM of your choice can be used in this step.

    from sparknlp.annotator import AutoGGUFModel
    
    autoGGUFModel = (
        AutoGGUFModel
        .pretrained()
        .setInputCols("prompt")
        .setOutputCol("completions")
        .setBatchSize(4)
        .setNGpuLayers(99)
        .setUseChatTemplate(True)  
    )

    3.3 Pass the prompt to the LLM to receive an answer to your query.

    Now we build a Spark NLP pipeline with the following stages:

    • PromptAssembler: Annotator that fills the chat template with the messages to produce a prompt.
    • AutoGGUFModel: LLM that completes our prompts.

    generationPipeline = Pipeline(stages=[promptAssembler, autoGGUFModel])
    output = generationPipeline.fit(promptDF).transform(promptDF)

    Let’s check our final results from the AutoGGUFModel.

    final_output = output.selectExpr(
        "explode(completions.result) as output"
    )
    
    final_output.show(truncate=False)
    
    #OUTPUT
    +--------------------------------------------------------------------------------------------------------+
    |output                                                                                                  |
    +--------------------------------------------------------------------------------------------------------+
    |Harry Potter's parents are Lily and James Potter.                                                       |
    |Harry belonged to Gryffindor House at Hogwarts School of Witchcraft and Wizardry.                       |
    |Harry's friends were Ron Weasley and Hermione Granger.\n\n                                              |
    |The major themes in the Harry Potter series include prejudice, corruption, madness, love, and death.\n\n|
    +--------------------------------------------------------------------------------------------------------+
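    The last two answers end with \n\n. If you collect the answers into Python, for example with [row.output for row in final_output.collect()], stripping surrounding whitespace tidies them up. The hypothetical helper below assumes these are real newline characters that show() displays escaped.

```python
def clean_completion(text: str) -> str:
    """Remove leading and trailing whitespace, including newlines, from an LLM answer."""
    return text.strip()


answers = [
    "Harry's friends were Ron Weasley and Hermione Granger.\n\n",
    "Harry Potter's parents are Lily and James Potter.",
]
print([clean_completion(answer) for answer in answers])
```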

    With this RAG pipeline, you’re now equipped to unlock richer insights and build smarter applications. Time to put your data into action!

    About the author:
    I'm a Data Scientist at John Snow Labs, working on improving and contributing to Spark NLP. From building scalable data pipelines to training deep learning models, I thrive at the intersection of data engineering and machine learning. With 2 years of experience in big data and an ongoing MS in Computer Science at University of California, San Diego, I’ve tackled challenges in NLP, ML, and retrieval-augmented generation projects. Beyond the code, I love mentoring others, simplifying complex ideas, and constantly pushing my learning boundaries.
