
Long Sequence Modeling with XGen: A 7B LLM Trained on 8K Input Sequence Length

TLDR

We trained a series of 7B LLMs named XGen-7B with standard dense attention on up to 8K sequence length for up to 1.5T tokens. We also fine-tune the models on public-domain instructional data. The main takeaways are:

  • On standard NLP benchmarks, XGen achieves comparable or better results than state-of-the-art open-source LLMs of similar size (e.g., MPT, Falcon, LLaMA, Redpajama, OpenLLaMA).
  • Our targeted evaluation on long sequence modeling benchmarks shows benefits of our 8K-seq models over 2K- and 4K-seq models.
  • XGen-7B achieves equally strong results on text (e.g., MMLU, QA) and code (HumanEval) tasks.
  • The training cost is $150K for 1T tokens under Google Cloud pricing for TPU-v4.

Paper: https://arxiv.org/abs/2309.03450
Codebase: https://github.com/salesforce/xGen
Model Checkpoint: https://huggingface.co/Salesforce/xgen-7b-8k-base


Why XGen-7B with 8K Sequence Length

As LLMs become ubiquitous, applying them to long sequences has become a key focus, especially for applications such as summarizing text (potentially interleaved with other data sources like tables and images), writing code, and predicting protein sequences, which require the model to effectively consider long-distance structural dependencies. A large context allows a pre-trained LLM to look at customer data (e.g., documents the LLM did not see during training) and respond to useful information-seeking queries.

Yet most open-source LLMs (e.g., LLaMA, MPT, Falcon) have been trained with a maximum sequence length of 2K tokens, which is a key limitation in modeling long sequences. Inference-time solutions such as ALiBi have yet to be evaluated for larger models (e.g., MPT-7b-StoryWriter-65k+). Recent work on model scaling has shown that, for a given compute budget, the best performance is not necessarily achieved by the largest models but by smaller models trained on more data (measured in number of tokens). A smaller model is also generally preferred for inference efficiency during serving, including on-device serving. In light of this, we train a series of 7B LLMs named XGen with standard dense attention on up to 8K sequence length for up to 1.5T tokens. We also fine-tune the XGen models on public-domain instructional data, creating their instruction-tuned counterparts (XGen-7B-inst).

Model | Description
XGen-7B-4K-base | Trained first for 800B tokens with a sequence length of 2K tokens, then for another 400B tokens (1.2T tokens in total) with a sequence length of 4K. Released under Apache-2.0.
XGen-7B-8K-base | Initialized with XGen-7B-4K-base and trained for a further 300B tokens (1.5T tokens in total) with a sequence length of 8K. Released under Apache-2.0.
XGen-7B-{4K,8K}-inst | Supervised fine-tuned on public-domain instructional data, including databricks-dolly-15k, oasst1, Baize, and GPT-related datasets. Released for research purposes only.


Pre-training Data

We employ a two-stage training strategy, where each stage uses a different data mixture.

First stage (1.37T tokens)

Dataset name | Effective number of tokens (B) | Sampling prop. (%)
Natural language data | 1309.99 | 95.31
Code data | 64.53 | 4.69
Total | 1374.52 | 100
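The sampling proportions above match each source's share of the effective token count. The following is a minimal Python sketch, assuming the proportions are plain token shares and that training examples are drawn per source with those weights; `mixture_proportions` and `sample_sources` are hypothetical helpers, not part of the XGen codebase.

```python
import random

def mixture_proportions(token_counts):
    """Return each source's share of the total token count, in percent."""
    total = sum(token_counts.values())
    return {name: 100.0 * n / total for name, n in token_counts.items()}

def sample_sources(proportions, k, seed=0):
    """Draw k source names with probability proportional to their share."""
    rng = random.Random(seed)
    names = list(proportions)
    weights = [proportions[name] for name in names]
    return rng.choices(names, weights=weights, k=k)

# Stage-1 effective token counts (billions) from the table above.
stage1 = {"natural_language": 1309.99, "code": 64.53}
props = mixture_proportions(stage1)
for name, pct in props.items():
    print(f"{name}: {pct:.2f}%")  # natural_language: 95.31%, code: 4.69%
```

Applied to the stage-2 counts (55B and 55B), the same helper yields the 50/50 split.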

Natural language data is a mixture of publicly available data. Code data is a mixture of the GitHub subset from the RedPajama dataset and the Apex code data we collected.

Second stage (110B tokens)

To better support code-generation tasks, in the second stage we mix more code data from Starcoder with the data from Stage 1.

Dataset name | Number of tokens used (B) | Sampling prop. (%)
Data from stage 1 | 55 | 50
BigCode Starcoder | 55 | 50

We use OpenAI’s tiktoken to tokenize our data. We add additional tokens for consecutive whitespaces and tabs, as well as the special tokens described in the Starcoder paper.
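Dedicated whitespace tokens matter for code: if runs of spaces are collapsed, indentation-sensitive Python breaks (the failure mode noted for OpenLLaMA in the HumanEval results). The sketch below is hypothetical. The `<|space_N|>`/`<|tab_N|>` spellings and the `MAX_RUN` chunk size are assumptions, and it operates on strings rather than on tiktoken's BPE vocabulary, whose XGen extensions are not detailed here. It only illustrates why whitespace-run tokens keep code round-trippable.

```python
import re

# Hypothetical token spellings; XGen's actual vocabulary entries are not specified here.
MAX_RUN = 32  # assumed longest whitespace run covered by a single token

RUN_PATTERN = re.compile(r" {2,}|\t+")
TOKEN_PATTERN = re.compile(r"<\|(space|tab)_(\d+)\|>")

def encode_whitespace_runs(text, max_run=MAX_RUN):
    """Split text into plain-text pieces and whitespace-run tokens."""
    pieces, pos = [], 0
    for m in RUN_PATTERN.finditer(text):
        if m.start() > pos:
            pieces.append(text[pos:m.start()])
        run = m.group()
        kind = "space" if run[0] == " " else "tab"
        # Runs longer than max_run are split across several run tokens.
        for i in range(0, len(run), max_run):
            pieces.append(f"<|{kind}_{len(run[i:i + max_run])}|>")
        pos = m.end()
    if pos < len(text):
        pieces.append(text[pos:])
    return pieces

def decode_whitespace_runs(pieces):
    """Expand run tokens back into whitespace. A real tokenizer would use reserved
    token ids, so literal text that merely looks like a run token is not an issue."""
    out = []
    for piece in pieces:
        m = TOKEN_PATTERN.fullmatch(piece)
        if m:
            out.append((" " if m.group(1) == "space" else "\t") * int(m.group(2)))
        else:
            out.append(piece)
    return "".join(out)

nested = "def f(x):\n    if x:\n        return 1\n    return 0\n"
assert decode_whitespace_runs(encode_whitespace_runs(nested)) == nested

# Collapsing every run of spaces/tabs to one space flattens the indentation levels.
collapsed = re.sub(r"[ \t]+", " ", nested)
try:
    compile(collapsed, "<collapsed>", "exec")
except SyntaxError as err:  # IndentationError is a subclass of SyntaxError
    print("collapsed code no longer compiles:", type(err).__name__)
```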


Training Details

The XGen-7b models are trained with our in-house library JaxFormer, which enables efficient training of LLMs with both data and model parallelism, optimized for TPU-v4 hardware. The training recipe and model architecture follow LLaMA, with two additional explorations. First, we investigate the occurrence of so-called “loss spikes” [PaLM, loss spikes] during training, i.e., sudden, temporary surges in the loss whose root cause is still unknown. Second, the XGen models support sequence lengths of up to 8,192 tokens (rather than the common 2,048), for which we introduce stage-wise training.

Loss Spikes

As models are scaled to larger sizes, training becomes increasingly sensitive to instabilities, which degrade model performance if not addressed carefully. In our exploration, we have gathered evidence for several factors that individually contribute to unstable training. These preliminary findings include “sequential over parallel circuits”, “swish-GLU over GeLU”, and “RMS-Norm over Layer-norm”. Specifically, the widely used parallel circuits, which compute self-attention and the feed-forward layer in parallel as adopted in [GPT-J, PaLM, CodeGen], may affect training stability.

The figure above displays the cross-entropy loss over the course of training, which follows the well-known scaling laws. Remarkably, apart from the sequence-length transitions, training does not suffer from any instabilities or loss spikes. The two spikes depicted in the figure are expected when the sequence length is extended, e.g., from 2k to 4k tokens, since the model needs to adapt to the longer sequences.
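The three preliminary findings listed above can be made concrete in code. The NumPy sketch below contrasts a sequential block (attention, then feed-forward, as in LLaMA) with a parallel block (both sublayers reading the same normalized input, as in GPT-J/PaLM), and spells out RMSNorm, LayerNorm, Swish-GLU and GeLU. It rests on simplifying assumptions: single-head attention, no positional encoding, random weights, and RMSNorm with Swish-GLU in both blocks so that only the circuit layout differs. It shows the structural difference only and does not reproduce the stability effects described above.

```python
import numpy as np

def rms_norm(x, gain, eps=1e-6):
    """RMSNorm: divide by the root-mean-square; no mean subtraction, no bias."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * gain

def layer_norm(x, gain, bias, eps=1e-5):
    """LayerNorm: subtract the mean and divide by the standard deviation."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * gain + bias

def silu(x):
    """Swish with beta = 1, the gate nonlinearity of Swish-GLU."""
    return x / (1.0 + np.exp(-x))

def gelu(x):
    """tanh approximation of GeLU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def swiglu_ffn(x, w_gate, w_up, w_down):
    """Swish-GLU feed-forward: silu(x W_gate) * (x W_up), projected back to d."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

def gelu_ffn(x, w_in, w_out):
    """Classic two-matrix GeLU feed-forward, for comparison."""
    return gelu(x @ w_in) @ w_out

def causal_attention(x, wq, wk, wv, wo):
    """Single-head causal self-attention over x of shape (seq_len, d)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)  # positions j > i
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ v) @ wo

def sequential_block(x, p):
    """LLaMA-style: the feed-forward reads the output of the attention sublayer."""
    h = x + causal_attention(rms_norm(x, p["g1"]), *p["attn"])
    return h + swiglu_ffn(rms_norm(h, p["g2"]), *p["ffn"])

def parallel_block(x, p):
    """GPT-J/PaLM-style: attention and feed-forward read the same normalized input."""
    n = rms_norm(x, p["g1"])
    return x + causal_attention(n, *p["attn"]) + swiglu_ffn(n, *p["ffn"])

rng = np.random.default_rng(0)
d, hidden, seq_len = 8, 16, 5
params = {
    "g1": np.ones(d),
    "g2": np.ones(d),
    "attn": [rng.normal(scale=0.3, size=(d, d)) for _ in range(4)],
    "ffn": [rng.normal(scale=0.3, size=(d, hidden)),
            rng.normal(scale=0.3, size=(d, hidden)),
            rng.normal(scale=0.3, size=(hidden, d))],
}
x = rng.normal(size=(seq_len, d))
y_seq = sequential_block(x, params)
y_par = parallel_block(x, params)
print(y_seq.shape, np.abs(y_seq - y_par).max())
```

The parallel layout saves a sequential dependency per layer, which is why it is attractive for throughput; the finding above is that this convenience may come at the cost of stability.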

Sequence Length

Training with longer sequences is disproportionately costly because the complexity of self-attention is quadratic in the sequence length, which slows down training. To mitigate this, we train in stages with increasing sequence length: first 800B tokens with a sequence length of 2k tokens, then 400B tokens with 4k, and finally 300B tokens with 8k.
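The schedule and the motivation for it can be checked with a little arithmetic. The sketch below encodes the three stages, looks up the sequence length in use at a given point in training, and compares the attention-score cost of the staged schedule against training all 1.5T tokens at 8k. It counts only the quadratic attention-score term (per token it grows linearly with sequence length) and ignores the feed-forward and projection FLOPs, which are constant per token and dominate at these lengths; `Stage`, `seq_len_at` and `relative_attention_cost` are hypothetical names.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Stage:
    seq_len: int   # tokens per training sequence
    tokens_b: int  # training tokens in this stage, in billions

# The three stages described above (2k = 2,048 and 8k = 8,192 tokens).
SCHEDULE = [Stage(2048, 800), Stage(4096, 400), Stage(8192, 300)]

def cumulative_tokens(schedule):
    """Running total of tokens (billions) at the end of each stage."""
    totals, running = [], 0
    for stage in schedule:
        running += stage.tokens_b
        totals.append(running)
    return totals

def seq_len_at(schedule, tokens_seen_b):
    """Sequence length in use once tokens_seen_b billion tokens have been consumed."""
    for stage, end in zip(schedule, cumulative_tokens(schedule)):
        if tokens_seen_b < end:
            return stage.seq_len
    return schedule[-1].seq_len

def relative_attention_cost(schedule, base_len=2048):
    """Attention-score cost in units of "one token at base_len".

    Per sequence the score matrix is seq_len x seq_len, so per token the cost
    grows linearly with seq_len. Constant per-token FLOPs are ignored.
    """
    return sum(stage.tokens_b * stage.seq_len / base_len for stage in schedule)

print(cumulative_tokens(SCHEDULE))                   # [800, 1200, 1500]
print(relative_attention_cost(SCHEDULE))             # 2800.0
print(relative_attention_cost([Stage(8192, 1500)]))  # 6000.0
```

Under these assumptions, the staged schedule spends less than half the attention-score compute of an 8k-only run over the same 1.5T tokens.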

We verify the adaptation to longer sequences by computing the average perplexity at each token position on a held-out validation set of documents with a sequence length of 8k tokens or more. If the model successfully learns to use the full sequence, we would expect perplexity to decrease with position, as previous tokens carry information about the next token to be predicted. That is, for a long sentence, the more preceding context is provided, the easier it becomes to guess the next word. The figure above indeed shows that XGen at each stage successfully learns to use longer contexts, up to 8k sequence length.
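A per-position perplexity curve can be computed from model logits as follows. This NumPy sketch assumes perplexity at position t is exp of the negative log-likelihood averaged over documents at that position (the exact aggregation used for the figure is not stated), and adds a bucketing helper because raw per-position curves are noisy. The function names and the synthetic NLL values in the usage example are hypothetical, not XGen measurements.

```python
import numpy as np

def token_nll_from_logits(logits, targets):
    """Per-token negative log-likelihood (nats).

    logits: (num_docs, seq_len, vocab) next-token scores.
    targets: (num_docs, seq_len) integer ids of the actual next tokens.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]

def per_position_perplexity(token_nll):
    """Average NLL over documents at each position, then exponentiate."""
    return np.exp(np.asarray(token_nll, dtype=float).mean(axis=0))

def bucketed(curve, bucket=1024):
    """Average a per-position curve over fixed-size buckets to smooth it."""
    n = len(curve) // bucket
    return curve[: n * bucket].reshape(n, bucket).mean(axis=1)

# Sanity check: uniform logits over 50 tokens give perplexity 50 at every position.
uniform_ppl = per_position_perplexity(
    token_nll_from_logits(np.zeros((2, 8, 50)), np.zeros((2, 8), dtype=int)))

# Synthetic NLL that falls with position, mimicking a model that uses more context.
synthetic_nll = np.tile(3.0 - 0.1 * np.log1p(np.arange(8192)), (3, 1))
print(bucketed(per_position_perplexity(synthetic_nll)))
```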


Results on Standard Benchmarks

(i) MMLU

We first consider the Measuring Massive Multitask Language Understanding (MMLU) benchmark (see examples here). Because it is more recent than other benchmarks, it is arguably less susceptible to data contamination, as reported in recent studies (see page 32 of the GPT-4 paper and a related discussion here), and it has been used consistently as a held-out evaluation benchmark. Recently, however, inconsistencies in reporting MMLU scores have been observed, which resulted in wrong rankings on Hugging Face's Open LLM Leaderboard; in fact, Hugging Face later had to write a blog post to clarify this. In our work, we follow the original MMLU standard, which is consistent with published results (e.g., those reported for LLaMA).

MMLU 5-shot In-context Learning Results: We first show results in the original (and recommended) 5-shot evaluation setting, where the LLM is provided with 5 demonstrations. XGen achieves the best results in most categories, as well as in the weighted average.

Models | Humanities | STEM | Social Sciences | Other | Weighted average
XGen-7b | 33.8 | 30.7 | 40.0 | 41.5 | 36.3
LLaMA-7b | 33.9 | 30.6 | 38.2 | 38.2 | 35.1
OpenLLaMA-7b | 28.1 | 28.5 | 31.2 | 32.8 | 29.9
Falcon-7b | 26.5 | 25.4 | 29.2 | 26.8 | 26.9
MPT-7b | 25.9 | 26.2 | 26.9 | 28.1 | 26.7
Redpajama-7b | 26.1 | 25.2 | 27.4 | 26.7 | 26.3
Cerebras-GPT-13b | 26.1 | 26.5 | 25.8 | 26.6 | 26.2
Dolly-v2-12b | 26.9 | 25.7 | 25.3 | 26.5 | 26.2
OPT-13b | 26.2 | 24.3 | 23.4 | 26 | 25.1
GPT-J-6b | 25.9 | 24.0 | 24.0 | 25.8 | 25.1
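The "Weighted average" column weights each category by its number of questions, which is why it differs from a plain mean of the four columns. The sketch below assumes the per-category question counts of the MMLU test split (4,705 / 3,018 / 3,077 / 3,242); these counts are not stated in this post, but they reproduce the reported averages (e.g., 36.3 for XGen-7b and 35.1 for LLaMA-7b). The helper names are hypothetical.

```python
# Assumed per-category question counts of the MMLU test split (14,042 in total).
MMLU_COUNTS = {"humanities": 4705, "stem": 3018, "social_sciences": 3077, "other": 3242}

def weighted_average(scores, counts=MMLU_COUNTS):
    """Average category accuracies, weighting each by its number of questions."""
    total = sum(counts.values())
    return sum(scores[cat] * n for cat, n in counts.items()) / total

def unweighted_average(scores):
    """Plain mean of the category accuracies, for comparison."""
    return sum(scores.values()) / len(scores)

xgen_5shot = {"humanities": 33.8, "stem": 30.7, "social_sciences": 40.0, "other": 41.5}
print(f"{weighted_average(xgen_5shot):.1f}")    # 36.3, as in the table
print(f"{unweighted_average(xgen_5shot):.1f}")  # 36.5
```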

MMLU 0-shot Results: On zero-shot MMLU, we again see good results, although the margin over LLaMA is generally smaller here.

Models | Humanities | STEM | Social Sciences | Other | Weighted average
XGen-7b | 31.4 | 27.8 | 32.1 | 37.2 | 32.1
LLaMA-7b | 32.3 | 27.1 | 31.3 | 36.8 | 32.0
OpenLLaMA-7b | 28.0 | 27.6 | 28.9 | 30.1 | 28.6
MPT-7b | 27.4 | 25.2 | 26.0 | 30.7 | 27.4
Redpajama-7b | 27.5 | 25.5 | 24.2 | 25.0 | 25.8
GPT-J-6b | 25.3 | 24.5 | 25.5 | 27.6 | 25.7
Dolly-v2-12b | 26.2 | 26.0 | 24.0 | 24.9 | 25.4
Cerebras-GPT-13b | 24.3 | 25.0 | 23.0 | 26.0 | 24.6
OPT-13b | 26.3 | 23.3 | 23.6 | 23.6 | 24.4
Falcon-7b | 24.8 | 21.7 | 24.0 | 24.4 | 23.9

(ii) General Zero-shot Results

Next, we report zero-shot results on general NLP tasks that involve commonsense reasoning and QA.

Models | MMLU-wavg | ARC_ch | HellaSwag | Winogrande | TruthfulQA | BoolQ | PiQA | OpenBookQA
XGen-7b | 32.1 | 41.2 | 74.2 | 64.9 | 39.1 | 74.3 | 75.5 | 40.2
LLaMA-7b | 32.0 | 44.8 | 76.2 | 69.6 | 34 | 74.9 | 78.7 | 44.2
Falcon-7b | 23.9 | 43.4 | 76.4 | 67.2 | 34.3 | 73.8 | 79.4 | 44.0
MPT-7b | 27.4 | 41.7 | 76.1 | 68.6 | 33.4 | 74.1 | 79.1 | 41.8
OpenLLaMA-7b | 28.6 | 38.7 | 71.8 | 67.0 | 35.2 | 70.6 | 76.0 | 39.0
Redpajama-7b | 25.8 | 39.1 | 70.3 | 63.8 | 33.3 | 69.3 | 76.9 | 40.0
GPT-neox-20b | 24.5 | 41.1 | 70.5 | 66.1 | 31.4 | 64.9 | 76.7 | 38.8
OPT-13b | 24.4 | 35.8 | 69.9 | 64.7 | 33.9 | 65.0 | 75.7 | 39.8
GPT-J-6b | 25.7 | 36.3 | 66.2 | 64.5 | 36.0 | 65.4 | 75.4 | 38.2
Dolly-v2-12b | 25.4 | 39.6 | 70.8 | 61.8 | 34.4 | 56.3 | 75.4 | 39.2
Cerebras-GPT-13b | 24.6 | 32.4 | 59.4 | 60.8 | 39.2 | 61.1 | 73.5 | 35.8
StableLM-alpha-7b | 24.4 | 27.0 | 40.7 | 51.5 | 41.7 | 59.0 | 65.8 | 32.4

(iii) Results on Code Generation

To evaluate XGen’s ability to generate code from natural language instructions (docstrings), we use the well-known HumanEval benchmark. We set the sampling temperature to 0.2, p to 0.95 (for top-p sampling), and num_samples_per_task (n) to 200. We report the standard zero-shot results with the pass@1 metric.
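The sampling settings above (temperature 0.2, top-p 0.95) can be sketched for a single decoding step as follows. This is a minimal NumPy illustration of temperature scaling followed by nucleus (top-p) filtering over one logit vector, not the evaluation harness actually used; `nucleus` and `sample_top_p` are hypothetical names.

```python
import numpy as np

def nucleus(logits, temperature=0.2, top_p=0.95):
    """Return (token_ids, probs) of the smallest high-probability set reaching top_p."""
    scaled = np.asarray(logits, dtype=float) / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]  # most likely token first
    cumulative = np.cumsum(probs[order])
    # Keep the shortest prefix whose cumulative probability reaches top_p.
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]
    return keep, probs[keep] / probs[keep].sum()

def sample_top_p(logits, temperature=0.2, top_p=0.95, rng=None):
    """Sample one token id from the renormalized nucleus."""
    rng = rng if rng is not None else np.random.default_rng()
    ids, probs = nucleus(logits, temperature, top_p)
    return int(rng.choice(ids, p=probs))

logits = [2.0, 1.0, 0.5, -1.0]
print(nucleus(logits, temperature=1.0)[0])  # three most likely ids survive at T=1.0
print(nucleus(logits)[0])                   # at T=0.2 only the top token survives
```

The low temperature sharpens the distribution so much that, for confident predictions, the nucleus often collapses to a single token, which keeps the 200 samples per problem close to greedy decoding while still allowing some diversity.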

Models | pass@1
XGen-7b | 14.20
LLaMA-7b | 10.38
OpenLLaMA-7b | 0 (consecutive whitespaces are treated as one, breaking Python syntax)
Falcon-7b | 0 (did not generate meaningful code)
MPT-7b | 15.90
Redpajama-7b | 5.24
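With n = 200 samples per problem, pass@k is commonly computed with the unbiased estimator introduced alongside HumanEval: for a problem with c correct samples out of n, pass@k = 1 - C(n-c, k) / C(n, k), averaged over problems. For k = 1 this reduces to c / n. The sketch below is an illustration of that estimator, not the exact scoring script used for the table above; the per-problem counts in the example are made up.

```python
from math import comb

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate for one problem: n samples drawn, c of them correct."""
    if n - c < k:
        return 1.0  # every size-k subset contains at least one correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)

def mean_pass_at_k(results, k):
    """results: list of (n, c) pairs, one per HumanEval problem."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)

# Hypothetical counts for three problems, 200 samples each.
example = [(200, 40), (200, 0), (200, 200)]
print(f"pass@1 = {mean_pass_at_k(example, 1):.3f}")  # pass@1 = 0.400
```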


Results on Long Sequence Generation Tasks

To further evaluate our XGen-7b 8k model against baselines that are limited to 2k inputs, we turn to long-form dialogue generation, text summarization, and QA. All of these tasks benefit from processing and understanding a long context to generate a correct response. Note that for these tasks most of the base pre-trained models failed to generate a plausible response because of the task difficulty, so we use instruction-tuned models.
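A 2k-context baseline can only see part of a long source document once room is left for the instruction and the generated answer. The post does not specify how inputs were truncated, so the sketch below assumes a simple head-or-tail cut on token ids; `truncate_source` is a hypothetical helper, and the 48-token prompt overhead and 256-token generation budget in the example are illustrative values only.

```python
def truncate_source(source_ids, context_len, prompt_overhead, max_new_tokens, keep="head"):
    """Cut source_ids so that source + prompt overhead + generated tokens fit in context_len.

    keep="head" keeps the beginning of the document, keep="tail" keeps the end.
    """
    if keep not in ("head", "tail"):
        raise ValueError("keep must be 'head' or 'tail'")
    budget = context_len - prompt_overhead - max_new_tokens
    if budget <= 0:
        raise ValueError("prompt overhead and generation budget exceed the context window")
    if len(source_ids) <= budget:
        return list(source_ids)
    return list(source_ids[:budget]) if keep == "head" else list(source_ids[-budget:])

doc = list(range(6000))  # stand-in for a ~6k-token dialogue transcript
print(len(truncate_source(doc, 2048, 48, 256)))  # 1744: a 2k model sees under a third
print(len(truncate_source(doc, 8192, 48, 256)))  # 6000: an 8k model sees everything
```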

Dialogue

To assess the long dialogue understanding and summarization capabilities, we report results on three dialogue summarization tasks: AMI meeting summarization, ForeverDreaming (FD), and TVMegaSite (TMS) screenplay summarization. The average source lengths of these datasets are approximately 5570, 6466, and 7653, respectively. We specifically evaluate samples that are less than 8K in length using various instruction-tuned models. Notably, when input truncation was not applied, both MPT-7b-inst and Alpaca-inst failed to perform well in this setting. Our model (XGen-7B-inst) achieved the highest ROUGE scores across all metrics.

Model | AMI R-1 | AMI R-2 | AMI R-L | FD R-1 | FD R-2 | FD R-L | TMS R-1 | TMS R-2 | TMS R-L
XGen-7b-inst | 31.34 | 8.25 | 17.00 | 29.34 | 5.39 | 16.43 | 26.39 | 3.94 | 13.71
Falcon-7b-inst | 14.89 | 1.97 | 9.28 | 18.90 | 1.80 | 9.37 | 18.90 | 1.80 | 9.37
MPT-7b-inst | 11.95 | 1.88 | 8.10 | 14.27 | 1.40 | 8.89 | 19.80 | 2.39 | 10.23
Alpaca-7b-inst | 9.69 | 1.77 | 6.43 | 16.26 | 1.56 | 10.66 | 12.26 | 1.15 | 7.30

Long-form QA

Next, we evaluate our XGen-7b-inst on a long-form QA task that we have designed in-house. We ask ChatGPT to generate questions from (a) long Wikipedia documents spanning four domains: Physics, Engineering, History, and Entertainment, and (b) summaries of these documents. Then we query the LLMs to generate answers to these questions. The answers are typically up to 256 tokens long. We use GPT-4 to rate answer quality in terms of coherence (structure and organization) and relevance (relevance of the generated answer to the question and the context document) on a scale of 0 to 3. As the results below show, our model scores higher than the baselines considered on both aspects.

Model | Coherence | Relevance | Avg. rating
XGen-7b-inst | 2.55 | 2.52 | 2.54
MPT-7b-inst | 2.5 | 2.45 | 2.48
Alpaca-7b-inst | 1.65 | 1.91 | 1.78
Falcon-7b-inst | 2.26 | 2.13 | 2.19

Summarization

Here, we evaluate our model on two text summarization datasets included in the SCROLLS benchmark, QMSum and GovReport, which cover two different domains: meeting conversations and government reports. In addition, QMSum includes natural language queries that tell the model which key aspects of the source document should be included in the summary. Our model, XGen-7b-inst, outperforms the other baselines on these tasks.

Model | QMSum R-1 | QMSum R-2 | QMSum R-L | GovReport R-1 | GovReport R-2 | GovReport R-L
XGen-7b-inst | 27.96 | 5.66 | 24.26 | 21.28 | 8.19 | 20.08
Falcon-7b-inst | 15.68 | 2.81 | 14.01 | 17.8 | 6.13 | 16.66
MPT-7b-inst | 21.75 | 4.38 | 19.29 | 18.11 | 6.96 | 17.11
Redpajama-7b-inst | 19.81 | 2.66 | 17.58 | 19.63 | 6.93 | 18.48

While the results of our XGen-7b models on these long sequence tasks are encouraging, we note that the models compared here are not trained on the same instructional data and are therefore not strictly comparable.


Note on Potential Risks

Finally, despite our effort in addressing the risks of bias, toxicity and hallucinations both in pre-training and fine-tuning stages, like other LLMs, XGen-7b models are not free from such limitations. We hope our open-sourced codebase will help other researchers better understand these challenges and improve on these key limitations for making AI beneficial for everyone.
